@@ -622,14 +622,15 @@ sycl::event inclusive_scan_iter(sycl::queue &exec_q,
622
622
size_t src_size = acc_groups - 1 ;
623
623
using LocalScanIndexerT =
624
624
dpctl::tensor::offset_utils::Strided1DIndexer;
625
- LocalScanIndexerT scan_iter_indexer{
625
+ const LocalScanIndexerT scan_iter_indexer{
626
626
0 , static_cast <ssize_t >(iter_nelems),
627
627
static_cast <ssize_t >(src_size)};
628
628
629
629
using IterIndexerT =
630
630
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
631
631
OutIterIndexerT, LocalScanIndexerT>;
632
- IterIndexerT iter_indexer_{out_iter_indexer, scan_iter_indexer};
632
+ const IterIndexerT iter_indexer_{out_iter_indexer,
633
+ scan_iter_indexer};
633
634
634
635
dependent_event =
635
636
inclusive_scan_base_step<outputT, outputT, n_wi, IterIndexerT,
@@ -651,17 +652,18 @@ sycl::event inclusive_scan_iter(sycl::queue &exec_q,
651
652
652
653
using LocalScanIndexerT =
653
654
dpctl::tensor::offset_utils::Strided1DIndexer;
654
- LocalScanIndexerT scan1_iter_indexer{
655
+ const LocalScanIndexerT scan1_iter_indexer{
655
656
0 , static_cast <ssize_t >(iter_nelems),
656
657
static_cast <ssize_t >(size_to_update)};
657
- LocalScanIndexerT scan2_iter_indexer{
658
+ const LocalScanIndexerT scan2_iter_indexer{
658
659
0 , static_cast <ssize_t >(iter_nelems),
659
660
static_cast <ssize_t >(src_size)};
660
661
661
662
using IterIndexerT =
662
663
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
663
664
LocalScanIndexerT, LocalScanIndexerT>;
664
- IterIndexerT iter_indexer_{scan1_iter_indexer, scan2_iter_indexer};
665
+ const IterIndexerT iter_indexer_{scan1_iter_indexer,
666
+ scan2_iter_indexer};
665
667
666
668
dependent_event =
667
669
inclusive_scan_base_step<outputT, outputT, n_wi, IterIndexerT,
@@ -705,21 +707,23 @@ sycl::event inclusive_scan_iter(sycl::queue &exec_q,
705
707
{iter_nelems * update_nelems},
706
708
[chunk_size, update_nelems, src_size, local_stride, src,
707
709
local_scans, scan_op, identity](auto wiid) {
708
- size_t gid = wiid[0 ];
710
+ const size_t gid = wiid[0 ];
709
711
710
- size_t iter_gid = gid / update_nelems;
711
- size_t axis_gid = gid - (iter_gid * update_nelems);
712
+ const size_t iter_gid = gid / update_nelems;
713
+ const size_t axis_gid =
714
+ gid - (iter_gid * update_nelems);
712
715
713
- size_t src_axis_id0 = axis_gid * updates_per_wi;
714
- size_t src_iter_id = iter_gid * src_size;
716
+ const size_t src_axis_id0 = axis_gid * updates_per_wi;
717
+ const size_t src_iter_id = iter_gid * src_size;
715
718
#pragma unroll
716
719
for (nwiT i = 0 ; i < updates_per_wi; ++i) {
717
- size_t src_axis_id = src_axis_id0 + i;
718
- size_t src_id = src_axis_id + src_iter_id;
720
+ const size_t src_axis_id = src_axis_id0 + i;
721
+ const size_t src_id = src_axis_id + src_iter_id;
719
722
720
723
if (src_axis_id < src_size) {
721
- size_t scan_axis_id = src_axis_id / chunk_size;
722
- size_t scan_id =
724
+ const size_t scan_axis_id =
725
+ src_axis_id / chunk_size;
726
+ const size_t scan_id =
723
727
scan_axis_id + iter_gid * local_stride;
724
728
725
729
src[src_id] =
@@ -759,22 +763,24 @@ sycl::event inclusive_scan_iter(sycl::queue &exec_q,
759
763
[chunk_size, update_nelems, src_size, local_stride, src,
760
764
local_scans, scan_op, identity, out_iter_indexer,
761
765
out_indexer](auto wiid) {
762
- size_t gid = wiid[0 ];
766
+ const size_t gid = wiid[0 ];
763
767
764
- size_t iter_gid = gid / update_nelems;
765
- size_t axis_gid = gid - (iter_gid * update_nelems);
768
+ const size_t iter_gid = gid / update_nelems;
769
+ const size_t axis_gid =
770
+ gid - (iter_gid * update_nelems);
766
771
767
- size_t src_axis_id0 = axis_gid * updates_per_wi;
768
- size_t src_iter_id = out_iter_indexer (iter_gid);
772
+ const size_t src_axis_id0 = axis_gid * updates_per_wi;
773
+ const size_t src_iter_id = out_iter_indexer (iter_gid);
769
774
#pragma unroll
770
775
for (nwiT i = 0 ; i < updates_per_wi; ++i) {
771
- size_t src_axis_id = src_axis_id0 + i;
772
- size_t src_id =
776
+ const size_t src_axis_id = src_axis_id0 + i;
777
+ const size_t src_id =
773
778
out_indexer (src_axis_id) + src_iter_id;
774
779
775
780
if (src_axis_id < src_size) {
776
- size_t scan_axis_id = src_axis_id / chunk_size;
777
- size_t scan_id =
781
+ const size_t scan_axis_id =
782
+ src_axis_id / chunk_size;
783
+ const size_t scan_id =
778
784
scan_axis_id + iter_gid * local_stride;
779
785
780
786
src[src_id] =
0 commit comments