From 7afe4073db4b703dd4a9ddc13dd16426fee3bd6f Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Wed, 24 May 2023 17:03:33 +0800
Subject: [PATCH 01/42] add quantization 2.0 document

---
 .../quantization_2_0_tutotial.rst             | 523 ++++++++++++++++++
 1 file changed, 523 insertions(+)
 create mode 100644 prototype_source/quantization_2_0_tutotial.rst

diff --git a/prototype_source/quantization_2_0_tutotial.rst b/prototype_source/quantization_2_0_tutotial.rst
new file mode 100644
index 00000000000..09a5ee7f579
--- /dev/null
+++ b/prototype_source/quantization_2_0_tutotial.rst
@@ -0,0 +1,523 @@
+(prototype) PyTorch Quantization 2.0 Tutorial
+=============================================
+
+Today we have `FX Graph Mode
+Quantization `__,
+which uses symbolic_trace to capture the model into a graph and then
+performs quantization transformations on top of the captured model. In a
+similar way, for the Quantization 2.0 flow, we will now use the PT2 Export
+workflow to capture the model into a graph, and perform quantization
+transformations on top of the ATen dialect graph. This is expected to
+have significantly higher model coverage, better programmability, and
+a simplified UX.
+
+Suppose we are a backend developer and we wish to integrate our backend
+with PyTorch's quantization 2.0 flow. We only need to define the backend
+specific quantizer. An existing quantizer object defined for
+QNNPack/XNNPack is here
+`QNNPackQuantizer `__.
+Taking QNNPackQuantizer as an example, the overall Quantization 2.0 flow could be:
+
+::
+
+    import copy
+
+    import torch
+    import torch._dynamo as torchdynamo
+    from torch.ao.quantization._quantize_pt2e import convert_pt2e, prepare_pt2e
+    import torch.ao.quantization._pt2e.quantizer.qnnpack_quantizer as qq
+
+    class M(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(5, 10)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    example_inputs = (torch.randn(1, 5),)
+    model = M().eval()
+
+    # Step 1: Trace the model into an FX graph of flattened ATen operators
+    exported_graph_module, guards = torchdynamo.export(
+        model,
+        *copy.deepcopy(example_inputs),
+        aten_graph=True,
+    )
+
+    # Step 2: Insert observers or fake quantize modules
+    quantizer = qq.QNNPackQuantizer()
+    operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)
+    quantizer.set_global(operator_config)
+    prepared_graph_module = prepare_pt2e(exported_graph_module, quantizer)
+
+    # Step 3: Quantize the model
+    converted_graph_module = convert_pt2e(prepared_graph_module)
+
+    # Step 4: Lower Reference Quantized Model into the backend
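+
+Note that between `prepare_pt2e` and `convert_pt2e` we would normally run the
+prepared model on calibration data, so that the inserted observers can record
+tensor statistics. As a minimal sketch, reusing `example_inputs` above as
+stand-in calibration data:
+
+::
+
+    # Step 2.5: Calibrate the prepared model (observers record the statistics
+    # that are later turned into quantization parameters)
+    prepared_graph_module(*example_inputs)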
+
+Inside the Quantizer, we will use the `QuantizationAnnotation API `__
+to convey the user's intent for what quantization spec to use and how to
+observe certain tensor values in the prepare step. Below is a step-by-step
+tutorial on how to use the `QuantizationAnnotation API` to create a quantizer.
+
+1. Define QuantizationConfig
+--------------------------------------------------------
+
+QuantizationConfig defines the data type and qscheme for activation, weight, and bias.
+Suppose we want to define:
+
+- Activation: `int8` data type, `per_tensor_affine` quantization, `HistogramObserver`
+- Weight: `int8` data type, `per_channel_symmetric` quantization, `PerChannelMinMaxObserver`
+- Bias: `float` data type, `PlaceholderObserver`
+
+::
+
+    def get_symmetric_quantization_config():
+        act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = \
+            HistogramObserver
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int8,
+            quant_min=-128,
+            quant_max=127,
+            qscheme=torch.per_tensor_affine,
+            is_dynamic=False,
+            observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(eps=2**-12),
+        )
+
+        weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PerChannelMinMaxObserver
+        extra_args: Dict[str, Any] = {"eps": 2**-12}
+        weight_quantization_spec = QuantizationSpec(
+            dtype=torch.int8,
+            quant_min=-127,
+            quant_max=127,
+            qscheme=torch.per_channel_symmetric,
+            ch_axis=0,
+            is_dynamic=False,
+            observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(**extra_args),
+        )
+
+        bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PlaceholderObserver
+        bias_quantization_spec = QuantizationSpec(
+            dtype=torch.float,
+            observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr
+        )
+        quantization_config = QuantizationConfig(
+            act_quantization_spec, weight_quantization_spec, bias_quantization_spec
+        )
+        return quantization_config
+
+2. Define the BackendQuantizer
+--------------------------------------------------------
+
+Then we will define the skeleton of a BackendQuantizer. The annotation methods for each operation will be
+defined later.
+
+::
+
+    class BackendQuantizer(Quantizer):
+
+        def __init__(self):
+            super().__init__()
+            self.global_config: QuantizationConfig = None  # type: ignore[assignment]
+            self.operator_type_config: Dict[str, Optional[QuantizationConfig]] = {}
+
+        def set_global(self, quantization_config: QuantizationConfig):
+            self.global_config = quantization_config
+            return self
+
+        def set_config_for_operator_type(
+            self, operator_type: str, quantization_config: QuantizationConfig
+        ):
+            self.operator_type_config[operator_type] = quantization_config
+            return self
+
+        def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+            """just handling global spec for now"""
+            global_config = self.global_config
+            self.annotate_symmetric_config(model, global_config)
+
+            return model
+
+        def annotate_symmetric_config(
+            self, model: torch.fx.GraphModule, config: QuantizationConfig
+        ) -> torch.fx.GraphModule:
+            for node in reversed(model.graph.nodes):
+                # the annotation method for each op will be defined later
+                pass
+            return model
+
+        def validate(self, model: torch.fx.GraphModule) -> None:
+            pass
+
+        @classmethod
+        def get_supported_operators(cls) -> List[OperatorConfig]:
+            return []
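+
+Although the annotation methods are still missing, the skeleton already shows how
+a backend wires a configuration into the quantizer. As a minimal usage sketch,
+reusing `get_symmetric_quantization_config` from step 1:
+
+::
+
+    quantizer = BackendQuantizer()
+    quantizer.set_global(get_symmetric_quantization_config())
+    # `annotate` will be invoked on the exported graph module by the prepare step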
+
+3. Annotate common operator patterns
+--------------------------------------------------------
+
+Now we will start to define the annotation methods inside the quantizer. For common
+operators like `conv2d`, we can use `QuantizationSpec` to
+annotate the input, weight, bias, and output.
+
+::
+
+    def _annotate_conv2d(
+        self, node: Node, quantization_config: QuantizationConfig
+    ) -> None:
+        conv_node = node
+        if (
+            conv_node.op == "call_function"
+            and conv_node.target == torch.ops.aten.convolution.default
+        ):
+            input_qspec_map = {}
+            # annotate the input activation
+            input_act = conv_node.args[0]
+            assert isinstance(input_act, Node)
+            input_qspec_map[input_act] = get_act_qspec(quantization_config)
+
+            # annotate the weight
+            weight = conv_node.args[1]
+            assert isinstance(weight, Node)
+            input_qspec_map[weight] = get_weight_qspec(quantization_config)
+
+            # annotate the bias, if there is one
+            bias = conv_node.args[2]
+            if isinstance(bias, Node):
+                input_qspec_map[bias] = get_bias_qspec(quantization_config)
+
+            conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                output_qspec=get_act_qspec(quantization_config),
+                _annotated=True
+            )
+
+4. Annotate sharing qparams operators
+--------------------------------------------------------
+
+For operators such as `add` and `cat`, where we want the two inputs to share
+quantization parameters, we can use `SharedQuantizationSpec` to make the two inputs
+share the same quantization parameters.
+
+::
+
+    def _annotate_add(
+        self, node: Node, quantization_config: QuantizationConfig
+    ) -> None:
+        add_node = node
+        if add_node.op == "call_function" and add_node.target in [
+            torch.ops.aten.add.Tensor,
+            torch.ops.aten.add_.Tensor,
+        ]:
+            act_qspec = get_act_qspec(quantization_config)
+
+            input_act0 = add_node.args[0]
+            input_act1 = add_node.args[1]
+
+            # the second input reuses the observer of the first input edge
+            share_qparams_with_input_act0_qspec = SharedQuantizationSpec((input_act0, add_node))
+
+            input_qspec_map = {input_act0: act_qspec, input_act1: share_qparams_with_input_act0_qspec}
+
+            add_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                output_qspec=act_qspec,
+                _annotated=True,
+            )
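+
+The same mechanism extends to operators with more than two inputs. As an
+illustrative sketch (the `_annotate_cat` helper below is an assumption made for
+this tutorial, not an existing API), every input of a `cat` node could share the
+observer of its first input edge:
+
+::
+
+    def _annotate_cat(
+        self, node: Node, quantization_config: QuantizationConfig
+    ) -> None:
+        cat_node = node
+        if cat_node.op == "call_function" and cat_node.target == torch.ops.aten.cat.default:
+            act_qspec = get_act_qspec(quantization_config)
+            inputs = cat_node.args[0]  # `cat` takes a list of input tensors
+            first_input = inputs[0]
+            shared_qspec = SharedQuantizationSpec((first_input, cat_node))
+            input_qspec_map = {first_input: act_qspec}
+            for input_act in inputs[1:]:
+                # every other input shares qparams with the first input edge
+                input_qspec_map[input_act] = shared_qspec
+            cat_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                output_qspec=shared_qspec,
+                _annotated=True,
+            )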
+
+5. Annotate fixed qparams operators
+--------------------------------------------------------
+
+For operator such as `sigmoid`, whose quantization parameters are known before,
+we want to use fixed parameters for it.
+
+**TODO(leslie)** `FixedQParamsQuantizationSpec` has not been implemented yet.
+Will add example of `FixedQParamsQuantizationSpec` with `sigmoid` after implementation.
+
+::
+
+    def _annotate_sigmoid(
+        self, node: Node, quantization_config: QuantizationConfig
+    ) -> None:
+        if sigmoid_node.op == "call_function" and sigmoid_node.target in [
+            torch.ops.aten.sigmoid.default,
+        ]:
+            act_qspec = get_act_qspec(quantization_config)
+
+            input_qspec_map = {}
+            input_act0 = sigmoid_node.args[0]
+
+            fixed_params_qspec = FixedQParamsQuantizationSpec...
+
+6. Annotate bias for linear
+--------------------------------------------------------
+
+`DerivedQuantizationSpec` is the quantization spec for the Tensors whose quantization parameters are derived from other Tensors.
+**TODO(leslie)** `DerivedQuantizationSpec` has not been implemented yet.
+Will add example of `DerivedQuantizationSpec` with `linear`.
+
+7. A Toy Example with Resnet18
+--------------------------------------------------------
+
+With the annotation methods defined above via the `QuantizationAnnotation API`, we can now put them together in the BackendQuantizer
+to run an example with Torchvision Resnet18.
+
+.. code:: ipython3
+
+    import copy
+    import functools
+    import operator
+    from typing import Callable, Dict, List, Optional, Set, Any
+
+    import torch
+    import torch._dynamo as torchdynamo
+    from torch.ao.quantization._pt2e.quantizer.utils import (
+        get_act_qspec,
+        get_weight_qspec,
+        get_bias_qspec,
+    )
+
+    from torch.fx import Node
+
+    from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
+
+    from torch.ao.quantization._pt2e.quantizer.quantizer import (
+        OperatorConfig,
+        QuantizationConfig,
+        QuantizationSpec,
+        Quantizer,
+        QuantizationAnnotation,
+        _annotate_input_qspec_map,
+        _annotate_output_qspec,
+    )
+    from torch.ao.quantization.observer import (
+        HistogramObserver,
+        PerChannelMinMaxObserver,
+        PlaceholderObserver,
+    )
+    from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
+    import torchvision
+    from torch.ao.quantization._quantize_pt2e import (
+        convert_pt2e,
+        prepare_pt2e_quantizer,
+    )
+
+    def _mark_nodes_as_annotated(nodes: List[Node]):
+        for node in nodes:
+            if node is not None:
+                if "quantization_annotation" not in node.meta:
+                    node.meta["quantization_annotation"] = QuantizationAnnotation()
+                node.meta["quantization_annotation"]._annotated = True
+
+    def _is_annotated(nodes: List[Node]):
+        annotated = False
+        for node in nodes:
+            annotated = annotated or (
+                "quantization_annotation" in node.meta
+                and node.meta["quantization_annotation"]._annotated
+            )
+        return annotated
+
+    class BackendQuantizer(Quantizer):
+
+        def __init__(self):
+            super().__init__()
+            self.global_config: QuantizationConfig = None  # type: ignore[assignment]
+            self.operator_type_config: Dict[str, Optional[QuantizationConfig]] = {}
+
+        def set_global(self, quantization_config: QuantizationConfig):
+            self.global_config = quantization_config
+            return self
+
+        def set_config_for_operator_type(
+            self, operator_type: str, quantization_config: QuantizationConfig
+        ):
+            self.operator_type_config[operator_type] = quantization_config
+            return self
+
+        def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+            """just handling global spec for now"""
+            global_config = self.global_config
+            self.annotate_symmetric_config(model, global_config)
+
+            return model
+
+        def annotate_symmetric_config(
+            self, model: torch.fx.GraphModule, config: QuantizationConfig
+        ) -> torch.fx.GraphModule:
+            self._annotate_linear(model, config)
+            for node in reversed(model.graph.nodes):
+                self._annotate_conv2d(node, config)
+                self._annotate_maxpool2d(node, config)
+            return model
+
+        def _annotate_conv2d(
+            self, node: Node, quantization_config: QuantizationConfig
+        ) -> None:
+            conv_node = node
+            if (
+                conv_node.op != "call_function"
+                or conv_node.target != torch.ops.aten.convolution.default
+            ):
+                return
+            # skip annotation if it is already annotated
+            if _is_annotated([conv_node]):
+                return
+
+            input_qspec_map = {}
+            input_act = conv_node.args[0]
+            assert isinstance(input_act, Node)
+            input_qspec_map[input_act] = get_act_qspec(quantization_config)
+
+            weight = conv_node.args[1]
+            assert isinstance(weight, Node)
+            input_qspec_map[weight] = get_weight_qspec(quantization_config)
+
+            bias = conv_node.args[2]
+            if isinstance(bias, Node):
+                input_qspec_map[bias] = get_bias_qspec(quantization_config)
+
+            conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                output_qspec=get_act_qspec(quantization_config),
+                _annotated=True
+            )
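+
+        # Unlike `conv2d`, `nn.Linear` is decomposed into several ATen ops
+        # (e.g. `t` + `addmm`) in the exported graph, so below we recover the
+        # pattern with `get_source_partitions` instead of matching a single node.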
+        def _annotate_linear(
+            self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
+        ) -> None:
+            module_partitions = get_source_partitions(
+                gm.graph, [torch.nn.Linear, torch.nn.functional.linear]
+            )
+            act_qspec = get_act_qspec(quantization_config)
+            weight_qspec = get_weight_qspec(quantization_config)
+            bias_qspec = get_bias_qspec(quantization_config)
+            for module_or_fn_type, partitions in module_partitions.items():
+                if module_or_fn_type == torch.nn.Linear:
+                    for p in partitions:
+                        act_node = p.input_nodes[0]
+                        output_node = p.output_nodes[0]
+                        weight_node = None
+                        bias_node = None
+                        for node in p.params:
+                            weight_or_bias = getattr(gm, node.target)  # type: ignore[arg-type]
+                            if weight_or_bias.ndim == 2:  # type: ignore[attr-defined]
+                                weight_node = node
+                            if weight_or_bias.ndim == 1:  # type: ignore[attr-defined]
+                                bias_node = node
+                        if weight_node is None:
+                            raise ValueError("No weight found in Linear pattern")
+                        # find use of act node within the matched pattern
+                        act_use_node = None
+                        for node in p.nodes:
+                            if node in act_node.users:  # type: ignore[union-attr]
+                                act_use_node = node
+                                break
+                        if act_use_node is None:
+                            raise ValueError(
+                                "Could not find a user of the act node within the matched pattern."
+                            )
+                        if _is_annotated([act_use_node]) is False:  # type: ignore[list-item]
+                            _annotate_input_qspec_map(
+                                act_use_node,
+                                act_node,
+                                act_qspec,
+                            )
+                        if bias_node and _is_annotated([bias_node]) is False:
+                            _annotate_output_qspec(bias_node, bias_qspec)
+                        if _is_annotated([weight_node]) is False:  # type: ignore[list-item]
+                            _annotate_output_qspec(weight_node, weight_qspec)
+                        if _is_annotated([output_node]) is False:
+                            _annotate_output_qspec(output_node, act_qspec)
+                        nodes_to_mark_annotated = list(p.nodes)
+                        _mark_nodes_as_annotated(nodes_to_mark_annotated)
+
+        # TODO: move to `_pt2e/_propagate_annotation.py` after we have
+        # decided on how we want to use pattern matching for annotation
+        def _annotate_maxpool2d(
+            self, node: Node, quantization_config: QuantizationConfig
+        ) -> None:
+            if (
+                node.op != "call_function"
+                or node.target != operator.getitem
+                or node.args[1] != 0
+            ):
+                return
+            getitem_node = node
+            maxpool_node = getitem_node.args[0]
+            assert isinstance(maxpool_node, Node)
+            if (
+                maxpool_node.op != "call_function"
+                or maxpool_node.target != torch.ops.aten.max_pool2d_with_indices.default
+            ):
+                return
+            if _is_annotated([getitem_node, maxpool_node]):
+                return
+
+            input_act = maxpool_node.args[0]
+            assert isinstance(input_act, Node)
+
+            act_qspec = get_act_qspec(quantization_config)
+            maxpool_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                input_qspec_map={
+                    input_act: act_qspec,
+                },
+                _annotated=True,
+            )
+            getitem_node.meta["quantization_annotation"] = QuantizationAnnotation(
+                output_qspec=act_qspec,
+                _input_output_share_observers=True,
+                _annotated=True,
+            )
+
+        def validate(self, model: torch.fx.GraphModule) -> None:
+            pass
+
+        @classmethod
+        def get_supported_operators(cls) -> List[OperatorConfig]:
+            return []
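+
+    # The helper below mirrors step 1: int8 per-tensor affine activations observed
+    # with HistogramObserver, int8 per-channel symmetric weights observed with
+    # PerChannelMinMaxObserver, and a float bias with a PlaceholderObserver.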
+
+    def get_symmetric_quantization_config():
+        act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = \
+            HistogramObserver
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int8,
+            quant_min=-128,
+            quant_max=127,
+            qscheme=torch.per_tensor_affine,
+            is_dynamic=False,
+            observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(eps=2**-12),
+        )
+
+        weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PerChannelMinMaxObserver
+        extra_args: Dict[str, Any] = {"eps": 2**-12}
+        weight_quantization_spec = QuantizationSpec(
+            dtype=torch.int8,
+            quant_min=-127,
+            quant_max=127,
+            qscheme=torch.per_channel_symmetric,
+            ch_axis=0,
+            is_dynamic=False,
+            observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(**extra_args),
+        )
+
+        bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PlaceholderObserver
+        bias_quantization_spec = QuantizationSpec(
+            dtype=torch.float,
+            observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr
+        )
+        quantization_config = QuantizationConfig(
+            act_quantization_spec, weight_quantization_spec, bias_quantization_spec
+        )
+        return quantization_config
+
+    if __name__ == "__main__":
+        example_inputs = (torch.randn(1, 3, 224, 224),)
+        m = torchvision.models.resnet18().eval()
+        m_copy = copy.deepcopy(m)
+        # program capture
+        m, guards = torchdynamo.export(
+            m,
+            *copy.deepcopy(example_inputs),
+            aten_graph=True,
+        )
+        quantizer = BackendQuantizer()
+        operator_config = get_symmetric_quantization_config()
+        quantizer.set_global(operator_config)
+        m = prepare_pt2e_quantizer(m, quantizer)
+        after_prepare_result = m(*example_inputs)
+        m = convert_pt2e(m)
+        print("converted module is: {}".format(m), flush=True)
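+
+Running the script prints the converted GraphModule, i.e. the Reference
+Quantized Model. As a rough, illustrative sketch of what to expect (the exact
+graph depends on the PyTorch version), quantized tensors show up as
+quantize/dequantize pairs from the `quantized_decomposed` namespace around the
+ATen compute ops:
+
+::
+
+    # illustrative excerpt of the printed graph, not verbatim output
+    x = torch.ops.quantized_decomposed.quantize_per_tensor.default(x, scale, zero_point, -128, 127, torch.int8)
+    x = torch.ops.quantized_decomposed.dequantize_per_tensor.default(x, scale, zero_point, -128, 127, torch.int8)
+    out = torch.ops.aten.convolution.default(x, dequantized_weight, ...)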

From 0beaa6fe8291c5e99da0c632638f3e12b991fff5 Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Fri, 26 May 2023 10:22:14 +0800
Subject: [PATCH 02/42] add example of DerivedQuantizationSpec

---
 .../quantization_2_0_tutotial.rst             | 42 ++++++++++++++++++-
 1 file changed, 40 insertions(+), 2 deletions(-)

diff --git a/prototype_source/quantization_2_0_tutotial.rst b/prototype_source/quantization_2_0_tutotial.rst
index 09a5ee7f579..8adba426a0b 100644
--- a/prototype_source/quantization_2_0_tutotial.rst
+++ b/prototype_source/quantization_2_0_tutotial.rst
@@ -246,8 +246,46 @@ Will add example of `FixedQParamsQuantizationSpec` with `sigmoid` after implemen
 --------------------------------------------------------
 
 `DerivedQuantizationSpec` is the quantization spec for the Tensors whose quantization parameters are derived from other Tensors.
-**TODO(leslie)** `DerivedQuantizationSpec` has not been implemented yet.
-Will add example of `DerivedQuantizationSpec` with `linear`.
+For example, we may want to derive the scale and zero_point of the bias from the activation and weight of a convolution node:
+
+::
+
+    def _annotate_conv2d_derived_bias(
+        self, node: Node, quantization_config: QuantizationConfig
+    ) -> None:
+        if (
+            node.op == "call_function"
+            and node.target == torch.ops.aten.convolution.default
+        ):
+            input_act = node.args[0]
+            weight = node.args[1]
+            bias = node.args[2]
+            act_qspec = get_act_qspec(quantization_config)
+            weight_qspec = get_weight_qspec(quantization_config)
+
+            def derive_qparams_fn(obs_or_fqs: List[ObserverOrFakeQuantize]) -> Tuple[Tensor, Tensor]:
+                assert len(obs_or_fqs) == 2, \
+                    "Expecting two obs/fqs, one for activation and one for weight, got: {}".format(len(obs_or_fqs))
+                act_obs_or_fq = obs_or_fqs[0]
+                weight_obs_or_fq = obs_or_fqs[1]
+                act_scale, act_zp = act_obs_or_fq.calculate_qparams()
+                weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams()
+                return torch.tensor([act_scale * weight_scale]).to(torch.float32), torch.tensor([0]).to(torch.int32)
+
+            bias_qspec = DerivedQuantizationSpec(
+                derived_from=[(input_act, node), (weight, node)],
+                derive_qparams_fn=derive_qparams_fn,
+                dtype=torch.int32,
+                quant_min=-2**31,
+                quant_max=2**31 - 1,
+                qscheme=torch.per_tensor_symmetric,
+            )
+            input_qspec_map = {input_act: act_qspec, weight: weight_qspec, bias: bias_qspec}
+            node.meta["quantization_annotation"] = QuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                output_qspec=act_qspec,
+                _annotated=True,
+            )
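+
+Note that `derive_qparams_fn` above follows the common convention
+`bias_scale = act_scale * weight_scale` with `zero_point = 0`, which lets the
+`int32` bias be added directly to the accumulator of the int8 convolution. As a
+worked example with hypothetical values: for `act_scale = 0.02` and
+`weight_scale = 0.001`, the bias would be quantized with `scale = 2e-05` and
+`zero_point = 0`.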
 
 7. A Toy Example with Resnet18
 --------------------------------------------------------

From 927ce7d06e8b10ee65d40e16d9d21c81eb10fe51 Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Mon, 29 May 2023 08:54:09 +0800
Subject: [PATCH 03/42] add example of FixedQParamsQuantizationSpec

---
 .../quantization_2_0_tutotial.rst             | 36 +++++++++++--------
 1 file changed, 21 insertions(+), 15 deletions(-)

diff --git a/prototype_source/quantization_2_0_tutotial.rst b/prototype_source/quantization_2_0_tutotial.rst
index 8adba426a0b..21c238a11b3 100644
--- a/prototype_source/quantization_2_0_tutotial.rst
+++ b/prototype_source/quantization_2_0_tutotial.rst
@@ -221,28 +221,36 @@ share the same quantization parameters.
 5. Annotate fixed qparams operators
 --------------------------------------------------------
 
-For operator such as `sigmoid`, whose quantization parameters are known before,
-we want to use fixed parameters for it.
-
-**TODO(leslie)** `FixedQParamsQuantizationSpec` has not been implemented yet.
-Will add example of `FixedQParamsQuantizationSpec` with `sigmoid` after implementation.
+For operators such as `sigmoid`, which have a predefined and fixed scale/zero_point,
+we can use fixed parameters via `FixedQParamsQuantizationSpec`.
 
 ::
 
     def _annotate_sigmoid(
         self, node: Node, quantization_config: QuantizationConfig
     ) -> None:
-        if sigmoid_node.op == "call_function" and sigmoid_node.target in [
+        if node.op == "call_function" and node.target in [
             torch.ops.aten.sigmoid.default,
         ]:
-            act_qspec = get_act_qspec(quantization_config)
-
-            input_qspec_map = {}
-            input_act0 = sigmoid_node.args[0]
-
-            fixed_params_qspec = FixedQParamsQuantizationSpec...
+            input_act = node.args[0]
+            assert isinstance(input_act, Node)
+            act_qspec = FixedQParamsQuantizationSpec(
+                dtype=torch.uint8,
+                quant_min=0,
+                quant_max=255,
+                qscheme=torch.per_tensor_affine,
+                scale=2.0 / 256.0,
+                zero_point=128,
+            )
+            node.meta["quantization_annotation"] = QuantizationAnnotation(
+                input_qspec_map={
+                    input_act: act_qspec,
+                },
+                output_qspec=act_qspec,
+                _annotated=True,
+            )
 
-6. Annotate bias for linear
+6. Annotate tensor with derived quantization parameters
 --------------------------------------------------------
 
 `DerivedQuantizationSpec` is the quantization spec for the Tensors whose quantization parameters are derived from other Tensors.
@@ -463,8 +471,6 @@ to run an example with Torchvision Resnet18.
             nodes_to_mark_annotated = list(p.nodes)
             _mark_nodes_as_annotated(nodes_to_mark_annotated)
 
-        # TODO: move to `_pt2e/_propagate_annotation.py` after we have
-        # decided on how we want to use pattern matching for annotation
         def _annotate_maxpool2d(
             self, node: Node, quantization_config: QuantizationConfig
         ) -> None:

From 66a5bde77aff0f1ea20fd10c33c0e31261558f50 Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Wed, 31 May 2023 08:56:14 +0800
Subject: [PATCH 04/42] add customcarditem to Quantization 2.0

---
 prototype_source/prototype_index.rst | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/prototype_source/prototype_index.rst b/prototype_source/prototype_index.rst
index cfdb2ffcca3..43ebd2e80e5 100644
--- a/prototype_source/prototype_index.rst
+++ b/prototype_source/prototype_index.rst
@@ -68,6 +68,13 @@ Prototype features are not available as part of binary distributions like PyPI o
    :link: ../prototype/numeric_suite_tutorial.html
    :tags: Debugging,Quantization
 
+.. customcarditem::
+   :header: PyTorch Quantization 2.0 Tutorial
+   :card_description: Learn how to use the PyTorch Quantization 2.0 stack.
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../prototype/quantization_2_0_tutotial.html
+   :tags: Quantization
+
 .. Mobile
 
 .. customcarditem::

From 89bccc99df7572d3de60eb2dc3124d0bfcef3384 Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Wed, 31 May 2023 09:07:18 +0800
Subject: [PATCH 05/42] add more into prototype_index

---
 prototype_source/prototype_index.rst | 1 +
 1 file changed, 1 insertion(+)

diff --git a/prototype_source/prototype_index.rst b/prototype_source/prototype_index.rst
index 43ebd2e80e5..b004cc2f365 100644
--- a/prototype_source/prototype_index.rst
+++ b/prototype_source/prototype_index.rst
@@ -200,6 +200,7 @@ Prototype features are not available as part of binary distributions like PyPI o
    prototype/fx_graph_mode_ptq_dynamic.html
    prototype/fx_graph_mode_ptq_static.html
    prototype/graph_mode_dynamic_bert_tutorial.html
+   prototype/quantization_2_0_tutotial.html
    prototype/ios_gpu_workflow.html
    prototype/nnapi_mobilenetv2.html
    prototype/tracing_based_selective_build.html

From 212e82033802c366b7ddfce78e14ace0ffbde975 Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Wed, 31 May 2023 09:53:43 +0800
Subject: [PATCH 06/42] add quantization 2.0 diagram

---
 .../pytorch_quantization_2_0_diagram.png       | Bin 0 -> 40506 bytes
 prototype_source/quantization_2_0_tutotial.rst |  10 +++++++---
 2 files changed, 7 insertions(+), 3 deletions(-)
 create mode 100644 _static/img/quantization/pytorch_quantization_2_0_diagram.png

diff --git a/_static/img/quantization/pytorch_quantization_2_0_diagram.png b/_static/img/quantization/pytorch_quantization_2_0_diagram.png
new file mode 100644
index 0000000000000000000000000000000000000000..2a58a0bffa927f42247a46a5797bde499155a672
GIT binary patch
literal 40506
[base85-encoded PNG data omitted]
zETNl$tBrB>PibTlaiw?j!#gQ3@FUd_)a^>Le+~~k2M7ZKN%nu^Id%NzJ@pJQ5fJFb z7%$!5FP=rP{T=d#`Ju$|A9!W|Ck^NvUtBLd0dm4Pvx=`3zWA2%Tq8w^9;CdPee=q( zv}|0jY>r=*-PpoU9KWf2@Yh73ucx_fW!Sh|K#Os`zGe(zJU2Vy} z`l`G4PGA_|;zN;a=jirVIzGXQHRPW|Wv#<#C&F|Bg(mBI-6y$6GnV>}RXf1ScV(@| zpTgs?9?shHN8Ry{_rPFc57$BQmcdZ=r>pX0Jt@Btk-6&n)m`tG-Lqo%7xx+i+n*5} zOokCMtGJjrOqf*8!o&B8AM&lcCAW1Nh7?D#m|)%xjdn<;fp$@aa{JtcCTj57Kgs5W zGuK_8U}3-%*wW933+S~z5P-<5L(X1HU7;jhubJ_>s9Qb5>-UT6x=+=Qt~oJ=YhG2{ zEk|Ss{G8pX5Mkgz6V+#F0 zO*$O-!>y%Fzo3kUI;ehgE27OJXm(0fq7)JV$RN0gA4D>NKwr+ZUj3ZGzDV~yV`@u; z(2=R7KpmY-T@t_dJyAAm-SB(v1ugCw(y7-w-hN&biq8bmG0UFHnQ>+*{1HU_aos?! z*5Pf2xcgrbz8ih*SM8W0C(9J}(vYOxocgKMh8~&PW}9Z&oqzTwHgzYScIa=_K9JE2 zeg0!}xZGPtd#lzw%rEuf($G|02Pf?su0@Fh1bTHoT&lCx8eFj9T`Az={R83-83C z9>uLAvYS64|8%f-`r!Wfw(73PT59g?c)z>%i)((ovU#ZsbtJMDy0niZOh)E*>^6w5 ztg?VWPNuJ>&J|qv@u}IbG{Y#`4x)MK*K%8JAz?jJg$thc;i)a6NbG#zOg#hT?#gLq zrA`D#cR&Z?rIy@omaPk&sTXX@wR)AWV-F(3yl!2)6;lrfVv*TpzPoKgukJ04maP@K zz`Ju0h8Af?zGWBkE}W~m$AS%j4wLYKyhc#t(U|8$Y|CT(W0mht``*g!-~^*&(c4UA zkUm%QT-4C%ZSAFV%{r?T)WxNmX@pyjeBOD>NAIsU{aCK!dmVdk0x|qpZ$eREY3lKW zS3j-cz>DCyR8u(K!TAN5bg{;xrbZi{v0Jz8TOP3cGVPx@uQhMiCsE1+)st56`+bu< zAQ|69QU^kc+WZv+vXg)4`8tMC;sK)+>(Ne_)>E%&(V8=vK0=*>o<05ti{yDrbK9m{ zs0{SOPxyN0tD@ztLbo+aa#uf3AEsL>BDb4E%6rmA2=>yK>2KFKkL;CNH5UbDFvhz? z`E^jcJZf`cKjn!_TZi~0B`t-)*K(zD6CYyb^Ug9%+i*?T_;&2#%__T)8dv1yy|gf* z`uUrIO0XELF@?6faxW)vkAG~XZu>sV_e%H*PeXJQF zUamkTYd|rYiQTNMY*heQ^Zf&W-;0f{3`EB9`Znjlon2bU_^P)LbHd-HZyZ<~DV1Zt zaUA`6z}3m_W4+s93BRhJ1#!AiqXu6Z#3ULM>O3-zn7n_VRjT+w=L}|Fv0#uKB_3pE zwD|tG#r3(b*(WuF!>SO=XIh+#>O zuJAN%=Wx3oJsKV&`WU9QTrEEuM0C76ddfe@%SwLEdGQtZXOpmXD+$de#Y zPb^z{E4!VguPOZ}&omkR3W(b6r{#xBieMuNMh4LY9o%T6;=$3fiIJG4O$Ygp;8;Xe zL5KD(A|sPJCk4M&;E6)me<2i&3Y*j`FE}-))xw7jz$}AU25^SHP#J-xX_B>&WNix; zxE9vTjed2w*I9a|V(9FtLK5k;)H{>(@s!sEU$lf}@^cPa;5N^*jt4PwVN<%4 zoY4}q!^opEAdvJWzPk}cGyF9?!&&eYY+xnFBlf!%5@*x=H&_pfQw>uEyTX&1vRbpM z?Y(DC_2>&&v0WW-ox%6;U#y|fM{+Nli6 z7h1#hDG%2{+wYo2uZ*<`5Gv7!w^zOnL6c6uSHvaWTA54}Mn&faMWZd7_R!_P>_}5-pIjCNEt_JpMw1tR2p}Lf$~1VB!1i zS764-X6pqjdSU1)v_mW@L%t|bPrNITGRQh%m@Dvm_wMqT_p-^74)kdsyM} zOr^F(Ojy;pxB*dkIrTH?7_(G;w;2wh??# zFWV^8$-ngEmgOQ%i=1JWKRLZNX^4xA59~LpgR04muaKOR@OyeoCp9%`R@bFxl^;o zqT!j@EQzun8e4%S^9ICBEWc1ATH2&Zu<$L@&`N0QxE3^<8lvau+?YOw^IF6;kz3i_XNXptW8%>XwYIsLOfmV7M?7x=obrxrr8`vb)G zRAa<;+~uQP>Cc!bjz194z;Ha|`@Rt7=Aq)ZmmKCzNAoWIVZl%^JQan?D6k~R7KDI3 zI)6w8kbJBfEIiU}=jQ`9H#eVCTnU&i$$MrA8Sb`iYXea8?orm|4<>xTmi~rqLOK2S ziN!d!3Hd*J$@D&tr=o@?WJxNk>pC8f^+U^qgj^ioGI?uve8Cn-w3G@tO@3n!%*JLN zls(A43>Z%Bj61K+#)~cBm3u#=`@V2Cv9c{~O^5z%;|p{aq4Z{-5B0T9JEf7YkDcT9 zt;?rDyqav`TyH+z=RG(dZVsG32MrM@@mfikx3^Y8l0Wo2pO}aS-VGL(fxQAEW#?4z-rIPNMuC!e*R&m3^OUnMay? 
z^goG%whb(gEgo1$$W5Bx9|eneq%CRBVVe1Ej3o@`p(6Hv$Tl^uE>s|OV6{}OkmeAVsp?>}`f!k# z2EZjU=l|-#j`d*cqChXg9Xjxh30ZHYH{4Zoh;Cs^fPybHH3YWSG*QkJ#tOM0U4ueJ z#oS!+dz%-VT3qTM=~i$ODqMB@KT!Vcja6+mpf*C&@+~*CO_9$A%hj5NC|V2jpe(Uh zQ>TC3WKHRat0p~ZYH}&DKA?T)w(^9zGHyKA%I0v2iNUA-+Q$}nN8MlbO5;>$CSz{} za_ey&D$0sm8V)Fpac(c^+t5h45sTJl@BJ?T$6_wNx|esmRB(E?hT}$Kp_X-}dC{+g$*2i;Ik{=~Qi?**B|H-p z1j<@J)5`o+{kmBybA2ct_GyORay8F{wbqy-(M|#U?iaiu#?PR{JT;zOXbmgH-6RQp z|8oGCvCnQSTIfMOM{tnekJ|jEh~3MCNd*w-3G=`9o%^NXqbkrEXkl722)~eTV4mdVm!L ztgP(m3kqVju(Udcm0s>|M#d}TNTMQ1y*$vHGN-*nQ&zrJ9mW#@<+LGeidi&|myB2R zyyN-V1C|DB3#@J{VZthfmdxF7@VpSfb46o;d%m){nJXN-dr#<8k-`cMSogfT+y*r= zAvlh)AU{;R@rf&)kIVb4Y@01ipBe?c)IF+NsI4el0Hcje>EfiAy1c2j6o%Q|Ch;j0m!aehNn;{%xAVmhN&%?Q)W zA4VRMF-GpPnk{%2StGaFmZp;7((z^K~4-680_7UQOO zk5$Oqp5JYODf;&ObT$BB0s{Q!Osk}=Za)>7ajRAvL8-4nxkL#GxjSk!{7wrcGHKQ7 z1YqMwgBX=BkblEw@uhQUU#lk;_3b4Vwe2O77$P$Tkr|c6Dh!)K$47%YZ-G2a-a+4$`$eMy z;n1m`T9{u$LjP;lCy~Oda@fEF+3FFdHdk3UIgwGf%0aFJiDD*QsFnZ7^JrInlxcQF z$)p35rEk?L+6qq?qx_tV{LT&Tt^z{qbMYZGh9I;AvHQln`s`cWf(roCDF^;TC#UCo zhK9iFto|;F)M)-Ax+MG)tv~swLZzY z)qglx7@ht|H15%l)2m9+*nsKH;2af4IC<* zX$Fq!XMnZSwDZ%YI#4Cwv~b{UPlCe_tngqNEGN}1U*W8G)L^DP|E-_oEoY4De1`Pu z#8U#C@g!ZuDx`-zc9xWK3i5tO(EQ6O}G0)_si7hBaEzFO)7rLKdPkUxufva{2Jzi^_ zIsoof(YE)>XDPPF%7&DK^{|o{!RU+`Qp{{?q6xT=sBtUPga!5Cy|6S3Vh=Xfd4^vj zq;ofK*tFZoWBgivw|qoZrg4`~&#YUN)Jr(L;iTKM-}2|Y^DE30&yy~muq0&LpzR50 zie-ZZmnYa?MLp_1BzJ{cr|#&*~hlplP&tv;0vX8mwK$QK5PS?G$yI94VSC4)!SNwP$}Z{508J z#;$|(i4*-uPS@5W)T&`lv=yc!?`72GY3S>WKX-Dmqh@n>RrX4-^E?wsd7mx(<*nxe zqhv#oCr}~Q>pxRyZOS6W=#27k@ftbc6wUKc;{C%eJS!R}y?vfO6{!(m37)iBg9?dL z0F}Q0E_cZJ__PLmF7wUOZ}aWz1$2y0xJP@9Ta78y<0pVzRL-1EP0{JK?1~59<3MSi z(uSCWtw+7Vt^cB>FE1OK`#LAE#c>+`>N~;KZfW=yP!vm=9A0`;EEel#U<|ennLi77 zHZSx{t8%J&vqKDXTU=8W=CQ9c2is8b<@e4`#}hMS;=9txixNilvz3Wy?>)6UIoi98 zZ4t+C_~hVsE*uMWeDkXs!)iQX_kKF-z1MgR{rEUpS>Cf%&J;Y^6-BM8;VfD?##R}D zTX~k5IUD*eez50;3Er~F$}b)X7whlu1+kp$6fK6anKX1a5U@5-C=^c;lj5fK2Vh|#m;q)y{^$l)486x`qW-Z=q`5JG9@ zD*9BuI$rs_X#l%(^C5!cO+>A^Fc7Owr*}osa)8WB@Xx@p+;$=lC0@7)FGBuAy7L^*!P%IAx6(9>7P!IXFKeE5k+8b+4*1Qv*;Jh+bFvbYK7 zP_$UfKQ>_J<5xBy-?u0+3hJAIrTpe?>*{u^i8HmE(CDMiD#URd1ewONoy#JzXR^P} zMQNgk^fFxd6p90@m#*HQ=r!iJbG*y&iUC9#r|PiNaHup2CtYiLh~Rflf?m`CW_==5!^2{nWOQak!T!?Q zv(`PiW$tK{OMm@TxeZV%Mi83Bl4OA+4;YSXmto<}KoqXcZ>F;bPOCUTOBlvT zn*=ryHU?;YT9s0a~e3u)lc>~sZ~fRUQ-qUA?m1zd8ZVzj}>7Zw10xegF2c4-Z_6+0Ii58a%ccuPUIQfsgJxxf0i!i62oX(adhLhDra* z_TzZsL0u?Kk(;=RWFa!JPHqVmrs1yT6xlX>Q!F8KrM&maRW6X^GZE$B!Cq zvd~-qojc|}Alkd%^6vpNS|TqD+fZlbvZx|OlX4-I#OA5JdicL{OjJ8_`u(~E!7KrL zy7-mMu~+k(o}kixsO@=^3k)Xk;pp(PMD>wgzm2C&ZcjYQrkMXjGj<7hK(nY0v9CH! 
zL1mUk*<0U^IG@fBKR#iez0>SUBZz(#euW-NI`&_Hgam|k6{>r^T8m~Y31SuLm-nTA zTBFi@(A5X)T1hBmuHyfZig1}<8qxi$!mUSuTvQA7(v`L)lwcG;(-#1D6&rg^05m9P zGH{GS91md=L7fbliP*f+R)Ny820Z*oKZkiOp|%HMk@)w)Dbc`#jSPiLY|^?-KkcSe z$E{eqf{cpGzuY;FioBbc^$rjllgo{3?TO1Yp*u)Xq8n+~UwZ&_o9rJ00-p0eMeuB| z^5Au=DcuuJMY_!scKF%3S911AXAt z|4Ndm^)x74qTd9fBeYrgAAAuYwpqtqX@W!BoJMui^C;$iOU;daN02l-fK~LH6|pdb zsCgeJGcz`|TUH{E6X_S{RR0&L`#%NF_MGb_>7t53+T9GFZd-JxIYfNmiB+2A1ul&R}Fu z$mlV_9MlTLKwgO6h4hz#XCH@5coFLzR!WS!*Bvcji2KtFhZA&>NdhbNy|=B_vW0F7 zODd%Ywk#^xBp==8)a;;+xaw99$|*@tBC1me){%QmW%UFk(-13vE9k|@Lv;x@l^_=4 zq{Z5&23XBP&+Jl9MD5Ynu;bE;PWN~^M_!F^eV&>GI&iCOiPE3-42QAg8{Nm;_kDnq z9l^t)!$roWCU#hlJhzaW08SIZCDPP;)n6ls_rpRh_J&v}|R-F)}# z6SP>b`o7u0|7P+N$zW;e1Mpds19|#gTb-t%l?bPCxy+z`_rehX5a+n5z=AfujV0Xg z16!F{VTid+ww_<*dZkpk<;Mejou8l$Rtf(!$9fD3guBD}-w`k%I8+So!}`eECo(W+zoF)_SohVKibdz1@c zeEnvEXHP%V;)6VRqL{Xq$<4|gQUow9Lf2jk1=rhLWpQ6^*^S>J#_33xYagK(m7MaJ zwakntahFhELbqqSX!n|uT%|+;_oIp(A2+wuZF`ju?iIOgBSS-)Q>1428?5N-!hd}v z54bCY0!-(=%yJIEg_h6;+`$-PpF(I_{o%*laA5nn)uJb_Xt5!@=LK%Z{IZpue_93`07p zYydS8$*Y#W=2W;-eI&Bcqt>>vmXd0nRD(`Tt+fvRsuQ-}b#Ct}UDPGb%gL>K@2dMW zkqR_f55*(7z3Tx1J2dYB+2XH$d$iKI^dIh$84Xlg5{a7PVjWIE&onw8xCB}F1c&bx zg&b~r^wtgvLltL^{Bns(O*(`-dlW~?UTs%TX(epM#I@gVs%nX+y8P#S{_2~N?A0h* zt=sA;w=|oc!S|zzlh#b{Cmysl09Y#sb6X?)WwOW{!5YXV(|}E&Ea0xw+)^~TcJF%C zUg-<=#G#S9CYOwcqg5Y3Ufy0vHMg#S#zzdt$|$sk#eOl3U)x)dVGKu{^sLyeDS(;@ z1EhV0Z3;ky4nwgmzK4CTZc!s4z47MDm@oD!ni0^CK*a_h?~_%(IsWK@#d&018)342 z?q@@&L#=xL*XfUNwG($6n*x_qXU-ksiS1Ic?g)Qz+NP_6yBHn&;_rbdz4 zJ|FW~at&1Z%!hM)%T1xDM%dhM5GDetJj=Q@h#8%hfy6rU$%yo=CReUY_ZxgxZ9DY& z7xF_+?*K*ttpjyC&w8O5MAxsa@a1<3!ByN15(rAv5y8B1gNVwWniujvDAiWkIdkBi=g z+FY`&$7Mw{>VR^=YiFH>Y9f0#w4<#)rB!#>`XL=^2$w9`pNvSsVMARhTSV_y+zh2A z7eam%*p&B^saJvMPdFd`@=DO}URR0<#j0)lcY@caX4?)N|4+5mCZ|FOl(8f;UUfAa zUR-=;gn9PMuWmVHkW17xDmO5zPFY}9uMMlvL2W2%&(AX) z^53Ft88mOyz(j{{jnM~2_N%Szr)0p{U#zCUqce;oV2Jl+h;N2ukxBU@4tVfJfKn|o zS}|=D8S=T8X;2}yrm1y?);;Nt<);Qxrj%{`6kzBfV#sR*qK2?~c5=U&TkbVhrcM-) zW)bi{yLw<}sKB)eF^{GFvgB%F?HHKekaNR)mqpwt~l zQLIq1E5Yn-cq!^xOn5#b7OdlfUPgQfcUGyVzGA3}fM?>A&cBFLnjw?lm(B!D_?2M> zb&n=vKTI3sf5Th(O_VkKd9r8X^CN%CG0y;Q+_OuX)UEtMmRT`b@{YxY7=#o>dHR{( zsuh)g9b5bzNLTv}00&LHT5TN7eaz%_IA(@^+}V9~fC6XWmN$O%Ft;3iVh&3>+V%3D z7-gyY0bZf@ikG+i)JnP2u7j>FS=&rKoE!B^xc^uojCQr~PaIsKyuT1s2HR=f?yDgf zRXR8A%viFfH5~3@m^)i%xnY_^d-b#uW9x%l&L0i)#Pz#7X==m%he4AHZTELqx6x?6 zB@<`8KMJ#ZEz;Wx8{zltCcjb=$QRroO9(%I&uvB{E?0$-y=zxc&&rg^Vv3DhaSi+R z4O;hiqUIQsng@L-EJ6W4f6v!QdUz;)cM~jk+nDj<&XkGV z378w=(UDS0-u3@Jigb=M|D<7dN%|5zcV+OL-AAFq_~TWIxef6C|Ip>iW;mwXV?E3q zp0kAYyG{Ho?TU*b?I(7PH}~#4w%x3r^fE3CK1iX}0ZhmTAXor7&PVX@aUR#Wvu+l~ zIA%GF$Z&9+20`>Cz9wVawL|y&7z$JP(ZX7zO}f&Sfy~-FyHXS{*TB^T0hsUHDS_Z0 zJ}VlTwkS?w`zKv;L z2q-id{G0pIDQva&9I}I!SB5?$?LeFuHU96?(HAcYCoS9=02SB?pj1vNfBq_BRXyYA z!RgTf>i-CwO|oPSqjZ@6(-p_+`g;1mFXjK=^{oHpBVJtg?7TiHU5IGxzBabHba)L^ z`tol*FD46KA8^P{zLVpV-sG>BJo*!PO4_nlCr*hOR4KRqc69mAc-l9DBnXoirC2wP z)L>=2zUR-E=lWC}GSU1=X*3xIp#0xTU3+qTH@geb4u8B|4V>p+^tLT;k$Zl0Z*KBn z2A>|tGfOgTD~-6}alQNAB;w)3EA@!JI~f}d?gHId$ z)!hqD_9rY0;I7xVBTG3r?%QRiZaSJ9pr?7fSF8BF4a_Dc6FqO`8W>+4g2e;-!N>1K9CR}Xg;NqoG6k zQ%hlrK|_fsjia0S?@L7!kf&Z)I$b;mI)OQRx{81Ym@k|*LA!lKxZ1QWQ3DK*6TvP^ z%#s?(9i$3-0*m-pigW^$6K!}$vPoqS+R!)ZzZ)C7$Mul~77pEA!2fKnTtGUk;{xFF zDHo%x?M{`2kD13D&Gwa!ANp?6Vjc8Gf08PTrTc=P$1!^#9+sjuaakOBs4CRS`?iak zvPHslYNAOcrKRSzkGRdZaclD16scAp5wbb)`gjdrX7vUouR~>tY@z4Eh`&VWKz(b# zKFi@Vu_$Q;vcw z)dOB5UOZ#=kErW|>jys6Zca)faYH{G)=7NcN@2pv#pPVsoT86dqt*3Ep|tDvdKaWp zw}*oqO1A444ydWME$LSC^ZgS}7t=Z8flBxn?+W^N$E$kghykuW)vIOFtSy+K{l((G zo+HJz;+>yJ#XB4%>A~P2GT31tp*#gr*bEExoH^&cSjv;k6V!LGK?;~FXz7s)3x1Zp 
z^Fu18d(vuuV1Fidc5Y|8<&H-%=394q3#r=@^D}sN_EvR&P@JlWY#P5cH2IP%Hnn^N zCgUpWF2yBljE9X?m40Vx1#1^@h@V^2RK4U+_Wn|?cayQGt4Kv80Vn)h)z=Wp!*e^R z{q%YI3z0*Hov^`PB!2eK)5+SGmM_((wIdcP!(Hm~pilM7_3NM}KIQrqhL>jfh})Y< z?=GMF%otXN_dqE!jF!opHaNoy>CiV^dE%}HBbZYI73*}(j`Zy(X-N|q_*){6;GIi$ z)-#GyJ&?~n-)lXJLaoAK88WZvYs)8?rdsV~>I@DfLX5XayZaMChLjRaoQ%@ApZa8M z+v!ZZT4GZHmbTP+8_=*a2*4Wx&8nZK?QDpMpDUSgYqo$sO$_4mX-^e(Xg3pomFTgZ zgfdsx2M+?J5s>)F5q6A+EhAw*aK-{+nAH)z__gt?cO<;tV)fPC4z5P!=>q}D3|o9% zM}73Xv~C>X_svY#cRqq=;rvlc*pfSgO7(rnv~&@oVY0rtxqt=v>g@#PL~{KOBPW0z zpuhH(LRvj{3&TpqVzrFEa0d4k*)C~37*;eH&LuWNXEnR`NI^KN$b#@6q2I5Aa7GJ| zG3e*OL4vsP$wdu=+SM^LBHGs<+;-w_9mz0+W#!82xY!wu%S6e`u+h*EgmlbEPwCGa zex5;XO}-48@EY+lj7i7zo0kHF`5YSm{?({qYU7Mz)KpMipY>$Yc4DL6%(L(>9)^@x zQ+|*8@>Lc+fQ`BM0kAEDo^YWw?$5`>&MnFrCm$WzB2#`Gt!1a|57u_qh8-g6f5q{n zI>+6Xh-<_QHSV^$Y%V6v%7W&Bye4Stw)7tJid{3-+HbS;#%}Uzm@TU}bupr+JWOI@ zBDCFUbE?L+r^Vxn{)og@9?DMnN*SPhpS`g&MlyNBGB)LT)rhfZ6yOwP$pJZlrj(R1k_ud zaSsORaFqS_+{9Ah7969l<>C4;b{lrRN7EIO>{^$c*e$lH5gN8*df)}=hU@g)P*eEO z`A=V-Rp81ClH8I#3JLR<$0`R)D{8s>4#-4Ks3(;Zs-y!q7l1{C7cj+7n?I2mo!8@G z43TGWBfXI3R#lJR7dneQ%tTOfF*9O}l>UWvFqvbMb zYr&hpU4L{(Az>0}!oslri6CVL@2xR|%_duhWMrny*NMPI!OgWkRtepMT^la9Fz;5+ z&B@8DZen*sS-nvl4gxEa#AT(tPeMzUcDTgf0n^eQVa9m|Rs`CiA<^M>Kq~@Tm9(6w zUZYNjs)yq+{>l;ZFAbt|xSA!m?wgkyXFE^h#~Xa_Ub4^SNdCg!Y_U*M8h=L%3WHXa zuvrlf;<$47kSY+j(urkVRqvN_MfZZ>BkixI^BKmSLlVV;V^DTKK8v-v1|Pb6!ZI{t z_Qt*5pM0*O2mzR8z~h|sU99(#!bGbm&s%&cggM;aT#DLJ^-%DB`{P_AGy9OYjUYnX zYg2kcQ~^RJ?r=kSb5HYes)#O^osJT!{9dq+oD@*iv1r&Bl%|vS_^DO25)^FwPjZR2 z!}~7}Cq5+XXq@OA#F=Be zCpLUy8NHDYPuKDeW})W4FP~vMhae7=Y%f-Z2;gJR?wzg=yNs!EiQyhqg{DGJqNj4KLb+5Fm}4}& zhjnTchi;8A<3`KyAhiS#8n*bV3If|RgHHOs?1~&QG%>IIV#Ia1rV{GxVB#IdQuC80 zxQ2q2AXFsolD#!Kc9$(>7W|siSNez<_=^lmM{A!I-pu@U!PK4+7<3LRz>x|GrD5y$ zQdH-v^Jis;LcN1^H_@9kS=}Rt^c`#Q{mk{6X4y5=MPZDXv2a1Zv9YghWWLT0k+Ox_ z5U13Mp~_G$=z&_haw2|gBnnlGV;di%wd)U=eq@583QEi=by<153Gco523GmMV-vfH#u#R7r zncL@bNbRt)a`E8y=_h5t)}Yk67+qm0hQK&NO@4+uPQN($Vb;4v?>ugSTvQRGms#kF z@#>EEQh^Iv3`9O`i9O$FZ_q)rJgC>6L9v(P8;9I&En8jGdD&28PnBXh=UP_bmn<68 zc{XUNyr~ebL88*LM-2`Rr??G9qMVmbzHC<0-&VRjdfwc&#kSMxl??!gJjb=H^c~&l zIA-#U{0`SsB6`G|KqXT#TXKD@1}K&itN6BIU+Ng~<6}&)O2o%NNaDBnoVO>*$0^t1 zx3U{oYsQ&Zs!zHbrrA)cMLvvdPG2xjxIJn#YrdCt{f%nJoe5JDVMPYxoDyF;m2#Ro zw9*xV=c}|5&^J*092RfZ1YTLx9o7a4VKMe|RY{b)-v_ACG3+Q=9_qaX|E0QrLZ24~ z?mL6tH0>3LV`r-4%63?lUc9cS{_N}i2;|{sqqgphrr-NYvO43tC+LaJ9KlotS2|B( zKl=zRJ8P;j_z^~z$L)Efik1YWFdZfYE)TCPJN@l1ma~zq6sY;Zz372j7OUimI$xMw zT&t_BxLk+^?aof9xe;-GqWy#4-E%F>&6`Q_|axilbSvPgVr#kKPS-uN0cv>iFeS}sYVxZgS-V)wnLWybM z*sfVE4*9thwS9k#2^m6h#RUH`8Ks&siz|#F+K?(y`@QfZ3{M~u$Y!buSWoSo@4-WX z9kevNB;zi>q;P`*^3CBU10HB0GF zcKA#oZv1=7c>mm^&(gJHo||w#|FKgj;1aTxWL4AYerg4}GVs^IGs-ov;F}^t_v4eu z$1k4FvosX!aA?Ix*Eo{7cp{}}m&;!BBVnUx2$KCB%B4fIT2Ub9d>jYTOn3B8VH4Ld znGs3<(AG)Rq`;H%fVodWZqY!6e)-QV=cZobWnsa_44inSRrxmO$P6gV{UP8?u!fOx z`w7KjSovH70s|-qnl*gcWVH}cvBlYBv9C zkyW)d1bLK2@Oi{^%SE4)%GrtH1xPO64%{{sSJ;6)Gk{Ej>=*#@*F;3j z$#o^Id2Hg(R51ANN@XS>sxl405Z*M!=v)E7+LaRZ12;WDR%Q6#Z`W_&QU}B+2WkyY zng4t63jvgLvH=i8UMzn)Cf`memt+JG8~zqWndN`3wN^}JJ10HxpI`z|9uoi$eCD{# zT)+-=H}HZ18vXM;{`BDG<4-_A&Iw-S{MPuBfH-#Ic)A-)|FU_~y+FNcfu(ub`5fnh zbb~qF_KPq8msosTeIIm&5A=6*@7A$m|5K0ae~AhF_YU~Iy!B^L*(jN|b`t0vHS0Tu zREqsm{o7sP`^(xhlq=NDeg?Tc2&kd~qGmOxn=J%TSI1KScY!3OWU3=i#j$ZU|8@(% zEcv&ez{LXSc69(+T#tT=0sN*5%iUjgf2q-Xo$LUGC+pF?_C9616@A>!N@)zW-%I|< zz-`NNf5Q+0w6kVJ&25iQ`1O>IW;zs#15BbrU-lkh(M{sG&pg5hOKN{G z0P8Ab0Ii*K;jSy#w&NDk(7*kik%hHZ{+&8ELJZ?QfDTHt2A34(W@F-7{J}Nie8B!y zW5PT&z(7Whn&Sd_vU;Me+t;tfF{&*8V80%)({X`?^gxBAT5`E|bEi(r`a}f$-Ig4{ 
zWXk>#4yNZ0_`(XruHZj3AMo+op=E1JROz4ybZFtiM>WRhu#Sy8**oytfw3A~eT}~z zHV2ees-PaT%bJD5PAd0jYQmB^n&vGV3?nP}NK3b_g&n%*KNG*dBVcZ6Owh+Wz7Lr5 zkRyzuUuljjP`9p(f;H0JK>qwP-L%Rp1|Cd*w0AKo4*g)RB6#q?@oO%y*I zvIy(ahu0rQJjPSWbR?`J^`XW4hls?GKYd5GM&j0n2a0A^<>jRZ=lA&XL<-zViYe@(26&j)J8K8D1J{iE%k^7`tBWw zhBN0g#s9K|weK#vaft?sQ<9qF0*?xL>biPa2PFr#wUE`W5Yd$&E;pX?`H9V6|Ivk= z#ztyxJ{2`;ViobwCH~uC;xVUuu#71Q0yvJ&Pypv%!q}--=rsH0l&*R1l=E!QkK8`N zlyvx3BzhpaRwE`NztTM;P|SpL(dtHEi+90-J#R5mD~2|5^}!RtEcE z^GJfK#dZG+VO~Rp-(pvU@g5^GuoSO~Hmkd6;Q+OW~_LhNGJZ zvDbnNrzdY*YRP8`eeU;3vDWUF#f3%y+VkFzo2(d&r;|Ge0QxQ}>YgqDnlJtF;VYer z*gZ3u4yrpETG zT8oZKH(yZjAmSzOLkT@y#w20=fh=s6Z~&A#2yn({xVFq3EmiNm1uVi*@9Wx(VTrg}RrVkfYDd6rYsebLn@U`#Q}r zJ}={Bw)wGC7t?!cFNAmv{Ck2t>YOb7r-bA&$G?-M@rUbyTK&VStC`=OIN>Wa!Yy85`U;YCbSMM0KJmbU+G*rf1 zh`=GRNR{E<@8rJ6(z<=XzYIuoo}f>%<8LxY3oq9N2mLy!I0Y0V->wX$J_j3_c=~Yh zM{@Q+9CiQp@!DkdCv}BGvwVLPy=FXZlRzH5TrrS=?~;v|*BXp`FFAuCh)&2kg!DW> zR=WS>*B>7G`CJ%ll(Ma8>Ug+2k%Yx}nG-58ojvU2;#t|9EV3TKK;O3ppV@^bAo(PLN z=2K3UW^#?A=O)iY7XEBvk)o%(?ZJtnyOrih&;y)+BCe+yPBND9_<30l=55w2j8){Y zOz!tnMiSrUi;wrUg~;v(&zEO9qu_>kd&Cw0s-y4SpC+;WNYvua4lBNQ?r?KF_wX7o zz~x{h60V~I3#1Dmg()72kQrI&UV=FP;pc5}F6WM4?yh-=21AXMx3svuYwrc2PHw*B z*9OBD^dl0!{ru#w#o2q8SU&4>zn~Wq{JvMNXy4JpXM$F4dJv_`9U(zf6uIq^S(geD zK#p#4GWX1W0GHPwKj4XKP>2WFV(<7_-20~ECz#h`k27g>&or$RS_ktC!3<*!6pdvJ`&{p- z@iz*uC&*ns7(`^0e@{w=l6a>F%u97onB>PCqup*Kqv|2;^5GYWH6E)smgjTcgAWE< zQgw}c){x_~2mImy>!Ic2RK%O!1I;~-F#U`cx;U2@&Xo$a6UG8tE^S<@P!-?4F&p)w zMC0i4<)F@W__eK{LE^iRvt}eJilC?R`1-SmsE6@brE8s+MxB3GC-V(s7`MCXRxx7FudT;!0@T0jcMqZ75g6WmPLyQ# ze^K}5;ZVO}->{MvMg1xfq2(u`P_~e@MG>-$Z7hlGW674i2+39nS;oH3*qIsIm{5uA z%vgp&ip&`M7>wn)#(h8ceLR1>?|VGQ`^Wpd{;J_S-|f22>%7kMvsh-JI__+=B+q;Y z{g!0X*6U3I$}J6aWNcU$sUcPl1-R<|>n3Kdiglv_@o_VGt^R4w-Mr8C3*1Rs4@UbZ zuv_OC&iCv6W`EklSoA9(7f**@2_+Q=C+-@mm&dVEHwN%h z)LId9nl{KA7k)0 z&9?g}js?do$2Y{)L3QrdYSty4#EjtU6R;TM;p=AKvEswu>uYZJwNHUs+Bmi4-h1&k zMZc};xo9XG4PqYcDv^^uau1^6Lw!$cJUnz*-pO?6fqB~{vge6NPt2Y|SHf1GUh`@wIf_|z1=gP+QO3>b=_|*$N(nQec@U5 zsxo1F_}>NrKMNMUx#Ew)@>OExv}NyNex0t&~k_I zt?SadFx@nDZqBhCpYecN8wAXo!N2R`6aWuJ5f>)udT{IF8x^(Fhcraov&!Sl!7M^Lfgds2Go!MaB(y78{LrMvWqU3ua>LbNI z^_Pfc@3T0g4gE_Hb#dx%4#pxN+-Na05dr%r(MfYd=Jc2{6=|6;B#YakEuflVkI@`) zis!WCu1l5&@A`%;lsSEGrvAhIqh_1ivcm9K9l9;?d;QM)&cRoP5ln|!pSc!_;!YC5-VVb#xX%0>p= z9(0a<8mvadKWhr~^P28_UX|HJ7Hwp2GLIuR$GZ`=VdXPgDz{-KbMmhsdtUQ2{~gQb zMaZCiE?qy-xsBf!U-j2$xZ!(nV9l?tAN?BCoq97tfo_4@k3h%#mkG4Zk5xI zgm-$^*)0V9V_V)f*t?ooG6JEYBb5fe)wziN>nl2wTqd$^W^Ntp%3mID%+RS>f6~KO zNm7}C8>#Z!!;Ah0e%&H!thx#%H`42`w!DyU$PmS<&ra4~3k}>(1kh7`TJZo!L|}lt(6M!D@RwI(}G{9+o)3dru$-^G+8I5OE=_>R0LR?#3VZRUe=&C^$6NU?qBjb$+7yrc0?%Cmp?0@{VU*2B& z;Wn~p4^r@4&Mp7gp4)g%UedQX@8Y_S<3&VK-_NB)-(i+|Kdr8ycj2y-XQJOOw!9NZvb5e^ z-g>e6_or4DcF)pfA#m&X>W`45)R4x?G37l#Jw%gN+DH_YRRYR|c8;M8sNyixp9A^f zcEhipc}|kxU3ohJpAF*vI(IAjypjj4oEgfIPpPbt-Z#S-T+A+Hnk?+-mmBwmJBn*u zrf2O`3$3r>EVAdut+O1H)9iIStSoC8Z1+s7mZb*DZ7Y=8CvsOwmXnl8wuo#ylX%wo z)y&e(FU0Q_m@gEATR+yw6&jcS{{H=WhoK{#8^G_nJ||B)0qW9Cz6^A@-E+ULQ8o5U z{G+~=&6iC=OS&;N%0u5h3028me9;-QnI--czdy?s1bA5SU6XaO0IXY6Sn`LvhXV%| za$8o9TRz=>`3GtT+99dkY)X%lTfUB;`S)dD0xT?7HJ3j67dw=ridyzyZL(J2^1?Jz z@I9FoKxIh6oKaaS+Jx$0Q2iH6{eZ@TW60U3Tsktkt?P(V9}EAJsMaJ2LU-kB%^v3v zjNi|-YntiM6?&)I*dq?y|6GfVB(&l+? 
zZ^b;XqbotH<@@BK<}LiQxDVkZ_m-RNx31S%Se2wXi1n{SC^|Ek696 zgpMvCwv%a%$x$x2UH&zoEH|Zk1=i3Q7fz68HV~A(7uE^C?B^SW_Bf}NI?qUs z=|#UbE5<)}l(nbgD_tEY8R7bdm2ozr>hU&aS`xw>fue8|wZIDK8TOQCkJayR#0z{FGNq$GSBNp~3X zOS*QgX?=6TS4_{aIQ5|JliJ_9EkaS=iHF1*i+?LL!pq*=tZkayy`Sih*h2c3CK-M@ zgmgl#iq0CQo_;uaOE&<$n40$ zteIrOJbQin6TeEBD>i%&NIv_oQ|lZitAby^Q2hq~U7SSUx~vAv12R*@(c!^>+Xc6b zP3rB{K{w7@y5_!Ferw4`QgPb+&H}YQC3W=Fpf0Q6n({gdIjD=@I7T+4(Iu~r)~`1c z5&htEW{8{HYVP6*+2nL{yKWWNSDW2KTd9C9D6Jvb+j>E^3ON$;&|WGqJ`TO(3cx05 zDMKz}c8XR0ZX|`(u?DfOR*4%oL;RhxH>s?dse*fRHlEF~Qk_vljn8$O)?uqway`;2 zIQz1?)_~nPxump{=UgfvRUa(8K9-ZRzp+KVw)~-e?jiobd$AG7BZ7hFu8UTiA#cIO zl3BylS(kOi0euI^)8j`zKb^A^lsGDP1hVYvSj_ocOIBL7o==eL`rO4iOk8?7!9tHSWBGR_9&ZI`f3hVxGs!Q;te zE3f5e$a0&lClsB2Z{7TnLT##6S5x{G;J?b5i5B(97 z_Of0@c&B*vD$MilVP7#m`sidS?3S9rdY$;t*jA`TlL#i`;61o)@z>3zzVXgE4H^Ip z$^O{h{rCY7L4~Y_VPP*bYC}#M%!E79vNxBx`MQ|9;j#v|3~XkF`m)lVnXT5R4;fz3 z0(uSRLrq5c4desHW2WN7mbnA`qUyMu(7vec9_kqW&Aw4(H`9B_> z6pb~g&{KuD^{PoFH9l?gR~~M{Y~|r9d~LrN1Xw3-IKi7G|7xD(LYvDt_N}P}mv@=@ zdWK(O3ousapJ#)f^)c`1wRRP((Ip=n?(?B2e*w2CVqQO;w_e)3!^Ph-S)XouV0^q= zL#VL1M}6&l(GL1Lv7$GQPd&3*S&SOFF5Fm}1LOa71cqX8`p*2;Tm;H@JpL$HeTXYe zft;lAI*8z$iqv+miax66b-|n?D&?modiokjc-)7USDP49J-TgoS99aU`UNo$)d}qc z=)VVkTwgf%XvBfI&Tsd9n>It0ckUDuEV=0{cU56fpE&ClxA=TfE`2DDG9Q~Qnya0k zK(jfVY10Z{DCC9o|8ZT+NeJ|X@?dqW+Y@6%OTr*Urj_N}!=|!NsPb)V$HIQA|@~MN%?%XG3U7_}^ zRvXx}{+xU-^=ZrQ3<^nCsDfD#=}urmSi?RMis`*CBqRb$wsKyDej+MksAlns8Kb*P z)P%wIPw470;?lJuk&=2CM{vQysz=Q@d-qf{*8^o7{-)0#vNe9SHagP~=Em4OVPzlnsc*Nt>lA zwCexuJr-Pr$XE=pR==eBv(BlHy_n2l)@&-g45GwRMYcOo$*^&nIKOb;xZL0e$PC92)pt*_NaRItUsX)q$5A8asNW&tdh@EAAj=G-of>`7WJji&Wbe>u8agA+ECGw zf|emm;e8weXab`?Y{B8iY;8~H3Y1l-$hb_)!X5WaiqWF?=`l`f+QM=P)t6CHNzy8h z+Pc82HE|ynO(?c0*nw7heAt^8pXzxsh7OO+QH?4ml)vZOEujUc7A~3_%Fr|y>7i>6 z9!XYlf{sSNnr!kR;nlGKv92}Wrdv|V7WT-;d+Ry>pbMP`twce~EJEM0$^V3J{;$*5 zr|xZKfxKwH^(?ZC+VTH9VR?Hj-Tk3NuGsuZ?C>3!>Vdo@O)^BZ>|?Mo`M9kV6ZHnD z_Q#TZXUb4O_(Khn0@>N>8^P$a$CT_t8blzy%1njOi#yEN0Dk^kYx@6#KK%c@hJ8vf ztEu%q$02JfBiOSJ2%0}Lj7?;^v(|)k`ki0Tpu9t%wS_xN&b2DMgm06)^LJx`XdjCE z_m$+nDtjtiotTI1mSj~pDfNx2m;yQtoduaXsUUrsg@hAGLri90qu1y>`3GtUd?#Rsc%w(=E z$m#}qrKMD^2`O4Nm*zkBmoed&7f7GH zbAy$oD6ORoA*;(9n4}_XUBdl}+0J@R-DC@qvPzU#&V01tg+HMf@~z|fu;?A5U z#j2rfv{`itsj>5p%AT5H>8v5%RM~PlDFyq>ThW3WVqk=y^M)Fw08r9w&-i_7=py~w zsQgUnx4o6wtE2paV|yk8*+KTul@f9b1s(XDWU{7P?MFNe<*4@sFyOh=vlk#A$qp8b zb>$W>@oO@mMM1rTY0_<+Yu)H+Oqr zf@oK2Ft&SrC28pX36D0}?Tk@l{Gs0P_e;L2#QwtREb=t6=jk*7H52x{eBP(>D>Wfl zcy#Zz#e>1ypeL1KioIbyMen0KegDWk&z5T2Cc;0*2rmHKan*UaOmQZl#egl)WfoS| z(`CPnFB!i!JYd(veb1aFNvUDprj1>j#_{DBLIl$lwab@YVmy)C(FDT?4S`DQ2WQW~ z9|FR|vvO_|4zmp(^(OI5J0r+A z(fMYUlISbVi?0Ys*f9aCAXl?RYgS!NvqUHi?tS25;I_2x)a~mV%B~C)&Og2py}7IZ z_6)##4J-P)6fc=d;6fN@b4Ow5Fe!zjOJPG})W|xoAFtLGczn zz}01ZaljuXj?qyvdSV#mZ!lNgS6|b;6n>=3xOkS?!m`Qyh{b8XqMa9WKv2=nKTD^P zL822)%~mgbr&uoBX%dwo_|r_;F!w3=D*rBJQhH^EH5I(t$5$AiR%7Bv}Rr-!D;WKaP#I=p)XeDNoy&UBKZ!^nFON&Vwm!YY1tr zMZdzV3~@`W({X}p(EO3_=Qyp3RCG&GXn{kCY_D0eL!)hLpdwYXmZJ3rVyB*-G)ltxKFT3)0Nm@m5#EHcmf0gf6U66zT-({Y;25u0>+|Z|Za+oA{QCOQUzUmt!-pBkJ!8w#z?lhsGtf0dbbYgkgvJZ9pvS7XeDTZ21 z$(1;@`Gg=g1E-E$g5_%Ra((5sd1b|6%sB6wOs@TCYuv$e&fDBDRn%F`X#U)`?3PeU zd#Ked_7wkEKJ)g1@1k8gM)|ipu_2$2CyL45Ku>Qmxi3ZlQSdk5wDP&pt1yfMU)UuQ zhmIl5j{9ZUM+nL|(J{(DRyhoo_-67ph=8voO#b6}ikmuU7T7S#b`4kc<}%cii*78+ zMP%4(mF0#S2G!h~^EMDELgp4tW;u0;!i=sNIl62-d!@PN{XrG`vY}VLhOxU_-S-hd z@bdvc9sLHhDy^K_)D{HVPM=t6R`cq6YVl1|aoj}Qh1tqxS*q)Qs`>iB#njR72(x1K zcM6yEUn0bwx(?Tm#z7IRmlGA76!8lXhy_F+34`#&9Z!5$kLVSR|s0Zzgcw;@J zSlMC{gq+d5i#!ESh6rA;s8eKE!X>2e0v;vpNL>YGDcSM`t)&f2;FV@LD=?KfELr+0{<{5+7WsWo_nLd-wVcAiwQeJm%7-20yO3Q*2=f 
zAExn^(Amp$2S!z}xTF*YK44dB?n-U!%1G)->abqg_@wgOhOqOQdH+}=w{A1>hwc2qXej8g4__>-oLnRX4wIu9P|^{QI@r^<>c zoWWt+D5pa5P6%LefF0pj{H6P?Curad#I#4b9><>>tZcSipps}R3!?YAvJ{w+tC@?! zjBh0*oV#TAO9p^o!v1(T>uL1Qv2d( zIkBpv(_|&oLdOtdOw6X^UzqJJWgH2}UbkQb)?sw9iYIDdDPa^ofO|Ih4NK50ny^?& zH{~a9IKG@~_PIbaCDY8T+|&3I3+UvENO)Nf{mYH==tHcECovSCFyez_0tQsi5Q%L^ zzK;5snsU!J+>18Zfa_=HRcMye*NGc-O0xd3q@Ey71*0IX!zrQY`*z(O7;2T7?-7Zo zN9b-{Gk*#^)lMqnw5)Zj9f|(c)<2#2-ZK(Ulv$Rep%lW)dp#aLBM7*nQ(*jBdXm)w zEP=px&U-$LI8kmZ!H{>#q>TtO1oEXY@eaK26hX51{&bn`ZE{81$EYtTo4rd;xO-6Lo z$TkBRAy}wE*#A3I9qF`pr4T^KW|`aP-|8>C%7#-n%0hyz+khn9@%^1yox#<$n*cR+ zXOn!(lSZxXIxc0j*#na4dO)aWx%lV7Ngy)BSdbj3)@n*wp7YWYXmqNP2g02NFxNQ0f#Y9PXrW8R1`{_i%IGhLe^C<^ck)3U`uu)s z`hK}UsqstZ!B+&ieTi{E_X2Pjs2ucJzS}QyWeA&7?zSa9uldqhES;XL*(nKpZ~NL~aGF z_djtsysi_Y==)7e@AG?K-twqtT2cscGAX2-gp{EF0gqMmxJ12e{oVDAg1PaV%b18oT_8?f z2)OR1z%b|{9Ozq@IkoAmw*S(eseYc;)sp{*R6BNktpEa7NZD?bMYEQb@vlCf<)l6? zYFh}z|MFMq?x}zb7tLq@tVVzd{)}-|5e_=v3a8wVqt8gt zl~7ZFg9%C(IlItSl;`_XRB9)I0iM)6*czP&+nVv&!kO)n%xIVUM|HrCuL@`zGKB{V zyHh!coyk_Po6*LuvQ}jX$)7+nSB&6tQ;Yx@wjCIT7QXcB$yTN!DRH`G9rjqQ>3iM! za@nT;*!Qzr@t)YFbmI5VXi|vu`1drzuyVKd9pjy0K4ht<4Y=opBDKPlvT&@+#?FhK z!u?G}=A$QXF95!pL2ucx_-dyx-PSOw;&`)Y-|lQF?DyRQ=uP#$|Fnvu}ihj zKe-jxSikIwD7B-0r`nAdXu36am5kSn#|`LHAk>W4#sZLl|XiZ{AI;{aB{#xdvkQAQKc>Wh$o6lYOqwUF^D|}=IIpi zv{abK=>CC?qAIt=bpiIeoVFf`4QTAedG(5e*tpnHs2g(ZS*9m3EBeTeB1fsRz?&&B zSqmTh*HUds(LJC_QFM689l~Ce8E@1$r<*O#9<`&nUNx|4gargCQ&?fbNMlg$c}xEa zILXM=I*G}F#b?wSM=H&G4KC}h-5&pT7^{n|pq`j(!ih|vNsXw#Dff!jGtwCNhJr#& z8EC;iexxSH81`6XV&jxEQ1j1MzvmQo36s#X|Jj7jkpoR=`u#9M-2elW>`ozG75nj` zjbSK?2Zh6JA`~xZD1o#dPb%*Ost=FT(q%~jIQ^i;tDIYLA}gUwK;xiOYIRp1yWN5k z3~2MM;As2iCxrGQ>Maq_+ks9I_&to0!TV!SQ}1Y+UyfB@mqI#-4_&N z3LK~7K|$nI%v4L>Aoz`O$A~y)jt)1%+awl5XvsoGMam9IKh7_83z5e@06J|v4KTjR z@P!#k^y0rz_L6?6Z==^GHKBU{#l67wBKYhzuZB|-l|jaeoCb5A7E~lez1$M%ovr#) ztOj|{l0UWBOTdAHxel;=Rj-n*DmFCswqolp>MdKxbls==aQQ&Yw#l<0^jY44hE`QY zqo>APvrAc)sBe&1&f`hjg!|c6LL*a>Cu$Dwmt}6HYD!aD6T8chCePcn^Ve(GKa1i| zM6`M3^c#7b3r&$pW%tavkDT>t4MW_F7R%QO=Op<<5_ui8_piY={i*k_kN@Jm!$Ki$ zLvE$m9rvK3=|X1BoH&9ol;RN>X#cFNG5Q`Ui)(kil>gDxn?1?piT6xF|FQzux}ghU z7BNqt&uH!(|MP-Aie+D?KA}^KBD^Mg;vhn|P59+aB}-xtqFU0OZ5ZEkLI_7==bnL* zLjdwJbyd+ynbHF@Z3ndT0H>^cu%MVQCH|NsgH9%9R$S5hJwH=H&hESO-cqt^feH3m zCDRORewKUty4A}S-2$e*n+VH~Kry)+Xq&{Z1S27{Y*fw5S8z*+*3vsO4JP#$rBxUz z;8Sv{OcqU}entuvBhbfpne1z-C2l1gH!h0R@fL=>vC3msVfjA0V%MctN7rzvzz-`e zJC_sH`io64l_^YkxafAIu*JnE?;Cwb^LNt;xfhWzznzgI+wwA9Da#)>@!!sL0;N4W z(e0ipK)qp@Dl=W$6w7qW(M~MAXWee>8hbk~9S9O0fzsbNk~Wl@9JOE;4ji5o9*}=3 zo(tF;Q!5<#IB0uzA+*_4O`kps<77vS7Hl!$06v-CR=*~9-i4hdjW-GYz9Goj;$On} z7kRJllGll&xQF&*!eC`&9Fj{vSaC3c>5rAVXc>$g*{e81b^klmBMUSFQ+#(~L-2%$U2C>u0w`dt`T#7j7(j9E*#G zM-unTWd*)%N-&;J9sI=kAm1SuOIJZfcISHCa@U5uAu1q_^1Y2-(R6t`5EK$dpTbj6 zF(d=Js6emd=ODPIB@xwLy_ics5uDM!FiWXr{NpxOkN&RYN3%4;5R3`W^Ty?<$(l9Z z<>N&((d+RG-k;$lfXv{2s4?PQYyEU6Xg=soeoocc`svc3zMw#qm07Sue5H>mEP>QE zZfq)T^;2SEW7D-wKf;UH1C(J1q`Q9W@8oKQ-Pt`dCkFxh@a@e+_j;DV_56X(+NeUJ zqee6R$C==#Oohh+PYh|-JWnbpg4^5^!_=5m5mw90)qL}ZX+(FE#g-Ma3Te|`#j+IX z>w54TLV!larF}?Acjh2*`7=pl_%!M1u`*3fRE*|*uD3KGp+XoTZU+9soG)7=aII~J z;zqnFO*k`43c+Ta^<}kHeX8i<4ZmigRN zbaoAD4*8T$G=_XIdcrwizedI_NJ&|Sj4BgmZ`lrBnt5NI2?+}9liXK><@H|$ch#6WopR0GaHCmz--YMMe!C-@Z6 z{Ep$&RRYGZoUSoP(pt}8QiBcxxM=nts*mI_2QF~D=-wRSjTY}$NMgk zF{N=(EA}37>BC`Z*Div4dR*+~3nUu1dggN^X|2Wdp`+vMwq-Tm3!Pugoo$~#|T zeM;!Z4?or8a*&@37+NrjwIJr$WEGHdbb?}z%II+@Ai7|;%Cwzo z$6@cGSBE3eWj?6((fWGXV)pW~BNXTE@t(ugPZvA&7EJ|3K|uytYT@OHKBhxJ4`B<7 zW}OPlhS)?{_N>V%>n3<~BSS#C26mhLAPwX`{Xl0I%rX7G?=^}$E7c=_#xA<8KN6Y~ 
z=QDOTOhe#Vn`pjaqJwu%{*ywOt`WZ?Ux^iVERGNQg8#%9C#$jdIyK-WRGkMR?UlzS zZui9DEo;!Cee8zyk!7`_m7+@8J`ICVG38arxURvkF?8N!uj+%(*r_QNbK^*>g^i|tA6HOF=>in57(gf!b8Mu9NDNag zE7pinpv-dG#7C-h-FG;ikUaSyx?~7P89O`pA5(3 z`Eze*CXAeqzrQ1HY{(^P)_PwEBO-$a1EMH93XiAcHSe+l6gd(Kt&8SYQ?M(uxCwZz z1-XFN$_@+V)5%`gg^D9NE$YKR_`9E@vCks}D6-0B+|y5cPoomKFpUiy`W0sizN3%% zm@7WjYyY#4qz`K$=2OK*7i>EKgnogC@@a9ep9!gfCvDXnJLlW;%r~&6yqR%P?i0}p zmg`?bkfjVC!pI=TQ7u&s6XoL4UKIH+OykD783OlwBcKYKji0GunCD#ZnU|MwP0V|b zJat`f;y_me5c-?HGbqSpi$7GUGXV02@0Us2FH2Rzw>q3N-PQr-*I+b2ci^hr&~ek~ z&Z=*j!(WWi^aC@)1LAVR@PJZ_?z!ILhVwy7l?4byzoCnn+d02%?VrhBA~MrggWA%) z=cETv8eQ%{j!S=)35D)0pdW|cUQ}9bk1S95Z=v3*{AUKR=Y!SPXXEEb0ah6PL|4$J0Pu{i}fS_K_ zgxv<6a{KfW3?$=9!xr7&x zS*l_eYW{JSWA4xSlxSu>tN$ePiGrX1Zu8@_KR}8&W0QU>Fr1DK6?5aWr(ctML5)ff zF#fZ+{&4#&6@mkLPlFh01RSbW6__^Fuh#A|MVvtF_&>^8`+uvc-h`)MQsh)DZIpMP z<(h8e?Lx%C4OJc6?**5SS+_2V36|Y*>o&D1WXiw`ULSrs`%eUTd{tWisX16#oYUWF zRzVsA3jC(P=cb#f&v18{KvjGk-%ai6KtZY7Rm%76REe7~(Yv$5wl|0jQ@ z(l{xV%lY&rOpR86XMZ-fE3Z=Q|NshV*fn$&(T(JP)a;_CC}`e`(N<6 zXe1_n24W8muDHPcf*<%u7Gbpua)gd9`~fy94vIT$|GJV?s9W}$KGmyQyh%pHGIajfX|{a zFY}JAA%Ja54!VISPY^YnviYL)$_|ge*bH5P{kCtxWC!cgzeon^^KwUiH}UK;kWuC4J5@LccP z*qCovBX_RgOThKz4^5bcO@<-OA6hgz^A zd?nLCk{9^LWSJI8%kTt4z8MQ4pJSnzdQaKj@I&M3>SO#ZV-_Xou-A)BSdd zVymC_w1qhuTR|$4_vzs8!G_sX-~?uCxZ!>F7wTlawn2eQ!0i(gFZCa*#PvTS!g-E%vX!< zC~f%!gfN0@l)6rBbq=Q0Ac_X;z-a#p5&+dwue97D6*B|=V93WJ=6G{46?FCeBx0y- z21D_OThoTMnSmiJ|W8S3Ec}s`1)_~b<=>(%k!kjz}u+BZTx66>UiW7 zp9HV<1CAU`;Q9CC0tgMyRIL|)(})_&gM_)@mL8Hc$?{1bwR%GbEMt$sLi3(&5ZK*p zi_YsP`4*{g$7bv60D{7AFj>KS4LaJade~qCOI?hz%($OT1wJKL8nYpDms?^_cFP@! z|LYPb044VTbBBi~f!KEIwWrtCB3c*0fE+;}4n3`r9gonGIdc!6PaFRijAyWJ9zA^P z?ZSWpO~TXp4b!hoD&hP3#qGg|1}p^oz_ar+ofHykI)Rg_y>Wk( zIU|^ssE2RGXohhG9>e|kw^_0z;P}96KG3eY0nXo5)Kz}4bwqD|g=!OkT>DTqO z?ceI2eT#M@`nT8)B!1!`Tg*$mPAce8Wi?9*KEl20TCho16x)_(MhGB<+^gvoKfxT> z7Awfp6~H-hCg;pyeMk*g@xLZiBgpA75dBrSyR>l!*q$o%f*=Q6!%x(FEcWe@Detm` zwJrjML>IT7HxmTl7;3S06%5Z*zDqm~LWskqr%H$y{z3O2xpZ6UUicKU3Cqp>4h+nM zzN z9)Q0^a!jxJ8XIA(2Vrr8Q;DP7`k{&PUTFfLo{~{<1cH}(b+AJGb}?cN*vIs}29d4^ z0KgnB!{1D7Om36=EyMTirEFRxdK(RizwAR&41$&);p8p8Tx7mpTQK0xh-~Wsn7{1W zi+C%b`t|IS*f(xw{CWr?X5On%$^0f&3K-~2`K97nPSgN+1?oo+lbH)8I!tW5|ETJg zr?m1@5O92yW-(*>$H9r|v9Rf#OWTN+vvg0s2LD^I<4MV&iNF-cz zd)9bP*eZ{Almu!e!rDM+`vqE8hyFc|6XqwR$>MHi2i!r%(t-5fj8u>QR$x#40+uW= zPGejJ7k^3+3_;Zc4-c43Y0xkH5x^~b&Ui15>Le~S*r{mlRDTwKskGa1>pz#SQ_KmY zD*TmQ6+D^ChuV%?gUoLO6u3XA8x-ou$ji%f@b0c8<2K66EM~jmSFAE!dRJ&SVWTxa zh5&j+hIZ3^@(c5ygOPEjgU-%;Ya>PMdC0@(3v|l84hH_Mt^8144l}{H!f{EzGhY6R z`Y}Y1HTM~LzKe90vu=6AJT!23b9NV*5HOdYeAx<0bYG~O-6}5||J|N&#mb5(={fUw z58v*XQUs24fRgVr&e66;5A(Oy|E+%O4}*m%uWt4FUL(^`99w0C`xBeYzl%99^*zL! zn!38^uq2a9C))<^Vy$Pek>aviDqDp?yV>?CH>oR%h7!3rYg%7=h(;R4EyQLoj-cVW zLv`!M!dEm~AV}SvsL4{DZF^0mM(>Hb zncNU7dr^$*N+uZWBN3-!N(z@0H{NP>UYb>TmgGut(|WA}tse&H)MEfeBMniyc2h(NoNCtHk#D%d!s;N7uy}!Y3>dLHS6^Oxmh1N@*oZ!K z1IjkJ!%GcMv2|`+&!DHTC>+bLmwH9WLpK{n0*gP8bbj~;q>JZc{R4d91D|{x-G~gt z*$90rmN`wVW-?9u^1rMs@2?(he(0<>WvrxxY!z#fc;!`yvnJ>1+AyQnMJ4bZ>tjJs zu3+Ifsw-%bYi_W;{|7AMU2=c5dBF|_bgoAYj+oqXdJ@IY+BMgq(k=UQ$n~b2rF7`~ z&`(k9t%$A{guLKD(2;krUgM0dnN+BIm$RhB@<3~(CcmF>=UREu$}Ed^VHT2E{;kBD-z{J%Xl6u-I0+^BCi(~C&IU_tk3Z%M_QqKyzADllgC)n`;)YB1TCubI(hd&9i7l6NkJn{{3q? 
z#|)e@4&X`rs~P&?R()3_;*K-REqqbcX}(+U*kA0(t5Ou0p&baY8`IOWK28nUmsf(Y zmo)|UjR_*#B(O!o&gM5v*J5d8d z!&~roj;h?(@Xw>di(o@Uf_&ks1V!uk3&U<%Krzl-~BUe(gvMNrxqsk3VZYsFZ$% z?;6wQQ(vmcw=%_l;R$o=xcn_B)LLX4@op8Em7&89O6Ie@&DDmAcD<8XzOpF&E{pd_v7P3zv}kpMuS?sPeFlkjPAd2M+|(esbB8!b zI(8IQkC*|!y&rJ$a^Z*jUg|)DxIZAeE9tgp--EYdb~l+6F~_RP@hBxhb`$hZgvrkydPJ{D^Yrs&f0B8zEz zEji*Gb06Er;#0^{a`8f5!2$D!>SJ$OD|V?*X3g|!O2ljX?F3lfKPsZ9-ai`%IXHi? z;%}vnzM(?Ns}mPfHL;;WO`>c@M?6%=YQfA4rls|M%euCOnQ5Af?i$2(0@qcR_veo` zo8A+?LY|Ux*?wC#GdscmJ<&gl(>N%=xVT&sQ*=1?wgFGm&Pv}SMp7@E5t}tW=d-_@ z8V}2Rd2nQlk=1I%wfk&!>7A^!tcfUbUOn~`gE~|6W#gpH_|rjpU`AX#b1Jm@Rap!2 z?CtY3Ad~8l6%1ew@%m4*ILpr7e%R&qaop(yND9^4OiWdUuZR(@Do7s7{PfA#aAX;{ zI>L@|SGDdtcQ|mSVT@8+*2lyeIDgB=LW$>lQ~BH?RZZ`y)>N**~Ec-alG$ zPD^x!MfHS+7$;_o+z@2ZqheZ2+U10%&bfnh3ap!)I_;{6xM+NA=&IK>w{ zvNK0Ris#iJHY{&3OlP+w0y(EhJg)wDUTPs2&Tp@b0{=^Ig`37)&-$V^7lZHAZN0I; z-Wei9+-_{tS#XAu!e195I{e`oT1L_8<;RY^d-~ZOpEcX<|4q3Gc=;fb2Z9eJx|!Qs zfX4HQ(oG}gY>cSoskF*AuH zw_XbPmi`lZ*PLMuH;K?au{I3fKCG+Yt&TV|G#OQ?qwLhsVL8-Z<8Z!f=oXm40q|+U zpV~HFAm5UZz9jqi%{5bo13b|AJXWWb*^3u+zF?RcbUaOJI{z@6pl#-gc7x8Sm~3ge zvT4YFiU2Yzdl`YKazTJr{BS}^V|dfa|B`E(^qPR=VD}lmCuJ%_Of$VvnD5WG>qAYtA1AcF&BXfb zovon;_z~4n&WGxlx4so;xgC9^3rJ5su~?m}Zw^HZ|)}`$XRJA`DKhN?$_Gt5^ z>8ly8y5`=ew}BKSAYj~C#iUWUnc|X)+5C31B1>~`428=UzM7N|wMXj2wC1bmXC^p$ z2&3B(GUXPi!MHGCk!J8gyv!#r>RgYuOPFSHC^3slIM&x1vYVNg&$_~F&N$04)b869 zHZjiS*{l^>Wm+vxhRx?KW(3O>2JR)}h38eUwEF=$WfNaYfN|o;Y@Ff_yP}o0#BXNzDqiNVTeK%oWrQ5@ zrRG9sgfVyAzUwNqiYLsg_cE~uZT?o+2gHim-W}?luhF~8c629cWiN1ThzyO zyOlo{+dxZKlIe*ECsehI*V~@4_>5ovHNnNATB5#`vne?Iv0&-cZkv^GSx$mB+#yA z6Wag;?>rl*{wi3vPpd%bj(ptXtsr-BIqvs02L1N@rtm_=^mLPF{*I1N?&AA~zborn z@uQWyLACRPYU}ESe;W{Ot`CjwwuO7#FvzSZ4)F~`g{+O*atE(1)+GIdH++GL8NZSdbe!ns(^D3X1`kIW5PAyl2 za`6+@0R@yxt<4R%`2}XT+)SuYM^v|fPmD*Pg_spI^r*H`d1r6__TvMPOIHG4;D>{p zA1n+z3FS9^>~ z;uq4h-C*V&_g9B7MuXB$D}dJ10d`;;noQLnyel(4YglY%5H&>82R1-|Oy4n5xKALljh zNKYzjP!CYXzB(y)=O31R7ZpLI&j`~^Y90V{Hf5;wbggw$c*H6Va`e>^-g`sixFZ!E zHF`z#hZ8vWECqe1X1JbcegPnmj#{{oaGp7?u_t?QPR0b4V!mE}lXJ6;j;_>ew0D;9 zZ_Th{B+YpYyGj>RHIdJnxYVd4vh3pzb}lXq5g_nA*}JNB-+*@d0kl-t_25!~I0;;O zvrv}6UOfw>-tbpdua06244A$@2nF#Q@0rom%`|B8w}LF{gN}#0y{?~-=PHJ(#D|v^ z&R9fnKmR9O=>FOgYVI#gi$PIZVi-nd3va!_k)K5U_wCLJ0L?yT1LRZU)Q zT%(QBz%f>q(wKMmp%`e;$$&EfeBI_O$y-4BG#IlnH}~7GP_e7P6aj1wmlz`{$DKcg zZl)Y5U3DYUkBB8*4*fPcY2o-lBcT4X05hQtirV3tXefO)F|*F=Me)dPVZ5=xnr=bT zb&fi;A*$U86QG2(4BfK}gpsyV^7I<7N3XsklfFPugj@@t{!jm}#Zacb7c15ctJ;7< zVulK9yS#nBje6c-Br~zJoa><~8Dz6r^VF-HRvh?LXa3ZsI@*{SEZm?-IxzX$@9FX9 zN-=h!kejpu>trfMVi<1yqTMAk1p(;s6Ts3QjNGVe`=QSN6Sn>K6a0Z%*abRla_)ME z1(JX#&Q>%)`_YiBv{t1#kscXl*Ic7Cd%o<@hkIVZR6+~>Mg<8cGt5zJZ1ULd)!Q*4 zOAg5xi#H=nUw+}Qx4!j$i7FP&pdOLPr$yxCeAD&^?Acuo(z{6MiE4T|a~g-gtV}K* z;5NxZq7xku_O6Tmg^_FfDif>8mBw-@xA5T=aygM--d+$ndjZbDYIO$|9TOj>4?Nl@ zbZl{yJ%QjkS8{9pBiGWuga)amYKTY(yaX_x!%ORSn zc2udoFb)aQE|>K&&ysCdPJP_*X`!qP+Hb>hTw0b9zLmQVCMRN^)`YtIC==!@L!Gny zY61&=W{Wlp2E`Y!q^&(B^tUyVd#J_a1RSO+5Z+uIQU)$z^r#mk9Q+O z^-ZFpPVzIUQ8uZnRh}7*AE|p-j5y z?gvpyfzNEgC^VRaEbDP1Gxs*>wU^Tc^2#Aq7|jSP_P&P-BS9Mx-o1jBsf1DBw`Q*I z@WgRScup34M{GTddQ?!t%Dsi++)o^ySVTjS@*63xlWZ4^AnBjf={KBH^mA@Y`tu?R zo!s18$4y$G>jiO_%asfT6BP~CcY1GQ)Lxa_{zCfRi+}mHScMoJ^1yFZ-3kJUc2=|3 zV6pyoaau7;z!=@deYJLFlE>^1TH)3w()hZX#+=aXaLrTb9Gi<5Q|)tB^kdZpqSw@t z#8nHuH|V7a(M!36PIO%6g}_fs57F9F^5^2W&?<>uqX`VAp{%+mVMG-#L$Hir!$BQ_ z!I}9NPd{$Tkh%&sj!f1Z6gp-4IVBcYxAh8QTlSnSjFy^aBd!hzSiJE3vuP2;MHky{ zIz-3UqZ}`4rX563GsJf$$E7=mwU4EFT>RkpdyhgUX!^Wu8i-B( zI6MoB87@%vIHal`tl1KIgW7(*Ig43NZSP9=pG4@QtultSd%15-9b0coGLsglM&RTg z|4?V9qCK2|EAMg@NnQsYs&i&`jK`CDi#NIQRrdxozFT|?4fr&lHOP!Yy>T?T6*9p0 
diff --git a/prototype_source/quantization_2_0_tutotial.rst b/prototype_source/quantization_2_0_tutotial.rst
index 21c238a11b3..1db08d61074 100644
--- a/prototype_source/quantization_2_0_tutotial.rst
+++ b/prototype_source/quantization_2_0_tutotial.rst
@@ -13,8 +13,12 @@
 a simplified UX.
 
 Suppose we are a backend developer and we wish to integrate our backend
 with PyTorch's quantization 2.0 flow. We only need to define the backend
-specific quantizer. An existing quantizer object defined for
-QNNPack/XNNPack is here
+specific quantizer. The high-level architecture of Quantization 2.0 with a quantizer could be:
+
+.. image:: /_static/img/quantization/pytorch_quantization_2_0_diagram.png
+   :width: 300 px
+
+An existing quantizer object defined for QNNPack/XNNPack is here
 `QNNPackQuantizer `__.
 Taking QNNPackQuantizer as an example, the overall Quantization 2.0 flow could be:
@@ -47,7 +51,7 @@ Taking QNNPackQuantizer as an example, the overall Quantization 2.0 flow could be:
     quantizer = qq.QNNPackQuantizer()
     operator_config = qq.get_symmetric_quantization_config(is_per_channel=True)
     quantizer.set_global(operator_config)
-    prepared_graph_module = prepare_pt2e(exported_graph_module, quantizer)
+    prepared_graph_module = prepare_pt2e_quantizer(exported_graph_module, quantizer)
 
     # Step 3: Quantize the model
     converted_graph_module = convert_pt2e(prepared_graph_module)

From 042fbbcf0ba91f96c5349c3f345ca52128011276 Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Wed, 31 May 2023 09:57:55 +0800
Subject: [PATCH 07/42] Modify the arch diagram

---
 .../pytorch_quantization_2_0_diagram.png | Bin 40506 -> 39873 bytes
 1 file changed, 0 insertions(+), 0 deletions(-)

diff --git a/_static/img/quantization/pytorch_quantization_2_0_diagram.png b/_static/img/quantization/pytorch_quantization_2_0_diagram.png
index 2a58a0bffa927f42247a46a5797bde499155a672..e00ebf90276de8ca39a020606a8f21ba5ddd5d35 100644
GIT binary patch
literal 39873
[base85-encoded PNG data omitted; the patch series is truncated here]
zM;!NkmD6!i9<798u$LR+UCr~*#>W{t3nyE2mucukDrWlB{am>L&T#Mf$&#TMzNd3@ zUuY(X%C9u<+!)!fU=~CxNwAKi&e|_6xb$*R%ctVtk}_J+!uzkdiohQ4)06!E@53l% zm%*u;u`PjPtNfv%prN4#a}Rl4!j8)kP#@>(u6tZOCtq=M3GA&Mq`b;~x+5C6^7Ijt zs;UFg91l}oCO)G|L1`y=9u384b~*WT=&C*$mfJpgo{S%xEs+%*V-TC2ah8v4p`Orm zC~k^s&t5`6WfvlK{U+izvFC{A$_8<|%N4;}Gk+uzHwFT}fIW+^MtgofiS;^Z99U9O zHk-bS5>a~u1D{1`897%^>XCA3u%=?Az*7~IA*|Si#o^F58NBNoyLwj{qEO0}0A<33 zxAOZJ*VA~jn)61ao1*7`g?0jvsyvK5g1vPjAvjg2xAsFj( zw{Ua*Nf6v`-1f)JHgCdDk7eh9wx~6sRb6jh#kapH#ub@o6=Ir+0>PAB>(r}yn;$H% zYm##M{>WkL_cKq^cFWAA$7}-W1%`{wD{>f)w~)?S(Y9WR^34guVc@x&?($(^XNA>+ z5c`haZ7G;g3!6ps>U?CyXcb-0$Z||a1A|JVk0{fguh)Bwt$Gz%JJ&&{bzafrI&|+( z^s*Z`m3dIEu=%l(Ie0f2S2^41vO!)y*KP-KSIka9U?#3O|(YKW{`@XH4#pwB}v#pLrj?HV$;K-ZnD7E)1(o_uFZ80a=v}JOcCFX0)ms%PyKO^cK zR&ALv(|LMmw0E^5%cDaC%5K)U;vI^fHp+lAX~)R=!lg#TVibmBrryr|acty0B?JvD zZM?xbQs;dp<8+VVTibfuGr0XvUG6o5Y^%_MWDfs|OT0s03K~_U-!=-xwybzrr)*GV zMq|Rx_3n*k5g%5VV%%OHoB+L8e^9RtrQUf&Tf0|0kzVFNwfwd4VA66WQs@3w$G(w{ z$X>iUQwMXC&5P0MhE>(+;->(i z`~8)5Nu_a+S8w~s9TUaftw;|{D|~}FrO7w8K2114Y?N%-syw7{X}7>6K^^fE@V$=# z-&+vH?z^3xDQ{o-J zm;ZFRge_``zgVsUq6dpMftiQ-DEWk3GYQxlVTbJ)5k&Rk%N$!(_qkt?bH$%<-x)49 zZ5>=M$S#@pBn!z9HgRRn@Y!*UKfyJ`d-XY~e0)#o?QisMiB3ac7Ex~E2q;_jjF{oj zs^jk0@$DZL^2Ncw$CCG+)KT+wIuoDGr*H68D7DBPb1_@)Dmyh;u5`IRgXVgHAhW&zT8!I4x;n z`ChpCo&f8JkLS8jVhFy(A*|(x0gvgP6)fqNoDDVjYw(2=AQi2v z77icO!R5zFyk=*4iB%9@-|5^JDE4c-M~21pGLAgN&`P7wZyJ2(t}M!gZe`0g&tJKf zG5;G1oAOab*%pbovP}~SHart#ee;c>Y8@Cwj+%?98kf6y_o1A=u-$gX*?lmXtl?{L z9W9khaV_~>K6R9Oqi4@*F!>cmw1P3f3pX>izT%R;zViSvBPEXhD4rPLBFhj62=+dR zVFFz`({6uG|MLga!F}6~C*{0m!C^~~s}t4R(pJ2hOAj9L>Oy%HHm2R9b9+5~&Wwyw zab7-DVpkSKiO?1&NB6}2ly*cFEJQyGcoma0l&K|jP2zsv{w=Y(UgXw8V=PdtS5;-vAV&^sUvhJ63Nd!eDtY>*mGH8p?Y1TCib_O;^HVwSgWw#HWp(` zi+GW@y%1Ic)@0svD?2999^<{B5@ARH?v_E81#jJWPf#cLAexN}kRyKO#3F z?No{|kS`!}mgbK1(JG=K!7xba3`+H~7<7J-N#m1S2ivaP>5J_zD|5Kd56^#`JKyA8 zpb8GZ;p4KSt5h)!OFQcYD}zxdxP+pGiYu;Ft3WeuD!9I44JB1c5%Xw&V zv9ubQnd~mzn;S)apQ^VQSL4I7EP?Gvu(%Tcs^A*6OkZYP55BbjX_{Wt`D+O=W}sRk zJFj(|S&Mb@k^rz+V{bj*nmuH`<9pQ_eEsiN|==2b!ZHZbLeQl0E;z(;b0l z&ETWJYfWdw&dk>jR+PGE@KVn^tQDfI^fxT3Nw=hh!=ENvq2a@JUaG?r%Mu?^kW!`7 z)ITU}`q#r31N(CHMCU4xXa@VR4M=S6y{X?EV}8-C)(Iqvz}M34|AVhiaT)S{0S8XG z)T+B>yYtiwN-|Tr{vW3)OrX^KL z?GAW@Z4s(;+?j}VMi4O@5#z?!3F*Cv^qFY$v8{&5-R5!r9pXrtd8)BlVum|HjT{d% z@0K~eXoYzIQu%(#V&qkn@THA2V&=zN1tnkXZ|_Wt7hxEEfz5;I(3E7`ckHCJ+*kaA zIoZ)Sf4JKIC1IqQtxUQ->Rk0rj)5+{@ib)@%X~4UQSWwPH$+=m1Ug>zC4FUokvE_X z%2e&@RX+cl0~|#RP%FRr6OS@G4|?(F|N1u#jc&ng5eVX+T8Tufcq5cu5gnm}7xORl zs0$jz8Ojf~WJkB&0f9On)|WE#>ZzRE?#UJV!nGISJp2>M_SAJ`o)A}NTNHklfbrHk z_MdqezGkT97{4#WWl8Ow@D+7;4DjocRFtQjxyA(QiTNjTTt>4st32)(#j==~0tfGY zGA@jMkSF7m-+^{^L@uuH6MpqD&f!pCnt!q$AmmqGI0II zy`up5taq>GEkp~Iv+-V>y;QwnD3 z?MPXy@Q>VA(uEnyt!NodQTVWA{FBD=MnT!^x?_?`tUKsj9~V@=z6`O08jVCKzv z^^{|v{B!O0qV>-oY_Hnp$`~xiUR)3k3jB6ToD+L@tD#K$B`35 zr^q9ZB=*X)k^Em)0m>rQeK$k?4-sqpa0`L%P?1>Oxx2OQT`K-60NE)MxFLEkkbb+; z8U%!qVkeR@zkaj!^h9VVKg7DvGHWhBH9TDhU%K;<_Dnl67h724SA{ZX;pP?5#7cpi z@4}${v4PH)!S5FFiXhPZe{{0X$bvP0?zchgEfGIrFT&_& zalFRH`e+aP2a$%pvSTld+Lt@NJI{tX=VxsUScN<2PhF|eh&^&(3IYxAO;?|@P`K>e zEGW5=hPE=OgBuZ%(T&ax0e(hr7KbL;6Rh4CrE6RQuk2Lq54jsCHep>!#FKNTHoz){ ziJ}0*7~~XlZToiiVmX>+L>hr!ny&YbtQt0MdVaaO&hIX265JFp*vP>W*SomfVRJ^o zEwFQ|OFCVO51Z=sr70lNp=f??Z};>(+zztCohk-)Qf3AHaC_Zl2EY$*(Aaqm4Z?`a zay{-@Q70azA&1ajcHg2;dW$3)RUVV*H9f!3f$c{;`Qe#%CkCUF<;1Q4=51=I{k~t?zt%W$TgZ#^EW*HuHub|V=f)9!~ zRm9bKBbG6jGT!o}^2w*?%17OjO2=KKlRWB_PxB;xUmS(Z_Bo0QK$37Px_zY^dw$h z3`Rz5nIt;pLrZbh62cXMu4VV4j7uk6jLX;B3GYN5zBP(E>=n|mJvl#jEa$D{ponl^ zV8v!9TW*`y3@0SQH}U8Bu63MRcDo)kubX`gWGD4+TB(#8WoG_GlOZ0CW97O__v 
zX|#8KZbU$_@rEXO4Ew<{%K0kuHMIhUBsO4Ax@Lqu)m7R}=GKT?)ga$qT%ov+2u>e8 zi*A{L>VU%a_uQlh>tyM{!>-=0O^Vw%{VLl?(?*vlQ*yX{Vy|2Qj=g8}Nk&iowB;gT z=YYRM%UJ|yI2HNz_iHE1XkH0!HuvX!D;LT46Pr#@Hl-=sb_)oFJ7f)9@$P=9sLjyz zr?-3nrCE<9L2BG*km7FclD7yjj&1*mMjV>MHl2SXx3ssKTc+eemJ5#&CYT-~1K)cS zh&S3Z4-+6Nd#5?UV7GG65lsv|7-^o(!_uGj%#9+yZJ9mLVQ;k=OfD-xnoDmnnXJL& zWW5fSjzK`6D}chT_7;`&P&s%UGFmLBES_N9O70;@MzG}4N?w_zzOj;I^yAw58PQk- zE~-C6-p#z2zJ;I#Uup9Di9zH$h>zM??K{~Ix3>VBZLf9t(W%>X3P_V@*8CRkNu-QX z%3s7qcv{co42RYE339mJ9#2i{%MHvUXO$D_B&6KLD;7BkEWCmTj+irifxG{;yQ6^qeIb-iyH%hfkZ~kSv z#Eh3NGY{mw3+tk>WeO|@6earg+loKuw7%3;qLLTw*Eg(X|I+W}C`zpyHxY@O)c}qu z3oXKdUNiQqm3nVV-DH-n_G1bXo0cBnzu2F;)5nzR{(XFdT&71zk#Q`S?jpEh)-1B~ z!9ZZap;oVbCDENqa^JX8KbdZY3n}4!fb%yOapT>9s5icCDL(oT_6^OYsWtjd`L>*$ z%Td@o_#VvpR=2CIT7G~N>}~!LJ4cW4(MdDdp3dd9w}fOPYm0Ys^C@=<(`5s+-zu}H zT<^p04l&@P`ltbjhVt4~hwBYa$wz=d@$*8fctz#evXO0?C140j6#< zk=M*e<}<)OB-%B9y;1rj`h|kZ^=ul#zBZKTPIHZ6TwbSVV<3UJJ{k>fF0ONqOi78M z7GMnJ`2>dhV+88D5Ed^=#S^V=I8oarHF?w7(k>p=RNlswWGxF@Tf4NXvl$1%beN9- zi?9mhcHMy86-Vv#SaD9|Vjo)l8mbk0ZM0`>i;*QZ21kxrK~0Biav7lu{1eet7-yQ0 zREB;A%=cNp-pec1&Pq?6|1rb_L(@4;L51=Ula1jna8A_gBX}0CyJ)N^(yLIeN{4N9p15vM7-bb>RkIyEk8!i%NZ zzZ#m$LkJ%wUjoXL^sxR|nkHh{0hTNK%%%#^GH!;wj*&dXw>$5p831V=zTJ1y2w|)$ z`P}?0Y{Jl)y-tb=2}R|%#IeF93HT_Y#NeSC4jqp4;nv&!aPJlDb))o$(j{f!0*J7l zTrKb6Of`=JGY!9nmM%Hf2dFM zZ#`M_IVVn7{j-OmV(A63=HUyRaqdnU02+_HWO3vs1a{o_u0Xl8){b`mQ9{h$n9%Mz z>jZU7-ZlOI8vCNYRJ*Ws33c(BrckQL*7OxH>NE3Cb=nPK@W^1&fApbSz5dw1n-Z(@ zFa%QFnT1cf!BfNA@McrXX7-_Rk3|D=+$0PZ0-Kaeyqy--GFJ(Ph94<_^mm+E(3S@r zXEoXXEi$2ArqF)$sKyvYF^->J0d!4jGyz6kKCa8 zGS-U^nEvf!w$h1Yl?a1Nm0%`ZORXH|iD3t*U^%#;rT9&w=&8f?Q#lX3d<0n6;5n@m zhb#Bu`xT2Lrr#r2#GQ_NxaOvh-MX(WS1Ua1Q1kxCKiyK1;1qj#yCD8%_`70Kn)dm@ zXqTD`+Drx%Ynau>Lp4B$z_Sm9PgYvr*9VLObnnO}DOn|z9?lHFuy^ur66%Z*8Vd=Bj>!=(oGGY6{uhO@{&cLRE z(zpR3!iv=w!x?a*R~$rQ-n4eB0hz$IKON(Lxb2{{v%nL1s9=YBVrhH9(%a(7$HmSU zeKC1P+4+du|M(9rP}Rf4Ox9tCj~X4Y4*0ND*h>SZ%tKqa#KS4ZbPbkEN-Mr*D3mrn zlD`jFBcSV(f&c!z!3ALOk5J5*=i~GeX|3U&uP&|{vQM;S%ps3~NCJe4@MO*IylH_} zH&m4zRRR{Osos_wP+0Jq)1~)EK`)FR)gS9CEtYQ5%0(uispw3$OF{aEq z7-1joQf~CnW1Z_sL&GC61_d95!NjP4(utlCW2(JBVWxIjG6*m>Rv6zG?P6)ILj{YQ z2be_zNUlo_o3+^Mf~ZC!6m59*&`N%A16y-*VQN`iM!dPT`Vq!7{BC%)QJHj!!DI6Y z8SWeZB#;0D68!WAPXT*arpIbidC|qNZ=2@@prJq9tylJdW%xIk4=|YXH;kpxyiFYB z*x88@b(kT~?o#v91mz-k5#Lf#P7xKC)c-P{h^a!hS|GVTT$W3SVI&~4s_(StU&R_) zHCv#}9>G(-LYl3dEstT{H03_6KW$u)@K05W0n5I0^#*6eGB3Ie5(L(Z338?)BD zHeccP_z$dzqgE&rQPR>Bi`3gRU!=!FQKc76V<@NIJ_~Gic#Y2Gy zfPJ;Q0}IE^=QMn{YzT7{mwWQ!pcoxJ3-jxKZ;^DkydYnI^Za2V^3BLb`wyUSJ!2Y- z5-AgOg2Ckk6+pA}n@0A}gCt^zHjk==J^>x-TP zT`A75)+kaKtTSh711u<8;C~9k{NE+VKW1A~)leXVqQ~@hLT1)Ll2%69)oAh&op(GW z8tE4#+ujQ>L_@{mpLD)5C4Vo8We7gxQ>5SbrrU?E2r`W=R2xF}Zh};<=lub^_92C7 z^5of%A8I9$J?SH)KL05}F-d}TP(FvcW{l{p{eA+ckD0t_P z?#y~8Zy{>!jl=KC)VHlgQA8>Q{Cyfym)5-Z4bx3|L%LoVWy? zynFw*IVH874yH9X-pl_Z@qBT@>*O=aGQfEe-)DFRMEjf7SsncT=y!em;0}VschERR z>X2FVKZ7G@U)3wK3e{gR=?V+oeXA%fCsuT@$r#?bFnh@KHYTiZ2r~8GXgKy_QKeHT zPKbMw`_7-PF+t*L95j+x^&sj2Jmo<&yW6zNdOsgN#XKY+*Wa5urQQ1SsvBp2waXN$ zk_*{|zdlA;w<_1kom&+Y&bqL*|9i3npa%0qT0*yJTMi?&a9RM*`@d1|O0*m|Hi4CFZSFp6`NG?EY5h;;IO#zSp|3u!0<9XT1;;{oQ{ne0YH~_}uy^hd(oo z76d@jPjsyr2J*{|i$4AXcoc{E~?f}fa3ON>IGjNO}pY7U+Vr_sSp zsd%M;Z*C?{?Fs7_Ou&~u5i(oxVxgZmms4;Ium`0lYb|GM_oEELWWL`cuHg^-bh=us zEQuevC|g+lw8?URLEMoz@Z}Qol*KG#ngsV>c1thMKmwy+K$$PTw0tAls=4{i_B>*qF)F`zrUl&K7lA{bl1 z_k-o@O`!o9SbK7y_$b38L? 
z|58}B1tm%~X|eW!xtf+Z?C*I2?#i(B5JxKmSlTvFx!}Lk^h04k*?JUVfO4p}8--S$ z^r#SM3jVaV%7?s-X1f)jQ_pR*Kv?)wh6=5lxswQ?ny^f<0|!A15HTk4{UZ_ec}ml$ z5Eh=Nxv%gymVKu=b34!?W1NST7>511lSP&xACzG{*S|E<=R7*l-&M2=204O8uMX@cFSS|>bDa`HZV`QMKT(W zaxtX8pe{+Gksop^Gm4AGiN)U6=W8gQ@~6yN?r)FWXtR1c%nrnuDiUXa$a-~jf-qJ^ z@-k2#rFL#E%g$#NAz?HV3y;>C!o7}|FG>up^5I@ds^npG!vX~V zo+59Da2zmZm{4R#*4EjB)U-#jiX$;RW)_}a8GF~F(6h(2@U@7ZVe$a8eRxyH(tdab zlKb{@##=l(Ot(2UKH-g**)Dn&i`s?3XOHrqg%XkH9*V=vbbP*ma^@;~8Wmr4Y z3NZJH_Hv44B3G05+?+|%p1=W3kmzOWWFAHWf}meOT5<9(J?b>fFClN5=*_5+`3BE6 zK#jx-E5}uSg{)!TKZrUWUTU|&kB&oM%BR&&$LB3#Z?3qI(xQ&!q5r%)k{%N+uj|52 zjrnlITCHK@H48@4H`6~`#6%$UyB_*WNzAg>RMSsgWd@L2psEa7Z2o)N#Z=x>4tGaB+u?-_F@OQMrg{)Bs>15D|cq!DX=L)KtK0m6Ftz4t$9l&#lOZl%W*cj zgoc?j?((RW0_n_hIkdbkTVm1f3W0upSVo_Pbh{A{+pj_k^JLf-hku!a{>~O4Hq^Sk zXuaFeqzzOEkPl0TeCGS{-YE`k#uWaJ>+#kLqPWVVs;6E z5~+eBwtyQih(OS6vNV~r6q2mT(H0X1+%vcU-p;=-x6TC0mj{0TJ9UBE|HJ9C{|EA6 zjpY^h?-}&{D~AMNjiP@vDChc*jyfP1uWT~ZuvpS7I1Ld4f;yns6bQfV$|#ej4lKZ@ zuNoK=kg02Xq6cx6T_(217UZOB6veF1hyM1jma>{_8i>pRE9LudxoTFc(KoTOzF%j_ zpIfHZDWKF?pjr_~wm)EHfs!4nM6+Ef$1HOLcvn_arVJUd>kRrKa<)8ZFvhyx{}cGA zig&=D8S_2VP&WH^aC6j-7_*`dv?*ra!WIj2zspxwF2zphxocgTM#h&tF7V!y=fDh> z$cRgc_XPjYjohH!ALSokHmPx5kXysc0zqtvHos5OdUhan!9c$hg@DD_{piMHM;{IDP}E7Z7C!>_YdagN)eJ! z2;&R=nZlG2k?tOj+tZK|5@9`APC_&thdK?@1*KN*?c9fbQ#_4Y8f~KxKj-hg&jpDt z{DhC)9XNfsg_ok{9~`_;I*WJ?edBLGd6k@M8j2N)x|o@<$7>#cP3YmC+gpn(B0>o- zEj14PB3QPJlyXK9sARnNiG<{c#Wzpfo~rHZh2JNYEI7$$xWxMl7+B_tU&!2EKV_b) z?7d~a&!5lN4jiB==L~tY!)2K-hNmRjN9{SH-6pYZ#!!n&WKwV`I`b`r2ad3M`X%Rx zf$bOkUmfCoc%GOHmT1*AjuMBFElDgMGwcJD?vo))2I3__NMw5CPbx1l}Baj-? zFG+0j$k?vB#XI(s&D4&`N_)~Naemr8p2UkK++{8Sk%hPXvMN@G(l0XPukjEq(6!wd z!=S)v<(^z$??s<0!l#;k9M@p}eb{)p!LyU(bKt}cN&BUPQkLujt%3+s$r+t@5fz9(~qeYFlBG)`w*KdKWv5_Og zH(S5{sS~Y^g=Kgnw%FFj{`#TSp5)Segc8`|vqVD)NdrO)3iZg6>8+SR%S}ZKdHmAD zVAsWu#OE$!zht@n%6&fhjL-yU&w4t9OhV9!1S}C$Bc@(ot=W`Y+M@+amJ&Z%&qD}XJR3&Jej$2 zqo~-}&-47s9?sgFZ zrpw)$T8$q_p;C>FWRwWY{Ox!AD7H4^k{oD^G)OSiG#WuyT*_%bB;#4?Or& z>kd<#kaHyXL~oUurH(}JqLp|cVwR5yjiM#pL-=I|32t(rg#+Cs|#p-u9Goc+LiA2<10lKr%zvSn+^w?&G&jS2~QMSSJ zlZ>E}I|ZDmUw>ojtjoW4C+;GqsxXboq!naDn7t>iQ||QA`H0x)8A$v1-E#;+2rl*> z;4{nn)jT`{zu*|hwXR^q%Y6(xrr$6LlPI}#p`e&A)PPAZ%H>UgrQL* zoBDTDVQsWTtD8@-LfSgW(jm;=%H8#l$F7;mjv@2y_)1R6=GJXZ*U|4N%Ld17(PAS~ zZo4L0`!Ic zFiZpd|l3Yzy-U|)OvVC4BSIJL1XLr7~f5v+k-rZN5H+P4XgV>zS|<+6lHR{mdDO0|Juj89-+zS zKl8kM20g^2vVt1g+(arWncSE6X+8b6knYnlG>?#w5gYHFf~!4wtqPwN)<|Ipo81ZK zFWsS`KqmfcQ?VOmrTM8cRHo+OS4?3oJZJtla~n{Yv$g!r?k8ffHw#Uw{U%PTg5+{d zq~%M)RPC`DVmGN%*}bktN5~N#9I|-hWux<{tw^#vcvgLEj`_o#@Qb`;ms>D)!$hsb z*COL1$b;Spryu@rY+b)Ogo>oG$E)g$4VFCEu8a>`tnDpe$K;|LB(m2kFH)vB%0-k}SY$<7`GJ zBGj3`Yg~^7zS&k;i`06QQliV}q^pdMl?^h2o9B*JIWeTYHDg|lKMh4A%a7fHrq^8x zzYoJFe4k9aX7YO|&ivzzzD89bo35et+`OOhNT8h$HH%)@1Szt+XRyCqC6ZCP+|>Ud}9xMy6$+B+y&xMKV&<*jW!ZvNdCNw8~we@f9zcChkKS5a?xcXca1LTS7b%!WNJJth;?Bx z5p?mMKR$8G0C77Ha(fu*5j2m&?&YGr8ZZ5Bw~;6nwRSIweMh{5UKcGz;nfheifr*N zfyFbNC+*}sP!A}Ru*l$p1=Rld1PbS{vKz_rf6U+}*1hDOpMv7lr@v8*Zbjfou@<{o z^JJAY6Oc0Xsjkj{$SaXGqA@-_Xqygz(eNnW$9?~ zMw@~ixb&lyk!xV0?<(c{3_7Wt*!p$k%ePV$hLHi#2l23mOd**65(BgvA zg>|}z9zdUMP%Dd+VKQJx7$I?9%erRfhSmr?JQS7{BX?SEQm)*|O!gy?lz{9KcZ$TS zXKG^0w^(>ozKrj0COfR6Mai$#0@7UZ{BD4cUC&ZQC~=H)-!&_MD0dGbC~@lE+ybN9N~ zWXbmT^~P^+T(#zJfM&6sl3mgv_Q|IQ1DE=pm#HzP?aKxF{WW_6$VR0O;{qF^JI>LW za>hZW=m|1(2gz?If^?Hact7(a_3BIpHWG^A1N-1eBF)%Q&5}Lvg4FHIx^3Ggx1qcJ zA-+GC-6;#tJ&4#zw^0LFqSOV7>O!AuEiD)~Q%3XN$&K$@&yZyTe71uUnn`jR>Do8#M3zZFZZ1>M0#r>n(8wY(6OH3lnj52 zpzLxrX?yn3U0bErS9)9NdA1=klKG_bfx4LYy|L{z5NF4-=lF8?xZ`!|S_<2^m$vX> 
zAKVL*$#mRE3|=1V43X_yHxe~MoEk1cKS*BiGPK?3Zm^-D_IiLgL%|iGE9Unl=N{P!6pSL`~J!!%wYnuh!8eDwVfhF|29T1pmmmycN@RNwQ<3dazlapI3svT z4QXw`V_Nr~-m`I@zW|HBda-a=cjM?Okg3!ElD$TM!5hpJe@XDk6M1L5tzFBzm>ajU z*-fI|OFK=uEdNRF_p?xuVEt!668Rzk+Bu<28Mh5I^es?x3i-s6FLIc+fG(=nv7~Nq z?Ik2(?(`lt;Q0S8^HmucHL`lLzqkLF+Q!;DC;#%%NT+*3A;Zz z+@`Lf;A4h%7I>a%Vr@7vE z{+1=Av6tgIP7x;etGD{`OV~9Y$Ye0USVRv&t*~ubXlEL_Kw94_ct(=0(75KmdqKGh zzYt{bP;x65a!>$2$RTeBD%GkT_9-834t~A$JwIx3{gB}6&yi00?1ySfIWM5VVbr_Z^GBNwiE(@fScCdOp=cb1Wd z9S@$C*!&zz)UU^tQ=`?}DNG?gjV7mG^|r5r!lv0QLLtPn**g zU{~@9;N9p!Hx)L+#EA<4Oinrvu#YLrRxRc`?M>y%1mmsY{o-TG4XsmlQ?}z@CAh1U`^lPnGH<=eCj)&N|D7&qiA*LPQf1hjWs4Hr|bcWal$=x#X@~+i&imYDnX`O z?&N?j7BZgCSTg_)RMAn;RsJ|Yc=G@4sfiyk)~9F*O}7Pui0L0vMJLj2vaSm&HH#{1c-Yf^zxLnmb}Ys&+ac$*Us#ZyKJeyQf4<3 zYYs+_EXH`JD7A_pPPr`N;}teLvPSn}%{=&P8}fVXZTrYvDhR;T?JECuAUd5MM>cRI zHQEp=k;7aE7iypBJxH5NBF*;6vJs;DRn&Jp{((RkQfnKm4<$%#nELKi&o zuJRW0Ml|#41dmDJz;QYFyT>qy8whlrvIvA@CcH_P0oG|Tq8Vsl3qkyphLx2>@k+FF zTX9?QJYN?uE23M*(t?(^>TAX2go*Hst2m8hfxSgX%{5EVxPg^JM+`e-daF8w_Dmkx zb-xW;l*GKwG=<1?;K4FNj@`L{!Ran>l~m z5{{}xALs(KfRxqQ-P_?>fniw-h9R4-=HEuVf={Q(K$m?|4e*q5XY-u2bxd7p%GP9aC45r+dHUp~w zs+jyK>yuIJ5V2q34N=p-9=1-ZUDi_aKVX(W7k&x0#_kbDZt`y%7`kK`7a3Mi{I2Z5 zYhwEGqbW`}jg03nGmKIC5;to1p8q{6btG=G17R)!IXMfE;?=g`sg)3vNa?5bumzW6 zTK!%_RRecTTk7eaM!bxMzGSAhwGvwYUCrwR`aS+f80GTUeTe}P^XzYJUJPN;)be|< zKp67_xWt8^<@?H7+ip^@%5J2eG9wWzcqjiO_^2q#{E0`PYRKP}s6E+#giSNGtZ$BEFoutF(3b`LwiEZcHqN^7Qbu?M)C)v78kVzMzmSQu%=7JLLdtDXOW*09 z3+}5qjJHH2V?UVA@7|#!WKx}6{JKnuTi=2Yn)u^k_JTeOG zBxl={(G%P75$s*4-#PLXRkp*&PGm9)w@n{>!O6ImH7p(UNVEpq#B9y2`4(zW6xO1= zBT}G8qxOXQK1C$IFB3xt*7N3&vek`M}7#l3M-Tn@84EmP^es% z>8T3CmxW|RybbFm(tcaCs1K6r7|#O*#$^fF&OL+PsV;)cH4`OjC~)2^z$2aj z^5L_EsxTecte8bXWzP6^ZiL-91CQjK31`MM7yO8okI>JN=Wf!$1zMdjcLSF+-pKEgnsl1`v9O!4Y3< zzPonT+x_vh@9U{HS>jVIonp(2!@jk$Zl50)YN_gd97IQ3#zu^PQjpifKV;Sp@5qC+ zx^DYjba_YgEe)B#{uC{=wSo^St9C-0J_Ut}NX|cRd@;ery~#soL0+)OtIhVlqk3OW zE27lCPtuWDih8du^wIoQ7p>kca2;3CuK*m<7tLA&Qo##X#`|dy?sNC_GKO}^;(|uj zp0~6MYHVO86KMrG0IXZkhha`ODnh>DD2T~iXZ|u*)j>O###hr)nC)F^KCNlwYkx26 zZ&@Pw3!aeLird7L(F4ZyrX-XPx*W-Z9-@-2GVU@!KhB-00|(``AAgQ^hi2UWT-(*N zYW+6p?_0t1JnQX39~(KvZ=zky6@%wpk&7p9$^MAw9ZpY9>u{WO-L~;XTzqrPIUQZUkX!53nc6RD$BhJeZF0HXdo_F~B;{FmI3x*K$8s;9vAk z)CyG)*QR@;dzZ(D2X$s+*qXY}9Mz>;r-u9`a&H<+;qEamlOB(Rwgio5g(ax~^~P}! 
zd6wcLPUQH|%9K0mNi+&6Iq~o`{VD70%BXrVxy8RBn{)Xst+md1F1hkXL zi3pVKWhh0GW1_;d@;$Zwh9;A?Qozgwd3Uxp1J8|9eCK>mw)>=~RTlgIhLr0miHEQT z8Tv0SEiB3!G?9y`KLMNPM9{t?q&1cHvxLasBU|q9PAW_U?WFm2>*4gb2o&B?mkO$3 zh|>X`{GK(q=fF~FVBB}uNq(Z`g){>7w7*f-VXF|GwFVGKBR3MC!H#V-CwUuPL2*I+ z0rrJ9YHErdE&Y3(HCfsvm@P?!BOOS!9M-UD1?mB1$M01H1I!mbs(l8)uN_)UqI=?$$295|sp0MGX+QHE`?o%I_BCh2Xzsp|oKYzPn@UKWt zLhxzX$(S&uy%nBN^t&WKoSXE}FW?_;um^;nMTFj;6FKG_GOml{4`*j?`{gwdSWi>l z#jkkDkf?g>~H{lGQDCM&SFal#3OE z*v6P5uC3h}bniR$ieXA`a8kUe$UWiTN`u>fe*w`JP~v3yDXYK@>pqB-bnL&rMtlz= z=OnjE%P1-kECO~q$npf7Bfz<|aK!uV+qR@`@3TpIJWi#3a$HuFOF8T_DOYr@^k0Ot zq>#ieJkh$vpc?+E;6ai1Dy%Ep9=wFAs&NZ&U%6+EXGug%&FNh*9F<- zAkW)$SeItHcX-viY2k&*sB-stZm>A3T znm*s(`Q6t!_qqPK??0|{U7gd>A@7-atR;Nj?u7b$7^fYHdv5OGNs}Tc z#c;RJU2Rt6-rP61qwr@Nzcl8%SQmv_aXh>N(#qiv)CTQO*n7~EdGooZj02Qq6pJ(A zZc$UEsmmy!*|+INi^k*OhENMrF>~jRw zW(F=rC}+9G+ot;Z6k-@!-PgSCMy}!Q6Rr2cNJW4y_Js5c+B$)iu8)aAw|R9w{_n;( zbX>At@h2eWTmN8^PMmiQvRKF=oen_kPvbj>GcFrY@Gwp&K5l%cV@@@UR2b}8!A_2r zkIi-chDf|{40F-{4IA%ajkX72Stkr`vhwpQ`p(VOZR_DiuKI~`JdvM8)O0N%RxW6I z@^JiQA^9S$ieldmiYDs)D@8xF~t~gbL>d0WPd{!ByC}DasZo;vXU2ms~EQnUxa} zgYK#kg!9VzO2|V*x>Z-(AKtM68BHNT<8}Y>IOu0J>rGXjySuS{%_FTNZm-49Hl%@E zhnPJkmUha`bmXSz&ryQ^^Lg6-AheM(*0D(q`|`SA14g<3H|OXcm~0ri@MQBm+TLyL zm6zkk_`O}za~o4jdZZavYCAN@gO-al#r<Mkt>fG$N z>qYnV_|IA6q3RTiT-&}JyJauL)`VZPVv@YCF>&!r8zOGKJsQ4jJ=v#mDR2+@B2FrC z!s&4;CIlNf>tbCOLVcv*x$CE0RE}M)bxzI6J6)3%rEa|<$F-w9UsrLkVKpQzQ}Jca zt6#}Zr(DLP=e0UMD`v{B#kM7QHZ-Uc)4IK!XX+coPu-tXm|>G=eKM!j20!S~-JCO7 zXeBrYfqD!@2}<*^rTl)q6A(Zvh$K6i+nQFKIDL$7o0qozW}`-OJ=on8 z(h4^bzvWF)P2ZF+SU$MJDvsGgrKgT8W2(FHMii#sH@?#D^6S~o%=PP z6lh!GwB2+G!U&`^E1-hVNN18x7Ue!EDr+IO=6^g%wzg*-YpO}TUD&kUJB95D=p@_r zDi&sK&oTf+i{BXklejSHh5Y7bZWCC?Z$q5@MT{ON z%&zW#{&0QbPr-t@!dEMW)d3d|vHNvY(NB>nIpe1*`$8LL?L2t=#$8cxU^^8w^lX^+&Ea460E0|BqJS`QoA68WK^ zJXe8VLid#%=J(*?tXF!to10;pp6i>j?5nvpa#ikVY;V=|4J9WxX;aFcKv439d8SgP zS>=g2?KxFaoz1vAZFF!)?iWIOgV4R1Jb@5pCZc7`MWjvnO*1;m!?8LB@G7dsF$niFOIACkpiM+mUnUo%Vg^JcCI(t%^Xc^SyJ5G-nM@Pjo}N9%JxMv zzoO$j_Pd=QW$^G`-c#TVtQZJ*LR(R4H6!!_YU$U#e))Z3+>hT;9lRSps^0_$&8oFT zjnusp(fO~5n(ftzB2#Ynh2_j*MH)VM&ZPT}U68BuaFgSlA$n=(V5o)BlI7fTL2*rf zu1;Fl-lip3J`sil%hpd;?dPGk-nXHXY#Wg#9@XVT{00m`8zrI3CWi_Q1(>3;Js%{# z)M{CTG#CpRK5+Aww~Tt|8QiuP4Jj~^x=(p5@=d70m{WsnT3;*rnv|usn&Xd)pxy`(lFr*CiNxvRNpJgpTY=+{6z7fTz}qPeJnS& z;oiAh?px|$jF%oUDs7KfeK&u};$lwAW{Ux%KWux2DrIsaL8$WOi-OW9e;qqvh5C;l z-u%e65jH3Y`-u`=Qlw1SLLO*E5Q!sKhwiOjz~7sbId@$>4gpUOQ+ika%YUDkDmBq` zyA`#S6mpj1amNcyoF(9btJf4DKdmwb1Jmthx9?ZtQ{6Xnh|tLz5YQUKfC1guA;Sx0}B0sRz}5@N#757nWmAOBI0>R&!$&($fxIREjh;WlRBPKzbeg++3lns z$a(2`>K0OEgo0*b}hX> z^DE!NDE}@9Sjlj391e@Q2fOz+6Q@<$y*E?XYWxnir2nAua`alpZ1#W88j>RR5=ctYox8ta-o-M*v@MFG<`V~J}3IXibvc#>%z|=>Mzj| ziC!pA#?o6H?-PU#is5T1jdqs&ZG)W`$9)mRo}G8 zt4ZriqsmJ*%BNT1=4`V>mL_zfvjV+;`A8G$OXV5+W>U2GZ6OGtq0TvEw(JXt zd9q5(J*~+a_0ZQKfr|RfW~cFbhLZ%HNzL0aUAD3c5(9V{eO`GKxJ?}Yv|k{prwV*a9Pb1GyzhlPl3-?1}t zeR8J6(_}q=b9Iv3$mvQ~Q=XNVUlOHDT-|PJyOMkkI*C60W$K4$r36W6dY9Ui!t2-< zA_G}1n^N+z>oaM3zP!q-HT+7Ua(rq02nz#t%%wYKT{!c%4EM!WOs=FlleTyri9p~0 z`^c?(iNE5On~VwKNItq?l8-R-!OuPP9VuxQ#d8;WYOX)4KLlGx^L)bzx-0T)KJB;NxTCgvJLJ4Pa1`NgIVP1d? 
zHP<@g5%m*Ig4MKV6B9K5{6WF8n`U04jXFER*JDV-qxzg7V&PF8M?V+`D~#bV zv8K4daaljo*}_4TTr>^nL>gCl-mRtn0UT7Tj zkI`x+bAIFP;mC5J{qQKP>^imRr@=Y> zHGP#Qq%zn*&nl;6Z3TBb?tC=bpbN6;+#dqjEfcGNydAfCRKw*EF2?la2G^Li&eFO1 z1lzg0t5qxxEA>YXe?aR!SM=|$<0W!o_M&ijR-+sZ5;sB@e8UPGiMG6BmZSgt`a|xb z4Un?_9$Kp9UulL@tgf|@`&$c)>HvQ0sQ&DN}u_E`FZ|dcnr=LLXx=W-o|-QiWqGCiel@>vNZYs*K*g+ z@pS40jNRX6(qqTlaMIwgo>mh~SB8#*1Z!ka`WhiiGno8joqRh3EjteQGGN2|04jz1 zi2=~;s4GzC*5wPSx}mseGD5OurVL=Fe=ORa z<;`C`6sQZenyHytO-t>XtR6xDR3bFT6626ahZpUSS#mt4gs=?vim!_9Br9W3mUC@; zu~HxvuNEvoNjJi~G&i=|L(_Ul(j$wNSh@E*T7-O+!USavkbe>W0r_?mz5g)RKnY%9 z`l-G&I}&gQ3{l4^Z7$ywcp`6{d%Kp#Nbl3o7(XTL-1OIpZ;v|Gl{EvFnzbIXM@e(oKs z#lg^TbuBi>Lyy+pz{;O21*)dU=u8@ZvEX2g*u^oWg*i2va`wg=cjsy}ey9W^czsuC z$3n+^w|ejAv zt=e@mU#6T&`iUw*yKwHw-!>4-(nt&;z_ntja!jm*!i{)E-O4_XiI#z0KfxgGrK+Rh z+!}3$ONI{(1-!am4Hq$_BV&-N@5AETtL71DMFiekKkR98JkPx4iE}>g9G)B>eYD@@ z$f+e(RNKwbYv1@d&hKHg{LZKe>pB0`gu`U%Ms#)V-e^fIs=LjmSd2F;i?Pg<#HP?9 z=V7*LdstWVVSy>^0cUAPr``sKTu0dIG}@Y0Lj4MRHzyWes9Jv{9vaJM^jK34v%yc! z)eji7#qMe_FkCZ08W&G$L|gw(Mj3vX?1@DfOyU67ML*!_EEx|wQwNJh>us2B7d!@^ zD#Av{1WnhW&)6gTbt{2ugiv3Ds$5xT0ptTeV_tJs}(6!0O!UpTD&LA^Ix z!Z$F(qlZfzC~3jX)Yx2fzS&j(-C8~a7svSa!U>HqQU#lRFdT|{QwS({@rYbaG2yx+ z&(gTrAiSpQ56zE1ND4!f$^LE*4ZQM1G!j)8CBz|oq=;B_@N~QW0C_cWcZEQ^AM~cZVARkMHyLZ4ao`^V`v+4hf6P`Cc{DBy zHTSyY>Qa4WG2sZt8Xt;y>hqzW4y3Nut^bQ@8_8fvgb1QUj3b~cBP)7R8B8G``=wF# z+KffFc4+upby-usp|K*{J_V0$4La2eLaWhFx8cuS^Q_kaJ6te#t&mXPmiWn~bdaD1 z)#@0dMK)BTpAKD;FOLK104(t_?1j>`DZ3olpl>w#*16d6A8Jr<#2s-VgG4uIMk10h zf5E*(z(QJz4)f!<{ufKw zW^>_!n(b4|ZD#rLhoYZyL7SSi`YIrFmtNa}g;VhJwuUO(j;&IeV)42gAS%j}qdPLIxqW4d^ z6NY$Gdog;aNa|kzqW^yb5P_)ZCU&!$k=o#X!SYuTJGA2D#H68MIAF0%l*MGor5e;5 zj~M7>L?MmIf4ge;ZSS=~@7x!Bcn5MKy`FI1H-plkoYP(rR(Z(Sl4{REiAR}xyb2UQ zg^A<9=uJt7uPD7Vm>Yy?;DQ?W5_g`mj;+`68ew8qtg_)*;A*Ki*HDNWE;3I>8rLXw z%0AVmZ1eGjYhs}f@tIJUJ{m~ob5S2p=bI|bw* zUp2a#AX(-RX_A+I(Jou4D;HD*fqk`z4O<$j-O^jFaB!^H?GL?*ID#=~vDk>>m8WCE zukjv$>9Z88UL1#ZF)dWJoVl--1vpt)+ad&`$jgPd)fu~)q8eSkwjVx0KF^=Y7-B^ufZ zZp?Iyha@P#N{UF|5a(T&BSG>uOhT8H-MU{_s~1|PKQ#gsUk?8}9x8LTTW9$WH@%U( z+Zgv-bn{p$;D4eyy0CBn1Ruh)M5s4>F0>25EY2MBBlXq7@kBs#_~LuA#%qbVn5jaY zuf?3Y^`dbF;1@2q-ipX5T)iLBU)q@kTi%Iy)BX#SD8uIAuHe8@AH z2)D%~fG?-0dZZiY?zH{n>RHr}lBEk){+Tvh>HPNGrb@+WpV0H?Gxc5UZEfjA~^;bp}#2<`f9%CAT)W+tF^ z{t$Psxl`1_FC<6jI{jO@pIN) zO__#L&So3!qgq0x!w@! z*Obs$KZDpELo;D+EvC3;P0|8Ur-&;EZnsbzqkb+m#Bn&QB(6l>BbVwa`0mcso)N<6 zMhVq$jT0(090+wVg?f$4_U}+0ZcP4JbZK(8Aj(i>2ZQFZwTg0)23eP{aKS|1q5;}a zAhCZoK|1+VA{;OtCB*mI;GpM3u_Q3w^W%M7tA-u{*$(bF|#vfT$>5GY}RAB(%%RBeBzMyBy)Y;=RnFMU$! zX6J67wUWeY2zuBWXt$mE`)ITmC}1{K@Hyzpl_tBq8~WZn&D50@#JGV2R7_GV?|<{}bE@G>U1(Cg zY|zv_i6s$_hxOmoEF*k`ob7gex@3ZzxKbv@*R?Bv?s2|%=E6Wj#==9~ur*}wA5k8U zit;?2*4P(1G~Wzuz zED>_iLfT>`He+@}J-wUGa^2SY-t3i{x~fxR2D=o{PzLL~V=4FRFYz*#jIH9j4J8(L z`dZoI>k&&P`z40#Q`ug0p0$x7u20u=1$lFZBs01&CLMsk<83~3b_0+X|MG>WeRBF0 z%m9ALB(*qeaa8B?2?z8(GTU`->NRUZgxcf1Lwv% znX;9O!)fWepeLU?r`#%2 zrg_t{-uT2aAUI6-siQ%D)$)JP1hJrV&)MixR1CF!^hymFx1)?Qa4eDME1{+ z(c1v90>`u^=DB}W#7q!+2|h@m$0ZLrJA-F*oV z7I(~mq)`gMyLvjctYt*H&eFW!o&*wL=YiZnK3luPVli{pf0)}ZxU{AIGPf&j>vpPS z+(*$JCWJC1t{&hh3;m2&YU4^~2_KZa4^36e{)cZbWt%u4U>a;(kW43_XSr!l#=htQ z$C`fG~I)WhiEZJQ?D=VM#1P^0BHtgV0!ZFpy{FE4qH z0>lR-gI`mtS7(NodZV5BE=mLd^qK@1)%$`kwIn&f2#tkr0cckbi0G?^qI@>OhW-p1 zk$o!POPu?e8{4{c^P0tzbMAA+0)@}cxleBcO1RsRZkyxRoz3l&ywc$as>JBD-eDJYnfas7jJ-fhe4j)A-`iK! 
zlA~qk?5ac2T+*oO!D``hj#{hLTg>Ryf-a`GnZArH4#Ohh4AENnx86>)Upttb-2nEbrtUX`iz*XGwz(HJ+Q0P}8~n z=bmQN_)MGn9+2w3VV36O37=2SD|GW^=}agsKFyvQbGn^TyVVJ<2NlgH7j+g%hf75g z-O||)T8dF&sKHjx2Ol)&J|&KX7yGQyY4pzgC^lUTY!7>r!~uYk&RRF1#QE6dLyejO zl!=4wbualogA~M6{oN)}V2&_2*hs8EsXfAuK%Th;gxe+xlw7`=c=hrI*aV=u;>A{X zN4u~Q6$}AjR@*uccL->8o+%#ek!jh91gD-A`PfKRSS6}4;uxh`r2Mt}!^-?>QRX2? zz^O=Wd!@hCyuk>~AB!D{SIb4L9Xh>=vAJRCnp0sp*AtOG`n)Um>xmU2ZP0m>XD3gi zdLiY;Qs;?5OVC6!`_q4a+t94i%*89rHT_N-BxMORJfT#hgtP6T7re>mpv+t+9g68h zMwq{G`gb*qE1-MHNoYzU&D<>HtmuLXTg@f$73j3Z-)H|7;&}FcP7|j|Bm|R^+s3zQ z?4YMknP};7)Vt5TO>_0n zj+j%Z4V@P-yAvHJNmrj^#q7E#q;AdD?6ckHDoGX-?^E=zQ4DSK-d-OH6?fnvY-=&9 z?3h=%yyY$VRf$GbF~0z^7358n*&!XA9P^Xty<$tbSwDlZi`^`2?px1?5TG-K%uWQwjLnkVpthA2vvDcbLGpQxQUc|HKiOQ+L}Q6szDZ+@HA z&;3#M%y;s=)}HJQ4Mqdf`Y_As9dFk(DKc~;J-Ra;7$$YRr_EKnJx~r3&7%8&S<2*_ zmpmT>&pJZ%(j4D68QnONWzIoCyUbTuuebGGA!DtkR$kI!ECjBy&F^7$n+Mq`K zsg!)G$fj7}5?#J$mS#Ez-%?&ylfSUjQ=?-`i=Ev&P;luy9AO-MI1M&jZuWR0Xy(>X zmG|i&3$wvpS^l&><$$60ti0?2N-Zhd=uv}5wsiFWEHJ64CIyXD&83WuDbN}LU@3(a zXtT8x(Ee1Y7=agc7FzKh^|S^_`)Vgr_pwha=-V|<#$`l|>dWTN z2VW+*TRHN}Gv8_K%ZGd963C+g{5+B@-qEnzgTW#%sBdvgb#Eq&sKjLebqNMdU=q5w zZZ(g3GP6&On&&gNCn*EBG@s5)0InqFOCn^2_LcN4Ey2o0Q+r$>r>X`;p<8!#tf7M? z@G@C$>}GbC(w0(+p_PGSPz=6hUgnUvGPrCq1mgj>zXEl`!T}-r6I@rddJNU~8H&!(j5f)I@3l3U< z9FMv>{^VtX=Iy`5WP!534jbV+`aD|DaH#MwXKV0ambo(ODs2T2$mkMYB^H`@Y^%;I zu9#bWl!akD^8`vzn9i)v?=kE4>?3?9*D3E&;R5>& z5DWueIHy9GjRSW-!kR*CSmkOI1t9F1f*nl41wiEtY$1F zvwYLQ4#WEbo1+ByN#r1|P~JOo*35%)6U3usrk$#QpCspdD4vny3RJU*&k7pDp)v%d zK3$>{x^|3i8y$ad^UXTAp6zCMytimmS>vW|n~tC-vnKOn)7_Qg#W{;_1@jFX`?T^! zs_`{RxOe_0!6X!uQacM*5@8FIxrrpp}d}#H4cl2RchCb zCMh$S)ZtGi$u+?SWSeHu$-W}n>SUn?wYvuYVf);FKsO*1wkvTMPF5{uv1?Pg!pz6) z#4HSnq3Z?iOEd}P@6vfY?WJL{BD^0Iv99B}sGRkQk~orAzfbF8A*%c}Nwn-50Pc_w zfKfR+FXl0xKrNoF6{wwfib6v?oJz`3Foo;`8MO2H3*_;!WME&Y9 zh&PTZR4rvGwhp$!-<<$neJ~e$fB(gaCGv*ox&Jt-Z48EJ@P<;3ia=m7S_N_CyfZFb z6E>Z9yd8~$Pps`;4t=HI%Q9y|5&DDw5Vw5pzcdgH>b^!S0+l{c2o!2=Wb3gJ29=C` zm(IK|pzl!a;0=|TgV-8@%U(3Ufl|=;*R~3wiF?TWbFuz`@n>ns$WlsFtO2gdyN|jo z^EDPkA&b${$@k4`8{AX6^vY1IJov&}{}5)y#>#Jfy}|CD_t3QIvl{G#Wx8^>Dj|j$^DwS!7s%_o9cxrtk6 z5qUj=3<4xDxArGYUlUS%gRMhlhq8#44c4zytVh~nr_V|LJR@&XoQC>>+Taftgazfg zHl*6@Euzv5WzhLj=dR;JYA$9R&>KuFKEBBPf;kU#FGp!7EL=6S2=EdeMeL`9<9G8c` z@~kW^$-CtDRGChU04BHh&jO`KL1WJ0#5a30p<*Lss|aRAR7A0P()C9-NxepTc`H`Y zQ?z@oW@MU$lYBYuR{5<+y=uE~6HH8b4ZjBL#CXe?{7l05p+Y`4Yvc9NhjmPgUksy4 zW9>YH>f!*S5<3+k{QjPcb0vr0Qv}O-s-S9oY^oauP@ntH)Pid~x5uGF_`0GeX$Eh5 z55sH_SDtI7YRPFlF4{qH0hO;J+Pc@tSK^{9%SblLvdE9kiVX*IP*G~;ojPsRDh|B6$1A~r*8H|Cx?D;@mO zk?{X?Tf*jR2HAh|=4+pQns~Wm<$i}P+DT7mgrw(%^(T(wwtXi8L6KSAapI@Ae1%YJ%&By9 z%<9c?`8o-aVR+q!768FLh}@FYh!d`g11F)rGU#BRCr1&IQm$E9KTfJBW{v}z390W` z$qv9@2j1Jbf0sugYq#e)9mL^!UE?d5A?I^I8S$$#sJD%Nyj@*8V&T)y;62Bg*Zm-p zQZ1N#Cyn~Cj+RbW3_6U|ybT>JUZYD*G=jMn$h!e38YWS^GR`}pBL|M$w-W2c~wEOjG|<6Ja`{Jz`6Xa zHb=i$*B2LIXpBcJul1s-rj!+;J~NUzYH_|O2k$ufp@G;k5pIUvZ)e}ijseg93YFW2jVISb^xuhthj)T@~=G1qs%Aej`c<9a#ggNb5KohGro9!O-1+RRm_pdTF zGqaoa6Fvi;^yKra9C*t8OVDvaAUFTaT`f|H!7xnoDqgk^dLT2D^)AzuCCO(OJm#<5 zr`da!L?h*g2IFwA^@h9MbgLns!hY;+L>C9S0!2ZwUj)Od4ce>ympH0%8zn2}i(Sq? 
zGIPHQtS@4X%5)qon3MgK_7Y*~3unZE#l%%0mf`HX_~+^cNMG6rW|RcU1bzL&(&oXj zP;yAmSx5X;Bhc~>jiis0I3rkjHBD9wPA<;`fRBEM}wgB zIUw7Y^)Ap5daE_}*A(ibWSmXO|0gxgN9qghyr;Y7G zp~PYIS@H%$f-EFOHB%PS7WIo0#Ku80kxeVe+x)Fh&SJGI%aV{i>X1~1 z&>5^Bdr-iu#Pu3dQQd{QHWt|?L4(?pt2SSa7X-20V3FGY)mw&IL&=D-HPQVG+B@4# zTlEUt>4$5s8oPxGrnjhpo!66Br+hR%hvRCdKP;+&zWFQ*3g2QmAnZ(9i1~arUGPf4 zh!7CkbJklz#U!~V#K*_L9yTW<>|Lr3u2;%K)A2X7HwBMeWCuqY2mbN%lX?0J#Nh@= zJaZ!5Sgg}Ha!in?o0^-i4b_32u7QK-dy%}I#Vd;v+-qUxjjj1P5H@_)D*e_e$ zi)Kz`|KgK}1q%-eB;UT?WY$6r-$w$ZaJBDLe87-;N6**i#$`tt7zXHK4d=r1!ED`b zW2;&SNule3gc#(IaBHbxb=Cny+kJ_}0V5cM^U}2}`A=EVU-XMVXo_-X{>7}lq9bol zgnyUcH^48I2%U83=s9>x=j}_PSL$C z?nElsHP=AM*52E^e^_2=q2KDIX1)S_QT(;DxtVAM(Px9a)eXC7ME#OKEeu*Shk@15 z9ap1bmftmTD|3B};*m|g0jacZDWNXs-RfW3{ZJd zk&1l6_}T0}#m);1&$EL0lRG{-hzA9l?wUcI#B{2WL*)$emm&EO^vP(RccwCz zj*jwS=D!D^iX+$sJvQL5?Mas~CF#rh`a6-p;;xaZHLq{2&NdiR^fd{T`n4)NV%5kk{jJ|#VcG7*Au;Nh?mFxp|3Qdv( zlfM9V@TYth+eA$(M5i`=5dn95$@eU&t-JgBfeFBY>=Y0Kzh{VSUY%rtbzTAEhpLd6 zEvu+I!mcvT28~2m{{m|pr`+gMLmn+%P;@#_B zh-RcuZSix@6OzwVXXyHn9apHxBW4jTzq|^1Rl%(+AO&{8QNxe<@7bIKrs8Gc{#+8< zV;6n7!FB>JQfI{NRsu=4;^4f81c2mdagbf?g@?<1K!B(RL@j?fAatjgqgnw_KB1AN^QuGr|;68kB$7NtV%RJQFM;!$j&($SOW&VYYL?DXG?4urv8^v0mt z8x}on7fUs`k|JYy+37z-EDdkFJyoBzro@te< z0$-%o>*`L;Wd@ZWEE09nB!$;0b*F}k^r+SuDM?ZEr=ts$YCQ}pDq#J$N~3$#ac*ur zKD|qKvvQCGg7kci8CURZo|QT2X5gkXc*kzvC-d3}Z-^y%+cT#l(oXX?j9lN|%-Y)4 zu=+Dr+v8@9eG?Li*D%`WgP|t(V)hix_;n~vt-_&kiyP-D^bfVWa_pJ=y?V8Ku0!Qt zyC>yF&~<@7>jXj6c+-(g80vO9XlZK>I%t?33bpotpyBITAl&l;vt_^mImYW1CHrf) zvpWH{=UzMmjBsduhdzq(n17r5zlqj*uF8 zTksiu1M0G>N;`KP#8PYqP2Jz9-Dv0>K3X#CG)S)R0VdZ(>!dZjywT1PK`{l3!|Z3_ z0k>}Dk$$0_!8@Bl>iDHq|3D={aERfaCurp>CK@mP^Z+;QZ2sM#e}oJ?m`^hQ!_YRx+<@TH_>FoG=2X;^ifnLv`R15iH_QynTha6T zyMui-z-t*B^ZZ;K+N#Er5!M#M-F;Pd3=EYlg30$ENgMZ?d#r#3pLF4s=n8U+Ia{!p zEQo*1$S?wYl!P^M3^SaOR8vCnJ-R`lVu9u$Y~_8dAQ{V>kLSpU>epc_;k)3A;wR)#n)qV?Gl!87dVLn9RgrFVZe( z_*!n|?c2lAbP5S+9fgZZW77dS5gGr*!HY�?;`TUj^G(Meedi?t^-oA?oE|xF(UC@#9Ll94ve9knZtRhG#xs z%I-7R-++=#7IrcK9|<(w_Zn2X44QlpPiN5^Nhw3-c1?}TEsOdrewQ`8X`Aaf5c9OpCPnttA%ey!5bpc!)n!+oVh z`$k7_;^)h?3ofd|LU9XfAfUu)OqLsEoMA}ILiudaxESI!k0&EJbJ!RR<3jjLdOYKF z1035^;Is3$1oweeL+yojLDv~koA2L6bii6r6=cX~2>xAU<-ogVl?i-S!6B8c{*S1^ zHH1)(N9Tz|hIq4fY!@4d>wpvR2{+&`>GPMC4O~gUX_|=tJP?xmkpkIwwRvCCf!75F)*BW`gZ!IIX+&nUagHjKWg_rO7NlkK-Izo@w zXG-y4QUkQUFHeSn=eu4QIZHJBTB1kl%3{fu z{;TVEDZM^6)PntlCK&N=D)riSuJZIW(cN6rps7kM*&RJZ*9wefT-I$fWR$oM{j?o+ zboA;4aSG;c#v%2A>YYpPq`xsdQ$w<+7IFwdA1{Oz%vbIIDUmd4V^@((96X{W1a&Qr zEqmrZ9heM7;`&~;*hl5IN>TMJ?XMmn(8|Xco{j$_I`n@3Ye{EoCD{Ba)*+!nlh+R; z>UWx&G=-tucM%WmCBnPGC7D=BZfU7E{%Bl$@GIze3d1w!mtB(fCAR|ZzHi^XAe!(z zu)i;=jZ8Z|^v!tj{ZiGV9oG3^N<`#|B}g z3;sGL#KSlJW=ofp$H=QO??I46ay!KE{D*F^^TBUr zL5#6~nO>POyACQZ`Dx-Hq%1iW{_Tf=aU2uFZ2-`}2>yOcWL7rXy>KzZVCS@ePF#mreswl@CmIJrG@xIU_R|tzUEM)(#*%-P{ z$Fo@ta6FG%peM3arTe*n4t#by%(Pr+zIjftAFu^WMmj-M#p3VL)%wUh9+lFazQ5^DGe0;1G zTr)EdGh`PJ03o}YL8b~+!fR1~kL}U?MHXAij+r@6&)RyUb2%))?=3IX0a`KS9i_QP zNm~mKT)Q+5jm%J-M$rCjE$Nn6>C*l&MtfxhY%eI|M|{gnyVPOnB@svh5XuQ!kTPyn zPf@boc|&h9ch_PvC79HkTvuMrF31M&*Sep>TY)W{fS($y>%%db|2bXG*A)Hz$L9~C zn?85WHUGfzQ;02`aZ;?PzLxNwv@O>chnU3mMhh*uI%d*`O|6f#%2GULld>|;1u&NOMI~8FuCfnd zeE4bHmXdDU(5|z8wD7+)H}g=Xjp%HHg8iLcx?hWwD6T1Kv!vyIvQiK$W&_G(Jtplk zj}Lv7ac!&pAi}#xB9w+5_r<(%4kgpB{0R`AY#vz2I#e1g=3$Nui58%44?3x34Mc}h z;v2TMl1tncEr9>a?6OwTNMEs1sG0OXDC{AIpWhcfWqsQhw!-wek}H>%qv9IsO)W z?*Y(tY-ahKA7m-+ zjeS_}ur{?mez3b`y3AA=a@4VTqq)_dGq^3MFKYhh_UlO{7i(`yOkawq;!Y}bzwCZO zTk!KMZRDYct1^MhURw>o6l$>7l=G?V?+(R|mK<<*ars8K2E2R!u#b8jFK{yYGMezyQybGk-~!tMaF)J*9JpF8?^U za2f}U;#|a@hHln0m5AL&ye(l-zEMJ`N8AqGyx5PieJN7j!OE<;o(%}m7a|eU) 
zrU|`Oo*BA6;SP(aJwF%3qyCo=jdpvvplGFfdb-6gaZg(?w=bS1K8-cy$Jp%!*YDz< zZ?UKiI6>paVHe5?~H)bT=u(Z?qTHTWQa-oN9|K6dbjo%W|&3tql&SN0yc`%jUGc6bti2%Ljs-&YVFe+&VQu2`KRk>rM>K->DUHYk zA^R4j{$It@f_Jg)IU%RJ=t5RrFwMx=D!aWK@0)j3Y3GNyUgb_w|D$Z#lDD|b{Dhdf z3$!uiu&sP0FO^=a=Cu64C8xl#D8^~m(s`B?A)?>k@i3v)v2$F4($s0$RD!n|W|oOi zAf|?BehKO^J=g3y3if}D?R*^@kGQP&=WXjJQAmdI!t;FRu-%|aQ`L5_}8}AcLw}Ssca$hMws#l*( zSz8U+i?P7E+amBjr%>N6GY3umkM7R?E2%VrdLfc%ao$@;+Z5R*D_O6@ZH2T zRL}`E9ne-CGh5NdH_Df`(n_1su}LKa1r)SK#Gz4+HKVSb&A~t&l+LoJ@=RU%RTo#&$-X@efiK>`cqdW(~9ZM-^!G09P_Ham9GAGm*rSeu#T6S z;0m)|{mH1i3x~7~wwYHM(jb%jGAAiNCJe=fIgvlcym+PrlA2*#1L%HS>MCU#PC4dl z@#_y2GH>PAahAL76Ta(URSuDp7)J>w)QNE8{utzvof+ftmGbaGz)9olxY!jnhU9%^ zbUK7D|4VJlA~rs5bs<~{k4x2nZ%3Y`OXTj-0~8NNS!;2Dnpx>{cjDoJ__pqdSZ9${ zy(PPwv=96}A8~%d%mz)jhFHAwwZ#+e5!9vT@E|!Y($9=xo=7`N=)7K-@Igs!C@8PA$Hh31ZG1@Gt_&aNfylfmN#&D-IX(NYX?3v?a9B< z>#Z3VR2u-CcH;G7M)SSHv(@!1??Yy^oNaXu6y(a(boQecyr*EN$bA=>M`aQfr}j|e zVwBW&bBv(M-i@q>&0zMT_CLS3?&t#@J$p>d4YQ9cb~}*K*dv4VBbRSTzn@0<5Edfn zHa2*Eok^P|Aw;y|vH+%cXWOlrx|Wmp?|!zZtuoKt_ymG8NurNXpEqf+#ApsusJt51 zu!XH?($-=!yw-ZRFKyTZC}=r$gp_~gtSAM~e*O39-l=X;tke!)pR?|jOoxNq``$9DhAf9blE@bAKI8PD12`{4H;XzZBN-_w|2M%S;hmszGqh`_r`=>s?%ww ze7(v3Z=$6*y>;|F4sy)$Bt`E2lLz*f=?>Nic|t3FwiJIERqsTnT;u$vFz1WfqY81M zlPXyO$2JCy)_y_o7-xwZZr+TF<&)s~W|&oSagXHZX4GCv&Kw8PU{G+-$F*Bt8{*lY z`|17193QbNv97Hf{XD5%L5;F>rZVbf`rs=RCUNjM|C;i8@sHGlJLMTpMGxr6;M0OC z=PAYk;9`5*j|%@hkqu%dKbhRq4B|m*>Jwp0GT-*wNLuZ`V6`HY}A``0jpf z`t!D#SQ*=t>kJF)B|uliOf5U+Q+ThJ`i=hoqZa;?ZAI84qCW!TSYl z_`xt?s(kA|=vFB?BkY5D0^f6vl)uh4=EIQZ>|&W=)atdO_j!9M>GOSi3OyLZ#H?Jb zeWI
d~!EsjX>45cQ&8#GnXb~CMo^H!Sh)=^J+85YJ64&@4*a25bPd7S*_B9%cAJ8Q?f zfRRB5@!_Z#29{`jq_EA{x%lbwOwGS0juK|GtNBZy?X^FNaUErEazu6eGCpqa?Xi{z zMpJ{DsP74O^Fgpqm|-(X7dkF>%rDMVb^CXjJPn+e0)!qn=18Wv;D{^j>b}I~Xv>b~wdyxw6+CpL--|2GI5U;3S=0b6dGOgz7r4|<3~R9@DR&b1 z)Yw-%nrz}cJBvUhX;RZAc@-izu4@7^tSzmohr&L5NE57!U5!D z8KCYMERsH;`~I)4f1n^gi~?tbF9_G#3+eBV7Y_ssp~{&DUcW|$QA^E_*>wf5R;|9)!`t*4{L%*e?I z0)d#-pFA=Efljr8KqrUJoCdxTxhatjyqxehP%F!Je4ulDp#1;@s*GdWe@YL0 zKI{3!)Efl4%zylQLfzoz8VHopr~c@Hp}!U7fZ+u@b^7S2B`#uy^V-3}n<(}NZ{O0p zIejjYeET@JvfU{sbi=qZ^kZ@ypXw*wp(k6&qUg|Aex6w3729wb8jB#K;G1y6MbB%q zETNl$tBrB>PibTlaiw?j!#gQ3@FUd_)a^>Le+~~k2M7ZKN%nu^Id%NzJ@pJQ5fJFb z7%$!5FP=rP{T=d#`Ju$|A9!W|Ck^NvUtBLd0dm4Pvx=`3zWA2%Tq8w^9;CdPee=q( zv}|0jY>r=*-PpoU9KWf2@Yh73ucx_fW!Sh|K#Os`zGe(zJU2Vy} z`l`G4PGA_|;zN;a=jirVIzGXQHRPW|Wv#<#C&F|Bg(mBI-6y$6GnV>}RXf1ScV(@| zpTgs?9?shHN8Ry{_rPFc57$BQmcdZ=r>pX0Jt@Btk-6&n)m`tG-Lqo%7xx+i+n*5} zOokCMtGJjrOqf*8!o&B8AM&lcCAW1Nh7?D#m|)%xjdn<;fp$@aa{JtcCTj57Kgs5W zGuK_8U}3-%*wW933+S~z5P-<5L(X1HU7;jhubJ_>s9Qb5>-UT6x=+=Qt~oJ=YhG2{ zEk|Ss{G8pX5Mkgz6V+#F0 zO*$O-!>y%Fzo3kUI;ehgE27OJXm(0fq7)JV$RN0gA4D>NKwr+ZUj3ZGzDV~yV`@u; z(2=R7KpmY-T@t_dJyAAm-SB(v1ugCw(y7-w-hN&biq8bmG0UFHnQ>+*{1HU_aos?! z*5Pf2xcgrbz8ih*SM8W0C(9J}(vYOxocgKMh8~&PW}9Z&oqzTwHgzYScIa=_K9JE2 zeg0!}xZGPtd#lzw%rEuf($G|02Pf?su0@Fh1bTHoT&lCx8eFj9T`Az={R83-83C z9>uLAvYS64|8%f-`r!Wfw(73PT59g?c)z>%i)((ovU#ZsbtJMDy0niZOh)E*>^6w5 ztg?VWPNuJ>&J|qv@u}IbG{Y#`4x)MK*K%8JAz?jJg$thc;i)a6NbG#zOg#hT?#gLq zrA`D#cR&Z?rIy@omaPk&sTXX@wR)AWV-F(3yl!2)6;lrfVv*TpzPoKgukJ04maP@K zz`Ju0h8Af?zGWBkE}W~m$AS%j4wLYKyhc#t(U|8$Y|CT(W0mht``*g!-~^*&(c4UA zkUm%QT-4C%ZSAFV%{r?T)WxNmX@pyjeBOD>NAIsU{aCK!dmVdk0x|qpZ$eREY3lKW zS3j-cz>DCyR8u(K!TAN5bg{;xrbZi{v0Jz8TOP3cGVPx@uQhMiCsE1+)st56`+bu< zAQ|69QU^kc+WZv+vXg)4`8tMC;sK)+>(Ne_)>E%&(V8=vK0=*>o<05ti{yDrbK9m{ zs0{SOPxyN0tD@ztLbo+aa#uf3AEsL>BDb4E%6rmA2=>yK>2KFKkL;CNH5UbDFvhz? z`E^jcJZf`cKjn!_TZi~0B`t-)*K(zD6CYyb^Ug9%+i*?T_;&2#%__T)8dv1yy|gf* z`uUrIO0XELF@?6faxW)vkAG~XZu>sV_e%H*PeXJQF zUamkTYd|rYiQTNMY*heQ^Zf&W-;0f{3`EB9`Znjlon2bU_^P)LbHd-HZyZ<~DV1Zt zaUA`6z}3m_W4+s93BRhJ1#!AiqXu6Z#3ULM>O3-zn7n_VRjT+w=L}|Fv0#uKB_3pE zwD|tG#r3(b*(WuF!>SO=XIh+#>O zuJAN%=Wx3oJsKV&`WU9QTrEEuM0C76ddfe@%SwLEdGQtZXOpmXD+$de#Y zPb^z{E4!VguPOZ}&omkR3W(b6r{#xBieMuNMh4LY9o%T6;=$3fiIJG4O$Ygp;8;Xe zL5KD(A|sPJCk4M&;E6)me<2i&3Y*j`FE}-))xw7jz$}AU25^SHP#J-xX_B>&WNix; zxE9vTjed2w*I9a|V(9FtLK5k;)H{>(@s!sEU$lf}@^cPa;5N^*jt4PwVN<%4 zoY4}q!^opEAdvJWzPk}cGyF9?!&&eYY+xnFBlf!%5@*x=H&_pfQw>uEyTX&1vRbpM z?Y(DC_2>&&v0WW-ox%6;U#y|fM{+Nli6 z7h1#hDG%2{+wYo2uZ*<`5Gv7!w^zOnL6c6uSHvaWTA54}Mn&faMWZd7_R!_P>_}5-pIjCNEt_JpMw1tR2p}Lf$~1VB!1i zS764-X6pqjdSU1)v_mW@L%t|bPrNITGRQh%m@Dvm_wMqT_p-^74)kdsyM} zOr^F(Ojy;pxB*dkIrTH?7_(G;w;2wh??# zFWV^8$-ngEmgOQ%i=1JWKRLZNX^4xA59~LpgR04muaKOR@OyeoCp9%`R@bFxl^;o zqT!j@EQzun8e4%S^9ICBEWc1ATH2&Zu<$L@&`N0QxE3^<8lvau+?YOw^IF6;kz3i_XNXptW8%>XwYIsLOfmV7M?7x=obrxrr8`vb)G zRAa<;+~uQP>Cc!bjz194z;Ha|`@Rt7=Aq)ZmmKCzNAoWIVZl%^JQan?D6k~R7KDI3 zI)6w8kbJBfEIiU}=jQ`9H#eVCTnU&i$$MrA8Sb`iYXea8?orm|4<>xTmi~rqLOK2S ziN!d!3Hd*J$@D&tr=o@?WJxNk>pC8f^+U^qgj^ioGI?uve8Cn-w3G@tO@3n!%*JLN zls(A43>Z%Bj61K+#)~cBm3u#=`@V2Cv9c{~O^5z%;|p{aq4Z{-5B0T9JEf7YkDcT9 zt;?rDyqav`TyH+z=RG(dZVsG32MrM@@mfikx3^Y8l0Wo2pO}aS-VGL(fxQAEW#?4z-rIPNMuC!e*R&m3^OUnMay? 
z^goG%whb(gEgo1$$W5Bx9|eneq%CRBVVe1Ej3o@`p(6Hv$Tl^uE>s|OV6{}OkmeAVsp?>}`f!k# z2EZjU=l|-#j`d*cqChXg9Xjxh30ZHYH{4Zoh;Cs^fPybHH3YWSG*QkJ#tOM0U4ueJ z#oS!+dz%-VT3qTM=~i$ODqMB@KT!Vcja6+mpf*C&@+~*CO_9$A%hj5NC|V2jpe(Uh zQ>TC3WKHRat0p~ZYH}&DKA?T)w(^9zGHyKA%I0v2iNUA-+Q$}nN8MlbO5;>$CSz{} za_ey&D$0sm8V)Fpac(c^+t5h45sTJl@BJ?T$6_wNx|esmRB(E?hT}$Kp_X-}dC{+g$*2i;Ik{=~Qi?**B|H-p z1j<@J)5`o+{kmBybA2ct_GyORay8F{wbqy-(M|#U?iaiu#?PR{JT;zOXbmgH-6RQp z|8oGCvCnQSTIfMOM{tnekJ|jEh~3MCNd*w-3G=`9o%^NXqbkrEXkl722)~eTV4mdVm!L ztgP(m3kqVju(Udcm0s>|M#d}TNTMQ1y*$vHGN-*nQ&zrJ9mW#@<+LGeidi&|myB2R zyyN-V1C|DB3#@J{VZthfmdxF7@VpSfb46o;d%m){nJXN-dr#<8k-`cMSogfT+y*r= zAvlh)AU{;R@rf&)kIVb4Y@01ipBe?c)IF+NsI4el0Hcje>EfiAy1c2j6o%Q|Ch;j0m!aehNn;{%xAVmhN&%?Q)W zA4VRMF-GpPnk{%2StGaFmZp;7((z^K~4-680_7UQOO zk5$Oqp5JYODf;&ObT$BB0s{Q!Osk}=Za)>7ajRAvL8-4nxkL#GxjSk!{7wrcGHKQ7 z1YqMwgBX=BkblEw@uhQUU#lk;_3b4Vwe2O77$P$Tkr|c6Dh!)K$47%YZ-G2a-a+4$`$eMy z;n1m`T9{u$LjP;lCy~Oda@fEF+3FFdHdk3UIgwGf%0aFJiDD*QsFnZ7^JrInlxcQF z$)p35rEk?L+6qq?qx_tV{LT&Tt^z{qbMYZGh9I;AvHQln`s`cWf(roCDF^;TC#UCo zhK9iFto|;F)M)-Ax+MG)tv~swLZzY z)qglx7@ht|H15%l)2m9+*nsKH;2af4IC<* zX$Fq!XMnZSwDZ%YI#4Cwv~b{UPlCe_tngqNEGN}1U*W8G)L^DP|E-_oEoY4De1`Pu z#8U#C@g!ZuDx`-zc9xWK3i5tO(EQ6O}G0)_si7hBaEzFO)7rLKdPkUxufva{2Jzi^_ zIsoof(YE)>XDPPF%7&DK^{|o{!RU+`Qp{{?q6xT=sBtUPga!5Cy|6S3Vh=Xfd4^vj zq;ofK*tFZoWBgivw|qoZrg4`~&#YUN)Jr(L;iTKM-}2|Y^DE30&yy~muq0&LpzR50 zie-ZZmnYa?MLp_1BzJ{cr|#&*~hlplP&tv;0vX8mwK$QK5PS?G$yI94VSC4)!SNwP$}Z{508J z#;$|(i4*-uPS@5W)T&`lv=yc!?`72GY3S>WKX-Dmqh@n>RrX4-^E?wsd7mx(<*nxe zqhv#oCr}~Q>pxRyZOS6W=#27k@ftbc6wUKc;{C%eJS!R}y?vfO6{!(m37)iBg9?dL z0F}Q0E_cZJ__PLmF7wUOZ}aWz1$2y0xJP@9Ta78y<0pVzRL-1EP0{JK?1~59<3MSi z(uSCWtw+7Vt^cB>FE1OK`#LAE#c>+`>N~;KZfW=yP!vm=9A0`;EEel#U<|ennLi77 zHZSx{t8%J&vqKDXTU=8W=CQ9c2is8b<@e4`#}hMS;=9txixNilvz3Wy?>)6UIoi98 zZ4t+C_~hVsE*uMWeDkXs!)iQX_kKF-z1MgR{rEUpS>Cf%&J;Y^6-BM8;VfD?##R}D zTX~k5IUD*eez50;3Er~F$}b)X7whlu1+kp$6fK6anKX1a5U@5-C=^c;lj5fK2Vh|#m;q)y{^$l)486x`qW-Z=q`5JG9@ zD*9BuI$rs_X#l%(^C5!cO+>A^Fc7Owr*}osa)8WB@Xx@p+;$=lC0@7)FGBuAy7L^*!P%IAx6(9>7P!IXFKeE5k+8b+4*1Qv*;Jh+bFvbYK7 zP_$UfKQ>_J<5xBy-?u0+3hJAIrTpe?>*{u^i8HmE(CDMiD#URd1ewONoy#JzXR^P} zMQNgk^fFxd6p90@m#*HQ=r!iJbG*y&iUC9#r|PiNaHup2CtYiLh~Rflf?m`CW_==5!^2{nWOQak!T!?Q zv(`PiW$tK{OMm@TxeZV%Mi83Bl4OA+4;YSXmto<}KoqXcZ>F;bPOCUTOBlvT zn*=ryHU?;YT9s0a~e3u)lc>~sZ~fRUQ-qUA?m1zd8ZVzj}>7Zw10xegF2c4-Z_6+0Ii58a%ccuPUIQfsgJxxf0i!i62oX(adhLhDra* z_TzZsL0u?Kk(;=RWFa!JPHqVmrs1yT6xlX>Q!F8KrM&maRW6X^GZE$B!Cq zvd~-qojc|}Alkd%^6vpNS|TqD+fZlbvZx|OlX4-I#OA5JdicL{OjJ8_`u(~E!7KrL zy7-mMu~+k(o}kixsO@=^3k)Xk;pp(PMD>wgzm2C&ZcjYQrkMXjGj<7hK(nY0v9CH! 
zL1mUk*<0U^IG@fBKR#iez0>SUBZz(#euW-NI`&_Hgam|k6{>r^T8m~Y31SuLm-nTA zTBFi@(A5X)T1hBmuHyfZig1}<8qxi$!mUSuTvQA7(v`L)lwcG;(-#1D6&rg^05m9P zGH{GS91md=L7fbliP*f+R)Ny820Z*oKZkiOp|%HMk@)w)Dbc`#jSPiLY|^?-KkcSe z$E{eqf{cpGzuY;FioBbc^$rjllgo{3?TO1Yp*u)Xq8n+~UwZ&_o9rJ00-p0eMeuB| z^5Au=DcuuJMY_!scKF%3S911AXAt z|4Ndm^)x74qTd9fBeYrgAAAuYwpqtqX@W!BoJMui^C;$iOU;daN02l-fK~LH6|pdb zsCgeJGcz`|TUH{E6X_S{RR0&L`#%NF_MGb_>7t53+T9GFZd-JxIYfNmiB+2A1ul&R}Fu z$mlV_9MlTLKwgO6h4hz#XCH@5coFLzR!WS!*Bvcji2KtFhZA&>NdhbNy|=B_vW0F7 zODd%Ywk#^xBp==8)a;;+xaw99$|*@tBC1me){%QmW%UFk(-13vE9k|@Lv;x@l^_=4 zq{Z5&23XBP&+Jl9MD5Ynu;bE;PWN~^M_!F^eV&>GI&iCOiPE3-42QAg8{Nm;_kDnq z9l^t)!$roWCU#hlJhzaW08SIZCDPP;)n6ls_rpRh_J&v}|R-F)}# z6SP>b`o7u0|7P+N$zW;e1Mpds19|#gTb-t%l?bPCxy+z`_rehX5a+n5z=AfujV0Xg z16!F{VTid+ww_<*dZkpk<;Mejou8l$Rtf(!$9fD3guBD}-w`k%I8+So!}`eECo(W+zoF)_SohVKibdz1@c zeEnvEXHP%V;)6VRqL{Xq$<4|gQUow9Lf2jk1=rhLWpQ6^*^S>J#_33xYagK(m7MaJ zwakntahFhELbqqSX!n|uT%|+;_oIp(A2+wuZF`ju?iIOgBSS-)Q>1428?5N-!hd}v z54bCY0!-(=%yJIEg_h6;+`$-PpF(I_{o%*laA5nn)uJb_Xt5!@=LK%Z{IZpue_93`07p zYydS8$*Y#W=2W;-eI&Bcqt>>vmXd0nRD(`Tt+fvRsuQ-}b#Ct}UDPGb%gL>K@2dMW zkqR_f55*(7z3Tx1J2dYB+2XH$d$iKI^dIh$84Xlg5{a7PVjWIE&onw8xCB}F1c&bx zg&b~r^wtgvLltL^{Bns(O*(`-dlW~?UTs%TX(epM#I@gVs%nX+y8P#S{_2~N?A0h* zt=sA;w=|oc!S|zzlh#b{Cmysl09Y#sb6X?)WwOW{!5YXV(|}E&Ea0xw+)^~TcJF%C zUg-<=#G#S9CYOwcqg5Y3Ufy0vHMg#S#zzdt$|$sk#eOl3U)x)dVGKu{^sLyeDS(;@ z1EhV0Z3;ky4nwgmzK4CTZc!s4z47MDm@oD!ni0^CK*a_h?~_%(IsWK@#d&018)342 z?q@@&L#=xL*XfUNwG($6n*x_qXU-ksiS1Ic?g)Qz+NP_6yBHn&;_rbdz4 zJ|FW~at&1Z%!hM)%T1xDM%dhM5GDetJj=Q@h#8%hfy6rU$%yo=CReUY_ZxgxZ9DY& z7xF_+?*K*ttpjyC&w8O5MAxsa@a1<3!ByN15(rAv5y8B1gNVwWniujvDAiWkIdkBi=g z+FY`&$7Mw{>VR^=YiFH>Y9f0#w4<#)rB!#>`XL=^2$w9`pNvSsVMARhTSV_y+zh2A z7eam%*p&B^saJvMPdFd`@=DO}URR0<#j0)lcY@caX4?)N|4+5mCZ|FOl(8f;UUfAa zUR-=;gn9PMuWmVHkW17xDmO5zPFY}9uMMlvL2W2%&(AX) z^53Ft88mOyz(j{{jnM~2_N%Szr)0p{U#zCUqce;oV2Jl+h;N2ukxBU@4tVfJfKn|o zS}|=D8S=T8X;2}yrm1y?);;Nt<);Qxrj%{`6kzBfV#sR*qK2?~c5=U&TkbVhrcM-) zW)bi{yLw<}sKB)eF^{GFvgB%F?HHKekaNR)mqpwt~l zQLIq1E5Yn-cq!^xOn5#b7OdlfUPgQfcUGyVzGA3}fM?>A&cBFLnjw?lm(B!D_?2M> zb&n=vKTI3sf5Th(O_VkKd9r8X^CN%CG0y;Q+_OuX)UEtMmRT`b@{YxY7=#o>dHR{( zsuh)g9b5bzNLTv}00&LHT5TN7eaz%_IA(@^+}V9~fC6XWmN$O%Ft;3iVh&3>+V%3D z7-gyY0bZf@ikG+i)JnP2u7j>FS=&rKoE!B^xc^uojCQr~PaIsKyuT1s2HR=f?yDgf zRXR8A%viFfH5~3@m^)i%xnY_^d-b#uW9x%l&L0i)#Pz#7X==m%he4AHZTELqx6x?6 zB@<`8KMJ#ZEz;Wx8{zltCcjb=$QRroO9(%I&uvB{E?0$-y=zxc&&rg^Vv3DhaSi+R z4O;hiqUIQsng@L-EJ6W4f6v!QdUz;)cM~jk+nDj<&XkGV z378w=(UDS0-u3@Jigb=M|D<7dN%|5zcV+OL-AAFq_~TWIxef6C|Ip>iW;mwXV?E3q zp0kAYyG{Ho?TU*b?I(7PH}~#4w%x3r^fE3CK1iX}0ZhmTAXor7&PVX@aUR#Wvu+l~ zIA%GF$Z&9+20`>Cz9wVawL|y&7z$JP(ZX7zO}f&Sfy~-FyHXS{*TB^T0hsUHDS_Z0 zJ}VlTwkS?w`zKv;L z2q-id{G0pIDQva&9I}I!SB5?$?LeFuHU96?(HAcYCoS9=02SB?pj1vNfBq_BRXyYA z!RgTf>i-CwO|oPSqjZ@6(-p_+`g;1mFXjK=^{oHpBVJtg?7TiHU5IGxzBabHba)L^ z`tol*FD46KA8^P{zLVpV-sG>BJo*!PO4_nlCr*hOR4KRqc69mAc-l9DBnXoirC2wP z)L>=2zUR-E=lWC}GSU1=X*3xIp#0xTU3+qTH@geb4u8B|4V>p+^tLT;k$Zl0Z*KBn z2A>|tGfOgTD~-6}alQNAB;w)3EA@!JI~f}d?gHId$ z)!hqD_9rY0;I7xVBTG3r?%QRiZaSJ9pr?7fSF8BF4a_Dc6FqO`8W>+4g2e;-!N>1K9CR}Xg;NqoG6k zQ%hlrK|_fsjia0S?@L7!kf&Z)I$b;mI)OQRx{81Ym@k|*LA!lKxZ1QWQ3DK*6TvP^ z%#s?(9i$3-0*m-pigW^$6K!}$vPoqS+R!)ZzZ)C7$Mul~77pEA!2fKnTtGUk;{xFF zDHo%x?M{`2kD13D&Gwa!ANp?6Vjc8Gf08PTrTc=P$1!^#9+sjuaakOBs4CRS`?iak zvPHslYNAOcrKRSzkGRdZaclD16scAp5wbb)`gjdrX7vUouR~>tY@z4Eh`&VWKz(b# zKFi@Vu_$Q;vcw z)dOB5UOZ#=kErW|>jys6Zca)faYH{G)=7NcN@2pv#pPVsoT86dqt*3Ep|tDvdKaWp zw}*oqO1A444ydWME$LSC^ZgS}7t=Z8flBxn?+W^N$E$kghykuW)vIOFtSy+K{l((G zo+HJz;+>yJ#XB4%>A~P2GT31tp*#gr*bEExoH^&cSjv;k6V!LGK?;~FXz7s)3x1Zp 
z^Fu18d(vuuV1Fidc5Y|8<&H-%=394q3#r=@^D}sN_EvR&P@JlWY#P5cH2IP%Hnn^N zCgUpWF2yBljE9X?m40Vx1#1^@h@V^2RK4U+_Wn|?cayQGt4Kv80Vn)h)z=Wp!*e^R z{q%YI3z0*Hov^`PB!2eK)5+SGmM_((wIdcP!(Hm~pilM7_3NM}KIQrqhL>jfh})Y< z?=GMF%otXN_dqE!jF!opHaNoy>CiV^dE%}HBbZYI73*}(j`Zy(X-N|q_*){6;GIi$ z)-#GyJ&?~n-)lXJLaoAK88WZvYs)8?rdsV~>I@DfLX5XayZaMChLjRaoQ%@ApZa8M z+v!ZZT4GZHmbTP+8_=*a2*4Wx&8nZK?QDpMpDUSgYqo$sO$_4mX-^e(Xg3pomFTgZ zgfdsx2M+?J5s>)F5q6A+EhAw*aK-{+nAH)z__gt?cO<;tV)fPC4z5P!=>q}D3|o9% zM}73Xv~C>X_svY#cRqq=;rvlc*pfSgO7(rnv~&@oVY0rtxqt=v>g@#PL~{KOBPW0z zpuhH(LRvj{3&TpqVzrFEa0d4k*)C~37*;eH&LuWNXEnR`NI^KN$b#@6q2I5Aa7GJ| zG3e*OL4vsP$wdu=+SM^LBHGs<+;-w_9mz0+W#!82xY!wu%S6e`u+h*EgmlbEPwCGa zex5;XO}-48@EY+lj7i7zo0kHF`5YSm{?({qYU7Mz)KpMipY>$Yc4DL6%(L(>9)^@x zQ+|*8@>Lc+fQ`BM0kAEDo^YWw?$5`>&MnFrCm$WzB2#`Gt!1a|57u_qh8-g6f5q{n zI>+6Xh-<_QHSV^$Y%V6v%7W&Bye4Stw)7tJid{3-+HbS;#%}Uzm@TU}bupr+JWOI@ zBDCFUbE?L+r^Vxn{)og@9?DMnN*SPhpS`g&MlyNBGB)LT)rhfZ6yOwP$pJZlrj(R1k_ud zaSsORaFqS_+{9Ah7969l<>C4;b{lrRN7EIO>{^$c*e$lH5gN8*df)}=hU@g)P*eEO z`A=V-Rp81ClH8I#3JLR<$0`R)D{8s>4#-4Ks3(;Zs-y!q7l1{C7cj+7n?I2mo!8@G z43TGWBfXI3R#lJR7dneQ%tTOfF*9O}l>UWvFqvbMb zYr&hpU4L{(Az>0}!oslri6CVL@2xR|%_duhWMrny*NMPI!OgWkRtepMT^la9Fz;5+ z&B@8DZen*sS-nvl4gxEa#AT(tPeMzUcDTgf0n^eQVa9m|Rs`CiA<^M>Kq~@Tm9(6w zUZYNjs)yq+{>l;ZFAbt|xSA!m?wgkyXFE^h#~Xa_Ub4^SNdCg!Y_U*M8h=L%3WHXa zuvrlf;<$47kSY+j(urkVRqvN_MfZZ>BkixI^BKmSLlVV;V^DTKK8v-v1|Pb6!ZI{t z_Qt*5pM0*O2mzR8z~h|sU99(#!bGbm&s%&cggM;aT#DLJ^-%DB`{P_AGy9OYjUYnX zYg2kcQ~^RJ?r=kSb5HYes)#O^osJT!{9dq+oD@*iv1r&Bl%|vS_^DO25)^FwPjZR2 z!}~7}Cq5+XXq@OA#F=Be zCpLUy8NHDYPuKDeW})W4FP~vMhae7=Y%f-Z2;gJR?wzg=yNs!EiQyhqg{DGJqNj4KLb+5Fm}4}& zhjnTchi;8A<3`KyAhiS#8n*bV3If|RgHHOs?1~&QG%>IIV#Ia1rV{GxVB#IdQuC80 zxQ2q2AXFsolD#!Kc9$(>7W|siSNez<_=^lmM{A!I-pu@U!PK4+7<3LRz>x|GrD5y$ zQdH-v^Jis;LcN1^H_@9kS=}Rt^c`#Q{mk{6X4y5=MPZDXv2a1Zv9YghWWLT0k+Ox_ z5U13Mp~_G$=z&_haw2|gBnnlGV;di%wd)U=eq@583QEi=by<153Gco523GmMV-vfH#u#R7r zncL@bNbRt)a`E8y=_h5t)}Yk67+qm0hQK&NO@4+uPQN($Vb;4v?>ugSTvQRGms#kF z@#>EEQh^Iv3`9O`i9O$FZ_q)rJgC>6L9v(P8;9I&En8jGdD&28PnBXh=UP_bmn<68 zc{XUNyr~ebL88*LM-2`Rr??G9qMVmbzHC<0-&VRjdfwc&#kSMxl??!gJjb=H^c~&l zIA-#U{0`SsB6`G|KqXT#TXKD@1}K&itN6BIU+Ng~<6}&)O2o%NNaDBnoVO>*$0^t1 zx3U{oYsQ&Zs!zHbrrA)cMLvvdPG2xjxIJn#YrdCt{f%nJoe5JDVMPYxoDyF;m2#Ro zw9*xV=c}|5&^J*092RfZ1YTLx9o7a4VKMe|RY{b)-v_ACG3+Q=9_qaX|E0QrLZ24~ z?mL6tH0>3LV`r-4%63?lUc9cS{_N}i2;|{sqqgphrr-NYvO43tC+LaJ9KlotS2|B( zKl=zRJ8P;j_z^~z$L)Efik1YWFdZfYE)TCPJN@l1ma~zq6sY;Zz372j7OUimI$xMw zT&t_BxLk+^?aof9xe;-GqWy#4-E%F>&6`Q_|axilbSvPgVr#kKPS-uN0cv>iFeS}sYVxZgS-V)wnLWybM z*sfVE4*9thwS9k#2^m6h#RUH`8Ks&siz|#F+K?(y`@QfZ3{M~u$Y!buSWoSo@4-WX z9kevNB;zi>q;P`*^3CBU10HB0GF zcKA#oZv1=7c>mm^&(gJHo||w#|FKgj;1aTxWL4AYerg4}GVs^IGs-ov;F}^t_v4eu z$1k4FvosX!aA?Ix*Eo{7cp{}}m&;!BBVnUx2$KCB%B4fIT2Ub9d>jYTOn3B8VH4Ld znGs3<(AG)Rq`;H%fVodWZqY!6e)-QV=cZobWnsa_44inSRrxmO$P6gV{UP8?u!fOx z`w7KjSovH70s|-qnl*gcWVH}cvBlYBv9C zkyW)d1bLK2@Oi{^%SE4)%GrtH1xPO64%{{sSJ;6)Gk{Ej>=*#@*F;3j z$#o^Id2Hg(R51ANN@XS>sxl405Z*M!=v)E7+LaRZ12;WDR%Q6#Z`W_&QU}B+2WkyY zng4t63jvgLvH=i8UMzn)Cf`memt+JG8~zqWndN`3wN^}JJ10HxpI`z|9uoi$eCD{# zT)+-=H}HZ18vXM;{`BDG<4-_A&Iw-S{MPuBfH-#Ic)A-)|FU_~y+FNcfu(ub`5fnh zbb~qF_KPq8msosTeIIm&5A=6*@7A$m|5K0ae~AhF_YU~Iy!B^L*(jN|b`t0vHS0Tu zREqsm{o7sP`^(xhlq=NDeg?Tc2&kd~qGmOxn=J%TSI1KScY!3OWU3=i#j$ZU|8@(% zEcv&ez{LXSc69(+T#tT=0sN*5%iUjgf2q-Xo$LUGC+pF?_C9616@A>!N@)zW-%I|< zz-`NNf5Q+0w6kVJ&25iQ`1O>IW;zs#15BbrU-lkh(M{sG&pg5hOKN{G z0P8Ab0Ii*K;jSy#w&NDk(7*kik%hHZ{+&8ELJZ?QfDTHt2A34(W@F-7{J}Nie8B!y zW5PT&z(7Whn&Sd_vU;Me+t;tfF{&*8V80%)({X`?^gxBAT5`E|bEi(r`a}f$-Ig4{ 
zWXk>#4yNZ0_`(XruHZj3AMo+op=E1JROz4ybZFtiM>WRhu#Sy8**oytfw3A~eT}~z zHV2ees-PaT%bJD5PAd0jYQmB^n&vGV3?nP}NK3b_g&n%*KNG*dBVcZ6Owh+Wz7Lr5 zkRyzuUuljjP`9p(f;H0JK>qwP-L%Rp1|Cd*w0AKo4*g)RB6#q?@oO%y*I zvIy(ahu0rQJjPSWbR?`J^`XW4hls?GKYd5GM&j0n2a0A^<>jRZ=lA&XL<-zViYe@(26&j)J8K8D1J{iE%k^7`tBWw zhBN0g#s9K|weK#vaft?sQ<9qF0*?xL>biPa2PFr#wUE`W5Yd$&E;pX?`H9V6|Ivk= z#ztyxJ{2`;ViobwCH~uC;xVUuu#71Q0yvJ&Pypv%!q}--=rsH0l&*R1l=E!QkK8`N zlyvx3BzhpaRwE`NztTM;P|SpL(dtHEi+90-J#R5mD~2|5^}!RtEcE z^GJfK#dZG+VO~Rp-(pvU@g5^GuoSO~Hmkd6;Q+OW~_LhNGJZ zvDbnNrzdY*YRP8`eeU;3vDWUF#f3%y+VkFzo2(d&r;|Ge0QxQ}>YgqDnlJtF;VYer z*gZ3u4yrpETG zT8oZKH(yZjAmSzOLkT@y#w20=fh=s6Z~&A#2yn({xVFq3EmiNm1uVi*@9Wx(VTrg}RrVkfYDd6rYsebLn@U`#Q}r zJ}={Bw)wGC7t?!cFNAmv{Ck2t>YOb7r-bA&$G?-M@rUbyTK&VStC`=OIN>Wa!Yy85`U;YCbSMM0KJmbU+G*rf1 zh`=GRNR{E<@8rJ6(z<=XzYIuoo}f>%<8LxY3oq9N2mLy!I0Y0V->wX$J_j3_c=~Yh zM{@Q+9CiQp@!DkdCv}BGvwVLPy=FXZlRzH5TrrS=?~;v|*BXp`FFAuCh)&2kg!DW> zR=WS>*B>7G`CJ%ll(Ma8>Ug+2k%Yx}nG-58ojvU2;#t|9EV3TKK;O3ppV@^bAo(PLN z=2K3UW^#?A=O)iY7XEBvk)o%(?ZJtnyOrih&;y)+BCe+yPBND9_<30l=55w2j8){Y zOz!tnMiSrUi;wrUg~;v(&zEO9qu_>kd&Cw0s-y4SpC+;WNYvua4lBNQ?r?KF_wX7o zz~x{h60V~I3#1Dmg()72kQrI&UV=FP;pc5}F6WM4?yh-=21AXMx3svuYwrc2PHw*B z*9OBD^dl0!{ru#w#o2q8SU&4>zn~Wq{JvMNXy4JpXM$F4dJv_`9U(zf6uIq^S(geD zK#p#4GWX1W0GHPwKj4XKP>2WFV(<7_-20~ECz#h`k27g>&or$RS_ktC!3<*!6pdvJ`&{p- z@iz*uC&*ns7(`^0e@{w=l6a>F%u97onB>PCqup*Kqv|2;^5GYWH6E)smgjTcgAWE< zQgw}c){x_~2mImy>!Ic2RK%O!1I;~-F#U`cx;U2@&Xo$a6UG8tE^S<@P!-?4F&p)w zMC0i4<)F@W__eK{LE^iRvt}eJilC?R`1-SmsE6@brE8s+MxB3GC-V(s7`MCXRxx7FudT;!0@T0jcMqZ75g6WmPLyQ# ze^K}5;ZVO}->{MvMg1xfq2(u`P_~e@MG>-$Z7hlGW674i2+39nS;oH3*qIsIm{5uA z%vgp&ip&`M7>wn)#(h8ceLR1>?|VGQ`^Wpd{;J_S-|f22>%7kMvsh-JI__+=B+q;Y z{g!0X*6U3I$}J6aWNcU$sUcPl1-R<|>n3Kdiglv_@o_VGt^R4w-Mr8C3*1Rs4@UbZ zuv_OC&iCv6W`EklSoA9(7f**@2_+Q=C+-@mm&dVEHwN%h z)LId9nl{KA7k)0 z&9?g}js?do$2Y{)L3QrdYSty4#EjtU6R;TM;p=AKvEswu>uYZJwNHUs+Bmi4-h1&k zMZc};xo9XG4PqYcDv^^uau1^6Lw!$cJUnz*-pO?6fqB~{vge6NPt2Y|SHf1GUh`@wIf_|z1=gP+QO3>b=_|*$N(nQec@U5 zsxo1F_}>NrKMNMUx#Ew)@>OExv}NyNex0t&~k_I zt?SadFx@nDZqBhCpYecN8wAXo!N2R`6aWuJ5f>)udT{IF8x^(Fhcraov&!Sl!7M^Lfgds2Go!MaB(y78{LrMvWqU3ua>LbNI z^_Pfc@3T0g4gE_Hb#dx%4#pxN+-Na05dr%r(MfYd=Jc2{6=|6;B#YakEuflVkI@`) zis!WCu1l5&@A`%;lsSEGrvAhIqh_1ivcm9K9l9;?d;QM)&cRoP5ln|!pSc!_;!YC5-VVb#xX%0>p= z9(0a<8mvadKWhr~^P28_UX|HJ7Hwp2GLIuR$GZ`=VdXPgDz{-KbMmhsdtUQ2{~gQb zMaZCiE?qy-xsBf!U-j2$xZ!(nV9l?tAN?BCoq97tfo_4@k3h%#mkG4Zk5xI zgm-$^*)0V9V_V)f*t?ooG6JEYBb5fe)wziN>nl2wTqd$^W^Ntp%3mID%+RS>f6~KO zNm7}C8>#Z!!;Ah0e%&H!thx#%H`42`w!DyU$PmS<&ra4~3k}>(1kh7`TJZo!L|}lt(6M!D@RwI(}G{9+o)3dru$-^G+8I5OE=_>R0LR?#3VZRUe=&C^$6NU?qBjb$+7yrc0?%Cmp?0@{VU*2B& z;Wn~p4^r@4&Mp7gp4)g%UedQX@8Y_S<3&VK-_NB)-(i+|Kdr8ycj2y-XQJOOw!9NZvb5e^ z-g>e6_or4DcF)pfA#m&X>W`45)R4x?G37l#Jw%gN+DH_YRRYR|c8;M8sNyixp9A^f zcEhipc}|kxU3ohJpAF*vI(IAjypjj4oEgfIPpPbt-Z#S-T+A+Hnk?+-mmBwmJBn*u zrf2O`3$3r>EVAdut+O1H)9iIStSoC8Z1+s7mZb*DZ7Y=8CvsOwmXnl8wuo#ylX%wo z)y&e(FU0Q_m@gEATR+yw6&jcS{{H=WhoK{#8^G_nJ||B)0qW9Cz6^A@-E+ULQ8o5U z{G+~=&6iC=OS&;N%0u5h3028me9;-QnI--czdy?s1bA5SU6XaO0IXY6Sn`LvhXV%| za$8o9TRz=>`3GtT+99dkY)X%lTfUB;`S)dD0xT?7HJ3j67dw=ridyzyZL(J2^1?Jz z@I9FoKxIh6oKaaS+Jx$0Q2iH6{eZ@TW60U3Tsktkt?P(V9}EAJsMaJ2LU-kB%^v3v zjNi|-YntiM6?&)I*dq?y|6GfVB(&l+? 
zZ^b;XqbotH<@@BK<}LiQxDVkZ_m-RNx31S%Se2wXi1n{SC^|Ek696 zgpMvCwv%a%$x$x2UH&zoEH|Zk1=i3Q7fz68HV~A(7uE^C?B^SW_Bf}NI?qUs z=|#UbE5<)}l(nbgD_tEY8R7bdm2ozr>hU&aS`xw>fue8|wZIDK8TOQCkJayR#0z{FGNq$GSBNp~3X zOS*QgX?=6TS4_{aIQ5|JliJ_9EkaS=iHF1*i+?LL!pq*=tZkayy`Sih*h2c3CK-M@ zgmgl#iq0CQo_;uaOE&<$n40$ zteIrOJbQin6TeEBD>i%&NIv_oQ|lZitAby^Q2hq~U7SSUx~vAv12R*@(c!^>+Xc6b zP3rB{K{w7@y5_!Ferw4`QgPb+&H}YQC3W=Fpf0Q6n({gdIjD=@I7T+4(Iu~r)~`1c z5&htEW{8{HYVP6*+2nL{yKWWNSDW2KTd9C9D6Jvb+j>E^3ON$;&|WGqJ`TO(3cx05 zDMKz}c8XR0ZX|`(u?DfOR*4%oL;RhxH>s?dse*fRHlEF~Qk_vljn8$O)?uqway`;2 zIQz1?)_~nPxump{=UgfvRUa(8K9-ZRzp+KVw)~-e?jiobd$AG7BZ7hFu8UTiA#cIO zl3BylS(kOi0euI^)8j`zKb^A^lsGDP1hVYvSj_ocOIBL7o==eL`rO4iOk8?7!9tHSWBGR_9&ZI`f3hVxGs!Q;te zE3f5e$a0&lClsB2Z{7TnLT##6S5x{G;J?b5i5B(97 z_Of0@c&B*vD$MilVP7#m`sidS?3S9rdY$;t*jA`TlL#i`;61o)@z>3zzVXgE4H^Ip z$^O{h{rCY7L4~Y_VPP*bYC}#M%!E79vNxBx`MQ|9;j#v|3~XkF`m)lVnXT5R4;fz3 z0(uSRLrq5c4desHW2WN7mbnA`qUyMu(7vec9_kqW&Aw4(H`9B_> z6pb~g&{KuD^{PoFH9l?gR~~M{Y~|r9d~LrN1Xw3-IKi7G|7xD(LYvDt_N}P}mv@=@ zdWK(O3ousapJ#)f^)c`1wRRP((Ip=n?(?B2e*w2CVqQO;w_e)3!^Ph-S)XouV0^q= zL#VL1M}6&l(GL1Lv7$GQPd&3*S&SOFF5Fm}1LOa71cqX8`p*2;Tm;H@JpL$HeTXYe zft;lAI*8z$iqv+miax66b-|n?D&?modiokjc-)7USDP49J-TgoS99aU`UNo$)d}qc z=)VVkTwgf%XvBfI&Tsd9n>It0ckUDuEV=0{cU56fpE&ClxA=TfE`2DDG9Q~Qnya0k zK(jfVY10Z{DCC9o|8ZT+NeJ|X@?dqW+Y@6%OTr*Urj_N}!=|!NsPb)V$HIQA|@~MN%?%XG3U7_}^ zRvXx}{+xU-^=ZrQ3<^nCsDfD#=}urmSi?RMis`*CBqRb$wsKyDej+MksAlns8Kb*P z)P%wIPw470;?lJuk&=2CM{vQysz=Q@d-qf{*8^o7{-)0#vNe9SHagP~=Em4OVPzlnsc*Nt>lA zwCexuJr-Pr$XE=pR==eBv(BlHy_n2l)@&-g45GwRMYcOo$*^&nIKOb;xZL0e$PC92)pt*_NaRItUsX)q$5A8asNW&tdh@EAAj=G-of>`7WJji&Wbe>u8agA+ECGw zf|emm;e8weXab`?Y{B8iY;8~H3Y1l-$hb_)!X5WaiqWF?=`l`f+QM=P)t6CHNzy8h z+Pc82HE|ynO(?c0*nw7heAt^8pXzxsh7OO+QH?4ml)vZOEujUc7A~3_%Fr|y>7i>6 z9!XYlf{sSNnr!kR;nlGKv92}Wrdv|V7WT-;d+Ry>pbMP`twce~EJEM0$^V3J{;$*5 zr|xZKfxKwH^(?ZC+VTH9VR?Hj-Tk3NuGsuZ?C>3!>Vdo@O)^BZ>|?Mo`M9kV6ZHnD z_Q#TZXUb4O_(Khn0@>N>8^P$a$CT_t8blzy%1njOi#yEN0Dk^kYx@6#KK%c@hJ8vf ztEu%q$02JfBiOSJ2%0}Lj7?;^v(|)k`ki0Tpu9t%wS_xN&b2DMgm06)^LJx`XdjCE z_m$+nDtjtiotTI1mSj~pDfNx2m;yQtoduaXsUUrsg@hAGLri90qu1y>`3GtUd?#Rsc%w(=E z$m#}qrKMD^2`O4Nm*zkBmoed&7f7GH zbAy$oD6ORoA*;(9n4}_XUBdl}+0J@R-DC@qvPzU#&V01tg+HMf@~z|fu;?A5U z#j2rfv{`itsj>5p%AT5H>8v5%RM~PlDFyq>ThW3WVqk=y^M)Fw08r9w&-i_7=py~w zsQgUnx4o6wtE2paV|yk8*+KTul@f9b1s(XDWU{7P?MFNe<*4@sFyOh=vlk#A$qp8b zb>$W>@oO@mMM1rTY0_<+Yu)H+Oqr zf@oK2Ft&SrC28pX36D0}?Tk@l{Gs0P_e;L2#QwtREb=t6=jk*7H52x{eBP(>D>Wfl zcy#Zz#e>1ypeL1KioIbyMen0KegDWk&z5T2Cc;0*2rmHKan*UaOmQZl#egl)WfoS| z(`CPnFB!i!JYd(veb1aFNvUDprj1>j#_{DBLIl$lwab@YVmy)C(FDT?4S`DQ2WQW~ z9|FR|vvO_|4zmp(^(OI5J0r+A z(fMYUlISbVi?0Ys*f9aCAXl?RYgS!NvqUHi?tS25;I_2x)a~mV%B~C)&Og2py}7IZ z_6)##4J-P)6fc=d;6fN@b4Ow5Fe!zjOJPG})W|xoAFtLGczn zz}01ZaljuXj?qyvdSV#mZ!lNgS6|b;6n>=3xOkS?!m`Qyh{b8XqMa9WKv2=nKTD^P zL822)%~mgbr&uoBX%dwo_|r_;F!w3=D*rBJQhH^EH5I(t$5$AiR%7Bv}Rr-!D;WKaP#I=p)XeDNoy&UBKZ!^nFON&Vwm!YY1tr zMZdzV3~@`W({X}p(EO3_=Qyp3RCG&GXn{kCY_D0eL!)hLpdwYXmZJ3rVyB*-G)ltxKFT3)0Nm@m5#EHcmf0gf6U66zT-({Y;25u0>+|Z|Za+oA{QCOQUzUmt!-pBkJ!8w#z?lhsGtf0dbbYgkgvJZ9pvS7XeDTZ21 z$(1;@`Gg=g1E-E$g5_%Ra((5sd1b|6%sB6wOs@TCYuv$e&fDBDRn%F`X#U)`?3PeU zd#Ked_7wkEKJ)g1@1k8gM)|ipu_2$2CyL45Ku>Qmxi3ZlQSdk5wDP&pt1yfMU)UuQ zhmIl5j{9ZUM+nL|(J{(DRyhoo_-67ph=8voO#b6}ikmuU7T7S#b`4kc<}%cii*78+ zMP%4(mF0#S2G!h~^EMDELgp4tW;u0;!i=sNIl62-d!@PN{XrG`vY}VLhOxU_-S-hd z@bdvc9sLHhDy^K_)D{HVPM=t6R`cq6YVl1|aoj}Qh1tqxS*q)Qs`>iB#njR72(x1K zcM6yEUn0bwx(?Tm#z7IRmlGA76!8lXhy_F+34`#&9Z!5$kLVSR|s0Zzgcw;@J zSlMC{gq+d5i#!ESh6rA;s8eKE!X>2e0v;vpNL>YGDcSM`t)&f2;FV@LD=?KfELr+0{<{5+7WsWo_nLd-wVcAiwQeJm%7-20yO3Q*2=f 
zAExn^(Amp$2S!z}xTF*YK44dB?n-U!%1G)->abqg_@wgOhOqOQdH+}=w{A1>hwc2qXej8g4__>-oLnRX4wIu9P|^{QI@r^<>c zoWWt+D5pa5P6%LefF0pj{H6P?Curad#I#4b9><>>tZcSipps}R3!?YAvJ{w+tC@?! zjBh0*oV#TAO9p^o!v1(T>uL1Qv2d( zIkBpv(_|&oLdOtdOw6X^UzqJJWgH2}UbkQb)?sw9iYIDdDPa^ofO|Ih4NK50ny^?& zH{~a9IKG@~_PIbaCDY8T+|&3I3+UvENO)Nf{mYH==tHcECovSCFyez_0tQsi5Q%L^ zzK;5snsU!J+>18Zfa_=HRcMye*NGc-O0xd3q@Ey71*0IX!zrQY`*z(O7;2T7?-7Zo zN9b-{Gk*#^)lMqnw5)Zj9f|(c)<2#2-ZK(Ulv$Rep%lW)dp#aLBM7*nQ(*jBdXm)w zEP=px&U-$LI8kmZ!H{>#q>TtO1oEXY@eaK26hX51{&bn`ZE{81$EYtTo4rd;xO-6Lo z$TkBRAy}wE*#A3I9qF`pr4T^KW|`aP-|8>C%7#-n%0hyz+khn9@%^1yox#<$n*cR+ zXOn!(lSZxXIxc0j*#na4dO)aWx%lV7Ngy)BSdbj3)@n*wp7YWYXmqNP2g02NFxNQ0f#Y9PXrW8R1`{_i%IGhLe^C<^ck)3U`uu)s z`hK}UsqstZ!B+&ieTi{E_X2Pjs2ucJzS}QyWeA&7?zSa9uldqhES;XL*(nKpZ~NL~aGF z_djtsysi_Y==)7e@AG?K-twqtT2cscGAX2-gp{EF0gqMmxJ12e{oVDAg1PaV%b18oT_8?f z2)OR1z%b|{9Ozq@IkoAmw*S(eseYc;)sp{*R6BNktpEa7NZD?bMYEQb@vlCf<)l6? zYFh}z|MFMq?x}zb7tLq@tVVzd{)}-|5e_=v3a8wVqt8gt zl~7ZFg9%C(IlItSl;`_XRB9)I0iM)6*czP&+nVv&!kO)n%xIVUM|HrCuL@`zGKB{V zyHh!coyk_Po6*LuvQ}jX$)7+nSB&6tQ;Yx@wjCIT7QXcB$yTN!DRH`G9rjqQ>3iM! za@nT;*!Qzr@t)YFbmI5VXi|vu`1drzuyVKd9pjy0K4ht<4Y=opBDKPlvT&@+#?FhK z!u?G}=A$QXF95!pL2ucx_-dyx-PSOw;&`)Y-|lQF?DyRQ=uP#$|Fnvu}ihj zKe-jxSikIwD7B-0r`nAdXu36am5kSn#|`LHAk>W4#sZLl|XiZ{AI;{aB{#xdvkQAQKc>Wh$o6lYOqwUF^D|}=IIpi zv{abK=>CC?qAIt=bpiIeoVFf`4QTAedG(5e*tpnHs2g(ZS*9m3EBeTeB1fsRz?&&B zSqmTh*HUds(LJC_QFM689l~Ce8E@1$r<*O#9<`&nUNx|4gargCQ&?fbNMlg$c}xEa zILXM=I*G}F#b?wSM=H&G4KC}h-5&pT7^{n|pq`j(!ih|vNsXw#Dff!jGtwCNhJr#& z8EC;iexxSH81`6XV&jxEQ1j1MzvmQo36s#X|Jj7jkpoR=`u#9M-2elW>`ozG75nj` zjbSK?2Zh6JA`~xZD1o#dPb%*Ost=FT(q%~jIQ^i;tDIYLA}gUwK;xiOYIRp1yWN5k z3~2MM;As2iCxrGQ>Maq_+ks9I_&to0!TV!SQ}1Y+UyfB@mqI#-4_&N z3LK~7K|$nI%v4L>Aoz`O$A~y)jt)1%+awl5XvsoGMam9IKh7_83z5e@06J|v4KTjR z@P!#k^y0rz_L6?6Z==^GHKBU{#l67wBKYhzuZB|-l|jaeoCb5A7E~lez1$M%ovr#) ztOj|{l0UWBOTdAHxel;=Rj-n*DmFCswqolp>MdKxbls==aQQ&Yw#l<0^jY44hE`QY zqo>APvrAc)sBe&1&f`hjg!|c6LL*a>Cu$Dwmt}6HYD!aD6T8chCePcn^Ve(GKa1i| zM6`M3^c#7b3r&$pW%tavkDT>t4MW_F7R%QO=Op<<5_ui8_piY={i*k_kN@Jm!$Ki$ zLvE$m9rvK3=|X1BoH&9ol;RN>X#cFNG5Q`Ui)(kil>gDxn?1?piT6xF|FQzux}ghU z7BNqt&uH!(|MP-Aie+D?KA}^KBD^Mg;vhn|P59+aB}-xtqFU0OZ5ZEkLI_7==bnL* zLjdwJbyd+ynbHF@Z3ndT0H>^cu%MVQCH|NsgH9%9R$S5hJwH=H&hESO-cqt^feH3m zCDRORewKUty4A}S-2$e*n+VH~Kry)+Xq&{Z1S27{Y*fw5S8z*+*3vsO4JP#$rBxUz z;8Sv{OcqU}entuvBhbfpne1z-C2l1gH!h0R@fL=>vC3msVfjA0V%MctN7rzvzz-`e zJC_sH`io64l_^YkxafAIu*JnE?;Cwb^LNt;xfhWzznzgI+wwA9Da#)>@!!sL0;N4W z(e0ipK)qp@Dl=W$6w7qW(M~MAXWee>8hbk~9S9O0fzsbNk~Wl@9JOE;4ji5o9*}=3 zo(tF;Q!5<#IB0uzA+*_4O`kps<77vS7Hl!$06v-CR=*~9-i4hdjW-GYz9Goj;$On} z7kRJllGll&xQF&*!eC`&9Fj{vSaC3c>5rAVXc>$g*{e81b^klmBMUSFQ+#(~L-2%$U2C>u0w`dt`T#7j7(j9E*#G zM-unTWd*)%N-&;J9sI=kAm1SuOIJZfcISHCa@U5uAu1q_^1Y2-(R6t`5EK$dpTbj6 zF(d=Js6emd=ODPIB@xwLy_ics5uDM!FiWXr{NpxOkN&RYN3%4;5R3`W^Ty?<$(l9Z z<>N&((d+RG-k;$lfXv{2s4?PQYyEU6Xg=soeoocc`svc3zMw#qm07Sue5H>mEP>QE zZfq)T^;2SEW7D-wKf;UH1C(J1q`Q9W@8oKQ-Pt`dCkFxh@a@e+_j;DV_56X(+NeUJ zqee6R$C==#Oohh+PYh|-JWnbpg4^5^!_=5m5mw90)qL}ZX+(FE#g-Ma3Te|`#j+IX z>w54TLV!larF}?Acjh2*`7=pl_%!M1u`*3fRE*|*uD3KGp+XoTZU+9soG)7=aII~J z;zqnFO*k`43c+Ta^<}kHeX8i<4ZmigRN zbaoAD4*8T$G=_XIdcrwizedI_NJ&|Sj4BgmZ`lrBnt5NI2?+}9liXK><@H|$ch#6WopR0GaHCmz--YMMe!C-@Z6 z{Ep$&RRYGZoUSoP(pt}8QiBcxxM=nts*mI_2QF~D=-wRSjTY}$NMgk zF{N=(EA}37>BC`Z*Div4dR*+~3nUu1dggN^X|2Wdp`+vMwq-Tm3!Pugoo$~#|T zeM;!Z4?or8a*&@37+NrjwIJr$WEGHdbb?}z%II+@Ai7|;%Cwzo z$6@cGSBE3eWj?6((fWGXV)pW~BNXTE@t(ugPZvA&7EJ|3K|uytYT@OHKBhxJ4`B<7 zW}OPlhS)?{_N>V%>n3<~BSS#C26mhLAPwX`{Xl0I%rX7G?=^}$E7c=_#xA<8KN6Y~ 
z=QDOTOhe#Vn`pjaqJwu%{*ywOt`WZ?Ux^iVERGNQg8#%9C#$jdIyK-WRGkMR?UlzS zZui9DEo;!Cee8zyk!7`_m7+@8J`ICVG38arxURvkF?8N!uj+%(*r_QNbK^*>g^i|tA6HOF=>in57(gf!b8Mu9NDNag zE7pinpv-dG#7C-h-FG;ikUaSyx?~7P89O`pA5(3 z`Eze*CXAeqzrQ1HY{(^P)_PwEBO-$a1EMH93XiAcHSe+l6gd(Kt&8SYQ?M(uxCwZz z1-XFN$_@+V)5%`gg^D9NE$YKR_`9E@vCks}D6-0B+|y5cPoomKFpUiy`W0sizN3%% zm@7WjYyY#4qz`K$=2OK*7i>EKgnogC@@a9ep9!gfCvDXnJLlW;%r~&6yqR%P?i0}p zmg`?bkfjVC!pI=TQ7u&s6XoL4UKIH+OykD783OlwBcKYKji0GunCD#ZnU|MwP0V|b zJat`f;y_me5c-?HGbqSpi$7GUGXV02@0Us2FH2Rzw>q3N-PQr-*I+b2ci^hr&~ek~ z&Z=*j!(WWi^aC@)1LAVR@PJZ_?z!ILhVwy7l?4byzoCnn+d02%?VrhBA~MrggWA%) z=cETv8eQ%{j!S=)35D)0pdW|cUQ}9bk1S95Z=v3*{AUKR=Y!SPXXEEb0ah6PL|4$J0Pu{i}fS_K_ zgxv<6a{KfW3?$=9!xr7&x zS*l_eYW{JSWA4xSlxSu>tN$ePiGrX1Zu8@_KR}8&W0QU>Fr1DK6?5aWr(ctML5)ff zF#fZ+{&4#&6@mkLPlFh01RSbW6__^Fuh#A|MVvtF_&>^8`+uvc-h`)MQsh)DZIpMP z<(h8e?Lx%C4OJc6?**5SS+_2V36|Y*>o&D1WXiw`ULSrs`%eUTd{tWisX16#oYUWF zRzVsA3jC(P=cb#f&v18{KvjGk-%ai6KtZY7Rm%76REe7~(Yv$5wl|0jQ@ z(l{xV%lY&rOpR86XMZ-fE3Z=Q|NshV*fn$&(T(JP)a;_CC}`e`(N<6 zXe1_n24W8muDHPcf*<%u7Gbpua)gd9`~fy94vIT$|GJV?s9W}$KGmyQyh%pHGIajfX|{a zFY}JAA%Ja54!VISPY^YnviYL)$_|ge*bH5P{kCtxWC!cgzeon^^KwUiH}UK;kWuC4J5@LccP z*qCovBX_RgOThKz4^5bcO@<-OA6hgz^A zd?nLCk{9^LWSJI8%kTt4z8MQ4pJSnzdQaKj@I&M3>SO#ZV-_Xou-A)BSdd zVymC_w1qhuTR|$4_vzs8!G_sX-~?uCxZ!>F7wTlawn2eQ!0i(gFZCa*#PvTS!g-E%vX!< zC~f%!gfN0@l)6rBbq=Q0Ac_X;z-a#p5&+dwue97D6*B|=V93WJ=6G{46?FCeBx0y- z21D_OThoTMnSmiJ|W8S3Ec}s`1)_~b<=>(%k!kjz}u+BZTx66>UiW7 zp9HV<1CAU`;Q9CC0tgMyRIL|)(})_&gM_)@mL8Hc$?{1bwR%GbEMt$sLi3(&5ZK*p zi_YsP`4*{g$7bv60D{7AFj>KS4LaJade~qCOI?hz%($OT1wJKL8nYpDms?^_cFP@! z|LYPb044VTbBBi~f!KEIwWrtCB3c*0fE+;}4n3`r9gonGIdc!6PaFRijAyWJ9zA^P z?ZSWpO~TXp4b!hoD&hP3#qGg|1}p^oz_ar+ofHykI)Rg_y>Wk( zIU|^ssE2RGXohhG9>e|kw^_0z;P}96KG3eY0nXo5)Kz}4bwqD|g=!OkT>DTqO z?ceI2eT#M@`nT8)B!1!`Tg*$mPAce8Wi?9*KEl20TCho16x)_(MhGB<+^gvoKfxT> z7Awfp6~H-hCg;pyeMk*g@xLZiBgpA75dBrSyR>l!*q$o%f*=Q6!%x(FEcWe@Detm` zwJrjML>IT7HxmTl7;3S06%5Z*zDqm~LWskqr%H$y{z3O2xpZ6UUicKU3Cqp>4h+nM zzN z9)Q0^a!jxJ8XIA(2Vrr8Q;DP7`k{&PUTFfLo{~{<1cH}(b+AJGb}?cN*vIs}29d4^ z0KgnB!{1D7Om36=EyMTirEFRxdK(RizwAR&41$&);p8p8Tx7mpTQK0xh-~Wsn7{1W zi+C%b`t|IS*f(xw{CWr?X5On%$^0f&3K-~2`K97nPSgN+1?oo+lbH)8I!tW5|ETJg zr?m1@5O92yW-(*>$H9r|v9Rf#OWTN+vvg0s2LD^I<4MV&iNF-cz zd)9bP*eZ{Almu!e!rDM+`vqE8hyFc|6XqwR$>MHi2i!r%(t-5fj8u>QR$x#40+uW= zPGejJ7k^3+3_;Zc4-c43Y0xkH5x^~b&Ui15>Le~S*r{mlRDTwKskGa1>pz#SQ_KmY zD*TmQ6+D^ChuV%?gUoLO6u3XA8x-ou$ji%f@b0c8<2K66EM~jmSFAE!dRJ&SVWTxa zh5&j+hIZ3^@(c5ygOPEjgU-%;Ya>PMdC0@(3v|l84hH_Mt^8144l}{H!f{EzGhY6R z`Y}Y1HTM~LzKe90vu=6AJT!23b9NV*5HOdYeAx<0bYG~O-6}5||J|N&#mb5(={fUw z58v*XQUs24fRgVr&e66;5A(Oy|E+%O4}*m%uWt4FUL(^`99w0C`xBeYzl%99^*zL! zn!38^uq2a9C))<^Vy$Pek>aviDqDp?yV>?CH>oR%h7!3rYg%7=h(;R4EyQLoj-cVW zLv`!M!dEm~AV}SvsL4{DZF^0mM(>Hb zncNU7dr^$*N+uZWBN3-!N(z@0H{NP>UYb>TmgGut(|WA}tse&H)MEfeBMniyc2h(NoNCtHk#D%d!s;N7uy}!Y3>dLHS6^Oxmh1N@*oZ!K z1IjkJ!%GcMv2|`+&!DHTC>+bLmwH9WLpK{n0*gP8bbj~;q>JZc{R4d91D|{x-G~gt z*$90rmN`wVW-?9u^1rMs@2?(he(0<>WvrxxY!z#fc;!`yvnJ>1+AyQnMJ4bZ>tjJs zu3+Ifsw-%bYi_W;{|7AMU2=c5dBF|_bgoAYj+oqXdJ@IY+BMgq(k=UQ$n~b2rF7`~ z&`(k9t%$A{guLKD(2;krUgM0dnN+BIm$RhB@<3~(CcmF>=UREu$}Ed^VHT2E{;kBD-z{J%Xl6u-I0+^BCi(~C&IU_tk3Z%M_QqKyzADllgC)n`;)YB1TCubI(hd&9i7l6NkJn{{3q? 
z#|)e@4&X`rs~P&?R()3_;*K-REqqbcX}(+U*kA0(t5Ou0p&baY8`IOWK28nUmsf(Y zmo)|UjR_*#B(O!o&gM5v*J5d8d z!&~roj;h?(@Xw>di(o@Uf_&ks1V!uk3&U<%Krzl-~BUe(gvMNrxqsk3VZYsFZ$% z?;6wQQ(vmcw=%_l;R$o=xcn_B)LLX4@op8Em7&89O6Ie@&DDmAcD<8XzOpF&E{pd_v7P3zv}kpMuS?sPeFlkjPAd2M+|(esbB8!b zI(8IQkC*|!y&rJ$a^Z*jUg|)DxIZAeE9tgp--EYdb~l+6F~_RP@hBxhb`$hZgvrkydPJ{D^Yrs&f0B8zEz zEji*Gb06Er;#0^{a`8f5!2$D!>SJ$OD|V?*X3g|!O2ljX?F3lfKPsZ9-ai`%IXHi? z;%}vnzM(?Ns}mPfHL;;WO`>c@M?6%=YQfA4rls|M%euCOnQ5Af?i$2(0@qcR_veo` zo8A+?LY|Ux*?wC#GdscmJ<&gl(>N%=xVT&sQ*=1?wgFGm&Pv}SMp7@E5t}tW=d-_@ z8V}2Rd2nQlk=1I%wfk&!>7A^!tcfUbUOn~`gE~|6W#gpH_|rjpU`AX#b1Jm@Rap!2 z?CtY3Ad~8l6%1ew@%m4*ILpr7e%R&qaop(yND9^4OiWdUuZR(@Do7s7{PfA#aAX;{ zI>L@|SGDdtcQ|mSVT@8+*2lyeIDgB=LW$>lQ~BH?RZZ`y)>N**~Ec-alG$ zPD^x!MfHS+7$;_o+z@2ZqheZ2+U10%&bfnh3ap!)I_;{6xM+NA=&IK>w{ zvNK0Ris#iJHY{&3OlP+w0y(EhJg)wDUTPs2&Tp@b0{=^Ig`37)&-$V^7lZHAZN0I; z-Wei9+-_{tS#XAu!e195I{e`oT1L_8<;RY^d-~ZOpEcX<|4q3Gc=;fb2Z9eJx|!Qs zfX4HQ(oG}gY>cSoskF*AuH zw_XbPmi`lZ*PLMuH;K?au{I3fKCG+Yt&TV|G#OQ?qwLhsVL8-Z<8Z!f=oXm40q|+U zpV~HFAm5UZz9jqi%{5bo13b|AJXWWb*^3u+zF?RcbUaOJI{z@6pl#-gc7x8Sm~3ge zvT4YFiU2Yzdl`YKazTJr{BS}^V|dfa|B`E(^qPR=VD}lmCuJ%_Of$VvnD5WG>qAYtA1AcF&BXfb zovon;_z~4n&WGxlx4so;xgC9^3rJ5su~?m}Zw^HZ|)}`$XRJA`DKhN?$_Gt5^ z>8ly8y5`=ew}BKSAYj~C#iUWUnc|X)+5C31B1>~`428=UzM7N|wMXj2wC1bmXC^p$ z2&3B(GUXPi!MHGCk!J8gyv!#r>RgYuOPFSHC^3slIM&x1vYVNg&$_~F&N$04)b869 zHZjiS*{l^>Wm+vxhRx?KW(3O>2JR)}h38eUwEF=$WfNaYfN|o;Y@Ff_yP}o0#BXNzDqiNVTeK%oWrQ5@ zrRG9sgfVyAzUwNqiYLsg_cE~uZT?o+2gHim-W}?luhF~8c629cWiN1ThzyO zyOlo{+dxZKlIe*ECsehI*V~@4_>5ovHNnNATB5#`vne?Iv0&-cZkv^GSx$mB+#yA z6Wag;?>rl*{wi3vPpd%bj(ptXtsr-BIqvs02L1N@rtm_=^mLPF{*I1N?&AA~zborn z@uQWyLACRPYU}ESe;W{Ot`CjwwuO7#FvzSZ4)F~`g{+O*atE(1)+GIdH++GL8NZSdbe!ns(^D3X1`kIW5PAyl2 za`6+@0R@yxt<4R%`2}XT+)SuYM^v|fPmD*Pg_spI^r*H`d1r6__TvMPOIHG4;D>{p zA1n+z3FS9^>~ z;uq4h-C*V&_g9B7MuXB$D}dJ10d`;;noQLnyel(4YglY%5H&>82R1-|Oy4n5xKALljh zNKYzjP!CYXzB(y)=O31R7ZpLI&j`~^Y90V{Hf5;wbggw$c*H6Va`e>^-g`sixFZ!E zHF`z#hZ8vWECqe1X1JbcegPnmj#{{oaGp7?u_t?QPR0b4V!mE}lXJ6;j;_>ew0D;9 zZ_Th{B+YpYyGj>RHIdJnxYVd4vh3pzb}lXq5g_nA*}JNB-+*@d0kl-t_25!~I0;;O zvrv}6UOfw>-tbpdua06244A$@2nF#Q@0rom%`|B8w}LF{gN}#0y{?~-=PHJ(#D|v^ z&R9fnKmR9O=>FOgYVI#gi$PIZVi-nd3va!_k)K5U_wCLJ0L?yT1LRZU)Q zT%(QBz%f>q(wKMmp%`e;$$&EfeBI_O$y-4BG#IlnH}~7GP_e7P6aj1wmlz`{$DKcg zZl)Y5U3DYUkBB8*4*fPcY2o-lBcT4X05hQtirV3tXefO)F|*F=Me)dPVZ5=xnr=bT zb&fi;A*$U86QG2(4BfK}gpsyV^7I<7N3XsklfFPugj@@t{!jm}#Zacb7c15ctJ;7< zVulK9yS#nBje6c-Br~zJoa><~8Dz6r^VF-HRvh?LXa3ZsI@*{SEZm?-IxzX$@9FX9 zN-=h!kejpu>trfMVi<1yqTMAk1p(;s6Ts3QjNGVe`=QSN6Sn>K6a0Z%*abRla_)ME z1(JX#&Q>%)`_YiBv{t1#kscXl*Ic7Cd%o<@hkIVZR6+~>Mg<8cGt5zJZ1ULd)!Q*4 zOAg5xi#H=nUw+}Qx4!j$i7FP&pdOLPr$yxCeAD&^?Acuo(z{6MiE4T|a~g-gtV}K* z;5NxZq7xku_O6Tmg^_FfDif>8mBw-@xA5T=aygM--d+$ndjZbDYIO$|9TOj>4?Nl@ zbZl{yJ%QjkS8{9pBiGWuga)amYKTY(yaX_x!%ORSn zc2udoFb)aQE|>K&&ysCdPJP_*X`!qP+Hb>hTw0b9zLmQVCMRN^)`YtIC==!@L!Gny zY61&=W{Wlp2E`Y!q^&(B^tUyVd#J_a1RSO+5Z+uIQU)$z^r#mk9Q+O z^-ZFpPVzIUQ8uZnRh}7*AE|p-j5y z?gvpyfzNEgC^VRaEbDP1Gxs*>wU^Tc^2#Aq7|jSP_P&P-BS9Mx-o1jBsf1DBw`Q*I z@WgRScup34M{GTddQ?!t%Dsi++)o^ySVTjS@*63xlWZ4^AnBjf={KBH^mA@Y`tu?R zo!s18$4y$G>jiO_%asfT6BP~CcY1GQ)Lxa_{zCfRi+}mHScMoJ^1yFZ-3kJUc2=|3 zV6pyoaau7;z!=@deYJLFlE>^1TH)3w()hZX#+=aXaLrTb9Gi<5Q|)tB^kdZpqSw@t z#8nHuH|V7a(M!36PIO%6g}_fs57F9F^5^2W&?<>uqX`VAp{%+mVMG-#L$Hir!$BQ_ z!I}9NPd{$Tkh%&sj!f1Z6gp-4IVBcYxAh8QTlSnSjFy^aBd!hzSiJE3vuP2;MHky{ zIz-3UqZ}`4rX563GsJf$$E7=mwU4EFT>RkpdyhgUX!^Wu8i-B( zI6MoB87@%vIHal`tl1KIgW7(*Ig43NZSP9=pG4@QtultSd%15-9b0coGLsglM&RTg z|4?V9qCK2|EAMg@NnQsYs&i&`jK`CDi#NIQRrdxozFT|?4fr&lHOP!Yy>T?T6*9p0 
From f1b911e1a38dc449d2c5318f01752e7ba01ea691 Mon Sep 17 00:00:00 2001
From: leslie-fang-intel 
Date: Wed, 31 May 2023 10:22:41 +0800
Subject: [PATCH 08/42] add explanation for QuantizationSpec

---
 .../quantization_2_0_tutotial.rst             | 34 ++++++++++---------
 1 file changed, 18 insertions(+), 16 deletions(-)

diff --git a/prototype_source/quantization_2_0_tutotial.rst b/prototype_source/quantization_2_0_tutotial.rst
index 1db08d61074..772e9cbd830 100644
--- a/prototype_source/quantization_2_0_tutotial.rst
+++ b/prototype_source/quantization_2_0_tutotial.rst
@@ -66,8 +66,15 @@ tutorial for how to use the `QuantizationAnnotation API` to create a quantizer.
 1. Define QuantizationConfig
 --------------------------------------------------------
 
-QuantizationConfig defines the data type and qscheme for activation, weight and bias. Suppose
-we want to define:
+QuantizationConfig defines the data type and qscheme for activation, weight and bias.
+`QuantizationConfig `__ is defined here.
+It consists of `QuantizationSpec `__ defined for activation, weight and bias.
+When annotating the model, methods of `get_act_qspec `__,
+`get_weight_qspec `__,
+`get_bias_qspec `__
+are used to get the `QuantizationSpec` from `QuantizationConfig`. Then the corresponding observer will be created
+based on the `QuantizationSpec`.
+Suppose we want to define:
 
 - Activation: `int8` data type, `per_tensor_affine` quantization, `HistogramObserver`
 - Weight : `int8` data type, `per_channel_symmetric` quantization, `PerChannelMinMaxObserver`
@@ -105,7 +112,7 @@ we want to define:
         observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr
     )
     quantization_config = QuantizationConfig(
-        act_quantization_spec, weight_quantization_spec, bias_quantization_spec, is_qat
+        act_quantization_spec, weight_quantization_spec, bias_quantization_spec
     )
     return quantization_config
 
@@ -125,17 +132,16 @@ defined later.
             self.operator_type_config: Dict[str, Optional[QuantizationConfig]] = {}
 
         def set_global(self, quantization_config: QuantizationConfig):
+            """set global QuantizationConfig used for the backend.
+            QuantizationConfig is defined in torch/ao/quantization/_pt2e/quantizer/quantizer.py.
+            """
             self.global_config = quantization_config
             return self
 
-        def set_config_for_operator_type(
-            self, operator_type: str, quantization_config: QuantizationConfig
-        ):
-            self.operator_type_config[operator_type] = quantization_config
-            return self
-
         def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
-            """just handling global spec for now"""
+            """annotate nodes in the graph with observer or fake quant constructors
+            to convey the desired way of quantization.
+            """
             global_config = self.global_config
             self.annotate_symmetric_config(model, global_config)
 
@@ -150,10 +156,12 @@ defined later.
             return model
 
         def validate(self, model: torch.fx.GraphModule) -> None:
+            """validate the annotated graph is supported by the backend"""
             pass
 
         @classmethod
         def get_supported_operators(cls) -> List[OperatorConfig]:
+            """return the operator list which is supported by the backend"""
             return []
 
 3. Annotate common operator patterns
@@ -372,12 +380,6 @@ to run an example with Torchvision Resnet18.
             self.global_config = quantization_config
             return self
 
-        def set_config_for_operator_type(
-            self, operator_type: str, quantization_config: QuantizationConfig
-        ):
-            self.operator_type_config[operator_type] = quantization_config
-            return self
-
         def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
             """just handling global spec for now"""
             global_config = self.global_config
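The patch above leans on `get_act_qspec`, `get_weight_qspec`, and `get_bias_qspec` without showing them. A minimal sketch of that helper pattern, assuming `QuantizationConfig` is a dataclass with `activation`, `weight`, and `bias` fields as in the prototype `torch.ao.quantization._pt2e.quantizer` code (illustrative only; the real helpers may differ in detail):

::

    def get_act_qspec(quantization_config):
        # Pull the activation QuantizationSpec out of the config; the prepare
        # step later materializes the spec into a concrete observer instance.
        if quantization_config is None:
            return None
        return quantization_config.activation

    def get_weight_qspec(quantization_config):
        if quantization_config is None:
            return None
        return quantization_config.weight

    def get_bias_qspec(quantization_config):
        if quantization_config is None:
            return None
        return quantization_config.bias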
From e78bcefe25611b770378808bbb82c82a9ec5677c Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Wed, 31 May 2023 10:31:25 +0800
Subject: [PATCH 09/42] add explanations for QuantizationSpec and QuantizationConfig

---
 prototype_source/quantization_2_0_tutotial.rst | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/prototype_source/quantization_2_0_tutotial.rst b/prototype_source/quantization_2_0_tutotial.rst
index 772e9cbd830..04ff8e7fafd 100644
--- a/prototype_source/quantization_2_0_tutotial.rst
+++ b/prototype_source/quantization_2_0_tutotial.rst
@@ -72,8 +72,8 @@ It consists of `QuantizationSpec `__,
 `get_weight_qspec `__,
 `get_bias_qspec `__
-are used to get the `QuantizationSpec` from `QuantizationConfig`. Then corresponding observer will been created
-based on the `QuantizationSpec`.
+are used to get the `QuantizationSpec` from `QuantizationConfig` for the specific node. Then corresponding observer will been created
+based on the node's `QuantizationSpec`.
 Suppose we want to define:
 
From 238e2f0cbe96edc8e252e754ce3c59e1faa2e38f Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Wed, 31 May 2023 10:51:25 +0800
Subject: [PATCH 10/42] unify to use module partitional API

---
 .../quantization_2_0_tutotial.rst             | 200 ++++++++++--------
 1 file changed, 110 insertions(+), 90 deletions(-)

diff --git a/prototype_source/quantization_2_0_tutotial.rst b/prototype_source/quantization_2_0_tutotial.rst
index 04ff8e7fafd..87c87daf11c 100644
--- a/prototype_source/quantization_2_0_tutotial.rst
+++ b/prototype_source/quantization_2_0_tutotial.rst
@@ -16,7 +16,6 @@ with PyTorch's quantization 2.0 flow. We only need to define the backend
 specific quantizer. The high level arch of quantization 2.0 with quantizer could be:
 
 .. image:: /_static/img/quantization/pytorch_quantization_2_0_diagram.png
-   :width: 300 px
 
 An existing quantizer object defined for QNNPack/XNNPack is here
 `QNNPackQuantizer `__.
@@ -173,12 +172,25 @@ annotate the input, weight, bias and output.
:: def _annotate_conv2d( - self, node: Node, quantization_config: QuantizationConfig + self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig ) -> None: - if ( - conv_node.op == "call_function" - and conv_node.target == torch.ops.aten.convolution.default - ): + conv_partitions = get_source_partitions( + gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d] + ) + conv_partitions = list(itertools.chain(*conv_partitions.values())) + for conv_partition in conv_partitions: + if len(conv_partition.output_nodes) > 1: + raise ValueError("conv partition has more than one output node") + conv_node = conv_partition.output_nodes[0] + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.convolution.default + ): + raise ValueError(f"{conv_node} is not an aten conv2d operator") + # skip annotation if it is already annotated + if _is_annotated([conv_node]): + continue + input_qspec_map = {} input_act = conv_node.args[0] assert isinstance(input_act, Node) @@ -195,7 +207,7 @@ annotate the input, weight, bias and output. conv_node.meta["quantization_annotation"] = QuantizationAnnotation( input_qspec_map=input_qspec_map, output_qspec=get_act_qspec(quantization_config), - _annotated=True + _annotated=True, ) 4. Annotate sharing qparams operators @@ -208,12 +220,14 @@ sharing the same quantization parameters. :: def _annotate_add( - self, node: Node, quantization_config: QuantizationConfig + self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig ) -> None: - if add_node.op == "call_function" and add_node.target in [ - torch.ops.aten.add.Tensor, - torch.ops.aten.add_.Tensor, - ]: + add_partitions = get_source_partitions(gm.graph, [operator.add, torch.add]) + add_partitions = list(itertools.chain(*add_partitions.values())) + for add_partition in add_partitions: + add_node = add_partition.output_nodes[0] + if _is_annotated([add_node]): + continue act_qspec = get_act_qspec(quantization_config) input_qspec_map = {} @@ -239,12 +253,14 @@ we can use fixed parameters for it with `FixedQParamsQuantizationSpec`. :: def _annotate_sigmoid( - self, node: Node, quantization_config: QuantizationConfig + self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig ) -> None: - if node.op == "call_function" and node.target in [ - torch.ops.aten.sigmoid.default, - ]: - input_act = node.args[0] + sigmoid_partitions = get_source_partitions(gm.graph, [torch.nn.Sigmoid]) + sigmoid_partitions = list(itertools.chain(*sigmoid_partitions.values())) + for sigmoid_partition in sigmoid_partitions: + sigmoid_node = sigmoid_partition.output_nodes[0] + + input_act = sigmoid_node.args[0] assert isinstance(input_act, Node) act_qspec = FixedQParamsQuantizationSpec( dtype=torch.uint8, @@ -254,7 +270,7 @@ we can use fixed parameters for it with `FixedQParamsQuantizationSpec`. 
scale=2.0 / 256.0, zero_point=128, ) - node.meta["quantization_annotation"] = QuantizationAnnotation( + sigmoid_node.meta["quantization_annotation"] = QuantizationAnnotation( input_qspec_map={ input_act: act_qspec, }, @@ -271,12 +287,14 @@ For example, we want to define the scale, zp for bias derived from activation an :: def _annotate_conv2d_derived_bias( - self, node: Node, quantization_config: QuantizationConfig + self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig ) -> None: - if ( - node.op == "call_function" - and node.target == torch.ops.aten.convolution.default - ): + conv_partitions = get_source_partitions( + gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d] + ) + conv_partitions = list(itertools.chain(*conv_partitions.values())) + for conv_partition in conv_partitions: + node = conv_partition.output_nodes[0] input_act = node.args[0] weight = node.args[1] bias = node.args[2] @@ -316,16 +334,18 @@ to run a example with Torchvision Resnet18. .. code:: ipython3 import copy - import functools + import itertools import operator from typing import Callable, Dict, List, Optional, Set, Any import torch import torch._dynamo as torchdynamo from torch.ao.quantization._pt2e.quantizer.utils import ( + _annotate_input_qspec_map, + _annotate_output_qspec, get_act_qspec, - get_weight_qspec, get_bias_qspec, + get_weight_qspec, ) from torch.fx import Node @@ -338,8 +358,6 @@ to run a example with Torchvision Resnet18. QuantizationSpec, Quantizer, QuantizationAnnotation, - _annotate_input_qspec_map, - _annotate_output_qspec, ) from torch.ao.quantization.observer import ( HistogramObserver, @@ -391,42 +409,48 @@ to run a example with Torchvision Resnet18. self, model: torch.fx.GraphModule, config: QuantizationConfig ) -> torch.fx.GraphModule: self._annotate_linear(model, config) - for node in reversed(model.graph.nodes): - self._annotate_conv2d(node, config) - self._annotate_maxpool2d(node, config) + self._annotate_conv2d(model, config) + self._annotate_maxpool2d(model, config) return model def _annotate_conv2d( - self, node: Node, quantization_config: QuantizationConfig + self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig ) -> None: - conv_node = node - if ( - conv_node.op != "call_function" - or conv_node.target != torch.ops.aten.convolution.default - ): - return - # skip annotation if it is already annotated - if _is_annotated([conv_node]): - return - - input_qspec_map = {} - input_act = conv_node.args[0] - assert isinstance(input_act, Node) - input_qspec_map[input_act] = get_act_qspec(quantization_config) - - weight = conv_node.args[1] - assert isinstance(weight, Node) - input_qspec_map[weight] = get_weight_qspec(quantization_config) - - bias = conv_node.args[2] - if isinstance(bias, Node): - input_qspec_map[bias] = get_bias_qspec(quantization_config) - - conv_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=get_act_qspec(quantization_config), - _annotated=True + conv_partitions = get_source_partitions( + gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d] ) + conv_partitions = list(itertools.chain(*conv_partitions.values())) + for conv_partition in conv_partitions: + if len(conv_partition.output_nodes) > 1: + raise ValueError("conv partition has more than one output node") + conv_node = conv_partition.output_nodes[0] + if ( + conv_node.op != "call_function" + or conv_node.target != torch.ops.aten.convolution.default + ): + raise ValueError(f"{conv_node} is not an aten conv2d 
operator") + # skip annotation if it is already annotated + if _is_annotated([conv_node]): + continue + + input_qspec_map = {} + input_act = conv_node.args[0] + assert isinstance(input_act, Node) + input_qspec_map[input_act] = get_act_qspec(quantization_config) + + weight = conv_node.args[1] + assert isinstance(weight, Node) + input_qspec_map[weight] = get_weight_qspec(quantization_config) + + bias = conv_node.args[2] + if isinstance(bias, Node): + input_qspec_map[bias] = get_bias_qspec(quantization_config) + + conv_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=get_act_qspec(quantization_config), + _annotated=True, + ) def _annotate_linear( self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig @@ -478,40 +502,36 @@ to run a example with Torchvision Resnet18. _mark_nodes_as_annotated(nodes_to_mark_annotated) def _annotate_maxpool2d( - self, node: Node, quantization_config: QuantizationConfig + self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig ) -> None: - if ( - node.op != "call_function" - or node.target != operator.getitem - or node.args[1] != 0 - ): - return - getitem_node = node - maxpool_node = getitem_node.args[0] - assert isinstance(maxpool_node, Node) - if ( - maxpool_node.op != "call_function" - or maxpool_node.target != torch.ops.aten.max_pool2d_with_indices.default - ): - return - if _is_annotated([getitem_node, maxpool_node]): - return - - input_act = maxpool_node.args[0] - assert isinstance(input_act, Node) - - act_qspec = get_act_qspec(quantization_config) - maxpool_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map={ - input_act: act_qspec, - }, - _annotated=True, - ) - getitem_node.meta["quantization_annotation"] = QuantizationAnnotation( - output_qspec=act_qspec, - _input_output_share_observers=True, - _annotated=True, + module_partitions = get_source_partitions( + gm.graph, [torch.nn.MaxPool2d, torch.nn.functional.max_pool2d] ) + maxpool_partitions = list(itertools.chain(*module_partitions.values())) + for maxpool_partition in maxpool_partitions: + output_node = maxpool_partition.output_nodes[0] + maxpool_node = None + for n in maxpool_partition.nodes: + if n.target == torch.ops.aten.max_pool2d_with_indices.default: + maxpool_node = n + if _is_annotated([output_node, maxpool_node]): # type: ignore[list-item] + continue + + input_act = maxpool_node.args[0] # type: ignore[union-attr] + assert isinstance(input_act, Node) + + act_qspec = get_act_qspec(quantization_config) + maxpool_node.meta["quantization_annotation"] = QuantizationAnnotation( # type: ignore[union-attr] + input_qspec_map={ + input_act: act_qspec, + }, + _annotated=True, + ) + output_node.meta["quantization_annotation"] = QuantizationAnnotation( + output_qspec=act_qspec, + _input_output_share_observers=True, + _annotated=True, + ) def validate(self, model: torch.fx.GraphModule) -> None: pass From f7c774701632acc5ada238b2764ecd8a02be3cce Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Wed, 31 May 2023 12:12:04 +0800 Subject: [PATCH 11/42] Modify the descriptation --- prototype_source/quantization_2_0_tutotial.rst | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/prototype_source/quantization_2_0_tutotial.rst b/prototype_source/quantization_2_0_tutotial.rst index 87c87daf11c..dcb9e18b515 100644 --- a/prototype_source/quantization_2_0_tutotial.rst +++ b/prototype_source/quantization_2_0_tutotial.rst @@ -65,20 +65,22 @@ tutorial for how 
to use the `QuantizationAnnotation API` to create a quantizer.
 
 1. Define QuantizationConfig
 --------------------------------------------------------
 
-QuantizationConfig defines the data type and qscheme for activation, weight and bias.
-`QuantizationConfig `__ is defined here.
-It consists of `QuantizationSpec `__ defined for activation, weight and bias.
-When annotating the model, methods of `get_act_qspec `__,
-`get_weight_qspec `__,
+`QuantizationConfig `__
+consists of `QuantizationSpec `__
+for activation, weight and bias seperately. Each `QuantizationSpec` defines the data type, qscheme and other quantization parameters used to create the observer.
+When annotating the model, methods of
+`get_act_qspec `__,
+`get_weight_qspec `__ and
 `get_bias_qspec `__
-are used to get the `QuantizationSpec` from `QuantizationConfig` for the specific node. Then corresponding observer will been created
-based on the node's `QuantizationSpec`.
-Suppose we want to define:
+are used to get the `QuantizationSpec` from `QuantizationConfig` for a specific node. Then corresponding observer will be created
+based on this node's `QuantizationSpec`. Suppose we want use these quantization parameters for activation, weight and bias:
 
 - Activation: `int8` data type, `per_tensor_affine` quantization, `HistogramObserver`
 - Weight : `int8` data type, `per_channel_symmetric` quantization, `PerChannelMinMaxObserver`
 - Bias : `float` data type, `PlaceholderObserver`
 
+We can define the ``QuantizationConfig`` as below:
+
 ::
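To make the role of each ``QuantizationSpec`` field concrete, here is a standalone sketch
of the activation spec described above. The values are the ones used throughout this
tutorial, and the ``eps`` override is carried over from ``get_symmetric_quantization_config``
as an assumption:

::

    act_quantization_spec = QuantizationSpec(
        dtype=torch.int8,                 # storage type of the quantized tensor
        quant_min=-128,                   # smallest representable quantized value
        quant_max=127,                    # largest representable quantized value
        qscheme=torch.per_tensor_affine,  # one scale/zero_point for the whole tensor
        is_dynamic=False,                 # static quantization: qparams fixed after calibration
        observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
    )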
From 75d74931b3a2ce927feb44a4f5694108aa4684ee Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Thu, 1 Jun 2023 09:36:02 +0800
Subject: [PATCH 12/42] fix comments

---
 .../quantization_2_0_tutotial.rst             | 61 ++++++++++---------
 1 file changed, 31 insertions(+), 30 deletions(-)

diff --git a/prototype_source/quantization_2_0_tutotial.rst b/prototype_source/quantization_2_0_tutotial.rst
index dcb9e18b515..2b908943cca 100644
--- a/prototype_source/quantization_2_0_tutotial.rst
+++ b/prototype_source/quantization_2_0_tutotial.rst
@@ -1,25 +1,26 @@
 (prototype) PyTorch Quantization 2.0 Tutorial
-==========================================
+=====================================================
 
 Today we have `FX Graph Mode
 Quantization `__
-which uses symbolic_trace to capture the model into a graph, and then
+which uses ``symbolic_trace`` to capture the model into a graph, and then
 perform quantization transformations on top of the captured model. In a
 similar way, for Quantization 2.0 flow, we will now use the PT2 Export
-workflow to capture the model into a graph, and perform quantizations
-transformations on top of the ATen dialect graph. This is expected to
+workflow to capture the model into a graph, and perform quantization
+transformations on top of the ATen dialect graph. This approach is expected to
 have significantly higher model coverage, better programmability, and
 a simplified UX.
 
-Suppose we are a backend developer and we wish to integrate our backend
-with PyTorch's quantization 2.0 flow. We only need to define the backend
-specific quantizer. The high level arch of quantization 2.0 with quantizer could be:
+Imagine a backend developer who wishes to integrate a third-party backend
+with PyTorch's quantization 2.0 flow. To accomplish this, they would only need
+to define the backend specific quantizer. The high level architecture of
+quantization 2.0 with quantizer could look like this:
 
 .. image:: /_static/img/quantization/pytorch_quantization_2_0_diagram.png
 
-An existing quantizer object defined for QNNPack/XNNPack is here
+An existing quantizer object defined for QNNPack/XNNPack is located in
 `QNNPackQuantizer `__.
-Taking QNNPackQuantizer as example, the overall Quantization 2.0 flow could be:
+Taking QNNPackQuantizer as an example, the overall Quantization 2.0 flow could be:
 
 ::
 
@@ -59,27 +60,27 @@ Taking QNNPackQuantizer as an example, the overall Quantization 2.0 flow could be:
 
 Inside the Quantizer, we will use the `QuantizationAnnotation API `__
 to convey user's intent for what quantization spec to use and how to
-observe certain tensor values in the prepare step. Now, we will have a step by step
-tutorial for how to use the `QuantizationAnnotation API` to create a quantizer.
+observe certain tensor values in the prepare step. Now, we will have a step-by-step
+tutorial for how to use the ``QuantizationAnnotation API`` to create a quantizer.
 
-1. Define QuantizationConfig
+1. Define ``QuantizationConfig``
 --------------------------------------------------------
 
 `QuantizationConfig `__
 consists of `QuantizationSpec `__
-for activation, weight and bias seperately. Each `QuantizationSpec` defines the data type, qscheme and other quantization parameters used to create the observer.
+for activation, weight, and bias separately. Each ``QuantizationSpec`` defines the data type, ``qscheme``, and other quantization parameters used to create the observer.
 When annotating the model, methods of
 `get_act_qspec `__,
-`get_weight_qspec `__ and
+`get_weight_qspec `__ and
 `get_bias_qspec `__
-are used to get the `QuantizationSpec` from `QuantizationConfig` for a specific node. Then corresponding observer will be created
-based on this node's `QuantizationSpec`. Suppose we want use these quantization parameters for activation, weight and bias:
+are used to get the ``QuantizationSpec`` from ``QuantizationConfig`` for a specific node. Then corresponding observer will be created
+based on this node's ``QuantizationSpec``. For example, we want to use these quantization parameters for activation, weight, and bias:
 
-- Activation: `int8` data type, `per_tensor_affine` quantization, `HistogramObserver`
-- Weight : `int8` data type, `per_channel_symmetric` quantization, `PerChannelMinMaxObserver`
-- Bias : `float` data type, `PlaceholderObserver`
+- Activation: ``int8`` data type, ``per_tensor_affine`` quantization, ``HistogramObserver``
+- Weight : ``int8`` data type, ``per_channel_symmetric`` quantization, ``PerChannelMinMaxObserver``
+- Bias : ``float`` data type, ``PlaceholderObserver``
 
 We can define the ``QuantizationConfig`` as below:
 
@@ -117,10 +118,10 @@ We can define the ``QuantizationConfig`` as below:
     )
     return quantization_config
 
-2. Define the BackendQuantizer
+2. Define the ``BackendQuantizer``
 --------------------------------------------------------
 
-Then we will define the skeleton of a BackendQuantizer. The annotatation methods for each operation will be
+Then we will define the skeleton of a ``BackendQuantizer``. The annotation methods for each operation will be
 defined later.
 
 ::
 
@@ -168,8 +169,8 @@ defined later.
 3. Annotate common operator patterns
 --------------------------------------------------------
 
-Now we will start to define the annotatation methods inside quantizer. For common operators like `conv2d`, we can use `QuantizationSpec` to
-annotate the input, weight, bias and output.
+Now we will start to define the annotation methods inside the quantizer. For common operators like ``conv2d``, we can use ``QuantizationSpec`` to
+annotate the input, weight, bias, and output.
 
 ::
 
@@ -215,8 +216,8 @@ annotate the input, weight, bias, and output.
 4. Annotate sharing qparams operators
 --------------------------------------------------------
 
-For operator such as `add` and `cat`, which we want the two inputs sharing
-quantization parameters, we can use the `SharedQuantizationSpec` to make the two inputs
+For operators such as ``add`` and ``cat``, where we want the two inputs to share
+quantization parameters, we can use ``SharedQuantizationSpec`` to keep the two inputs
 sharing the same quantization parameters.
 
 ::
 
@@ -249,8 +250,8 @@ sharing the same quantization parameters.
 5. Annotate fixed qparams operators
 --------------------------------------------------------
 
-For operator such as `sigmoid`, which has predefined and fixed scale/zero_point,
-we can use fixed parameters for it with `FixedQParamsQuantizationSpec`.
+For operators such as ``sigmoid``, which have a predefined and fixed scale/zero_point,
+we can use ``FixedQParamsQuantizationSpec`` to annotate them with those fixed parameters.
 
 ::
 
@@ -283,8 +284,8 @@ we can use fixed parameters for it with ``FixedQParamsQuantizationSpec``.
 6. Annotate tensor with derived quantization parameters
 --------------------------------------------------------
 
-`DerivedQuantizationSpec` is the quantization spec for the Tensors whose quantization parameters are derived from other Tensors.
-For example, we want to define the scale, zp for bias derived from activation and weight of convolution node.
+``DerivedQuantizationSpec`` is the quantization spec for tensors whose quantization parameters are derived from other tensors.
+For example, we may want to derive the ``scale`` and ``zero_point`` for the bias from those of the activation and weight of a convolution node.
 
 ::
 
@@ -330,7 +331,7 @@ For example, we want to define the scale, zp for bias derived from activation an
 7. A Toy Example with Resnet18
 --------------------------------------------------------
 
-After above annotation methods defined with `QuantizationAnnotation API`, we can now put them together for the BackendQuantizer
+After the above annotation methods are defined with the ``QuantizationAnnotation API``, we can now put them together in the ``BackendQuantizer``
 to run a example with Torchvision Resnet18.
 
 .. code:: ipython3
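Since the body of ``_annotate_conv2d_derived_bias`` is elided in the excerpt above, a rough
sketch of the derivation step may help. The conventional choice of bias scale
``act_scale * weight_scale`` with a zero ``zero_point``, and the exact
``DerivedQuantizationSpec`` arguments, are assumptions here rather than verbatim code:

::

    def _derive_bias_qparams_fn(obs_or_fqs):
        # obs_or_fqs holds the observers of the tensors listed in derived_from,
        # here the convolution's input activation and its weight.
        act_scale, _ = obs_or_fqs[0].calculate_qparams()
        weight_scale, _ = obs_or_fqs[1].calculate_qparams()
        derived_scale = (act_scale * weight_scale).reshape(-1)
        derived_zero = torch.zeros_like(derived_scale).to(torch.int32)
        return derived_scale, derived_zero

    bias_qspec = DerivedQuantizationSpec(
        derived_from=[(input_act, node), (weight, node)],  # observed edges feeding the derivation
        derive_qparams_fn=_derive_bias_qparams_fn,
        dtype=torch.int32,
        quant_min=-(2**31),
        quant_max=2**31 - 1,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
    )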
From 79e05489957dd24c6d07e065363d11884d7c5e3f Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Thu, 1 Jun 2023 09:48:09 +0800
Subject: [PATCH 13/42] add Prerequisites

---
 prototype_source/quantization_2_0_tutotial.rst | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/prototype_source/quantization_2_0_tutotial.rst b/prototype_source/quantization_2_0_tutotial.rst
index 2b908943cca..c2e5785ef74 100644
--- a/prototype_source/quantization_2_0_tutotial.rst
+++ b/prototype_source/quantization_2_0_tutotial.rst
@@ -11,6 +11,12 @@ transformations on top of the ATen dialect graph. This approach is expected to
 have significantly higher model coverage, better programmability, and
 a simplified UX.
 
+Prerequisites:
+
+- `Understanding of the quantization concepts in PyTorch `__
+- `Understanding of FX graph mode post training static quantization `__
+- `Understanding of torchdynamo concepts in PyTorch `__
+
 Imagine a backend developer who wishes to integrate a third-party backend
 with PyTorch's quantization 2.0 flow. To accomplish this, they would only need
 to define the backend specific quantizer. The high level architecture of
From a9015690fbc6340f2afa4de48b39504a285128d9 Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Thu, 1 Jun 2023 10:05:26 +0800
Subject: [PATCH 14/42] add conclusion and further readings

---
 prototype_source/quantization_2_0_tutotial.rst | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/prototype_source/quantization_2_0_tutotial.rst b/prototype_source/quantization_2_0_tutotial.rst
index c2e5785ef74..3d86266f759 100644
--- a/prototype_source/quantization_2_0_tutotial.rst
+++ b/prototype_source/quantization_2_0_tutotial.rst
@@ -12,6 +12,7 @@ have significantly higher model coverage, better programmability, and
 a simplified UX.
 
 Prerequisites:
+-----------------------
 
 - `Understanding of the quantization concepts in PyTorch `__
 - `Understanding of FX graph mode post training static quantization `__
 - `Understanding of torchdynamo concepts in PyTorch `__
@@ -600,3 +601,17 @@ to run a example with Torchvision Resnet18.
     after_prepare_result = m(*example_inputs)
     m = convert_pt2e(m)
     print("converted module is: {}".format(m), flush=True)
+
+8. Conclusion
+-------------
+
+With this tutorial, we introduce the new quantization path in PyTorch 2.0. Users can learn about
+how to define a ``BackendQuantizer`` with the ``QuantizationAnnotation API`` and integrate it into the quantization 2.0 flow.
+Examples of ``QuantizationSpec``, ``SharedQuantizationSpec``, ``FixedQParamsQuantizationSpec``, and ``DerivedQuantizationSpec``
+are given for specific annotation use cases. Quantization 2.0 flow is still under active development. If user wants
+to learn more about the design, here are some further reading materials. Please contact to @jerryzh168 if you want access to
+below materials.
+
+- `Quantization in PyTorch 2.0 Export Detailed Design `__
+- `Quantization Annotation API Design `__
+- `Quantized Model Representation in PyTorch 2.0 Export `__
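Of the spec types recapped in the conclusion above, ``SharedQuantizationSpec`` is the one
whose constructor the excerpts in this series leave out. A rough sketch for a two-input
``cat``, assuming the spec is built from the edge whose observer should be reused:

::

    cat_node = cat_partition.output_nodes[0]
    input_act0, input_act1 = cat_node.args[0]  # assuming a two-tensor list input
    act_qspec = get_act_qspec(quantization_config)

    # Reuse the observer inserted on the (input_act0 -> cat_node) edge for the
    # second input and the output, so all three share one set of qparams.
    shared_with_input0 = SharedQuantizationSpec((input_act0, cat_node))

    cat_node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={
            input_act0: act_qspec,
            input_act1: shared_with_input0,
        },
        output_qspec=shared_with_input0,
        _annotated=True,
    )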
From 0073dc60b0b0d5c07170809109a500487249e89c Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Thu, 1 Jun 2023 10:11:08 +0800
Subject: [PATCH 15/42] update the title and filename

---
 prototype_source/prototype_index.rst                          | 4 ++--
 ..._tutotial.rst => quantization_in_pytorch_2_0_tutorial.rst} | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)
 rename prototype_source/{quantization_2_0_tutotial.rst => quantization_in_pytorch_2_0_tutorial.rst} (99%)

diff --git a/prototype_source/prototype_index.rst b/prototype_source/prototype_index.rst
index b004cc2f365..a7e7941ab2a 100644
--- a/prototype_source/prototype_index.rst
+++ b/prototype_source/prototype_index.rst
@@ -72,7 +72,7 @@ Prototype features are not available as part of binary distributions like PyPI o
    :header: PyTorch Quantization 2.0 Tutorial
    :card_description: Learn how to use the PyTorch Quantization 2.0 stack.
    :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
-   :link: ../prototype/quantization_2_0_tutotial.html
+   :link: ../prototype/quantization_in_pytorch_2_0_tutorial.html
    :tags: Quantization
 
 .. Mobile
@@ -200,7 +200,7 @@ Prototype features are not available as part of binary distributions like PyPI o
    prototype/fx_graph_mode_ptq_dynamic.html
    prototype/fx_graph_mode_ptq_static.html
    prototype/graph_mode_dynamic_bert_tutorial.html
-   prototype/quantization_2_0_tutotial.html
+   prototype/quantization_in_pytorch_2_0_tutorial.html
    prototype/ios_gpu_workflow.html
    prototype/nnapi_mobilenetv2.html
    prototype/tracing_based_selective_build.html
diff --git a/prototype_source/quantization_2_0_tutotial.rst b/prototype_source/quantization_in_pytorch_2_0_tutorial.rst
similarity index 99%
rename from prototype_source/quantization_2_0_tutotial.rst
rename to prototype_source/quantization_in_pytorch_2_0_tutorial.rst
index 3d86266f759..fd0bb265194 100644
--- a/prototype_source/quantization_2_0_tutotial.rst
+++ b/prototype_source/quantization_in_pytorch_2_0_tutorial.rst
@@ -1,4 +1,4 @@
-(prototype) PyTorch Quantization 2.0 Tutorial
+(prototype) Quantization in PyTorch 2.0 Export Tutorial
 =====================================================
From b50bfa2aed4c5204bfbf30fe1a016c340d89c7d8 Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Thu, 1 Jun 2023 10:13:28 +0800
Subject: [PATCH 16/42] Modify the arch graph

---
 .../pytorch_quantization_2_0_diagram.png      | Bin 39873 -> 0 bytes
 .../quantization_in_pytorch_2_0_tutorial.rst  | 31 +++++++++++++++++-
 2 files changed, 30 insertions(+), 1 deletion(-)
 delete mode 100644 _static/img/quantization/pytorch_quantization_2_0_diagram.png

diff --git a/_static/img/quantization/pytorch_quantization_2_0_diagram.png b/_static/img/quantization/pytorch_quantization_2_0_diagram.png
deleted file mode 100644
index e00ebf90276de8ca39a020606a8f21ba5ddd5d35..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 39873
[ base85 git binary patch data for the deleted PNG, and the garbled start of the following RST hunk, omitted ]
+(prepare_pt2e_quantizer)  |    /
+—--------------------------------------------------------
+| prepare_pt2e_quantizer |
+—--------------------------------------------------------
+ |
+ Calibrate/Train
+(convert_pt2e) |
+—--------------------------------------------------------
+| convert_pt2e |
+—--------------------------------------------------------
+ |
+ Reference Quantized Model
+ |
+—--------------------------------------------------------
+| Lowering |
+—--------------------------------------------------------
+ |
+ Executorch, Inductor,
+```
 
 An existing quantizer object defined for QNNPack/XNNPack is located in
 `QNNPackQuantizer `__.
From c4c0761e61a750911ebecd24a17c0457c94fa339 Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Thu, 1 Jun 2023 10:42:46 +0800
Subject: [PATCH 17/42] modify the ascii graph format

---
 .../quantization_in_pytorch_2_0_tutorial.rst  | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/prototype_source/quantization_in_pytorch_2_0_tutorial.rst b/prototype_source/quantization_in_pytorch_2_0_tutorial.rst
index 9b188b3d807..f899227d78d 100644
--- a/prototype_source/quantization_in_pytorch_2_0_tutorial.rst
+++ b/prototype_source/quantization_in_pytorch_2_0_tutorial.rst
@@ -1,5 +1,5 @@
 (prototype) Quantization in PyTorch 2.0 Export Tutorial
-=====================================================
+==============================================================
 
 Today we have `FX Graph Mode
 Quantization `__
@@ -27,7 +27,7 @@ quantization 2.0 with quantizer could look like this:
 
 float_model(Python) Input
 \ /
- \ /
+ \ /
 —-------------------------------------------------------
 | Dynamo Export |
 —-------------------------------------------------------
 |
 FX Graph in CATen QNNPackQuantizer,
 | or X86InductorQuantizer,
 | or
-(prepare_pt2e_quantizer) | /
+ | /
 —--------------------------------------------------------
 | prepare_pt2e_quantizer |
 —--------------------------------------------------------
 |
 Calibrate/Train
-(convert_pt2e) |
+ |
 —--------------------------------------------------------
 | convert_pt2e |
 —--------------------------------------------------------
 |
 Reference Quantized Model
 |
 —--------------------------------------------------------
 | Lowering |
 —--------------------------------------------------------
 |
 Executorch, Inductor,
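The ``Calibrate/Train`` step in the diagram is plain module execution. A minimal
post-training calibration sketch, assuming a ``prepared_graph_module`` returned by
``prepare_pt2e`` and a hypothetical ``calibration_data`` iterable of input tuples:

::

    # Calibration: run representative data through the observed model so the
    # inserted observers record value ranges; no weights are updated here.
    with torch.no_grad():
        for example_inputs in calibration_data:
            prepared_graph_module(*example_inputs)

    # convert_pt2e then folds the collected qparams into the graph, yielding
    # the reference quantized model shown in the diagram.
    converted_graph_module = convert_pt2e(prepared_graph_module)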
customcarditem:: - :header: PyTorch Quantization 2.0 Tutorial - :card_description: Learn how to use the PyTorch Quantization 2.0 stack. + :header: Quantization in PyTorch 2.0 Tutorial + :card_description: Learn how to use the Quantization in PyTorch 2.0 stack. :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png :link: ../prototype/quantization_in_pytorch_2_0_tutorial.html :tags: Quantization diff --git a/prototype_source/quantization_in_pytorch_2_0_tutorial.rst b/prototype_source/quantization_in_pytorch_2_0_tutorial.rst index f899227d78d..db5358ca71e 100644 --- a/prototype_source/quantization_in_pytorch_2_0_tutorial.rst +++ b/prototype_source/quantization_in_pytorch_2_0_tutorial.rst @@ -25,34 +25,34 @@ quantization 2.0 with quantizer could look like this: :: -float_model(Python) Input - \ / - \ / -—------------------------------------------------------- -| Dynamo Export | -—------------------------------------------------------- - | - FX Graph in CATen QNNPackQuantizer, - | or X86InductorQuantizer, - | or - | / -—-------------------------------------------------------- -| prepare_pt2e_quantizer | -—-------------------------------------------------------- - | - Calibrate/Train - | -—-------------------------------------------------------- -| convert_pt2e | -—-------------------------------------------------------- - | - Reference Quantized Model - | -—-------------------------------------------------------- -| Lowering | -—-------------------------------------------------------- - | - Executorch, Inductor, + float_model(Python) Input + \ / + \ / + —------------------------------------------------------- + | Dynamo Export | + —------------------------------------------------------- + | + FX Graph in CATen QNNPackQuantizer, + | or X86InductorQuantizer, + | or + | / + —-------------------------------------------------------- + | prepare_pt2e_quantizer | + —-------------------------------------------------------- + | + Calibrate/Train + | + —-------------------------------------------------------- + | convert_pt2e | + —-------------------------------------------------------- + | + Reference Quantized Model + | + —-------------------------------------------------------- + | Lowering | + —-------------------------------------------------------- + | + Executorch, Inductor, An existing quantizer object defined for QNNPack/XNNPack is located in `QNNPackQuantizer `__. From 1eca8320b2802d90410f7530e660b62b7f41fd40 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Thu, 1 Jun 2023 12:19:24 +0800 Subject: [PATCH 19/42] modify title name --- prototype_source/prototype_index.rst | 4 ++-- ...al.rst => quantization_in_pytorch_2_0_export_tutorial.rst} | 0 2 files changed, 2 insertions(+), 2 deletions(-) rename prototype_source/{quantization_in_pytorch_2_0_tutorial.rst => quantization_in_pytorch_2_0_export_tutorial.rst} (100%) diff --git a/prototype_source/prototype_index.rst b/prototype_source/prototype_index.rst index 20ca55eff34..05870d6ddbb 100644 --- a/prototype_source/prototype_index.rst +++ b/prototype_source/prototype_index.rst @@ -72,7 +72,7 @@ Prototype features are not available as part of binary distributions like PyPI o :header: Quantization in PyTorch 2.0 Tutorial :card_description: Learn how to use the Quantization in PyTorch 2.0 stack. :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png - :link: ../prototype/quantization_in_pytorch_2_0_tutorial.html + :link: ../prototype/quantization_in_pytorch_2_0_export_tutorial.html :tags: Quantization .. 
Mobile @@ -200,7 +200,7 @@ Prototype features are not available as part of binary distributions like PyPI o prototype/fx_graph_mode_ptq_dynamic.html prototype/fx_graph_mode_ptq_static.html prototype/graph_mode_dynamic_bert_tutorial.html - prototype/quantization_in_pytorch_2_0_tutorial.html + prototype/quantization_in_pytorch_2_0_export_tutorial.html prototype/ios_gpu_workflow.html prototype/nnapi_mobilenetv2.html prototype/tracing_based_selective_build.html diff --git a/prototype_source/quantization_in_pytorch_2_0_tutorial.rst b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst similarity index 100% rename from prototype_source/quantization_in_pytorch_2_0_tutorial.rst rename to prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst From fa8437a9ad5790eb8f734e326ffcd1c452cfdcd9 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Thu, 1 Jun 2023 12:30:50 +0800 Subject: [PATCH 20/42] format document --- ...quantization_in_pytorch_2_0_export_tutorial.rst | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst index db5358ca71e..4f6a0b0fb36 100644 --- a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst +++ b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst @@ -27,7 +27,7 @@ quantization 2.0 with quantizer could look like this: float_model(Python) Input \ / - \ / + \ / —------------------------------------------------------- | Dynamo Export | —------------------------------------------------------- @@ -40,10 +40,10 @@ quantization 2.0 with quantizer could look like this: | prepare_pt2e_quantizer | —-------------------------------------------------------- | - Calibrate/Train + Calibrate/Train | —-------------------------------------------------------- - | convert_pt2e | + | convert_pt2e | —-------------------------------------------------------- | Reference Quantized Model @@ -52,7 +52,7 @@ quantization 2.0 with quantizer could look like this: | Lowering | —-------------------------------------------------------- | - Executorch, Inductor, + Executorch, or Inductor, or An existing quantizer object defined for QNNPack/XNNPack is located in `QNNPackQuantizer `__. @@ -637,9 +637,9 @@ to run a example with Torchvision Resnet18. With this tutorial, we introduce the new quantization path in PyTorch 2.0. Users can learn about how to define a ``BackendQuantizer`` with the ``QuantizationAnnotation API`` and integrate it into the quantization 2.0 flow. Examples of ``QuantizationSpec``, ``SharedQuantizationSpec``, ``FixedQParamsQuantizationSpec``, and ``DerivedQuantizationSpec`` -are given for specific annotation use cases. Quantization 2.0 flow is still under active development. If user wants -to learn more about the design, here are some further reading materials. Please contact to @jerryzh168 if you want access to -below materials. +are given for specific annotation use case. Quantization 2.0 flow is still under active development. If the user wants +to learn more about the design, here are some further reading links. Please contact to `Jerry `__ if you want access to +below links. 
- `Quantization in PyTorch 2.0 Export Detailed Design `__ - `Quantization Annotation API Design `__ From a2569f80c0206e0478865ea08d435297b5c10116 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Thu, 1 Jun 2023 12:33:07 +0800 Subject: [PATCH 21/42] format document --- prototype_source/prototype_index.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/prototype_source/prototype_index.rst b/prototype_source/prototype_index.rst index 05870d6ddbb..0f190d51190 100644 --- a/prototype_source/prototype_index.rst +++ b/prototype_source/prototype_index.rst @@ -69,8 +69,8 @@ Prototype features are not available as part of binary distributions like PyPI o :tags: Debugging,Quantization .. customcarditem:: - :header: Quantization in PyTorch 2.0 Tutorial - :card_description: Learn how to use the Quantization in PyTorch 2.0 stack. + :header: Quantization in PyTorch 2.0 Export Tutorial + :card_description: Learn how to use the Quantization in PyTorch 2.0 Export. :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png :link: ../prototype/quantization_in_pytorch_2_0_export_tutorial.html :tags: Quantization From 6923015a6540b775b76e5e562f2dfec314c96e6b Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Fri, 2 Jun 2023 09:14:53 +0800 Subject: [PATCH 22/42] change descriptation --- ...ization_in_pytorch_2_0_export_tutorial.rst | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst index 4f6a0b0fb36..21f5d42b8a3 100644 --- a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst +++ b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst @@ -32,7 +32,7 @@ quantization 2.0 with quantizer could look like this: | Dynamo Export | —------------------------------------------------------- | - FX Graph in CATen QNNPackQuantizer, + FX Graph in ATen QNNPackQuantizer, | or X86InductorQuantizer, | or | / @@ -194,12 +194,17 @@ defined later. return model def validate(self, model: torch.fx.GraphModule) -> None: - """validate the annotated graph is supported by the backend""" + """validate if the annotated graph is supported by the backend""" pass @classmethod def get_supported_operators(cls) -> List[OperatorConfig]: - """return the operator list which is supported by the backend""" + """return the OperatorConfig list supported by the backend. + An OperatorConfig is a mapping from QuantizationConfig to a list of operators patterns. + The return value can be used to check: + 1. If a QuantizationConfig is supported by the BackendQuantizer. + 2. For a specific QuantizationConfig, if an operators' pattern is supported by the BackendQuantizer. + """ return [] 3. Annotate common operator patterns @@ -318,7 +323,7 @@ we can use fixed parameters for it with ``FixedQParamsQuantizationSpec``. ) 6. Annotate tensor with derived quantization parameters --------------------------------------------------------- +--------------------------------------------------------------- ``DerivedQuantizationSpec`` is the quantization spec for the Tensors whose quantization parameters are derived from other Tensors. For example, we want to define the ``scale``, ``zp`` for bias derived from activation and weight of convolution node. @@ -632,15 +637,9 @@ to run a example with Torchvision Resnet18. print("converted module is: {}".format(m), flush=True) 8. 
Conclusion ------------- +--------------------- With this tutorial, we introduce the new quantization path in PyTorch 2.0. Users can learn about how to define a ``BackendQuantizer`` with the ``QuantizationAnnotation API`` and integrate it into the quantization 2.0 flow. Examples of ``QuantizationSpec``, ``SharedQuantizationSpec``, ``FixedQParamsQuantizationSpec``, and ``DerivedQuantizationSpec`` -are given for specific annotation use case. Quantization 2.0 flow is still under active development. If the user wants -to learn more about the design, here are some further reading links. Please contact to `Jerry `__ if you want access to -below links. - -- `Quantization in PyTorch 2.0 Export Detailed Design `__ -- `Quantization Annotation API Design `__ -- `Quantized Model Representation in PyTorch 2.0 Export `__ +are given for specific annotation use case. From 376abc49004599e506237360263cc08380f7cb85 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Fri, 2 Jun 2023 12:43:50 +0800 Subject: [PATCH 23/42] Modify --- ...ization_in_pytorch_2_0_export_tutorial.rst | 250 ++++++++---------- 1 file changed, 116 insertions(+), 134 deletions(-) diff --git a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst index 21f5d42b8a3..8039beab7b0 100644 --- a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst +++ b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst @@ -94,124 +94,19 @@ Taking QNNPackQuantizer as an example, the overall Quantization 2.0 flow could b # Step 4: Lower Reference Quantized Model into the backend -Inside the Quantizer, we will use the `QuantizationAnnotation API `__ +Inside the Quantizer, we will use the ``QuantizationAnnotation API`` to convey user's intent for what quantization spec to use and how to observe certain tensor values in the prepare step. Now, we will have a step-by-step -tutorial for how to use the ``QuantizationAnnotation API`` to create a quantizer. +tutorial for how to use the ``QuantizationAnnotation API`` with different types of +``QuantizationSpec``. -1. Define ``QuantizationConfig`` +1. Annotate common operator patterns -------------------------------------------------------- -`QuantizationConfig `__ -consists of `QuantizationSpec `__ -for activation, weight, and bias separately. Each ``QuantizationSpec`` defines the data type, ``qscheme``, and other quantization parameters used to create the observer. -When annotating the model, methods of -`get_act_qspec `__, -`get_weight_qspec `__ and -`get_bias_qspec `__ -are used to get the ``QuantizationSpec`` from ``QuantizationConfig`` for a specific node. Then corresponding observer will be created -based on this node's ``QuantizationSpec``. 
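To make the observer-creation step above concrete, here is a minimal illustrative sketch, not the actual internal helper, of how an observer instance could be built from the ``act_quantization_spec`` defined in this tutorial; the exact arguments the prepare step passes may differ:

::

    # Illustrative sketch only: construct the observer described by a
    # QuantizationSpec. ``act_quantization_spec`` is the activation spec
    # defined earlier in this tutorial; the real prepare step uses an
    # internal helper whose behavior may differ.
    observer = act_quantization_spec.observer_or_fake_quant_ctr(
        dtype=act_quantization_spec.dtype,
        qscheme=act_quantization_spec.qscheme,
        quant_min=act_quantization_spec.quant_min,
        quant_max=act_quantization_spec.quant_max,
    )
    # ``observer`` is a HistogramObserver instance that records the value
    # range of the tensor it is attached to during calibration.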
For example, we want to use these quantization parameters for activation, weight, and bias: - -- Activation: ``int8`` data type, ``per_tensor_affine`` quantization, ``HistogramObserver`` -- Weight : ``int8`` data type, ``per_channel_symmetric`` quantization, ``PerChannelMinMaxObserver`` -- Bias : ``float`` data type, ``PlaceholderObserver`` - -We can define the ``QuantizationConfig`` as below: - -:: - - def get_symmetric_quantization_config(): - act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = \ - HistogramObserver - act_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=-128, - quant_max=127, - qscheme=torch.per_tensor_affine, - is_dynamic=False, - observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(eps=2**-12), - ) - - weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PerChannelMinMaxObserver - extra_args: Dict[str, Any] = {"eps": 2**-12} - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=-127, - quant_max=127, - qscheme=torch.per_channel_symmetric, - ch_axis=0, - is_dynamic=False, - observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(**extra_args), - ) - - bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PlaceholderObserver - bias_quantization_spec = QuantizationSpec( - dtype=torch.float, - observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr - ) - quantization_config = QuantizationConfig( - act_quantization_spec, weight_quantization_spec, bias_quantization_spec - ) - return quantization_config - -2. Define the ``BackendQuantizer`` --------------------------------------------------------- - -Then we will define the skeleton of a ``BackendQuantizer``. The annotatation methods for each operation will be -defined later. - -:: - - class BackendQuantizer(Quantizer): - - def __init__(self): - super().__init__() - self.global_config: QuantizationConfig = None # type: ignore[assignment] - self.operator_type_config: Dict[str, Optional[QuantizationConfig]] = {} - - def set_global(self, quantization_config: QuantizationConfig): - """set global QuantizationConfig used for the backend. - QuantizationConfig is defined in torch/ao/quantization/_pt2e/quantizer/quantizer.py. - """ - self.global_config = quantization_config - return self - - def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: - """annotate nodes in the graph with observer or fake quant constructors - to convey the desired way of quantization. - """ - global_config = self.global_config - self.annotate_symmetric_config(model, global_config) - - return model - - def annotate_symmetric_config( - self, model: torch.fx.GraphModule, config: QuantizationConfig - ) -> torch.fx.GraphModule: - for node in reversed(model.graph.nodes): - # The annotation methods for each op will defined later - pass - return model - - def validate(self, model: torch.fx.GraphModule) -> None: - """validate if the annotated graph is supported by the backend""" - pass - - @classmethod - def get_supported_operators(cls) -> List[OperatorConfig]: - """return the OperatorConfig list supported by the backend. - An OperatorConfig is a mapping from QuantizationConfig to a list of operators patterns. - The return value can be used to check: - 1. If a QuantizationConfig is supported by the BackendQuantizer. - 2. For a specific QuantizationConfig, if an operators' pattern is supported by the BackendQuantizer. - """ - return [] - -3. 
Annotate common operator patterns --------------------------------------------------------- - -Now we will start to define the annotatation methods inside quantizer. For common operators like ``conv2d``, we can use ``QuantizationSpec`` to -annotate the input, weight, bias, and output. +`QuantizationSpec `__ +is used to annotate common operators like ``conv2d``. It allows user to specify how to quantize +input tensors and output tensor of this operator which includes parameters of ``observer type``, ``dtype``, +``quant_min``, and ``quant_max`` etc. :: @@ -235,18 +130,47 @@ annotate the input, weight, bias, and output. if _is_annotated([conv_node]): continue + act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = \ + HistogramObserver + act_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(eps=2**-12), + ) + + weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PerChannelMinMaxObserver + extra_args: Dict[str, Any] = {"eps": 2**-12} + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-127, + quant_max=127, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + is_dynamic=False, + observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(**extra_args), + ) + + bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PlaceholderObserver + bias_quantization_spec = QuantizationSpec( + dtype=torch.float, + observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr + ) + input_qspec_map = {} input_act = conv_node.args[0] assert isinstance(input_act, Node) - input_qspec_map[input_act] = get_act_qspec(quantization_config) + input_qspec_map[input_act] = act_quantization_spec weight = conv_node.args[1] assert isinstance(weight, Node) - input_qspec_map[weight] = get_weight_qspec(quantization_config) + input_qspec_map[weight] = weight_quantization_spec bias = conv_node.args[2] if isinstance(bias, Node): - input_qspec_map[bias] = get_bias_qspec(quantization_config) + input_qspec_map[bias] = bias_quantization_spec conv_node.meta["quantization_annotation"] = QuantizationAnnotation( input_qspec_map=input_qspec_map, @@ -254,12 +178,16 @@ annotate the input, weight, bias, and output. _annotated=True, ) -4. Annotate sharing qparams operators +2. Annotate sharing qparams operators -------------------------------------------------------- -For operator such as ``add`` and ``cat``, which we want the two inputs sharing -quantization parameters, we can use the ``SharedQuantizationSpec`` to make the two inputs -sharing the same quantization parameters. +`SharedQuantizationSpec `__ +is used to annotate tensors whose quantization parameters are shared with other tensors. +As example, for operators like ``add`` and ``cat``, which user may want two input tensors +sharing quantization parameters. Then user can use the ``SharedQuantizationSpec`` to annotate +this operator. Input of ``SharedQuantizationSpec`` can be an input edge or an output value. +Input edge is the connection between input node and the node consuming the input, so it's a +Tuple[Node, Node]. Output value is an fx Node. :: @@ -272,7 +200,18 @@ sharing the same quantization parameters. 
add_node = add_partition.output_nodes[0] if _is_annotated([add_node]): continue - act_qspec = get_act_qspec(quantization_config) + + act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = \ + HistogramObserver + act_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(eps=2**-12), + ) + act_qspec = act_quantization_spec input_qspec_map = {} input_act0 = add_node.args[0] @@ -288,11 +227,13 @@ sharing the same quantization parameters. _annotated=True, ) -5. Annotate fixed qparams operators +3. Annotate fixed qparams operators -------------------------------------------------------- -For operator such as ``sigmoid``, which has predefined and fixed scale/zero_point, -we can use fixed parameters for it with ``FixedQParamsQuantizationSpec``. +`FixedQParamsQuantizationSpec `__ +is a quantization spec for tensors whose quantization parmaters are known beforehand. +For example, operator like ``sigmoid``, which has predefined and fixed scale/zero_point +at input and output tensors. We can annotate it with ``FixedQParamsQuantizationSpec``. :: @@ -322,11 +263,14 @@ we can use fixed parameters for it with ``FixedQParamsQuantizationSpec``. _annotated=True, ) -6. Annotate tensor with derived quantization parameters +4. Annotate tensor with derived quantization parameters --------------------------------------------------------------- -``DerivedQuantizationSpec`` is the quantization spec for the Tensors whose quantization parameters are derived from other Tensors. -For example, we want to define the ``scale``, ``zp`` for bias derived from activation and weight of convolution node. +`DerivedQuantizationSpec `__ +is for the tensors whose quantization parameters are derived from other tensors. For example, +if we want to annotate a convolution node, and define the ``scale``, ``zp`` of its bias input tensor +as derived from the activation and weight tensors. We can use ``DerivedQuantizationSpec`` to annotate +this bias tensor. :: @@ -342,8 +286,30 @@ For example, we want to define the ``scale``, ``zp`` for bias derived from activ input_act = node.args[0] weight = node.args[1] bias = node.args[2] - act_qspec = get_act_qspec(quantization_config) - weight_qspec = get_weight_qspec(quantization_config) + + act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = \ + HistogramObserver + act_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(eps=2**-12), + ) + weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PerChannelMinMaxObserver + extra_args: Dict[str, Any] = {"eps": 2**-12} + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-127, + quant_max=127, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + is_dynamic=False, + observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(**extra_args), + ) + act_qspec = act_quantization_spec + weight_qspec = weight_quantization_spec def derive_qparams_fn(obs_or_fqs: List[ObserverOrFakeQuantize]) -> Tuple[Tensor, Tensor]: assert len(obs_or_fqs) == 2, \ @@ -369,11 +335,21 @@ For example, we want to define the ``scale``, ``zp`` for bias derived from activ _annotated=True, ) -7. A Toy Example with Resnet18 +5. 
A Toy Example with Resnet18 -------------------------------------------------------- -After above annotation methods defined with ``QuantizationAnnotation API``, we can now put them together for the ``BackendQuantizer`` -to run a example with Torchvision Resnet18. +After above annotation methods defined with ``QuantizationAnnotation API``, we can now put them together to construct a ``BackendQuantizer`` +to run a example with Torchvision Resnet18. Here are some basic concepts before we move on to this example: + +- `QuantizationSpec `__ +defines the ``data type``, ``qscheme``, and other quantization parameters used to quantize a tensor. +- `QuantizationConfig `__ +consists of ``QuantizationSpec`` for activation, weight, and bias separately. +- When annotating the model, methods of +`get_act_qspec `__, +`get_weight_qspec `__ and +`get_bias_qspec `__ +can be used to get the ``QuantizationSpec`` from ``QuantizationConfig`` for a specific node. .. code:: ipython3 @@ -439,11 +415,16 @@ to run a example with Torchvision Resnet18. self.operator_type_config: Dict[str, Optional[QuantizationConfig]] = {} def set_global(self, quantization_config: QuantizationConfig): + """set global QuantizationConfig used for the backend. + QuantizationConfig is defined in torch/ao/quantization/_pt2e/quantizer/quantizer.py. + """ self.global_config = quantization_config return self def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: - """just handling global spec for now""" + """annotate nodes in the graph with observer or fake quant constructors + to convey the desired way of quantization. + """ global_config = self.global_config self.annotate_symmetric_config(model, global_config) @@ -578,6 +559,7 @@ to run a example with Torchvision Resnet18. ) def validate(self, model: torch.fx.GraphModule) -> None: + """validate if the annotated graph is supported by the backend""" pass @classmethod @@ -636,7 +618,7 @@ to run a example with Torchvision Resnet18. m = convert_pt2e(m) print("converted module is: {}".format(m), flush=True) -8. Conclusion +6. Conclusion --------------------- With this tutorial, we introduce the new quantization path in PyTorch 2.0. Users can learn about From b809ba4d85ac382b027e868db72bee57af071467 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Fri, 2 Jun 2023 13:31:12 +0800 Subject: [PATCH 24/42] format document --- ...tization_in_pytorch_2_0_export_tutorial.rst | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst index 8039beab7b0..c8c773ebb5e 100644 --- a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst +++ b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst @@ -341,15 +341,15 @@ this bias tensor. After above annotation methods defined with ``QuantizationAnnotation API``, we can now put them together to construct a ``BackendQuantizer`` to run a example with Torchvision Resnet18. Here are some basic concepts before we move on to this example: -- `QuantizationSpec `__ -defines the ``data type``, ``qscheme``, and other quantization parameters used to quantize a tensor. -- `QuantizationConfig `__ -consists of ``QuantizationSpec`` for activation, weight, and bias separately. -- When annotating the model, methods of -`get_act_qspec `__, -`get_weight_qspec `__ and -`get_bias_qspec `__ -can be used to get the ``QuantizationSpec`` from ``QuantizationConfig`` for a specific node. 
+- `QuantizationSpec `__ + defines the ``data type``, ``qscheme``, and other quantization parameters used to quantize a tensor. +- `QuantizationConfig `__ + consists of ``QuantizationSpec`` for activation, weight, and bias separately. +- When annotating the model, methods of + `get_act_qspec `__, + `get_weight_qspec `__ and + `get_bias_qspec `__ + can be used to get the ``QuantizationSpec`` from ``QuantizationConfig`` for a specific node. .. code:: ipython3 From b4782b2c8fef38fc53ff704465b3398e02d0e7e9 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Fri, 2 Jun 2023 14:05:23 +0800 Subject: [PATCH 25/42] format --- .../quantization_in_pytorch_2_0_export_tutorial.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst index c8c773ebb5e..4140761be28 100644 --- a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst +++ b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst @@ -347,7 +347,7 @@ to run a example with Torchvision Resnet18. Here are some basic concepts before consists of ``QuantizationSpec`` for activation, weight, and bias separately. - When annotating the model, methods of `get_act_qspec `__, - `get_weight_qspec `__ and + `get_weight_qspec `__, and `get_bias_qspec `__ can be used to get the ``QuantizationSpec`` from ``QuantizationConfig`` for a specific node. From 3402a8ea1ea4ae200cde3bd821757a6c39616246 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Sat, 3 Jun 2023 14:32:15 +0800 Subject: [PATCH 26/42] update document --- ...ization_in_pytorch_2_0_export_tutorial.rst | 133 +++++++++--------- 1 file changed, 68 insertions(+), 65 deletions(-) diff --git a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst index 4140761be28..2f47a7d0945 100644 --- a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst +++ b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst @@ -103,31 +103,33 @@ tutorial for how to use the ``QuantizationAnnotation API`` with different types 1. Annotate common operator patterns -------------------------------------------------------- +In order to use the quantized operators, e.g. ``quantized add``, +backend developers will have intent to quantize (as expressed by `QuantizationSpec `__ -is used to annotate common operators like ``conv2d``. It allows user to specify how to quantize -input tensors and output tensor of this operator which includes parameters of ``observer type``, ``dtype``, -``quant_min``, and ``quant_max`` etc. +) input, output of the operator. Following is an example flow (with ``add``) +of how this intent is conveyed in the quantization workflow with node annotation API. + +- Step1: Identify the original floating point ``add`` node in the FX graph. There are + several ways to identify this node: 1. User may use a pattern matcher (e.g. SubgraphMatcher) + to match the operator pattern. 2. User may go through the nodes from start to the end and compare + the node's target type. +- Step2: Define the ``QuantizationSpec`` for two inputs and one output of the ``add`` node to specify + how to quantize input tensors and output tensor which includes parameters of ``observer type``, + ``dtype``, ``quant_min``, and ``quant_max`` etc. +- Step3: Annotate the inputs and output of the ``add`` node. 
User will create the ``QuantizationAnnotation`` + object and add it into ``add`` node's ``meta`` property. :: - def _annotate_conv2d( + def _annotate_add( self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig ) -> None: - conv_partitions = get_source_partitions( - gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d] - ) - conv_partitions = list(itertools.chain(*conv_partitions.values())) - for conv_partition in conv_partitions: - if len(conv_partition.output_nodes) > 1: - raise ValueError("conv partition has more than one output node") - conv_node = conv_partition.output_nodes[0] - if ( - conv_node.op != "call_function" - or conv_node.target != torch.ops.aten.convolution.default - ): - raise ValueError(f"{conv_node} is not an aten conv2d operator") - # skip annotation if it is already annotated - if _is_annotated([conv_node]): + # Step1: Identify the ``add`` node in the original floating point FX graph. + add_partitions = get_source_partitions(gm.graph, [operator.add, torch.add]) + add_partitions = list(itertools.chain(*add_partitions.values())) + for add_partition in add_partitions: + add_node = add_partition.output_nodes[0] + if _is_annotated([add_node]): continue act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = \ @@ -141,53 +143,47 @@ input tensors and output tensor of this operator which includes parameters of `` observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(eps=2**-12), ) - weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PerChannelMinMaxObserver - extra_args: Dict[str, Any] = {"eps": 2**-12} - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=-127, - quant_max=127, - qscheme=torch.per_channel_symmetric, - ch_axis=0, - is_dynamic=False, - observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(**extra_args), - ) - - bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PlaceholderObserver - bias_quantization_spec = QuantizationSpec( - dtype=torch.float, - observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr - ) + # Step2: The ``add`` node has two inputs and one output. We define the ``QuantizationSpec`` + # for each input and output. + input_act_qspec = act_quantization_spec + output_act_qspec = act_quantization_spec input_qspec_map = {} - input_act = conv_node.args[0] - assert isinstance(input_act, Node) - input_qspec_map[input_act] = act_quantization_spec - - weight = conv_node.args[1] - assert isinstance(weight, Node) - input_qspec_map[weight] = weight_quantization_spec + input_act0 = add_node.args[0] + if isinstance(input_act0, Node): + input_qspec_map[input_act0] = input_act_qspec - bias = conv_node.args[2] - if isinstance(bias, Node): - input_qspec_map[bias] = bias_quantization_spec + input_act1 = add_node.args[1] + if isinstance(input_act1, Node): + input_qspec_map[input_act1] = input_act_qspec - conv_node.meta["quantization_annotation"] = QuantizationAnnotation( + # Step3: Annotate the inputs and outputs of the ``add`` node. + add_node.meta["quantization_annotation"] = QuantizationAnnotation( input_qspec_map=input_qspec_map, - output_qspec=get_act_qspec(quantization_config), + output_qspec=output_act_qspec, _annotated=True, ) 2. Annotate sharing qparams operators -------------------------------------------------------- -`SharedQuantizationSpec `__ -is used to annotate tensors whose quantization parameters are shared with other tensors. 
-As example, for operators like ``add`` and ``cat``, which user may want two input tensors -sharing quantization parameters. Then user can use the ``SharedQuantizationSpec`` to annotate -this operator. Input of ``SharedQuantizationSpec`` can be an input edge or an output value. -Input edge is the connection between input node and the node consuming the input, so it's a -Tuple[Node, Node]. Output value is an fx Node. +It is natural that users want to annotate a quantized model where quantization +parameters can be shared among some tensors explicitly. Two typical use cases are: + +- Example 1: One example is for ``add`` where having both inputs sharing quantization + parameters makes operator implementation much easier. Without using of + `SharedQuantizationSpec `__, + we have to annotate ``add`` as example in above section 1, in which two inputs of ``add`` + has different quantization parameters. +- Example 2: Another example is that of sharing quantization parameters between inputs and output. + This typically results from operators such as ``maxpool``, ``average_pool``, ``concat`` etc. + +``SharedQuantizationSpec`` is designed for this use case to annotate tensors whose quantization +parameters are shared with other tensors. Input of ``SharedQuantizationSpec`` can be an input edge +or an output value. Input edge is the connection between input node and the node consuming the input, +so it's a Tuple[Node, Node]. Output value is an fx Node. + +Now, we have a example to rewrite ``add`` annotation example with ``SharedQuantizationSpec``. :: @@ -230,10 +226,12 @@ Tuple[Node, Node]. Output value is an fx Node. 3. Annotate fixed qparams operators -------------------------------------------------------- +Another typical use case to annotate a quantized model is for tensors whose +quantization parmaters are known beforehand. For example, operator like ``sigmoid``, which has +predefined and fixed scale/zero_point at input and output tensors. `FixedQParamsQuantizationSpec `__ -is a quantization spec for tensors whose quantization parmaters are known beforehand. -For example, operator like ``sigmoid``, which has predefined and fixed scale/zero_point -at input and output tensors. We can annotate it with ``FixedQParamsQuantizationSpec``. +is designed for this use case. To use ``FixedQParamsQuantizationSpec``, users need to pass in parameters +of ``scale`` and ``zero_point`` explicitly. :: @@ -266,8 +264,9 @@ at input and output tensors. We can annotate it with ``FixedQParamsQuantizationS 4. Annotate tensor with derived quantization parameters --------------------------------------------------------------- +We also need to define the constraint that the scale of bias is a product of input scale and weight scale in the annotation API. `DerivedQuantizationSpec `__ -is for the tensors whose quantization parameters are derived from other tensors. For example, +is designed for this use case where a tensor's quantization parameters is derived from other tensors. For example, if we want to annotate a convolution node, and define the ``scale``, ``zp`` of its bias input tensor as derived from the activation and weight tensors. We can use ``DerivedQuantizationSpec`` to annotate this bias tensor. @@ -363,7 +362,8 @@ to run a example with Torchvision Resnet18. 
Here are some basic concepts before from torch.ao.quantization._pt2e.quantizer.utils import ( _annotate_input_qspec_map, _annotate_output_qspec, - get_act_qspec, + get_input_act_qspec, + get_output_act_qspec, get_bias_qspec, get_weight_qspec, ) @@ -461,7 +461,7 @@ to run a example with Torchvision Resnet18. Here are some basic concepts before input_qspec_map = {} input_act = conv_node.args[0] assert isinstance(input_act, Node) - input_qspec_map[input_act] = get_act_qspec(quantization_config) + input_qspec_map[input_act] = get_input_act_qspec(quantization_config) weight = conv_node.args[1] assert isinstance(weight, Node) @@ -473,7 +473,7 @@ to run a example with Torchvision Resnet18. Here are some basic concepts before conv_node.meta["quantization_annotation"] = QuantizationAnnotation( input_qspec_map=input_qspec_map, - output_qspec=get_act_qspec(quantization_config), + output_qspec=get_output_act_qspec(quantization_config), _annotated=True, ) @@ -483,7 +483,7 @@ to run a example with Torchvision Resnet18. Here are some basic concepts before module_partitions = get_source_partitions( gm.graph, [torch.nn.Linear, torch.nn.functional.linear] ) - act_qspec = get_act_qspec(quantization_config) + act_qspec = get_input_act_qspec(quantization_config) weight_qspec = get_weight_qspec(quantization_config) bias_qspec = get_bias_qspec(quantization_config) for module_or_fn_type, partitions in module_partitions.items(): @@ -545,7 +545,7 @@ to run a example with Torchvision Resnet18. Here are some basic concepts before input_act = maxpool_node.args[0] # type: ignore[union-attr] assert isinstance(input_act, Node) - act_qspec = get_act_qspec(quantization_config) + act_qspec = get_input_act_qspec(quantization_config) maxpool_node.meta["quantization_annotation"] = QuantizationAnnotation( # type: ignore[union-attr] input_qspec_map={ input_act: act_qspec, @@ -596,7 +596,10 @@ to run a example with Torchvision Resnet18. Here are some basic concepts before observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr ) quantization_config = QuantizationConfig( - act_quantization_spec, weight_quantization_spec, bias_quantization_spec + act_quantization_spec, + act_quantization_spec, + weight_quantization_spec, + bias_quantization_spec, ) return quantization_config From 7a997a26a1b2a4f2dea5c863b8b2d910d466861a Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Wed, 7 Jun 2023 13:52:04 +0800 Subject: [PATCH 27/42] add introduction of Quantizer --- ...ization_in_pytorch_2_0_export_tutorial.rst | 37 ++++++++++++++++--- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst index 2f47a7d0945..97da2817ba0 100644 --- a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst +++ b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst @@ -18,6 +18,34 @@ Prerequisites: - `Understanding of FX graph mode post training static quantization `__ - `Understanding of torchdynamo concepts in PyTorch `__ +Previously in ``FX Graph Mode Quantization`` we were using ``QConfigMapping`` for users to specify how the model to be quantized +and ``BackendConfig`` to specify the supported ways of quantization in their backend. 
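For readers who have not used the FX flow, the configuration style referred to above looks roughly
like the following sketch; the backend string is just an example, and the matching ``BackendConfig``
(not shown) is supplied separately:

::

    # Sketch of the earlier FX graph mode configuration style, for
    # comparison with the quantizer-based flow discussed here.
    from torch.ao.quantization import QConfigMapping, get_default_qconfig

    qconfig_mapping = QConfigMapping().set_global(get_default_qconfig("x86"))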
+This API covers most use cases relatively well, but the main problem is that this API is not fully extensible +with two main limitations: + +- Limitation around expressing quantization intentions for complicated operator patterns such as in the + `discussion `__ to support `conv add` fusion with oneDNN library. + It also requires some changes to current already complicated pattern matching code such as in the + `PR `__ to support `conv add` fusion. +- Limitation around supporting user's advanced intention to quantize their model. For example, ``FX Graph Mode Quantization`` + doesn't support this quantization intention: only quantize inputs and outputs when the ``linear`` has a third input. + +To address these scalability issues, +`Quantizer `__ +is introduced for quantization in PyTorch 2.0 export. ``Quantizer`` is a class that users can use to +programmably set the observer or fake quant objects for each node in the model graph. It adds flexibility +to the quantization API and allows modeling users and backend developers to configure quantization programmatically. +This will allow users to express how they want an operator pattern to be observed in a more explicit +way by annotating the appropriate nodes. To define a backend specific quantizer, user mainly need to override +several APIs: + +- `annotate method `__ + is used to annotate nodes in the graph with observer or fake quant constructors to convey the desired way of quantization. +- `validate method `__ + is used to validate if the annotated graph is supported by the backend. +- `set_global method `__ + is used to set the global ``QuantizationConfig`` object for this quantizer to specify how the model will be quantized. + Imagine a backend developer who wishes to integrate a third-party backend with PyTorch's quantization 2.0 flow. To accomplish this, they would only need to define the backend specific quantizer. The high level architecture of @@ -264,12 +292,11 @@ of ``scale`` and ``zero_point`` explicitly. 4. Annotate tensor with derived quantization parameters --------------------------------------------------------------- -We also need to define the constraint that the scale of bias is a product of input scale and weight scale in the annotation API. +We also may need to define the constraint for tensors whose quantization parameters are derived from other tensors. +For example, if we want to annotate a convolution node, and define the ``scale`` of its bias input tensor +as product of the activation tensor's ``scale`` and weight tensor's ``scale``. We can use `DerivedQuantizationSpec `__ -is designed for this use case where a tensor's quantization parameters is derived from other tensors. For example, -if we want to annotate a convolution node, and define the ``scale``, ``zp`` of its bias input tensor -as derived from the activation and weight tensors. We can use ``DerivedQuantizationSpec`` to annotate -this bias tensor. +to annotate this bias tensor. 
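The reason for this particular derivation: a quantized convolution accumulates products of quantized
activation and weight values, so the int32 bias must be expressed in the accumulator's scale, which
is the product of the two input scales. A small numeric sketch with made-up values:

::

    # Made-up numbers: why bias_scale = act_scale * weight_scale.
    act_scale, weight_scale = 0.02, 0.005
    # Each accumulated product (q_act * q_weight) carries a scale of
    # act_scale * weight_scale, so the int32 bias must use it as well.
    bias_scale = act_scale * weight_scale  # 1e-4
    bias_fp = 0.37
    q_bias = round(bias_fp / bias_scale)   # 3700, stored as int32

The annotation below wires exactly this rule in through a ``derive_qparams_fn`` callback.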
:: From 41b27d9fb3afb17bf6816663dc7b47e1714c6df7 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Wed, 7 Jun 2023 14:00:05 +0800 Subject: [PATCH 28/42] remove final example --- ...ization_in_pytorch_2_0_export_tutorial.rst | 282 +----------------- 1 file changed, 7 insertions(+), 275 deletions(-) diff --git a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst index 97da2817ba0..5b90a7159cd 100644 --- a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst +++ b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst @@ -137,14 +137,14 @@ backend developers will have intent to quantize (as expressed by ) input, output of the operator. Following is an example flow (with ``add``) of how this intent is conveyed in the quantization workflow with node annotation API. -- Step1: Identify the original floating point ``add`` node in the FX graph. There are +- Step 1: Identify the original floating point ``add`` node in the FX graph. There are several ways to identify this node: 1. User may use a pattern matcher (e.g. SubgraphMatcher) to match the operator pattern. 2. User may go through the nodes from start to the end and compare the node's target type. -- Step2: Define the ``QuantizationSpec`` for two inputs and one output of the ``add`` node to specify +- Step 2: Define the ``QuantizationSpec`` for two inputs and one output of the ``add`` node to specify how to quantize input tensors and output tensor which includes parameters of ``observer type``, ``dtype``, ``quant_min``, and ``quant_max`` etc. -- Step3: Annotate the inputs and output of the ``add`` node. User will create the ``QuantizationAnnotation`` +- Step 3: Annotate the inputs and output of the ``add`` node. User will create the ``QuantizationAnnotation`` object and add it into ``add`` node's ``meta`` property. :: @@ -364,8 +364,7 @@ to annotate this bias tensor. 5. A Toy Example with Resnet18 -------------------------------------------------------- -After above annotation methods defined with ``QuantizationAnnotation API``, we can now put them together to construct a ``BackendQuantizer`` -to run a example with Torchvision Resnet18. Here are some basic concepts before we move on to this example: +To better understand the final example, here are some basic concepts before we move on to this part: - `QuantizationSpec `__ defines the ``data type``, ``qscheme``, and other quantization parameters used to quantize a tensor. @@ -377,276 +376,9 @@ to run a example with Torchvision Resnet18. Here are some basic concepts before `get_bias_qspec `__ can be used to get the ``QuantizationSpec`` from ``QuantizationConfig`` for a specific node. -.. 
code:: ipython3 - - import copy - import itertools - import operator - from typing import Callable, Dict, List, Optional, Set, Any - - import torch - import torch._dynamo as torchdynamo - from torch.ao.quantization._pt2e.quantizer.utils import ( - _annotate_input_qspec_map, - _annotate_output_qspec, - get_input_act_qspec, - get_output_act_qspec, - get_bias_qspec, - get_weight_qspec, - ) - - from torch.fx import Node - - from torch.fx.passes.utils.source_matcher_utils import get_source_partitions - - from torch.ao.quantization._pt2e.quantizer.quantizer import ( - OperatorConfig, - QuantizationConfig, - QuantizationSpec, - Quantizer, - QuantizationAnnotation, - ) - from torch.ao.quantization.observer import ( - HistogramObserver, - PerChannelMinMaxObserver, - PlaceholderObserver, - ) - from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor - import torchvision - from torch.ao.quantization._quantize_pt2e import ( - convert_pt2e, - prepare_pt2e_quantizer, - ) - - def _mark_nodes_as_annotated(nodes: List[Node]): - for node in nodes: - if node is not None: - if "quantization_annotation" not in node.meta: - node.meta["quantization_annotation"] = QuantizationAnnotation() - node.meta["quantization_annotation"]._annotated = True - - def _is_annotated(nodes: List[Node]): - annotated = False - for node in nodes: - annotated = annotated or ( - "quantization_annotation" in node.meta - and node.meta["quantization_annotation"]._annotated - ) - return annotated - - class BackendQuantizer(Quantizer): - - def __init__(self): - super().__init__() - self.global_config: QuantizationConfig = None # type: ignore[assignment] - self.operator_type_config: Dict[str, Optional[QuantizationConfig]] = {} - - def set_global(self, quantization_config: QuantizationConfig): - """set global QuantizationConfig used for the backend. - QuantizationConfig is defined in torch/ao/quantization/_pt2e/quantizer/quantizer.py. - """ - self.global_config = quantization_config - return self - - def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: - """annotate nodes in the graph with observer or fake quant constructors - to convey the desired way of quantization. 
- """ - global_config = self.global_config - self.annotate_symmetric_config(model, global_config) - - return model - - def annotate_symmetric_config( - self, model: torch.fx.GraphModule, config: QuantizationConfig - ) -> torch.fx.GraphModule: - self._annotate_linear(model, config) - self._annotate_conv2d(model, config) - self._annotate_maxpool2d(model, config) - return model - - def _annotate_conv2d( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig - ) -> None: - conv_partitions = get_source_partitions( - gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d] - ) - conv_partitions = list(itertools.chain(*conv_partitions.values())) - for conv_partition in conv_partitions: - if len(conv_partition.output_nodes) > 1: - raise ValueError("conv partition has more than one output node") - conv_node = conv_partition.output_nodes[0] - if ( - conv_node.op != "call_function" - or conv_node.target != torch.ops.aten.convolution.default - ): - raise ValueError(f"{conv_node} is not an aten conv2d operator") - # skip annotation if it is already annotated - if _is_annotated([conv_node]): - continue - - input_qspec_map = {} - input_act = conv_node.args[0] - assert isinstance(input_act, Node) - input_qspec_map[input_act] = get_input_act_qspec(quantization_config) - - weight = conv_node.args[1] - assert isinstance(weight, Node) - input_qspec_map[weight] = get_weight_qspec(quantization_config) - - bias = conv_node.args[2] - if isinstance(bias, Node): - input_qspec_map[bias] = get_bias_qspec(quantization_config) - - conv_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=get_output_act_qspec(quantization_config), - _annotated=True, - ) - - def _annotate_linear( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig - ) -> None: - module_partitions = get_source_partitions( - gm.graph, [torch.nn.Linear, torch.nn.functional.linear] - ) - act_qspec = get_input_act_qspec(quantization_config) - weight_qspec = get_weight_qspec(quantization_config) - bias_qspec = get_bias_qspec(quantization_config) - for module_or_fn_type, partitions in module_partitions.items(): - if module_or_fn_type == torch.nn.Linear: - for p in partitions: - act_node = p.input_nodes[0] - output_node = p.output_nodes[0] - weight_node = None - bias_node = None - for node in p.params: - weight_or_bias = getattr(gm, node.target) # type: ignore[arg-type] - if weight_or_bias.ndim == 2: # type: ignore[attr-defined] - weight_node = node - if weight_or_bias.ndim == 1: # type: ignore[attr-defined] - bias_node = node - if weight_node is None: - raise ValueError("No weight found in Linear pattern") - # find use of act node within the matched pattern - act_use_node = None - for node in p.nodes: - if node in act_node.users: # type: ignore[union-attr] - act_use_node = node - break - if act_use_node is None: - raise ValueError( - "Could not find an user of act node within matched pattern." 
- ) - if _is_annotated([act_use_node]) is False: # type: ignore[list-item] - _annotate_input_qspec_map( - act_use_node, - act_node, - act_qspec, - ) - if bias_node and _is_annotated([bias_node]) is False: - _annotate_output_qspec(bias_node, bias_qspec) - if _is_annotated([weight_node]) is False: # type: ignore[list-item] - _annotate_output_qspec(weight_node, weight_qspec) - if _is_annotated([output_node]) is False: - _annotate_output_qspec(output_node, act_qspec) - nodes_to_mark_annotated = list(p.nodes) - _mark_nodes_as_annotated(nodes_to_mark_annotated) - - def _annotate_maxpool2d( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig - ) -> None: - module_partitions = get_source_partitions( - gm.graph, [torch.nn.MaxPool2d, torch.nn.functional.max_pool2d] - ) - maxpool_partitions = list(itertools.chain(*module_partitions.values())) - for maxpool_partition in maxpool_partitions: - output_node = maxpool_partition.output_nodes[0] - maxpool_node = None - for n in maxpool_partition.nodes: - if n.target == torch.ops.aten.max_pool2d_with_indices.default: - maxpool_node = n - if _is_annotated([output_node, maxpool_node]): # type: ignore[list-item] - continue - - input_act = maxpool_node.args[0] # type: ignore[union-attr] - assert isinstance(input_act, Node) - - act_qspec = get_input_act_qspec(quantization_config) - maxpool_node.meta["quantization_annotation"] = QuantizationAnnotation( # type: ignore[union-attr] - input_qspec_map={ - input_act: act_qspec, - }, - _annotated=True, - ) - output_node.meta["quantization_annotation"] = QuantizationAnnotation( - output_qspec=act_qspec, - _input_output_share_observers=True, - _annotated=True, - ) - - def validate(self, model: torch.fx.GraphModule) -> None: - """validate if the annotated graph is supported by the backend""" - pass - - @classmethod - def get_supported_operators(cls) -> List[OperatorConfig]: - return [] - - def get_symmetric_quantization_config(): - act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = \ - HistogramObserver - act_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=-128, - quant_max=127, - qscheme=torch.per_tensor_affine, - is_dynamic=False, - observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(eps=2**-12), - ) - - weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PerChannelMinMaxObserver - extra_args: Dict[str, Any] = {"eps": 2**-12} - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=-127, - quant_max=127, - qscheme=torch.per_channel_symmetric, - ch_axis=0, - is_dynamic=False, - observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(**extra_args), - ) - - bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PlaceholderObserver - bias_quantization_spec = QuantizationSpec( - dtype=torch.float, - observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr - ) - quantization_config = QuantizationConfig( - act_quantization_spec, - act_quantization_spec, - weight_quantization_spec, - bias_quantization_spec, - ) - return quantization_config - - if __name__ == "__main__": - example_inputs = (torch.randn(1, 3, 224, 224),) - m = torchvision.models.resnet18().eval() - m_copy = copy.deepcopy(m) - # program capture - m, guards = torchdynamo.export( - m, - *copy.deepcopy(example_inputs), - aten_graph=True, - ) - quantizer = BackendQuantizer() - operator_config = get_symmetric_quantization_config() - quantizer.set_global(operator_config) - m = prepare_pt2e_quantizer(m, quantizer) - 
after_prepare_result = m(*example_inputs) - m = convert_pt2e(m) - print("converted module is: {}".format(m), flush=True) +After above annotation methods defined with ``QuantizationAnnotation API``, we can now put them together to construct a ``BackendQuantizer`` +to run a `toy example `__ +with Torchvision Resnet18. 6. Conclusion --------------------- From 09f6a54916a67c19f8fc5e4a78430a56dd765615 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Wed, 7 Jun 2023 14:03:05 +0800 Subject: [PATCH 29/42] add note of prepare_pt2e_quantizer --- ...ization_in_pytorch_2_0_export_tutorial.rst | 291 +++++++----------- 1 file changed, 113 insertions(+), 178 deletions(-) diff --git a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst index 5b90a7159cd..083c0b50816 100644 --- a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst +++ b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst @@ -82,6 +82,8 @@ quantization 2.0 with quantizer could look like this: | Executorch, or Inductor, or +Note: ``prepare_pt2e_quantizer`` will be updated to ``prepare_pt2e`` soon. + An existing quantizer object defined for QNNPack/XNNPack is located in `QNNPackQuantizer `__. Taking QNNPackQuantizer as an example, the overall Quantization 2.0 flow could be: @@ -131,66 +133,71 @@ tutorial for how to use the ``QuantizationAnnotation API`` with different types 1. Annotate common operator patterns -------------------------------------------------------- -In order to use the quantized operators, e.g. ``quantized add``, +In order to use the quantized pattern/operators, e.g. ``quantized add``, backend developers will have intent to quantize (as expressed by `QuantizationSpec `__ -) input, output of the operator. Following is an example flow (with ``add``) -of how this intent is conveyed in the quantization workflow with node annotation API. +) inputs, output of the pattern. Following is an example flow (take ``add`` operator as example) +of how this intent is conveyed in the quantization workflow with annotation API. -- Step 1: Identify the original floating point ``add`` node in the FX graph. There are - several ways to identify this node: 1. User may use a pattern matcher (e.g. SubgraphMatcher) +- Step 1: Identify the original floating point pattern in the FX graph. There are + several ways to identify this pattern: 1. User may use a pattern matcher (e.g. SubgraphMatcher) to match the operator pattern. 2. User may go through the nodes from start to the end and compare - the node's target type. -- Step 2: Define the ``QuantizationSpec`` for two inputs and one output of the ``add`` node to specify - how to quantize input tensors and output tensor which includes parameters of ``observer type``, - ``dtype``, ``quant_min``, and ``quant_max`` etc. -- Step 3: Annotate the inputs and output of the ``add`` node. User will create the ``QuantizationAnnotation`` - object and add it into ``add`` node's ``meta`` property. + the node's target type to match the operator pattern. In this example, we use the + `get_source_partitions `__ + to match this pattern. The original floating point ``add`` pattern only contain a single ``add`` node. 
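As a side note on the pattern matcher option mentioned in Step 1, a hypothetical sketch with
``SubgraphMatcher`` could look like the code below; whether a symbolically traced pattern lines up
with the exported ATen graph depends on the dialects involved, so treat this purely as an
illustration. The ``get_source_partitions`` route this tutorial actually takes follows the sketch.

::

    # Hypothetical sketch: match an ``add`` pattern with SubgraphMatcher
    # instead of get_source_partitions. ``gm`` is assumed to be the
    # captured GraphModule from the export step.
    import torch
    from torch.fx import symbolic_trace
    from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

    class AddPattern(torch.nn.Module):
        def forward(self, x, y):
            return x + y

    pattern_graph = symbolic_trace(AddPattern()).graph
    matcher = SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False)
    for match in matcher.match(gm.graph):
        print(match.nodes_map)  # mapping from pattern nodes to graph nodes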
+ +:: + + add_partitions = get_source_partitions(gm.graph, [operator.add, torch.add]) + add_partitions = list(itertools.chain(*add_partitions.values())) + for add_partition in add_partitions: + add_node = add_partition.output_nodes[0] + +- Step 2: Define the ``QuantizationSpec`` for inputs and output of the pattern. ``QuantizationSpec`` + defines the ``data type``, ``qscheme``, and other quantization parameters used to quantize a tensor. + In this example, the ``add`` pattern has two input tensors and one output tensor. We will define the ``QuantizationSpec`` + for each of the input or output tensor to specify how to quantize it. + +:: + + act_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12), + ) + + input_act_qspec = act_quantization_spec + output_act_qspec = act_quantization_spec + +- Step 3: Annotate the inputs and output of the pattern with + `QuantizationAnnotation `__ + . ``QuantizationAnnotation`` is a ``dataclass`` with several fields as: 1. ``input_qspec_map`` field is ``Dict`` + to map each input ``Node`` to a ``QuantizationSpec``. 2. ``output_qspec`` field expresses the ``QuantizationSpec`` used for + output node. 3. ``_annotated`` field indicates if this node has already been annotated by quantizer. + In this example, we will create the ``QuantizationAnnotation``object with the ``QuantizationSpec`` objects + created in above step 2. :: - def _annotate_add( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig - ) -> None: - # Step1: Identify the ``add`` node in the original floating point FX graph. - add_partitions = get_source_partitions(gm.graph, [operator.add, torch.add]) - add_partitions = list(itertools.chain(*add_partitions.values())) - for add_partition in add_partitions: - add_node = add_partition.output_nodes[0] - if _is_annotated([add_node]): - continue - - act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = \ - HistogramObserver - act_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=-128, - quant_max=127, - qscheme=torch.per_tensor_affine, - is_dynamic=False, - observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(eps=2**-12), - ) - - # Step2: The ``add`` node has two inputs and one output. We define the ``QuantizationSpec`` - # for each input and output. - input_act_qspec = act_quantization_spec - output_act_qspec = act_quantization_spec - - input_qspec_map = {} - input_act0 = add_node.args[0] - if isinstance(input_act0, Node): - input_qspec_map[input_act0] = input_act_qspec - - input_act1 = add_node.args[1] - if isinstance(input_act1, Node): - input_qspec_map[input_act1] = input_act_qspec - - # Step3: Annotate the inputs and outputs of the ``add`` node. 
-            add_node.meta["quantization_annotation"] = QuantizationAnnotation(
-                input_qspec_map=input_qspec_map,
-                output_qspec=output_act_qspec,
-                _annotated=True,
-            )
+    input_qspec_map = {}
+    input_act0 = add_node.args[0]
+    input_qspec_map[input_act0] = input_act_qspec
+
+    input_act1 = add_node.args[1]
+    input_qspec_map[input_act1] = input_act_qspec
+
+    add_node.meta["quantization_annotation"] = QuantizationAnnotation(
+        input_qspec_map=input_qspec_map,
+        output_qspec=output_act_qspec,
+        _annotated=True,
+    )
+
+After we annotate the ``add`` node like this, in the follow-up quantization flow, a ``HistogramObserver`` will
+be inserted at its two input nodes and one output node in the prepare phase, and each ``HistogramObserver`` will be substituted with
+a ``quantize`` node and a ``dequantize`` node in the convert phase.
 
 2. Annotate sharing qparams operators
 --------------------------------------------------------
 
@@ -208,48 +215,27 @@ parameters can be shared among some tensors explicitly. Two typical use cases ar
 
 ``SharedQuantizationSpec`` is designed for this use case to annotate tensors whose quantization
 parameters are shared with other tensors. Input of ``SharedQuantizationSpec`` can be an input edge
-or an output value. Input edge is the connection between input node and the node consuming the input,
-so it's a Tuple[Node, Node]. Output value is an fx Node.
+or an output value.
 
-Now, we have a example to rewrite ``add`` annotation example with ``SharedQuantizationSpec``.
+- Input edge is the connection between input node and the node consuming the input,
+  so it's a Tuple[Node, Node].
+- Output value is an fx Node.
+
+Now, If we want to rewrite ``add`` annotation example with ``SharedQuantizationSpec`` to indicate
+two input tensors as sharing quantization parameters. We can define its ``QuantizationAnnotation``
+as this:
 
 ::
 
-    def _annotate_add(
-        self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
-    ) -> None:
-        add_partitions = get_source_partitions(gm.graph, [operator.add, torch.add])
-        add_partitions = list(itertools.chain(*add_partitions.values()))
-        for add_partition in add_partitions:
-            add_node = add_partition.output_nodes[0]
-            if _is_annotated([add_node]):
-                continue
-
-            act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = \
-                HistogramObserver
-            act_quantization_spec = QuantizationSpec(
-                dtype=torch.int8,
-                quant_min=-128,
-                quant_max=127,
-                qscheme=torch.per_tensor_affine,
-                is_dynamic=False,
-                observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(eps=2**-12),
-            )
-            act_qspec = act_quantization_spec
-
-            input_qspec_map = {}
-            input_act0 = add_node.args[0]
-            input_act1 = add_node.args[1]
-
-            share_qparams_with_input_act0_qspec = SharedQuantizationSpec((input_act0, add_node))
-
-            input_qspec_map = {input_act0: act_qspec, input_act1: share_qparams_with_input_act0_qspec}
-
-            add_node.meta["quantization_annotation"] = QuantizationAnnotation(
-                input_qspec_map=input_qspec_map,
-                output_qspec=act_qspec,
-                _annotated=True,
-            )
+    input_qspec_map = {}
+    share_qparams_with_input_act0_qspec = SharedQuantizationSpec((input_act0, add_node))
+    input_qspec_map = {input_act0: act_quantization_spec, input_act1: share_qparams_with_input_act0_qspec}
+
+    add_node.meta["quantization_annotation"] = QuantizationAnnotation(
+        input_qspec_map=input_qspec_map,
+        output_qspec=act_quantization_spec,
+        _annotated=True,
+    )
 
 3.
Annotate fixed qparams operators -------------------------------------------------------- @@ -263,31 +249,19 @@ of ``scale`` and ``zero_point`` explicitly. :: - def _annotate_sigmoid( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig - ) -> None: - sigmoid_partitions = get_source_partitions(gm.graph, [torch.nn.Sigmoid]) - sigmoid_partitions = list(itertools.chain(*sigmoid_partitions.values())) - for sigmoid_partition in sigmoid_partitions: - sigmoid_node = sigmoid_partition.output_nodes[0] - - input_act = sigmoid_node.args[0] - assert isinstance(input_act, Node) - act_qspec = FixedQParamsQuantizationSpec( - dtype=torch.uint8, - quant_min=0, - quant_max=255, - qscheme=torch.per_tensor_affine, - scale=2.0 / 256.0, - zero_point=128, - ) - sigmoid_node.meta["quantization_annotation"] = QuantizationAnnotation( - input_qspec_map={ - input_act: act_qspec, - }, - output_qspec=act_qspec, - _annotated=True, - ) + act_qspec = FixedQParamsQuantizationSpec( + dtype=torch.uint8, + quant_min=0, + quant_max=255, + qscheme=torch.per_tensor_affine, + scale=2.0 / 256.0, + zero_point=128, + ) + sigmoid_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={input_act: act_qspec}, + output_qspec=act_qspec, + _annotated=True, + ) 4. Annotate tensor with derived quantization parameters --------------------------------------------------------------- @@ -300,74 +274,35 @@ to annotate this bias tensor. :: - def _annotate_conv2d_derived_bias( - self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig - ) -> None: - conv_partitions = get_source_partitions( - gm.graph, [torch.nn.Conv2d, torch.nn.functional.conv2d] - ) - conv_partitions = list(itertools.chain(*conv_partitions.values())) - for conv_partition in conv_partitions: - node = conv_partition.output_nodes[0] - input_act = node.args[0] - weight = node.args[1] - bias = node.args[2] - - act_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = \ - HistogramObserver - act_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=-128, - quant_max=127, - qscheme=torch.per_tensor_affine, - is_dynamic=False, - observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(eps=2**-12), - ) - weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = PerChannelMinMaxObserver - extra_args: Dict[str, Any] = {"eps": 2**-12} - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=-127, - quant_max=127, - qscheme=torch.per_channel_symmetric, - ch_axis=0, - is_dynamic=False, - observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(**extra_args), - ) - act_qspec = act_quantization_spec - weight_qspec = weight_quantization_spec - - def derive_qparams_fn(obs_or_fqs: List[ObserverOrFakeQuantize]) -> Tuple[Tensor, Tensor]: - assert len(obs_or_fqs) == 2, \ - "Expecting two obs/fqs, one for activation and one for weight, got: {}".format(len(obs_or_fq)) - act_obs_or_fq = obs_or_fqs[0] - weight_obs_or_fq = obs_or_fqs[1] - act_scale, act_zp = act_obs_or_fq.calculate_qparams() - weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams() - return torch.tensor([act_scale * weight_scale]).to(torch.float32), torch.tensor([0]).to(torch.int32) - - bias_qspec = DerivedQuantizationSpec( - derived_from=[(input_act, node), (weight, node)], - derive_qparams_fn=derive_qparams_fn, - dtype=torch.int32, - quant_min=-2**31, - quant_max=2**31 - 1, - qscheme=torch.per_tensor_symmetric, - ) - input_qspec_map = {input_act: act_qspec, weight: weight_qspec, 
bias: bias_qspec}
-            node.meta["quantization_annotation"] = QuantizationAnnotation(
-                input_qspec_map=input_qspec_map,
-                output_qspec=act_qspec,
-                _annotated=True,
-            )
+    def derive_qparams_fn(obs_or_fqs: List[ObserverOrFakeQuantize]) -> Tuple[Tensor, Tensor]:
+        assert len(obs_or_fqs) == 2, \
+            "Expecting two obs/fqs, one for activation and one for weight, got: {}".format(len(obs_or_fqs))
+        act_obs_or_fq = obs_or_fqs[0]
+        weight_obs_or_fq = obs_or_fqs[1]
+        act_scale, act_zp = act_obs_or_fq.calculate_qparams()
+        weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams()
+        return torch.tensor([act_scale * weight_scale]).to(torch.float32), torch.tensor([0]).to(torch.int32)
+
+    bias_qspec = DerivedQuantizationSpec(
+        derived_from=[(input_act, node), (weight, node)],
+        derive_qparams_fn=derive_qparams_fn,
+        dtype=torch.int32,
+        quant_min=-2**31,
+        quant_max=2**31 - 1,
+        qscheme=torch.per_tensor_symmetric,
+    )
+    input_qspec_map = {input_act: act_quantization_spec, weight: weight_quantization_spec, bias: bias_qspec}
+    node.meta["quantization_annotation"] = QuantizationAnnotation(
+        input_qspec_map=input_qspec_map,
+        output_qspec=act_quantization_spec,
+        _annotated=True,
+    )
 
 5. A Toy Example with Resnet18
 --------------------------------------------------------
 
 To better understand the final example, here are some basic concepts before we move on to this part:
 
-- `QuantizationSpec `__
-  defines the ``data type``, ``qscheme``, and other quantization parameters used to quantize a tensor.
 - `QuantizationConfig `__
   consists of ``QuantizationSpec`` for activation, weight, and bias separately.
 - When annotating the model, methods of

From d9060a5e6a34e810cb99f502a63076edafb9c148 Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Wed, 7 Jun 2023 16:10:46 +0800
Subject: [PATCH 30/42] Modify the documents

---
 ...ization_in_pytorch_2_0_export_tutorial.rst | 60 ++++++++++++-------
 1 file changed, 39 insertions(+), 21 deletions(-)

diff --git a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst
index 083c0b50816..424788d23dc 100644
--- a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst
+++ b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst
@@ -27,8 +27,9 @@ with two main limitations:
   `discussion `__ to support `conv add` fusion with oneDNN library.
   It also requires some changes to current already complicated pattern matching code such as in the
   `PR `__ to support `conv add` fusion.
-- Limitation around supporting user's advanced intention to quantize their model. For example, ``FX Graph Mode Quantization``
-  doesn't support this quantization intention: only quantize inputs and outputs when the ``linear`` has a third input.
+- Limitation around supporting user's advanced quantization intention to quantize their model. For example, if backend
+  developer only want to quantize inputs and outputs when the ``linear`` has a third input, it requires co-work from quantization
+  team and backend developer.
 
 To address these scalability issues,
 `Quantizer `__
 is introduced for quantization in PyTorch 2.0 export. ``Quantizer`` is a class that users can use to
 programmably set the observer or fake quant objects for each node in the model graph. It adds flexibility
 to the quantization API and allows modeling users and backend developers to configure quantization programmatically.
This will allow users to express how they want an operator pattern to be observed in a more explicit -way by annotating the appropriate nodes. To define a backend specific quantizer, user mainly need to override -several APIs: +way by annotating the appropriate nodes. A backend specific quantizer inherited from base quantizer, +some APIs need to be overrided: - `annotate method `__ is used to annotate nodes in the graph with observer or fake quant constructors to convey the desired way of quantization. -- `validate method `__ +- `validate method `__ is used to validate if the annotated graph is supported by the backend. -- `set_global method `__ - is used to set the global ``QuantizationConfig`` object for this quantizer to specify how the model will be quantized. Imagine a backend developer who wishes to integrate a third-party backend with PyTorch's quantization 2.0 flow. To accomplish this, they would only need @@ -140,9 +139,9 @@ backend developers will have intent to quantize (as expressed by of how this intent is conveyed in the quantization workflow with annotation API. - Step 1: Identify the original floating point pattern in the FX graph. There are - several ways to identify this pattern: 1. User may use a pattern matcher (e.g. SubgraphMatcher) - to match the operator pattern. 2. User may go through the nodes from start to the end and compare - the node's target type to match the operator pattern. In this example, we use the + several ways to identify this pattern: User may use a pattern matcher (e.g. SubgraphMatcher) + to match the operator pattern; User may go through the nodes from start to the end and compare + the node's target type to match the operator pattern. In this example, we can use the `get_source_partitions `__ to match this pattern. The original floating point ``add`` pattern only contain a single ``add`` node. @@ -154,9 +153,8 @@ of how this intent is conveyed in the quantization workflow with annotation API. add_node = add_partition.output_nodes[0] - Step 2: Define the ``QuantizationSpec`` for inputs and output of the pattern. ``QuantizationSpec`` - defines the ``data type``, ``qscheme``, and other quantization parameters used to quantize a tensor. - In this example, the ``add`` pattern has two input tensors and one output tensor. We will define the ``QuantizationSpec`` - for each of the input or output tensor to specify how to quantize it. + defines the ``data type``, ``qscheme``, and other quantization parameters about users' intent of + how to observer/quantize a tensor. :: @@ -174,11 +172,11 @@ of how this intent is conveyed in the quantization workflow with annotation API. - Step 3: Annotate the inputs and output of the pattern with `QuantizationAnnotation `__ - . ``QuantizationAnnotation`` is a ``dataclass`` with several fields as: 1. ``input_qspec_map`` field is ``Dict`` - to map each input ``Node`` to a ``QuantizationSpec``. 2. ``output_qspec`` field expresses the ``QuantizationSpec`` used for - output node. 3. ``_annotated`` field indicates if this node has already been annotated by quantizer. - In this example, we will create the ``QuantizationAnnotation``object with the ``QuantizationSpec`` objects - created in above step 2. + . 
``QuantizationAnnotation`` is a ``dataclass`` with several fields as: ``input_qspec_map`` field is ``Dict`` + to map each input ``Node`` to a ``QuantizationSpec``; ``output_qspec`` field expresses the ``QuantizationSpec`` used for + output node; ``_annotated`` field indicates if this node has already been annotated by quantizer. + In this example, we will create the ``QuantizationAnnotation`` object with the ``QuantizationSpec`` objects + created in above step 2 for two inputs and one output of ``add`` node. :: @@ -221,10 +219,15 @@ or an output value. so it's a Tuple[Node, Node]. - Output value is an fx Node. -Now, If we want to rewrite ``add`` annotation example with ``SharedQuantizationSpec`` to indicate +Now, if we want to rewrite ``add`` annotation example with ``SharedQuantizationSpec`` to indicate two input tensors as sharing quantization parameters. We can define its ``QuantizationAnnotation`` as this: +- Step 1: Annotate input_act0 of ``add`` with ``QuantizationSpec``. +- Step 2: Create a ``SharedQuantizationSpec`` object with input edge defined as ``(input_act0, add_node)`` which means to + share the observer used for this edge. Then, user can annotate input_act1 with this ``SharedQuantizationSpec`` + object. + :: input_qspec_map = {} @@ -247,6 +250,10 @@ predefined and fixed scale/zero_point at input and output tensors. is designed for this use case. To use ``FixedQParamsQuantizationSpec``, users need to pass in parameters of ``scale`` and ``zero_point`` explicitly. +- Step 1: Create ``FixedQParamsQuantizationSpec`` object with inputs of fixed ``scale``, ``zero_point`` value. + These values will be used to create the ``quantize`` node and ``dequantize`` node in the convert phase. +- Step 2: Annotate inputs and output to use this ``FixedQParamsQuantizationSpec`` object. + :: act_qspec = FixedQParamsQuantizationSpec( @@ -266,11 +273,22 @@ of ``scale`` and ``zero_point`` explicitly. 4. Annotate tensor with derived quantization parameters --------------------------------------------------------------- -We also may need to define the constraint for tensors whose quantization parameters are derived from other tensors. +Another use case is to define the constraint for tensors whose quantization parameters are derived from other tensors. For example, if we want to annotate a convolution node, and define the ``scale`` of its bias input tensor as product of the activation tensor's ``scale`` and weight tensor's ``scale``. We can use `DerivedQuantizationSpec `__ -to annotate this bias tensor. +to annotate this conv node. + +- Step 1: Define ``derive_qparams_fn`` function, it accepts list of ``ObserverOrFakeQuantize`` ( + `ObserverBase `__ + or `FakeQuantizeBase `__) + as input. From each ``ObserverOrFakeQuantize`` object, user can get the ``scale``, ``zero point`` value. Combine these values together, + user can define its heuristic about how to derive new ``scale``, ``zero point`` value. +- Step 2: Define ``DerivedQuantizationSpec`` obejct, it accepts inputs of: list of ``EdgeOrNode`` (Edge is for the connection + between input node and the node consuming the input as Tuple[Node, Node]; Node is for the output node.). The observer at the input edge or output node + will be passed in to the ``derive_qparams_fn`` function; ``derive_qparams_fn`` function; + several other quantization parameters such as ``dtype``, ``qscheme``. +- Step 3: Annotate the inputs and output of this conv node with ``QuantizationAnnotation``. 
:: From b2784079a88359e114d67c6482d5555e9d4241c0 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Wed, 7 Jun 2023 17:04:08 +0800 Subject: [PATCH 31/42] format the document --- ...ization_in_pytorch_2_0_export_tutorial.rst | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst index 424788d23dc..cf9e9792048 100644 --- a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst +++ b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst @@ -23,10 +23,10 @@ and ``BackendConfig`` to specify the supported ways of quantization in their bac This API covers most use cases relatively well, but the main problem is that this API is not fully extensible with two main limitations: -- Limitation around expressing quantization intentions for complicated operator patterns such as in the - `discussion `__ to support `conv add` fusion with oneDNN library. +- Limitation around expressing quantization intentions for complicated operator patterns such as in the discussion of + `issue-96288 `__ to support ``conv add`` fusion with oneDNN library. It also requires some changes to current already complicated pattern matching code such as in the - `PR `__ to support `conv add` fusion. + `PR-97122 `__ to support ``conv add`` fusion. - Limitation around supporting user's advanced quantization intention to quantize their model. For example, if backend developer only want to quantize inputs and outputs when the ``linear`` has a third input, it requires co-work from quantization team and backend developer. @@ -212,8 +212,8 @@ parameters can be shared among some tensors explicitly. Two typical use cases ar This typically results from operators such as ``maxpool``, ``average_pool``, ``concat`` etc. ``SharedQuantizationSpec`` is designed for this use case to annotate tensors whose quantization -parameters are shared with other tensors. Input of ``SharedQuantizationSpec`` can be an input edge -or an output value. +parameters are shared with other tensors. Input of ``SharedQuantizationSpec`` is an ``EdgeOrNode`` object which +can be an input edge or an output value. - Input edge is the connection between input node and the node consuming the input, so it's a Tuple[Node, Node]. @@ -284,10 +284,9 @@ to annotate this conv node. or `FakeQuantizeBase `__) as input. From each ``ObserverOrFakeQuantize`` object, user can get the ``scale``, ``zero point`` value. Combine these values together, user can define its heuristic about how to derive new ``scale``, ``zero point`` value. -- Step 2: Define ``DerivedQuantizationSpec`` obejct, it accepts inputs of: list of ``EdgeOrNode`` (Edge is for the connection - between input node and the node consuming the input as Tuple[Node, Node]; Node is for the output node.). The observer at the input edge or output node - will be passed in to the ``derive_qparams_fn`` function; ``derive_qparams_fn`` function; - several other quantization parameters such as ``dtype``, ``qscheme``. +- Step 2: Define ``DerivedQuantizationSpec`` obejct, it accepts inputs of: list of ``EdgeOrNode`` objects. + The observer corresponding to each ``EdgeOrNode`` object will be passed into the ``derive_qparams_fn`` function; + ``derive_qparams_fn`` function; several other quantization parameters such as ``dtype``, ``qscheme``. - Step 3: Annotate the inputs and output of this conv node with ``QuantizationAnnotation``. 
:: @@ -319,7 +318,9 @@ to annotate this conv node. 5. A Toy Example with Resnet18 -------------------------------------------------------- -To better understand the final example, here are some basic concepts before we move on to this part: +After above annotation methods defined with ``QuantizationAnnotation API``, we can now put them together to construct a ``BackendQuantizer`` +and run a `toy example `__ +with ``Torchvision Resnet18``. To better understand the final example, here are some basic concepts which are used: - `QuantizationConfig `__ consists of ``QuantizationSpec`` for activation, weight, and bias separately. @@ -329,10 +330,6 @@ To better understand the final example, here are some basic concepts before we m `get_bias_qspec `__ can be used to get the ``QuantizationSpec`` from ``QuantizationConfig`` for a specific node. -After above annotation methods defined with ``QuantizationAnnotation API``, we can now put them together to construct a ``BackendQuantizer`` -to run a `toy example `__ -with Torchvision Resnet18. - 6. Conclusion --------------------- From 13cefd340d5c87ee9d59cb77b9f228d02aee8c89 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Wed, 7 Jun 2023 17:17:06 +0800 Subject: [PATCH 32/42] add links to Prerequisites --- .../quantization_in_pytorch_2_0_export_tutorial.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst index cf9e9792048..a6ee2cfdc06 100644 --- a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst +++ b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst @@ -14,9 +14,11 @@ a simplified UX. Prerequisites: ----------------------- +- `Understanding of torchdynamo concepts in PyTorch `__ - `Understanding of the quantization concepts in PyTorch `__ - `Understanding of FX graph mode post training static quantization `__ -- `Understanding of torchdynamo concepts in PyTorch `__ +- `Understanding of BackendConfig in PyTorch Quantization FX Graph Mode `__ +- `Understanding of QConfigMapping in PyTorch Quantization FX Graph Mode `__ Previously in ``FX Graph Mode Quantization`` we were using ``QConfigMapping`` for users to specify how the model to be quantized and ``BackendConfig`` to specify the supported ways of quantization in their backend. From ef4fa73cfbeca134924779f8b76176c2a065f7ac Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Thu, 8 Jun 2023 09:01:31 +0800 Subject: [PATCH 33/42] update descriptation --- ...ization_in_pytorch_2_0_export_tutorial.rst | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst index a6ee2cfdc06..e9ba5bf169a 100644 --- a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst +++ b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst @@ -36,16 +36,16 @@ with two main limitations: To address these scalability issues, `Quantizer `__ is introduced for quantization in PyTorch 2.0 export. ``Quantizer`` is a class that users can use to -programmably set the observer or fake quant objects for each node in the model graph. It adds flexibility +programmably set the quantization specifications for input and output of each node in the model graph. 
It adds flexibility to the quantization API and allows modeling users and backend developers to configure quantization programmatically. This will allow users to express how they want an operator pattern to be observed in a more explicit way by annotating the appropriate nodes. A backend specific quantizer inherited from base quantizer, -some APIs need to be overrided: +some methods that need to be implemented: - `annotate method `__ - is used to annotate nodes in the graph with observer or fake quant constructors to convey the desired way of quantization. -- `validate method `__ - is used to validate if the annotated graph is supported by the backend. + is used to annotate nodes in the graph with + `QuantizationAnnotation `__ + objects to convey the desired way of quantization. Imagine a backend developer who wishes to integrate a third-party backend with PyTorch's quantization 2.0 flow. To accomplish this, they would only need @@ -156,7 +156,7 @@ of how this intent is conveyed in the quantization workflow with annotation API. - Step 2: Define the ``QuantizationSpec`` for inputs and output of the pattern. ``QuantizationSpec`` defines the ``data type``, ``qscheme``, and other quantization parameters about users' intent of - how to observer/quantize a tensor. + how to observe or fake quantize a tensor. :: @@ -172,9 +172,8 @@ of how this intent is conveyed in the quantization workflow with annotation API. input_act_qspec = act_quantization_spec output_act_qspec = act_quantization_spec -- Step 3: Annotate the inputs and output of the pattern with - `QuantizationAnnotation `__ - . ``QuantizationAnnotation`` is a ``dataclass`` with several fields as: ``input_qspec_map`` field is ``Dict`` +- Step 3: Annotate the inputs and output of the pattern with ``QuantizationAnnotation``. + ``QuantizationAnnotation`` is a ``dataclass`` with several fields as: ``input_qspec_map`` field is ``Dict`` to map each input ``Node`` to a ``QuantizationSpec``; ``output_qspec`` field expresses the ``QuantizationSpec`` used for output node; ``_annotated`` field indicates if this node has already been annotated by quantizer. In this example, we will create the ``QuantizationAnnotation`` object with the ``QuantizationSpec`` objects @@ -218,8 +217,8 @@ parameters are shared with other tensors. Input of ``SharedQuantizationSpec`` is can be an input edge or an output value. - Input edge is the connection between input node and the node consuming the input, - so it's a Tuple[Node, Node]. -- Output value is an fx Node. + so it's a ``Tuple[Node, Node]``. +- Output value is an fx ``Node``. Now, if we want to rewrite ``add`` annotation example with ``SharedQuantizationSpec`` to indicate two input tensors as sharing quantization parameters. We can define its ``QuantizationAnnotation`` From 40caeefe4d67d3c6c3ad9fdf3d8232d7eebc1fd5 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Thu, 8 Jun 2023 09:11:25 +0800 Subject: [PATCH 34/42] add pattern match as step 1 --- ...ization_in_pytorch_2_0_export_tutorial.rst | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst index e9ba5bf169a..94cd3e60a05 100644 --- a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst +++ b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst @@ -145,7 +145,7 @@ of how this intent is conveyed in the quantization workflow with annotation API. 
to match the operator pattern; User may go through the nodes from start to the end and compare
   the node's target type to match the operator pattern. In this example, we can use the
   `get_source_partitions `__
-  to match this pattern. The original floating point ``add`` pattern only contain a single ``add`` node. 
+  to match this pattern. The original floating point ``add`` pattern only contains a single ``add`` node.

@@ -224,8 +224,10 @@
 two input tensors as sharing quantization parameters. We can define its ``QuantizationAnnotation``
 as this:
 
-- Step 1: Annotate input_act0 of ``add`` with ``QuantizationSpec``.
-- Step 2: Create a ``SharedQuantizationSpec`` object with input edge defined as ``(input_act0, add_node)`` which means to
+- Step 1: Identify the original floating point pattern in the FX graph. We can use the same
+  methods introduced in the ``QuantizationSpec`` example to identify the ``add`` pattern.
+- Step 2: Annotate input_act0 of ``add`` with ``QuantizationSpec``.
+- Step 3: Create a ``SharedQuantizationSpec`` object with input edge defined as ``(input_act0, add_node)`` which means to
   share the observer used for this edge. Then, user can annotate input_act1 with this ``SharedQuantizationSpec``
   object.
 
@@ -247,6 +249,10 @@ predefined and fixed scale/zero_point at input and output tensors.
 is designed for this use case. To use ``FixedQParamsQuantizationSpec``, users need to pass in parameters
 of ``scale`` and ``zero_point`` explicitly.
 
-- Step 1: Create ``FixedQParamsQuantizationSpec`` object with inputs of fixed ``scale``, ``zero_point`` value.
+- Step 1: Identify the original floating point pattern in the FX graph. We can use the same
+  methods introduced in the ``QuantizationSpec`` example to identify the ``sigmoid`` pattern.
+- Step 2: Create a ``FixedQParamsQuantizationSpec`` object with inputs of fixed ``scale`` and ``zero_point`` values.
   These values will be used to create the ``quantize`` node and ``dequantize`` node in the convert phase.
-- Step 2: Annotate inputs and output to use this ``FixedQParamsQuantizationSpec`` object.
+- Step 3: Annotate inputs and output to use this ``FixedQParamsQuantizationSpec`` object.
 
@@ -280,15 +284,17 @@ to annotate this conv node.
 
-- Step 1: Define ``derive_qparams_fn`` function, it accepts list of ``ObserverOrFakeQuantize`` (
+- Step 1: Identify the original floating point pattern in the FX graph. We can use the same
+  methods introduced in the ``QuantizationSpec`` example to identify the ``convolution`` pattern.
+- Step 2: Define ``derive_qparams_fn`` function, it accepts list of ``ObserverOrFakeQuantize`` (
   `ObserverBase `__
   or `FakeQuantizeBase `__)
   as input. From each ``ObserverOrFakeQuantize`` object, user can get the ``scale``, ``zero point`` value. Combine these values together,
   user can define its heuristic about how to derive new ``scale``, ``zero point`` value.
-- Step 2: Define ``DerivedQuantizationSpec`` obejct, it accepts inputs of: list of ``EdgeOrNode`` objects.
-  The observer corresponding to each ``EdgeOrNode`` object will be passed into the ``derive_qparams_fn`` function;
-  ``derive_qparams_fn`` function; several other quantization parameters such as ``dtype``, ``qscheme``.
+- Step 3: Define the ``DerivedQuantizationSpec`` object. It accepts the following inputs: a list of ``EdgeOrNode`` objects,
+  whose corresponding observers will be passed into the ``derive_qparams_fn`` function; the ``derive_qparams_fn`` function itself;
+  and several other quantization parameters such as ``dtype`` and ``qscheme``.
-- Step 3: Annotate the inputs and output of this conv node with ``QuantizationAnnotation``. +- Step 4: Annotate the inputs and output of this conv node with ``QuantizationAnnotation``. :: From 1c59f15213af3f4ac3bcb9ef31739e6339324044 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Thu, 8 Jun 2023 09:21:05 +0800 Subject: [PATCH 35/42] thanks and the link updated --- ...ization_in_pytorch_2_0_export_tutorial.rst | 29 ++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst index 94cd3e60a05..5a482dc30c7 100644 --- a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst +++ b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst @@ -266,8 +266,8 @@ of ``scale`` and ``zero_point`` explicitly. quant_min=0, quant_max=255, qscheme=torch.per_tensor_affine, - scale=2.0 / 256.0, - zero_point=128, + scale=1.0 / 256.0, + zero_point=0, ) sigmoid_node.meta["quantization_annotation"] = QuantizationAnnotation( input_qspec_map={input_act: act_qspec}, @@ -289,8 +289,9 @@ to annotate this conv node. - Step 2: Define ``derive_qparams_fn`` function, it accepts list of ``ObserverOrFakeQuantize`` ( `ObserverBase `__ or `FakeQuantizeBase `__) - as input. From each ``ObserverOrFakeQuantize`` object, user can get the ``scale``, ``zero point`` value. Combine these values together, - user can define its heuristic about how to derive new ``scale``, ``zero point`` value. + as input. From each ``ObserverOrFakeQuantize`` object, user can get the ``scale``, ``zero point`` value. + User can define its heuristic about how to derive new ``scale``, ``zero point`` value based on the + quantization parameters calculated from the observer or fake quant instances. - Step 3: Define ``DerivedQuantizationSpec`` obejct, it accepts inputs of: list of ``EdgeOrNode`` objects. The observer corresponding to each ``EdgeOrNode`` object will be passed into the ``derive_qparams_fn`` function; ``derive_qparams_fn`` function; several other quantization parameters such as ``dtype``, ``qscheme``. @@ -327,15 +328,17 @@ to annotate this conv node. After above annotation methods defined with ``QuantizationAnnotation API``, we can now put them together to construct a ``BackendQuantizer`` and run a `toy example `__ -with ``Torchvision Resnet18``. To better understand the final example, here are some basic concepts which are used: - -- `QuantizationConfig `__ - consists of ``QuantizationSpec`` for activation, weight, and bias separately. -- When annotating the model, methods of - `get_act_qspec `__, - `get_weight_qspec `__, and - `get_bias_qspec `__ - can be used to get the ``QuantizationSpec`` from ``QuantizationConfig`` for a specific node. +with ``Torchvision Resnet18``. To better understand the final example, here are the classes and utility +functions that are used in the example: + +- `QuantizationConfig `__ + consists of ``QuantizationSpec`` for activation, weight, and bias separately. +- When annotating the model, + `get_input_act_qspec `__, + `get_output_act_qspec `__, + `get_weight_qspec `__, and + `get_bias_qspec `__ + can be used to get the ``QuantizationSpec`` from ``QuantizationConfig`` for a specific node. 6. 
Conclusion --------------------- From 5edc1f2a01ce5d595673d3f08f2841b07deaa000 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Thu, 8 Jun 2023 10:22:46 +0800 Subject: [PATCH 36/42] fix typo --- ...ization_in_pytorch_2_0_export_tutorial.rst | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst index 5a482dc30c7..1f9ca12497a 100644 --- a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst +++ b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst @@ -16,7 +16,7 @@ Prerequisites: - `Understanding of torchdynamo concepts in PyTorch `__ - `Understanding of the quantization concepts in PyTorch `__ -- `Understanding of FX graph mode post training static quantization `__ +- `Understanding of FX Graph Mode post training static quantization `__ - `Understanding of BackendConfig in PyTorch Quantization FX Graph Mode `__ - `Understanding of QConfigMapping in PyTorch Quantization FX Graph Mode `__ @@ -30,13 +30,13 @@ with two main limitations: It also requires some changes to current already complicated pattern matching code such as in the `PR-97122 `__ to support ``conv add`` fusion. - Limitation around supporting user's advanced quantization intention to quantize their model. For example, if backend - developer only want to quantize inputs and outputs when the ``linear`` has a third input, it requires co-work from quantization + developer only wants to quantize inputs and outputs when the ``linear`` has a third input, it requires co-work from quantization team and backend developer. To address these scalability issues, `Quantizer `__ is introduced for quantization in PyTorch 2.0 export. ``Quantizer`` is a class that users can use to -programmably set the quantization specifications for input and output of each node in the model graph. It adds flexibility +programmatically set the quantization specifications for input and output of each node in the model graph. It adds flexibility to the quantization API and allows modeling users and backend developers to configure quantization programmatically. This will allow users to express how they want an operator pattern to be observed in a more explicit way by annotating the appropriate nodes. A backend specific quantizer inherited from base quantizer, @@ -85,7 +85,7 @@ quantization 2.0 with quantizer could look like this: Note: ``prepare_pt2e_quantizer`` will be updated to ``prepare_pt2e`` soon. -An existing quantizer object defined for QNNPack/XNNPack is located in +An existing quantizer object defined for QNNPack/XNNPack is in `QNNPackQuantizer `__. Taking QNNPackQuantizer as an example, the overall Quantization 2.0 flow could be: @@ -172,8 +172,9 @@ of how this intent is conveyed in the quantization workflow with annotation API. input_act_qspec = act_quantization_spec output_act_qspec = act_quantization_spec -- Step 3: Annotate the inputs and output of the pattern with ``QuantizationAnnotation``. - ``QuantizationAnnotation`` is a ``dataclass`` with several fields as: ``input_qspec_map`` field is ``Dict`` +- Step 3: Annotate the inputs and output of the pattern with + `QuantizationAnnotation `__. 
+ ``QuantizationAnnotation`` is a ``dataclass`` with several fields as: ``input_qspec_map`` field is of class ``Dict`` to map each input ``Node`` to a ``QuantizationSpec``; ``output_qspec`` field expresses the ``QuantizationSpec`` used for output node; ``_annotated`` field indicates if this node has already been annotated by quantizer. In this example, we will create the ``QuantizationAnnotation`` object with the ``QuantizationSpec`` objects @@ -207,7 +208,7 @@ parameters can be shared among some tensors explicitly. Two typical use cases ar - Example 1: One example is for ``add`` where having both inputs sharing quantization parameters makes operator implementation much easier. Without using of `SharedQuantizationSpec `__, - we have to annotate ``add`` as example in above section 1, in which two inputs of ``add`` + we must annotate ``add`` as example in above section 1, in which two inputs of ``add`` has different quantization parameters. - Example 2: Another example is that of sharing quantization parameters between inputs and output. This typically results from operators such as ``maxpool``, ``average_pool``, ``concat`` etc. @@ -218,7 +219,7 @@ can be an input edge or an output value. - Input edge is the connection between input node and the node consuming the input, so it's a ``Tuple[Node, Node]``. -- Output value is an fx ``Node``. +- Output value is an FX ``Node``. Now, if we want to rewrite ``add`` annotation example with ``SharedQuantizationSpec`` to indicate two input tensors as sharing quantization parameters. We can define its ``QuantizationAnnotation`` @@ -247,7 +248,7 @@ as this: -------------------------------------------------------- Another typical use case to annotate a quantized model is for tensors whose -quantization parmaters are known beforehand. For example, operator like ``sigmoid``, which has +quantization parameters are known beforehand. For example, operator like ``sigmoid``, which has predefined and fixed scale/zero_point at input and output tensors. `FixedQParamsQuantizationSpec `__ is designed for this use case. To use ``FixedQParamsQuantizationSpec``, users need to pass in parameters From 4b8606005a853a052cadc594d100cbabf360dfd7 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Fri, 9 Jun 2023 08:52:06 +0800 Subject: [PATCH 37/42] add author --- .../quantization_in_pytorch_2_0_export_tutorial.rst | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst index 1f9ca12497a..157f1aa35e1 100644 --- a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst +++ b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst @@ -1,6 +1,8 @@ -(prototype) Quantization in PyTorch 2.0 Export Tutorial +(prototype) Quantization in PyTorch 2.0 Export Tutorial (Work in Progress) ============================================================== +**Author**: `Leslie Fang `_, `Weiwen Xia `__, `Jiong Gong `__ + Today we have `FX Graph Mode Quantization `__ which uses ``symbolic_trace`` to capture the model into a graph, and then @@ -339,7 +341,7 @@ functions that are used in the example: `get_output_act_qspec `__, `get_weight_qspec `__, and `get_bias_qspec `__ - can be used to get the ``QuantizationSpec`` from ``QuantizationConfig`` for a specific node. + can be used to get the ``QuantizationSpec`` from ``QuantizationConfig`` for a specific pattern. 6. 
Conclusion --------------------- From 5a43584415b88d1bf580de219db866d64f0154c2 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Fri, 9 Jun 2023 09:10:36 +0800 Subject: [PATCH 38/42] add more explain of limitations --- ...ization_in_pytorch_2_0_export_tutorial.rst | 50 +++++++++++++------ 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst index 157f1aa35e1..9c3000557f0 100644 --- a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst +++ b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst @@ -1,7 +1,7 @@ -(prototype) Quantization in PyTorch 2.0 Export Tutorial (Work in Progress) +(Work in Progress) Quantization in PyTorch 2.0 Export Tutorial ============================================================== -**Author**: `Leslie Fang `_, `Weiwen Xia `__, `Jiong Gong `__ +**Author**: `Leslie Fang `_, `Weiwen Xia `__, `Jiong Gong `__, `Kimish Patel `__, `Jerry Zhang `__ Today we have `FX Graph Mode Quantization `__ @@ -20,12 +20,12 @@ Prerequisites: - `Understanding of the quantization concepts in PyTorch `__ - `Understanding of FX Graph Mode post training static quantization `__ - `Understanding of BackendConfig in PyTorch Quantization FX Graph Mode `__ -- `Understanding of QConfigMapping in PyTorch Quantization FX Graph Mode `__ +- `Understanding of QConfig and QConfigMapping in PyTorch Quantization FX Graph Mode `__ Previously in ``FX Graph Mode Quantization`` we were using ``QConfigMapping`` for users to specify how the model to be quantized and ``BackendConfig`` to specify the supported ways of quantization in their backend. This API covers most use cases relatively well, but the main problem is that this API is not fully extensible -with two main limitations: +without involvement of the quantization team: - Limitation around expressing quantization intentions for complicated operator patterns such as in the discussion of `issue-96288 `__ to support ``conv add`` fusion with oneDNN library. @@ -34,6 +34,15 @@ with two main limitations: - Limitation around supporting user's advanced quantization intention to quantize their model. For example, if backend developer only wants to quantize inputs and outputs when the ``linear`` has a third input, it requires co-work from quantization team and backend developer. +- Currently we use ``QConfigMapping`` and ``BackendConfig`` as separate object. ``QConfigMapping`` describes user's + intention of how they want their model to be quantized. ``BackendConfig`` describes what kind of quantization a backend support. + Currently ``BackendConfig`` is backend specific, but ``QConfigMapping`` is not. And user can provide a ``QConfigMapping`` + that is incompatible with a specific BackendConfig. This is not a great UX. Ideally we can structure this better + by making both configuration (``QConfigMapping``) and quantization capability (``BackendConfig``) backend + specific, so there will be less confusion about incompatibilities. +- Currently in ``QConfig`` we are exposing observer/fake_quant classes as an object for user to configure quantization. + This increases the things that user may need to care about, e.g. not only the ``dtype`` but also how the observation should + happen. These could potentially be hidden from user so that the user interface is simpler. 
To address these scalability issues, `Quantizer `__ @@ -127,24 +136,34 @@ Taking QNNPackQuantizer as an example, the overall Quantization 2.0 flow could b # Step 4: Lower Reference Quantized Model into the backend -Inside the Quantizer, we will use the ``QuantizationAnnotation API`` -to convey user's intent for what quantization spec to use and how to -observe certain tensor values in the prepare step. Now, we will have a step-by-step -tutorial for how to use the ``QuantizationAnnotation API`` with different types of +Quantizer uses annotation API to convey quantization intent for different operators/patterns. +Annotation API uses ``QuantizationSpec`` ( +`definition is here `__ +) to convey intent of how a tensor will be quantized, +e.g. dtype, bitwidth, min, max values, symmetric vs. asymmetric etc. +Furthermore, annotation API also allows quantizer to specify how a +tensor value should be observed, e.g. ``MinMaxObserver``, or ``HistogramObserver`` +, or some customized observer. + +``QuantizationSpec`` is used to annotate nodes' output tensor or input tensors. Annotating +input tensors is equivalent of annotating edge of the graph, while annotating output tensor is +equivalent of annotating node. Thus annotation API requires quantizer to annotate nodes (output tensor) +or edges (input tensors) of the graph. + +Now, we will have a step-by-step tutorial for how to use the annotation API with different types of ``QuantizationSpec``. 1. Annotate common operator patterns -------------------------------------------------------- In order to use the quantized pattern/operators, e.g. ``quantized add``, -backend developers will have intent to quantize (as expressed by -`QuantizationSpec `__ -) inputs, output of the pattern. Following is an example flow (take ``add`` operator as example) +backend developers will have intent to quantize (as expressed by ``QuantizationSpec``) +inputs, output of the pattern. Following is an example flow (take ``add`` operator as example) of how this intent is conveyed in the quantization workflow with annotation API. - Step 1: Identify the original floating point pattern in the FX graph. There are - several ways to identify this pattern: User may use a pattern matcher (e.g. SubgraphMatcher) - to match the operator pattern; User may go through the nodes from start to the end and compare + several ways to identify this pattern: Quantizer may use a pattern matcher (e.g. SubgraphMatcher) + to match the operator pattern; Quantizer may go through the nodes from start to the end and compare the node's target type to match the operator pattern. In this example, we can use the `get_source_partitions `__ to match this pattern. The original floating point ``add`` pattern only contain a single ``add`` node. @@ -177,8 +196,9 @@ of how this intent is conveyed in the quantization workflow with annotation API. - Step 3: Annotate the inputs and output of the pattern with `QuantizationAnnotation `__. ``QuantizationAnnotation`` is a ``dataclass`` with several fields as: ``input_qspec_map`` field is of class ``Dict`` - to map each input ``Node`` to a ``QuantizationSpec``; ``output_qspec`` field expresses the ``QuantizationSpec`` used for - output node; ``_annotated`` field indicates if this node has already been annotated by quantizer. + to map each input ``Node`` to a ``QuantizationSpec``. 
It means to annotate each input edge with this ``QuantizationSpec``; + ``output_qspec`` field expresses the ``QuantizationSpec`` used to + annotate the output node; ``_annotated`` field indicates if this node has already been annotated by quantizer. In this example, we will create the ``QuantizationAnnotation`` object with the ``QuantizationSpec`` objects created in above step 2 for two inputs and one output of ``add`` node. From d9d32458e80af3989c501d3779a95e35312978cb Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Fri, 9 Jun 2023 10:37:05 +0800 Subject: [PATCH 39/42] adjust descriptation --- ...ization_in_pytorch_2_0_export_tutorial.rst | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst index 9c3000557f0..ef4dc9d29d6 100644 --- a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst +++ b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst @@ -27,22 +27,22 @@ and ``BackendConfig`` to specify the supported ways of quantization in their bac This API covers most use cases relatively well, but the main problem is that this API is not fully extensible without involvement of the quantization team: -- Limitation around expressing quantization intentions for complicated operator patterns such as in the discussion of - `issue-96288 `__ to support ``conv add`` fusion with oneDNN library. - It also requires some changes to current already complicated pattern matching code such as in the - `PR-97122 `__ to support ``conv add`` fusion. -- Limitation around supporting user's advanced quantization intention to quantize their model. For example, if backend +- Current API has limitation around expressing quantization intentions for complicated operator patterns such as in the discussion of + `Issue-96288 `__ to support ``conv add`` fusion. + Supporting ``conv add`` fusion also requires some changes to current already complicated pattern matching code such as in the + `PR-97122 `__. +- Current API also has limitation around supporting user's advanced quantization intention to quantize their model. For example, if backend developer only wants to quantize inputs and outputs when the ``linear`` has a third input, it requires co-work from quantization team and backend developer. -- Currently we use ``QConfigMapping`` and ``BackendConfig`` as separate object. ``QConfigMapping`` describes user's +- Current API uses ``QConfigMapping`` and ``BackendConfig`` as separate object. ``QConfigMapping`` describes user's intention of how they want their model to be quantized. ``BackendConfig`` describes what kind of quantization a backend support. - Currently ``BackendConfig`` is backend specific, but ``QConfigMapping`` is not. And user can provide a ``QConfigMapping`` - that is incompatible with a specific BackendConfig. This is not a great UX. Ideally we can structure this better + Currently, ``BackendConfig`` is backend specific, but ``QConfigMapping`` is not. And user can provide a ``QConfigMapping`` + that is incompatible with a specific ``BackendConfig``. This is not a great UX. Ideally, we can structure this better by making both configuration (``QConfigMapping``) and quantization capability (``BackendConfig``) backend - specific, so there will be less confusion about incompatibilities. -- Currently in ``QConfig`` we are exposing observer/fake_quant classes as an object for user to configure quantization. 
- This increases the things that user may need to care about, e.g. not only the ``dtype`` but also how the observation should - happen. These could potentially be hidden from user so that the user interface is simpler. + specific. So there will be less confusion about incompatibilities. +- Currently, in ``QConfig`` we are exposing observer/fake_quant classes as an object for user to configure quantization. + This increases the things that user needs to care about, e.g. not only the ``dtype`` but also how the observation should + happen. These could potentially be hidden from user to make user interface simpler. To address these scalability issues, `Quantizer `__ @@ -137,18 +137,18 @@ Taking QNNPackQuantizer as an example, the overall Quantization 2.0 flow could b # Step 4: Lower Reference Quantized Model into the backend Quantizer uses annotation API to convey quantization intent for different operators/patterns. -Annotation API uses ``QuantizationSpec`` ( -`definition is here `__ -) to convey intent of how a tensor will be quantized, +Annotation API uses +`QuantizationSpec `__ +to convey intent of how a tensor will be quantized, e.g. dtype, bitwidth, min, max values, symmetric vs. asymmetric etc. Furthermore, annotation API also allows quantizer to specify how a tensor value should be observed, e.g. ``MinMaxObserver``, or ``HistogramObserver`` , or some customized observer. -``QuantizationSpec`` is used to annotate nodes' output tensor or input tensors. Annotating +``QuantizationSpec`` is used to annotate nodes' input tensors or output tensor. Annotating input tensors is equivalent of annotating edge of the graph, while annotating output tensor is -equivalent of annotating node. Thus annotation API requires quantizer to annotate nodes (output tensor) -or edges (input tensors) of the graph. +equivalent of annotating node. Thus annotation API requires quantizer to annotate +edges (input tensors) or nodes (output tensor) of the graph. Now, we will have a step-by-step tutorial for how to use the annotation API with different types of ``QuantizationSpec``. @@ -162,7 +162,7 @@ inputs, output of the pattern. Following is an example flow (take ``add`` operat of how this intent is conveyed in the quantization workflow with annotation API. - Step 1: Identify the original floating point pattern in the FX graph. There are - several ways to identify this pattern: Quantizer may use a pattern matcher (e.g. SubgraphMatcher) + several ways to identify this pattern: Quantizer may use a pattern matcher to match the operator pattern; Quantizer may go through the nodes from start to the end and compare the node's target type to match the operator pattern. In this example, we can use the `get_source_partitions `__ @@ -200,7 +200,7 @@ of how this intent is conveyed in the quantization workflow with annotation API. ``output_qspec`` field expresses the ``QuantizationSpec`` used to annotate the output node; ``_annotated`` field indicates if this node has already been annotated by quantizer. In this example, we will create the ``QuantizationAnnotation`` object with the ``QuantizationSpec`` objects - created in above step 2 for two inputs and one output of ``add`` node. + created in above step 2 for two inputs and one output of the ``add`` node. 
::

From 1c0eeadf42cd04a16b374f152a825ffadfbcfd77 Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Fri, 9 Jun 2023 16:34:40 +0800
Subject: [PATCH 40/42] Move QuantizationAnnotation to preface

---
 ...ization_in_pytorch_2_0_export_tutorial.rst | 41 ++++++++++---------
 1 file changed, 22 insertions(+), 19 deletions(-)

diff --git a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst
index ef4dc9d29d6..2fe1a6f2c2f 100644
--- a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst
+++ b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst
@@ -41,8 +41,8 @@ without involvement of the quantization team:
   by making both configuration (``QConfigMapping``) and quantization capability (``BackendConfig``) backend
   specific. So there will be less confusion about incompatibilities.
 - Currently, in ``QConfig`` we are exposing observer/fake_quant classes as an object for user to configure quantization.
-  This increases the things that user needs to care about, e.g. not only the ``dtype`` but also how the observation should
-  happen. These could potentially be hidden from user to make user interface simpler.
+  This increases the things that user needs to care about, e.g. not only the ``dtype`` but also how the
+  observation should happen. These could potentially be hidden from user to make user interface simpler.
 
 To address these scalability issues,
 `Quantizer `__
@@ -136,22 +136,30 @@ Taking QNNPackQuantizer as an example, the overall Quantization 2.0 flow could b
 
     # Step 4: Lower Reference Quantized Model into the backend
 
-Quantizer uses annotation API to convey quantization intent for different operators/patterns.
-Annotation API uses
+``Quantizer`` uses annotation API to convey quantization intent for different operators/patterns.
+Annotation API mainly consists of
 `QuantizationSpec `__
-to convey intent of how a tensor will be quantized,
+and
+`QuantizationAnnotation `__.
+
+``QuantizationSpec`` is used to convey intent of how a tensor will be quantized,
 e.g. dtype, bitwidth, min, max values, symmetric vs. asymmetric etc.
-Furthermore, annotation API also allows quantizer to specify how a
+Furthermore, ``QuantizationSpec`` also allows quantizer to specify how a
 tensor value should be observed, e.g. ``MinMaxObserver``, or ``HistogramObserver``
 , or some customized observer.
-``QuantizationSpec`` is used to annotate nodes' input tensors or output tensor. Annotating
-input tensors is equivalent of annotating edge of the graph, while annotating output tensor is
-equivalent of annotating node. Thus annotation API requires quantizer to annotate
-edges (input tensors) or nodes (output tensor) of the graph.
+``QuantizationAnnotation``, composed of ``QuantizationSpec`` objects, is used to annotate input tensors
+and output tensor of a ``FX Node``. Annotating input tensors is equivalent to annotating input edges,
+while annotating output tensor is equivalent to annotating node. ``QuantizationAnnotation`` is a ``dataclass``
+with several fields:
+
+- ``input_qspec_map`` field is of class ``Dict`` to map each input tensor (as input edge) to a ``QuantizationSpec``.
+- ``output_qspec`` field expresses the ``QuantizationSpec`` used to annotate the output tensor;
+- ``_annotated`` field indicates if this node has already been annotated by quantizer.
 
-Now, we will have a step-by-step tutorial for how to use the annotation API with different types of
-``QuantizationSpec``.
+Thus annotation API requires quantizer to annotate edges (input tensors) or
+nodes (output tensor) of the graph. Now, we will have a step-by-step tutorial for
+how to use the annotation API with different types of ``QuantizationSpec``.
 
 1. Annotate common operator patterns
 --------------------------------------------------------
@@ -193,13 +201,8 @@ of how this intent is conveyed in the quantization workflow with annotation API.
     input_act_qspec = act_quantization_spec
     output_act_qspec = act_quantization_spec
 
-- Step 3: Annotate the inputs and output of the pattern with
-  `QuantizationAnnotation `__.
-  ``QuantizationAnnotation`` is a ``dataclass`` with several fields as: ``input_qspec_map`` field is of class ``Dict``
-  to map each input ``Node`` to a ``QuantizationSpec``. It means to annotate each input edge with this ``QuantizationSpec``;
-  ``output_qspec`` field expresses the ``QuantizationSpec`` used to
-  annotate the output node; ``_annotated`` field indicates if this node has already been annotated by quantizer.
-  In this example, we will create the ``QuantizationAnnotation`` object with the ``QuantizationSpec`` objects
+- Step 3: Annotate the inputs and output of the pattern with ``QuantizationAnnotation``.
+  In this example, we will create the ``QuantizationAnnotation`` object with the ``QuantizationSpec`` objects
   created in Step 2 above for two inputs and one output of the ``add`` node.
 
 ::

From 1fb6c57e556a7ad3aed470b71569100d746898fa Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Fri, 9 Jun 2023 16:47:38 +0800
Subject: [PATCH 41/42] fix description

---
 .../quantization_in_pytorch_2_0_export_tutorial.rst | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst
index 2fe1a6f2c2f..bb1f33fb6ad 100644
--- a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst
+++ b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst
@@ -27,20 +27,20 @@ and ``BackendConfig`` to specify the supported ways of quantization in their bac
 This API covers most use cases relatively well, but the main problem is that this API is not fully extensible
 without involvement of the quantization team:
 
-- Current API has limitation around expressing quantization intentions for complicated operator patterns such as in the discussion of
+- This API has limitations around expressing quantization intentions for complicated operator patterns such as in the discussion of
   `Issue-96288 `__ to support ``conv add`` fusion.
   Supporting ``conv add`` fusion also requires some changes to current already complicated pattern matching code such as in the
   `PR-97122 `__.
-- Current API also has limitation around supporting user's advanced quantization intention to quantize their model. For example, if backend
+- This API also has limitations around supporting user's advanced quantization intention to quantize their model. For example, if backend
   developer only wants to quantize inputs and outputs when the ``linear`` has a third input, it requires co-work from quantization
   team and backend developer.
-- Current API uses ``QConfigMapping`` and ``BackendConfig`` as separate object. ``QConfigMapping`` describes user's
+- This API uses ``QConfigMapping`` and ``BackendConfig`` as separate objects. ``QConfigMapping`` describes user's
   intention of how they want their model to be quantized. ``BackendConfig`` describes what kind of quantization a backend supports.
-  Currently, ``BackendConfig`` is backend specific, but ``QConfigMapping`` is not. And user can provide a ``QConfigMapping``
+  ``BackendConfig`` is backend specific, but ``QConfigMapping`` is not. And user can provide a ``QConfigMapping``
   that is incompatible with a specific ``BackendConfig``. This is not a great UX. Ideally, we can structure this better
   by making both configuration (``QConfigMapping``) and quantization capability (``BackendConfig``) backend
   specific. So there will be less confusion about incompatibilities.
-- Currently, in ``QConfig`` we are exposing observer/fake_quant classes as an object for user to configure quantization.
+- In ``QConfig``, we are exposing observer/fake_quant classes as an object for user to configure quantization.
   This increases the things that user needs to care about, e.g. not only the ``dtype`` but also how the
   observation should happen. These could potentially be hidden from user to make user interface simpler.

From cfd095acf8d44bf7d1469ab8cae645607f331e8b Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Fri, 9 Jun 2023 17:04:13 +0800
Subject: [PATCH 42/42] fix typo

---
 .../quantization_in_pytorch_2_0_export_tutorial.rst | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst
index bb1f33fb6ad..c1c22d94e04 100644
--- a/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst
+++ b/prototype_source/quantization_in_pytorch_2_0_export_tutorial.rst
@@ -149,7 +149,7 @@ tensor value should be observed, e.g. ``MinMaxObserver``, or ``HistogramObserver
 , or some customized observer.
 
 ``QuantizationAnnotation``, composed of ``QuantizationSpec`` objects, is used to annotate input tensors
-and output tensor of a ``FX Node``. Annotating input tensors is equivalent to annotating input edges,
+and output tensor of a pattern. Annotating input tensors is equivalent to annotating input edges,
 while annotating output tensor is equivalent to annotating node. ``QuantizationAnnotation`` is a ``dataclass``
 with several fields:
 
@@ -157,7 +157,7 @@ with several fields:
 - ``output_qspec`` field expresses the ``QuantizationSpec`` used to annotate the output tensor;
 - ``_annotated`` field indicates if this node has already been annotated by quantizer.
 
-Thus annotation API requires quantizer to annotate edges (input tensors) or
+To conclude, the annotation API requires the quantizer to annotate edges (input tensors) or
 nodes (output tensor) of the graph. Now, we will have a step-by-step tutorial for
 how to use the annotation API with different types of ``QuantizationSpec``.
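
To make the final wording of Steps 1-3 concrete, below is a minimal sketch of an ``add`` annotation helper
such as a quantizer's ``annotate`` method could call. It is an illustrative sketch rather than code taken
from the patches above: the ``QuantizationAnnotation`` import path and the ``"quantization_annotation"``
node meta key follow the prototype API linked in this tutorial and may have moved in later releases, the
helper name ``annotate_add`` is hypothetical, and ``act_qspec`` stands in for the activation
``QuantizationSpec`` (the ``act_quantization_spec`` used in Step 2).

::

    import itertools
    import operator

    import torch
    from torch.fx import Node
    from torch.fx.passes.utils.source_matcher_utils import get_source_partitions

    # Prototype-era import path, assumed from the links in this tutorial;
    # it may have moved in later PyTorch releases.
    from torch.ao.quantization._pt2e.quantizer.quantizer import QuantizationAnnotation

    def annotate_add(gm: torch.fx.GraphModule, act_qspec) -> None:
        # Step 1: find every `add` pattern that was captured from
        # `operator.add` / `torch.add` in the original model source.
        add_partitions = itertools.chain(
            *get_source_partitions(gm.graph, [operator.add, torch.add]).values()
        )
        for add_partition in add_partitions:
            add_node = add_partition.output_nodes[0]
            # Step 2: reuse the same activation QuantizationSpec for the
            # inputs and the output of the pattern.
            input_act_qspec = act_qspec
            output_act_qspec = act_qspec
            # Step 3: map each input edge to its QuantizationSpec, annotate
            # the output, and mark the node so it is not annotated twice.
            input_qspec_map = {}
            for input_act in add_node.args[:2]:
                if isinstance(input_act, Node):
                    input_qspec_map[input_act] = input_act_qspec
            add_node.meta["quantization_annotation"] = QuantizationAnnotation(
                input_qspec_map=input_qspec_map,
                output_qspec=output_act_qspec,
                _annotated=True,
            )

A quantizer would run a helper like this over the exported graph after its global ``QuantizationConfig``
has been resolved, so that ``prepare_pt2e`` can then insert observers at exactly the annotated edges
(input tensors) and nodes (output tensor).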