Skip to content

Commit 5fd506c

Browse files
Use const qualifiers to make compiler's job easier
Indexers are made const, integral variables in kernels made const too
1 parent da03954 commit 5fd506c

File tree

1 file changed

+29
-23
lines changed

1 file changed

+29
-23
lines changed

dpctl/tensor/libtensor/include/kernels/accumulators.hpp

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -622,14 +622,15 @@ sycl::event inclusive_scan_iter(sycl::queue &exec_q,
622622
size_t src_size = acc_groups - 1;
623623
using LocalScanIndexerT =
624624
dpctl::tensor::offset_utils::Strided1DIndexer;
625-
LocalScanIndexerT scan_iter_indexer{
625+
const LocalScanIndexerT scan_iter_indexer{
626626
0, static_cast<ssize_t>(iter_nelems),
627627
static_cast<ssize_t>(src_size)};
628628

629629
using IterIndexerT =
630630
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
631631
OutIterIndexerT, LocalScanIndexerT>;
632-
IterIndexerT iter_indexer_{out_iter_indexer, scan_iter_indexer};
632+
const IterIndexerT iter_indexer_{out_iter_indexer,
633+
scan_iter_indexer};
633634

634635
dependent_event =
635636
inclusive_scan_base_step<outputT, outputT, n_wi, IterIndexerT,
@@ -651,17 +652,18 @@ sycl::event inclusive_scan_iter(sycl::queue &exec_q,
651652

652653
using LocalScanIndexerT =
653654
dpctl::tensor::offset_utils::Strided1DIndexer;
654-
LocalScanIndexerT scan1_iter_indexer{
655+
const LocalScanIndexerT scan1_iter_indexer{
655656
0, static_cast<ssize_t>(iter_nelems),
656657
static_cast<ssize_t>(size_to_update)};
657-
LocalScanIndexerT scan2_iter_indexer{
658+
const LocalScanIndexerT scan2_iter_indexer{
658659
0, static_cast<ssize_t>(iter_nelems),
659660
static_cast<ssize_t>(src_size)};
660661

661662
using IterIndexerT =
662663
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
663664
LocalScanIndexerT, LocalScanIndexerT>;
664-
IterIndexerT iter_indexer_{scan1_iter_indexer, scan2_iter_indexer};
665+
const IterIndexerT iter_indexer_{scan1_iter_indexer,
666+
scan2_iter_indexer};
665667

666668
dependent_event =
667669
inclusive_scan_base_step<outputT, outputT, n_wi, IterIndexerT,
@@ -705,21 +707,23 @@ sycl::event inclusive_scan_iter(sycl::queue &exec_q,
705707
{iter_nelems * update_nelems},
706708
[chunk_size, update_nelems, src_size, local_stride, src,
707709
local_scans, scan_op, identity](auto wiid) {
708-
size_t gid = wiid[0];
710+
const size_t gid = wiid[0];
709711

710-
size_t iter_gid = gid / update_nelems;
711-
size_t axis_gid = gid - (iter_gid * update_nelems);
712+
const size_t iter_gid = gid / update_nelems;
713+
const size_t axis_gid =
714+
gid - (iter_gid * update_nelems);
712715

713-
size_t src_axis_id0 = axis_gid * updates_per_wi;
714-
size_t src_iter_id = iter_gid * src_size;
716+
const size_t src_axis_id0 = axis_gid * updates_per_wi;
717+
const size_t src_iter_id = iter_gid * src_size;
715718
#pragma unroll
716719
for (nwiT i = 0; i < updates_per_wi; ++i) {
717-
size_t src_axis_id = src_axis_id0 + i;
718-
size_t src_id = src_axis_id + src_iter_id;
720+
const size_t src_axis_id = src_axis_id0 + i;
721+
const size_t src_id = src_axis_id + src_iter_id;
719722

720723
if (src_axis_id < src_size) {
721-
size_t scan_axis_id = src_axis_id / chunk_size;
722-
size_t scan_id =
724+
const size_t scan_axis_id =
725+
src_axis_id / chunk_size;
726+
const size_t scan_id =
723727
scan_axis_id + iter_gid * local_stride;
724728

725729
src[src_id] =
@@ -759,22 +763,24 @@ sycl::event inclusive_scan_iter(sycl::queue &exec_q,
759763
[chunk_size, update_nelems, src_size, local_stride, src,
760764
local_scans, scan_op, identity, out_iter_indexer,
761765
out_indexer](auto wiid) {
762-
size_t gid = wiid[0];
766+
const size_t gid = wiid[0];
763767

764-
size_t iter_gid = gid / update_nelems;
765-
size_t axis_gid = gid - (iter_gid * update_nelems);
768+
const size_t iter_gid = gid / update_nelems;
769+
const size_t axis_gid =
770+
gid - (iter_gid * update_nelems);
766771

767-
size_t src_axis_id0 = axis_gid * updates_per_wi;
768-
size_t src_iter_id = out_iter_indexer(iter_gid);
772+
const size_t src_axis_id0 = axis_gid * updates_per_wi;
773+
const size_t src_iter_id = out_iter_indexer(iter_gid);
769774
#pragma unroll
770775
for (nwiT i = 0; i < updates_per_wi; ++i) {
771-
size_t src_axis_id = src_axis_id0 + i;
772-
size_t src_id =
776+
const size_t src_axis_id = src_axis_id0 + i;
777+
const size_t src_id =
773778
out_indexer(src_axis_id) + src_iter_id;
774779

775780
if (src_axis_id < src_size) {
776-
size_t scan_axis_id = src_axis_id / chunk_size;
777-
size_t scan_id =
781+
const size_t scan_axis_id =
782+
src_axis_id / chunk_size;
783+
const size_t scan_id =
778784
scan_axis_id + iter_gid * local_stride;
779785

780786
src[src_id] =

0 commit comments

Comments
 (0)