|
31 | 31 | #
|
32 | 32 | # In this tutorial, we will cover three scenarios that require extending the ONNX registry with custom operators:
|
33 | 33 | #
|
34 |
| -# * Unsupported ATen operators |
35 | 34 | # * Custom operators with existing ONNX Runtime support
|
36 | 35 | # * Custom operators without ONNX Runtime support
|
37 | 36 | #
|
38 |
| -# Unsupported ATen operators |
39 |
| -# -------------------------- |
40 |
| -# |
41 |
| -# Although the ONNX exporter team makes its best effort to support all ATen operators, some of them |
42 |
| -# might not be supported yet. In this section, we will demonstrate how you can add |
43 |
| -# unsupported ATen operators to the ONNX Registry. |
44 |
| -# |
45 |
| -# .. note:: |
46 |
| -# The steps to implement unsupported ATen operators are the same as those used to replace the implementation of an existing |
47 |
| -# ATen operator with a custom implementation. |
48 |
| -# Because we don't actually have an unsupported ATen operator to use in this tutorial, we are going to leverage |
49 |
| -# this and replace the implementation of ``aten::add.Tensor`` with a custom implementation the same way we would |
50 |
| -# if the operator was not present in the ONNX Registry. |
51 |
| -# |
52 |
| -# When a model cannot be exported to ONNX due to an unsupported operator, the ONNX exporter will show an error message |
53 |
| -# similar to: |
54 |
| -# |
55 |
| -# .. code-block:: python |
56 |
| -# |
57 |
| -# RuntimeErrorWithDiagnostic: Unsupported FX nodes: {'call_function': ['aten.add.Tensor']}. |
58 |
| -# |
59 |
| -# The error message indicates that the fully qualified name of the unsupported ATen operator is ``aten::add.Tensor``. |
60 |
| -# The fully qualified name of an operator is composed of the namespace, operator name, and overload following |
61 |
| -# the format ``namespace::operator_name.overload``. |
62 |
| -# |
63 |
| -# To add support for an unsupported ATen operator or to replace the implementation for an existing one, we need: |
64 |
| -# |
65 |
| -# * The fully qualified name of the ATen operator (e.g. ``aten::add.Tensor``). |
66 |
| -# This information is always present in the error message as shown above. |
67 |
| -# * The implementation of the operator using `ONNX Script <https://github.com/microsoft/onnxscript>`__. |
68 |
| -# ONNX Script is a prerequisite for this tutorial. Please make sure you have read the |
69 |
| -# `ONNX Script tutorial <https://github.com/microsoft/onnxscript/blob/main/docs/tutorial/index.md>`_ |
70 |
| -# before proceeding. |
71 |
| -# |
72 |
| -# Because ``aten::add.Tensor`` is already supported by the ONNX Registry, we will demonstrate how to replace it with a |
73 |
| -# custom implementation, but keep in mind that the same steps apply to support new unsupported ATen operators. |
74 |
| -# |
75 |
| -# This is possible because the :class:`OnnxRegistry` allows users to override an operator registration. |
76 |
| -# We will override the registration of ``aten::add.Tensor`` with our custom implementation and verify it exists. |
77 |
| -# |
78 | 37 |
|
79 | 38 | import torch
|
80 | 39 | import onnxruntime
|
81 | 40 | import onnxscript
|
82 | 41 | from onnxscript import opset18 # opset 18 is the latest (and only) supported version for now
|
83 | 42 |
|
84 |
| -class Model(torch.nn.Module): |
85 |
| - def forward(self, input_x, input_y): |
86 |
| - return torch.ops.aten.add(input_x, input_y) # generates a aten::add.Tensor node |
87 |
| - |
88 |
| -input_add_x = torch.randn(3, 4) |
89 |
| -input_add_y = torch.randn(3, 4) |
90 |
| -aten_add_model = Model() |
91 |
| - |
92 |
| - |
93 |
| -# Now we create an ONNX Script function that implements ``aten::add.Tensor``. |
94 |
| -# The function name (e.g. ``custom_aten_add``) is displayed in the ONNX graph, so we recommend to use intuitive names. |
95 |
| -custom_aten = onnxscript.values.Opset(domain="custom.aten", version=1) |
96 |
| - |
97 |
| -# NOTE: The function signature must match the signature of the unsupported ATen operator. |
98 |
| -# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml |
99 |
| -# NOTE: All attributes must be annotated with type hints. |
100 |
| -@onnxscript.script(custom_aten) |
101 |
| -def custom_aten_add(input_x, input_y, alpha: float = 1.0): |
102 |
| - input_y = opset18.Mul(input_y, alpha) |
103 |
| - return opset18.Add(input_x, input_y) |
104 |
| - |
105 |
| - |
106 |
| -# Now we have everything we need to support unsupported ATen operators. |
107 |
| -# Let's register the ``custom_aten_add`` function to ONNX registry, and export the model to ONNX again. |
108 |
| -onnx_registry = torch.onnx.OnnxRegistry() |
109 |
| -onnx_registry.register_op( |
110 |
| - namespace="aten", op_name="add", overload="Tensor", function=custom_aten_add |
111 |
| - ) |
112 |
| -print(f"aten::add.Tensor is supported by ONNX registry: \ |
113 |
| - {onnx_registry.is_registered_op(namespace='aten', op_name='add', overload='Tensor')}" |
114 |
| - ) |
115 |
| -export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry) |
116 |
| -onnx_program = torch.onnx.dynamo_export( |
117 |
| - aten_add_model, input_add_x, input_add_y, export_options=export_options |
118 |
| - ) |
119 |
| - |
120 |
| -###################################################################### |
121 |
| -# Now let's inspect the model and verify the model has a ``custom_aten_add`` instead of ``aten::add.Tensor``. |
122 |
| -# The graph has one graph node for ``custom_aten_add``, and inside of it there are three function nodes, one for each |
186 | 123 |
| -# operator, and one for the constant attribute. |
124 |
| -# |
125 |
| - |
126 |
| -# graph node domain is the custom domain we registered |
127 |
| -assert onnx_program.model_proto.graph.node[0].domain == "custom.aten" |
128 |
| -assert len(onnx_program.model_proto.graph.node) == 1 |
129 |
| -# graph node name is the function name |
130 |
| -assert onnx_program.model_proto.graph.node[0].op_type == "custom_aten_add" |
131 |
| -# function node domain is empty because we use standard ONNX operators |
132 |
| -assert {node.domain for node in onnx_program.model_proto.functions[0].node} == {""} |
133 |
| -# function node name is the standard ONNX operator name |
134 |
| -assert {node.op_type for node in onnx_program.model_proto.functions[0].node} == {"Add", "Mul", "Constant"} |
135 |
| - |
136 |
| - |
137 |
| -###################################################################### |
138 |
| -# This is how ``custom_aten_add_model`` looks in the ONNX graph using Netron: |
139 |
| -# |
140 |
| -# .. image:: /_static/img/onnx/custom_aten_add_model.png |
141 |
| -# :width: 70% |
142 |
| -# :align: center |
143 |
| -# |
144 |
| -# Inside the ``custom_aten_add`` function, we can see the two ONNX nodes we |
230 | 145 |
| -# used in the function (``Add`` and ``Mul``), and one ``Constant`` attribute: |
146 |
| -# |
147 |
| -# .. image:: /_static/img/onnx/custom_aten_add_function.png |
148 |
| -# :width: 70% |
149 |
| -# :align: center |
150 |
| -# |
151 |
| -# This was all that we needed to register the new ATen operator into the ONNX Registry. |
152 |
| -# As an additional step, we can use ONNX Runtime to run the model, and compare the results with PyTorch. |
153 |
| -# |
154 |
| - |
155 |
| - |
156 |
| -# Use ONNX Runtime to run the model, and compare the results with PyTorch |
157 |
| -onnx_program.save("./custom_add_model.onnx") |
158 |
| -ort_session = onnxruntime.InferenceSession( |
159 |
| - "./custom_add_model.onnx", providers=['CPUExecutionProvider'] |
160 |
| - ) |
161 |
| - |
162 |
| -def to_numpy(tensor): |
163 |
| - return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() |
164 |
| - |
165 |
| -onnx_input = onnx_program.adapt_torch_inputs_to_onnx(input_add_x, input_add_y) |
166 |
| -onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)} |
167 |
| -onnxruntime_outputs = ort_session.run(None, onnxruntime_input) |
168 |
| - |
169 |
| -torch_outputs = aten_add_model(input_add_x, input_add_y) |
170 |
| -torch_outputs = onnx_program.adapt_torch_outputs_to_onnx(torch_outputs) |
171 |
| - |
172 |
| -assert len(torch_outputs) == len(onnxruntime_outputs) |
173 |
| -for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs): |
174 |
| - torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output)) |
175 |
| - |
176 | 43 |
|
177 | 44 | ######################################################################
|
178 | 45 | # Custom operators with existing ONNX Runtime support
|
@@ -262,12 +129,11 @@ def custom_aten_gelu(input_x, approximate: str = "none"):
|
262 | 129 | def to_numpy(tensor):
|
263 | 130 | return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
|
264 | 131 |
|
265 |
| -onnx_input = onnx_program.adapt_torch_inputs_to_onnx(input_gelu_x) |
| 132 | +onnx_input = [input_gelu_x] |
266 | 133 | onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
|
267 |
| -onnxruntime_outputs = ort_session.run(None, onnxruntime_input) |
| 134 | +onnxruntime_outputs = ort_session.run(None, onnxruntime_input)[0] |
268 | 135 |
|
269 | 136 | torch_outputs = aten_gelu_model(input_gelu_x)
|
270 |
| -torch_outputs = onnx_program.adapt_torch_outputs_to_onnx(torch_outputs) |
271 | 137 |
|
272 | 138 | assert len(torch_outputs) == len(onnxruntime_outputs)
|
273 | 139 | for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
|
@@ -369,25 +235,17 @@ def custom_addandround(input_x):
|
369 | 235 | #
|
370 | 236 |
|
371 | 237 | assert onnx_program.model_proto.graph.node[0].domain == "test.customop"
|
372 |
| -assert onnx_program.model_proto.graph.node[0].op_type == "custom_addandround" |
373 |
| -assert onnx_program.model_proto.functions[0].node[0].domain == "test.customop" |
374 |
| -assert onnx_program.model_proto.functions[0].node[0].op_type == "CustomOpOne" |
375 |
| -assert onnx_program.model_proto.functions[0].node[1].domain == "test.customop" |
376 |
| -assert onnx_program.model_proto.functions[0].node[1].op_type == "CustomOpTwo" |
| 238 | +assert onnx_program.model_proto.graph.node[0].op_type == "CustomOpOne" |
| 239 | +assert onnx_program.model_proto.graph.node[1].domain == "test.customop" |
| 240 | +assert onnx_program.model_proto.graph.node[1].op_type == "CustomOpTwo" |
377 | 241 |
|
378 | 242 |
|
379 | 243 | ######################################################################
|
380 |
| -# This is how ``custom_addandround_model`` ONNX graph looks using Netron: |
381 |
| -# |
382 |
| -# .. image:: /_static/img/onnx/custom_addandround_model.png |
383 |
| -# :width: 70% |
384 |
| -# :align: center |
385 |
| -# |
386 |
| -# Inside the ``custom_addandround`` function, we can see the two custom operators we |
387 |
| -# used in the function (``CustomOpOne``, and ``CustomOpTwo``), and they are from module |
388 |
| -# ``test.customop``: |
| 244 | +# This is how ``custom_addandround_model`` ONNX graph looks using Netron. |
| 245 | +# We can see the two custom operators we used in the function (``CustomOpOne``, and ``CustomOpTwo``), |
| 246 | +# and they are from module ``test.customop``: |
389 | 247 | #
|
390 |
| -# .. image:: /_static/img/onnx/custom_addandround_function.png |
| 248 | +# .. image:: /_static/img/onnx/custom_addandround.png |
391 | 249 | #
|
392 | 250 | # Custom Ops Registration in ONNX Runtime
|
393 | 251 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
|
0 commit comments