Skip to content

Commit 7b0dc38

Browse files
committed
fixed issue for inference sample codes
1 parent 1101ae5 commit 7b0dc38

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

recipes_source/recipes/intel_extension_for_pytorch.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,9 @@ def forward(self, input):
173173

174174
input = torch.randn(2, 4)
175175
model = Model()
176+
model.eval()
176177
# Invoke optimize function against the model object
177-
model = ipex.optimize(model)
178+
model = ipex.optimize(model, dtype=torch.float32)
178179
res = model(input)
179180

180181
###############################################################################
@@ -196,6 +197,7 @@ def forward(self, input):
196197

197198
input = torch.randn(2, 4)
198199
model = Model()
200+
model.eval()
199201
# Invoke optimize function against the model object with data type set to torch.bfloat16
200202
model = ipex.optimize(model, dtype=torch.bfloat16)
201203
with torch.cpu.amp.autocast():
@@ -233,9 +235,10 @@ def forward(self, input):
233235

234236
input = torch.randn(2, 4)
235237
model = Model()
238+
model.eval()
236239
# Invoke optimize function against the model object
237-
model = ipex.optimize(model)
238-
model = torch.jit.trace(model, torch.rand(args.batch_size, 3, 224, 224))
240+
model = ipex.optimize(model, dtype=torch.float32)
241+
model = torch.jit.trace(model, torch.randn(2, 4))
239242
model = torch.jit.freeze(model)
240243
res = model(input)
241244

@@ -261,10 +264,11 @@ def forward(self, input):
261264

262265
input = torch.randn(2, 4)
263266
model = Model()
267+
model.eval()
264268
# Invoke optimize function against the model with data type set to torch.bfloat16
265269
model = ipex.optimize(model, dtype=torch.bfloat16)
266270
with torch.cpu.amp.autocast():
267-
model = torch.jit.trace(model, torch.rand(args.batch_size, 3, 224, 224))
271+
model = torch.jit.trace(model, torch.randn(2, 4))
268272
model = torch.jit.freeze(model)
269273
res = model(input)
270274

@@ -327,6 +331,12 @@ def forward(self, input):
327331
set_property(TARGET example-app PROPERTY CXX_STANDARD 14)
328332
'''
329333

334+
###############################################################################
335+
# **Note:** Since Intel® Extension for PyTorch* is still under development, name of
336+
# the c++ dynamic library in the master branch may defer to
337+
# *libintel-ext-pt-cpu.so* shown above. Please check the name out in the
338+
# installation folder. The so file name starts with *libintel-*.
339+
330340
###############################################################################
331341
# **Command for compilation**
332342

0 commit comments

Comments
 (0)