
Commit fa73b96

Merge branch 'master' into LRfuse
2 parents: 3e499cb + ce71457

30 files changed: +826 −189 lines changed

.gitignore

Lines changed: 3 additions & 2 deletions
@@ -11,6 +11,7 @@
 .coverage
 .hypothesis
 .mypy_cache
+*.so*
 */*.pyc
 */*.so*
 */**/__pycache__
@@ -91,9 +92,9 @@ torch/version.py
 intel_pytorch_extension_py/version.py
 torch_ipex/csrc/version.cpp
 torch_ipex/csrc/aten_ipex_sparse_type_default.*
-torch_ipex/csrc/cpu/SparseOPs.*
+torch_ipex/csrc/cpu/SparseOPs*
 torch_ipex/csrc/cpu/OPs.*
-torch_ipex/csrc/cpu/DenseOPs.*
+torch_ipex/csrc/cpu/DenseOPs*
 
 cscope.*

cmake/CPU.cmake

Lines changed: 4 additions & 3 deletions
@@ -17,7 +17,7 @@ list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules)
 
 FIND_PACKAGE(AVX)
 
-IF (NOT C_AVX512_FOUND)
+IF (NOT C_AVX512_FOUND AND NOT CXX_AVX512_FOUND)
   message(FATAL_ERROR "Please build IPEX on Machines that support AVX512.")
 ENDIF()
 
@@ -58,13 +58,14 @@ endif()
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=pedantic")
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=redundant-decls")
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=old-style-cast")
-IF (C_AVX512_FOUND)
+IF (C_AVX512_FOUND OR CXX_AVX512_FOUND)
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DAVX512")
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f")
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512bw")
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512vl")
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mf16c")
 ENDIF()
-IF (C_AVX512_BF16_FOUND)
+IF (C_AVX512_BF16_FOUND OR CXX_AVX512_BF16_FOUND)
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512bf16 -DAVX512_BF16")
 ENDIF()
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
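The guard now passes if either the C or the C++ compiler probe reports AVX512 (FIND_PACKAGE(AVX) presumably sets both the C_ and CXX_ variables), and the new -DAVX512 define lets sources guard AVX512-only code paths. As a rough illustration of what the FATAL_ERROR is protecting against, the sketch below checks a Linux build machine for the same CPU features the -mavx512f/-mavx512bw/-mavx512vl flags target; machine_supports_avx512 is a hypothetical helper, not part of this commit.

```python
# Hypothetical pre-flight check, not part of this commit: mirrors the CMake
# FATAL_ERROR guard by testing for the AVX512 features IPEX compiles with.
def machine_supports_avx512():
    required = {'avx512f', 'avx512bw', 'avx512vl'}  # matches -mavx512f/bw/vl
    try:
        with open('/proc/cpuinfo') as f:
            for line in f:
                if line.startswith('flags'):
                    return required.issubset(line.split())
    except OSError:
        pass  # non-Linux or unreadable: report unsupported
    return False

if __name__ == '__main__':
    if not machine_supports_avx512():
        print('Please build IPEX on machines that support AVX512.')
```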

intel_pytorch_extension_py/__init__.py

Lines changed: 1 addition & 3 deletions
@@ -3,6 +3,4 @@
 from .version import __version__
 from .optim import *
 from .ops import *
-import _torch_ipex as core
-
-core._initialize_aten_bindings()
+import _torch_ipex as core

intel_pytorch_extension_py/ops/embeddingbag.py

Lines changed: 23 additions & 1 deletion
@@ -3,6 +3,8 @@
 from torch.autograd import Function
 import _torch_ipex as core
 
+'''
+# extension for BF16 fast path only
 torch_embedding_bag = torch.embedding_bag
 def embeddingbag(weights, inputs, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset):
     if weights.dtype == torch.float:
@@ -12,21 +14,41 @@ def embeddingbag(weights, inputs, offsets, scale_grad_by_freq, mode, sparse, per
         ret = (ret, None, None, None)
     else:
         assert(0, "unimplement embeddingbag path in extension")
-
+'''
+def embeddingbag(weights, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset):
+    ret = EmbeddingBagFunction.apply(weights, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset)
     return ret
 
 
 class EmbeddingBagFunction(Function):
+    '''
     @staticmethod
    def forward(ctx, weights, inputs, offsets):
        ctx.save_for_backward(weights, inputs, offsets)
        output = core.embedding_bag_forward(weights, inputs, offsets)
        return output
+    '''
+    @staticmethod
+    def forward(ctx, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset):
+        ctx.scale_grad_by_freq = scale_grad_by_freq
+        ctx.mode = mode
+        ctx.sparse = sparse
+        ctx.num_weight = weight.size(0)
+        ctx.save_for_backward(indices, offsets, per_sample_weights)
+        ret = core.embedding_bag_forward(weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset)
+        return ret
 
+    '''
     @staticmethod
    def backward(ctx, grad_out):
        weights, inputs, offsets = ctx.saved_tensors
        grad_weight = core.embedding_bag_backward(grad_out, weights, inputs, offsets)
        return (grad_weight, None, None)
+    '''
+    @staticmethod
+    def backward(ctx, grad, offset2bag, bag_size, maximum_indices):
+        indices, offsets, per_sample_weights = ctx.saved_tensors
+        grad_weight = core.embedding_bag_backward(grad, indices, offsets, offset2bag, bag_size, maximum_indices, ctx.num_weight, ctx.scale_grad_by_freq, ctx.mode, ctx.sparse, per_sample_weights)
+        return grad_weight, None, None, None, None, None, None, None
 
 torch.embedding_bag = embeddingbag
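With the BF16-only fast path commented out, embeddingbag now mirrors ATen's full embedding_bag signature and routes every call through EmbeddingBagFunction. By analogy with ATen's _embedding_bag, core.embedding_bag_forward appears to return a 4-tuple (output, offset2bag, bag_size, maximum_indices), which is why backward takes four incoming gradients and returns one gradient per forward input: grad_weight plus seven Nones. A minimal usage sketch, assuming the extension builds and imports cleanly:

```python
# Minimal sketch, assuming a working IPEX build: nn.EmbeddingBag lowers to
# torch.embedding_bag, which this module has monkey-patched on import.
import torch
import torch.nn as nn
import intel_pytorch_extension_py  # installs torch.embedding_bag = embeddingbag

bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode='sum')
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.tensor([0, 4])   # two bags: indices[0:4] and indices[4:]

out = bag(indices, offsets)      # forward -> core.embedding_bag_forward
out.sum().backward()             # backward -> core.embedding_bag_backward
print(bag.weight.grad.shape)     # expected: torch.Size([10, 4])
```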

scripts/cpu/common/codegen.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+import os
+
+def write_or_skip(filepath, content):
+    try:
+        with open(filepath, 'r') as f:
+            old_content = f.read()
+    except IOError:
+        old_content = None
+
+    if old_content != content:
+        with open(filepath, 'w') as f:
+            print('writing', filepath)
+            f.write(content)
+    else:
+        print('skipped writing', filepath)
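write_or_skip is the small utility that makes the sharded codegen incremental: it rewrites a generated file only when its content actually changed, so an unchanged file keeps its mtime and an mtime-based build (make/ninja) will not recompile that translation unit. Illustrative behavior, with an example filename:

```python
# Illustrative run (the path is an example): the second call leaves the
# file untouched, so the build system skips recompiling it.
from common.codegen import write_or_skip

write_or_skip('OPs_0.cpp', '// generated\n')  # prints: writing OPs_0.cpp
write_or_skip('OPs_0.cpp', '// generated\n')  # prints: skipped writing OPs_0.cpp
```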

scripts/cpu/gen-dense-cpu-ops.py

Lines changed: 33 additions & 27 deletions
@@ -9,6 +9,7 @@
 import sys
 import json
 
+from common.codegen import write_or_skip
 from common.cpp_sig_parser import CPPSig
 from common.aten_sig_parser import AtenSig
 
@@ -92,6 +93,12 @@
     .op(torch::RegisterOperators::options().schema("{}")
       .impl_unboxedOnlyKernel<{}, &{}>(at::DispatchKey::DPCPPTensorId)
       .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))"""
+
+_REG_BLOCK = """
+namespace {{
+  static auto dispatch = torch::RegisterOperators(){reg_ops};
+}}"""
+
 _H_HEADER = """// Autogenerated file by {gen}. Do not edit directly!
 #pragma once
 
@@ -105,8 +112,6 @@ class AtenIpexCPUDefault {{
 {hfuncs}
 }};
 
-void RegisterIpexDenseOPs();
-
 }} // namespace cpu
 
 }} // namespace torch_ipex
@@ -145,9 +150,7 @@ def __init__(self, reg_dec_file_path, func_file_path, op_h_file_path, op_cpp_fil
         self._reg_dec_file_path = reg_dec_file_path
         self._func_file_path = func_file_path
         self._op_h_file_path = op_h_file_path
-        self._op_h_file = None
         self._op_cpp_file_path = op_cpp_file_path
-        self._op_cpp_file = None
         self._sigs = []
         self._err_info = []
         self._func_data = ''
@@ -223,9 +226,6 @@ def prepare_functions(self):
         with open(self._func_file_path, 'r') as ff:
             self._func_data = ff.read()
 
-        self._op_h_file = open(self._op_h_file_path, 'w')
-        self._op_cpp_file = open(self._op_cpp_file_path, 'w')
-
         print('Extracted {} functions ({} errors) from {}'.format(
             len(self._sigs),
             len(self._err_info),
@@ -452,22 +452,37 @@ def gen_fallback_post_code(self, cpp_sig):
     def gen_head_dec_code(self, cpp_func_str_h):
         return ' static {};\n'.format(cpp_func_str_h)
 
+    def gen_cpu_ops_shard(self, func_defs, cpp_path, header_path, num_shards=1):
+        head_file_content = _H_HEADER.format(gen=os.path.basename(sys.argv[0]), hfuncs=''.join([f['dec'] for f in func_defs]))
+        write_or_skip(header_path, head_file_content)
+
+        shards = [[] for _ in range(num_shards)]
+        for idx, func in enumerate(func_defs):
+            shards[idx % num_shards].append(func)
+
+        for idx, shard in enumerate(shards):
+            regs_code = _REG_BLOCK.format(reg_ops=''.join([f['reg'] for f in shard]))
+            defs_code = ''.join([f['def'] for f in shard])
+
+            filename, ext = os.path.splitext(cpp_path)
+            shard_filepath = '%s_%s%s' % (filename, idx, ext)
+            shard_content = _CPP_HEADER.format(gen=os.path.basename(sys.argv[0]), funcs=defs_code, regs=regs_code)
+            write_or_skip(shard_filepath, shard_content)
+
     def gen_code(self):
         self.prepare_functions()
         assert len(self._err_info) == 0
 
         def is_conv_overrideable_func(fname):
             return fname in ['convolution_overrideable', 'convolution_backward_overrideable']
 
-        func_decs = []
-        func_regs = []
         func_defs = []
-        for cpp_sig, aten_sig, cpp_func_sig_str, aten_func_sig_str in self._sigs:
+        for cpp_sig, _, cpp_func_sig_str, aten_func_sig_str in self._sigs:
             cpp_func_str_h, cpp_func_str_cpp = self.gen_func_signature(cpp_func_sig_str)
             # Gen declaration code for head file
-            func_decs.append(self.gen_head_dec_code(cpp_func_str_h))
+            func_dec = self.gen_head_dec_code(cpp_func_str_h)
 
-            func_regs.append(_REG_PATTERN.format(aten_func_sig_str, self.get_func_dec(cpp_sig), "AtenIpexCPUDefault::" + cpp_sig.def_name))
+            func_reg = _REG_PATTERN.format(aten_func_sig_str, self.get_func_dec(cpp_sig), "AtenIpexCPUDefault::" + cpp_sig.def_name)
 
             # Gen definition code for cpp file
             code = '{} {{\n'.format(cpp_func_str_cpp)
@@ -480,23 +495,14 @@ def is_conv_overrideable_func(fname):
             code += self.gen_fallback_code(cpp_sig)
             code += self.gen_fallback_post_code(cpp_sig)
 
-            code += '}\n'
-
-            code += '\n'
-
-            func_defs.append(code)
-
-        head_file_content = _H_HEADER.format(gen=os.path.basename(sys.argv[0]), hfuncs=''.join(func_decs))
-
-        regs_code = 'void RegisterIpexDenseOPs() {\n'
-        regs_code += ' static auto dispatch = torch::RegisterOperators()\n'
-        regs_code += ''.join(func_regs)
-        regs_code += ';\n}\n'
+            code += '}\n\n'
 
-        source_file_content = _CPP_HEADER.format(gen=os.path.basename(sys.argv[0]), funcs=''.join(func_defs), regs=regs_code)
-        print(head_file_content, file=self._op_h_file)
-        print(source_file_content, file=self._op_cpp_file)
+            func_defs.append({'dec': func_dec, 'reg': func_reg, 'def': code})
 
+        self.gen_cpu_ops_shard(func_defs,
+                               cpp_path=self._op_cpp_file_path,
+                               header_path=self._op_h_file_path,
+                               num_shards=8)
 
 
 if __name__ == '__main__':
     arg_parser = argparse.ArgumentParser()
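The core of this change is gen_cpu_ops_shard: instead of streaming everything into one huge autogenerated .cpp, it round-robins the collected {'dec', 'reg', 'def'} records across num_shards=8 files named <base>_0.cpp through <base>_7.cpp, which the build can compile in parallel. This is also why the .gitignore patterns above lost their trailing dot: DenseOPs* and SparseOPs* now have to match the numbered shard files. The bucketing in isolation (names here are illustrative):

```python
# Self-contained sketch of the round-robin bucketing used by
# gen_cpu_ops_shard; the function names are placeholders.
import os

def shard_paths(cpp_path, funcs, num_shards):
    shards = [[] for _ in range(num_shards)]
    for idx, func in enumerate(funcs):
        shards[idx % num_shards].append(func)  # function i -> shard i % N
    filename, ext = os.path.splitext(cpp_path)
    return {'%s_%s%s' % (filename, i, ext): s for i, s in enumerate(shards)}

print(shard_paths('DenseOPs.cpp', ['add', 'sub', 'mul'], 2))
# {'DenseOPs_0.cpp': ['add', 'mul'], 'DenseOPs_1.cpp': ['sub']}
```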

scripts/cpu/gen-sparse-cpu-ops.py

Lines changed: 43 additions & 33 deletions
@@ -9,6 +9,7 @@
 import sys
 import json
 
+from common.codegen import write_or_skip
 from common.cpp_sig_parser import CPPSig
 from common.aten_sig_parser import AtenSig
 
@@ -47,6 +48,13 @@
     .op(torch::RegisterOperators::options().schema("{}")
       .impl_unboxedOnlyKernel<{}, &{}>(at::DispatchKey::SparseDPCPPTensorId)
       .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA))"""
+
+_REG_BLOCK = """
+namespace {{
+  static auto dispatch = torch::RegisterOperators(){reg_ops};
+}}"""
+
+
 _H_HEADER = """// Autogenerated file by {gen}. Do not edit directly!
 #pragma once
 
@@ -60,8 +68,6 @@ class AtenIpexCPUSparse {{
 {hfuncs}
 }};
 
-void RegisterIpexSparseOPs();
-
 }} // namespace cpu
 
 }} // namespace torch_ipex
@@ -100,9 +106,7 @@ def __init__(self, reg_dec_file_path, func_file_path, sparse_dec_file_path, spar
         self._sparse_dec_file_path = sparse_dec_file_path
         self._sparse_attr_file_path = sparse_attr_file_path
         self._op_h_file_path = op_h_file_path
-        self._op_h_file = None
         self._op_cpp_file_path = op_cpp_file_path
-        self._op_cpp_file = None
         self._sigs = []
         self._sparse_attr_data = ''
         self._sparse_sigs = []
@@ -155,8 +159,8 @@ def prepare_functions(self):
                 continue
             cpp_func_sig_str = m.group(1)
             _sparse_sig_strs.append(cpp_func_sig_str)
-            print(cpp_func_sig_str)
-            print("********************")
+            # print(cpp_func_sig_str)
+            # print("********************")
 
         # Parse SparseAttrType.h
         with open(self._sparse_attr_file_path, 'r') as ff:
@@ -202,9 +206,6 @@ def prepare_functions(self):
                 self._err_info.append((cpp_func_sig, str(e)))
                 print('Error parsing "{}": {}'.format(cpp_func_sig, e), file=sys.stderr)
 
-        self._op_h_file = open(self._op_h_file_path, 'w')
-        self._op_cpp_file = open(self._op_cpp_file_path, 'w')
-
        print('Extracted {} functions ({} errors) from {}'.format(
             len(self._sigs),
             len(self._err_info),
@@ -369,44 +370,53 @@ def gen_fallback_post_code(self, cpp_sig):
     def gen_head_dec_code(self, cpp_func_str_h):
         return ' static {};\n'.format(cpp_func_str_h)
 
+    def gen_cpu_ops_shard(self, func_defs, cpp_path, header_path, num_shards=1):
+        head_file_content = _H_HEADER.format(gen=os.path.basename(sys.argv[0]), hfuncs=''.join([f['dec'] for f in func_defs]))
+        write_or_skip(header_path, head_file_content)
+
+        shards = [[] for _ in range(num_shards)]
+        for idx, func in enumerate(func_defs):
+            shards[idx % num_shards].append(func)
+
+        for idx, shard in enumerate(shards):
+            regs_code = _REG_BLOCK.format(reg_ops=''.join([f['reg'] for f in shard]))
+            defs_code = ''.join([f['def'] for f in shard])
+
+            filename, ext = os.path.splitext(cpp_path)
+            shard_filepath = '%s_%s%s' % (filename, idx, ext)
+            shard_content = _CPP_HEADER.format(gen=os.path.basename(sys.argv[0]), funcs=defs_code, regs=regs_code)
+            write_or_skip(shard_filepath, shard_content)
+
     def gen_code(self):
         self.prepare_functions()
         assert len(self._err_info) == 0
 
-        func_decs = []
-        func_regs = []
         func_defs = []
         for cpp_sparse_sig, _, cpp_sparse_func_sig_str, aten_func_sig_str in self._sigs:
-            func_regs.append(_REG_PATTERN.format(aten_func_sig_str, self.get_func_dec(cpp_sparse_sig), "AtenIpexCPUSparse::" + cpp_sparse_sig.def_name))
             # Gen declaration code for head file
             cpp_func_str_h, cpp_func_str_cpp = self.gen_func_signature(cpp_sparse_func_sig_str)
-            func_decs.append(self.gen_head_dec_code(cpp_func_str_h))
+            func_dec = self.gen_head_dec_code(cpp_func_str_h)
 
-            # Since we have pre-defined attr OPs, we don't need to regenerate it
-            if self.is_sparse_attr_function(cpp_sparse_sig.def_name):
-                continue
+            func_reg = _REG_PATTERN.format(aten_func_sig_str, self.get_func_dec(cpp_sparse_sig), "AtenIpexCPUSparse::" + cpp_sparse_sig.def_name)
 
-            # Gen definition code for cpp file
-            code = '{} {{\n'.format(cpp_func_str_cpp)
-            code += self.gen_fallback_prepare_code(cpp_sparse_sig)
-            code += self.gen_fallback_code(cpp_sparse_sig)
-            code += self.gen_fallback_post_code(cpp_sparse_sig)
-
-            code += '}\n\n'
-
-            func_defs.append(code)
+            code = ''
+            # Since we have pre-defined attr OPs, we don't need to regenerate it
+            if not self.is_sparse_attr_function(cpp_sparse_sig.def_name):
 
-        head_file_content = _H_HEADER.format(gen=os.path.basename(sys.argv[0]), hfuncs=''.join(func_decs))
+                # Gen definition code for cpp file
+                code += '{} {{\n'.format(cpp_func_str_cpp)
+                code += self.gen_fallback_prepare_code(cpp_sparse_sig)
+                code += self.gen_fallback_code(cpp_sparse_sig)
+                code += self.gen_fallback_post_code(cpp_sparse_sig)
 
-        regs_code = 'void RegisterIpexSparseOPs() {\n'
-        regs_code += ' static auto dispatch = torch::RegisterOperators()\n'
-        regs_code += ''.join(func_regs)
-        regs_code += ';\n}\n'
+                code += '}\n\n'
 
-        source_file_content = _CPP_HEADER.format(gen=os.path.basename(sys.argv[0]), funcs=''.join(func_defs), regs=regs_code)
-        print(head_file_content, file=self._op_h_file)
-        print(source_file_content, file=self._op_cpp_file)
+            func_defs.append({'dec': func_dec, 'reg': func_reg, 'def': code})
 
+        self.gen_cpu_ops_shard(func_defs,
+                               cpp_path=self._op_cpp_file_path,
+                               header_path=self._op_h_file_path,
+                               num_shards=1)
 
 
 if __name__ == '__main__':
     arg_parser = argparse.ArgumentParser()
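The sparse generator gets the same restructuring but keeps num_shards=1; the dense generator carries the bulk of the operators, so it is the one split eight ways. Note also the registration change shared by both scripts: _REG_BLOCK emits a static torch::RegisterOperators object inside an anonymous namespace, so each shard registers its ops when the library is loaded, replacing the RegisterIpexDenseOPs()/RegisterIpexSparseOPs() entry points a caller had to invoke, and presumably why __init__.py above no longer calls core._initialize_aten_bindings(). A sketch of what one shard's block expands to (the .op(...) body is abbreviated from the real _REG_PATTERN output):

```python
# Expanding _REG_BLOCK for a single illustrative operator registration.
_REG_BLOCK = """
namespace {{
  static auto dispatch = torch::RegisterOperators(){reg_ops};
}}"""

reg_ops = '\n  .op(/* schema + kernel for one generated operator */)'
print(_REG_BLOCK.format(reg_ops=reg_ops))
# namespace {
#   static auto dispatch = torch::RegisterOperators()
#   .op(/* schema + kernel for one generated operator */);
# }
```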
