14
14
* Copyright (c) 2016 Research Organization for Information Science
15
15
* and Technology (RIST). All rights reserved.
16
16
* Copyright (c) 2017 IBM Corporation. All rights reserved.
17
+ * Copyright (c) 2025 Triad National Security, LLC. All rights reserved.
17
18
* $COPYRIGHT$
18
19
*
19
20
* Additional copyrights may follow
34
35
#include "coll_base_topo.h"
35
36
#include "coll_base_util.h"
36
37
38
/*
 * Saturating subtraction on unsigned sizes.
 *
 * Returns a - b when a is strictly greater than b, and 0 otherwise.
 * A plain "a - b" on size_t operands cannot go negative: when b > a it
 * wraps around to a very large value.  Clamping at zero lets callers
 * test the result with "== 0" / "> 0".
 */
static inline size_t
rectify_diff(size_t a, size_t b)
{
    return a > b ? a - b : 0;
}
46
+
37
47
int
38
48
ompi_coll_base_bcast_intra_generic ( void * buffer ,
39
49
size_t original_count ,
@@ -811,8 +821,8 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
811
821
if (vrank & mask ) {
812
822
int parent = (rank - mask + comm_size ) % comm_size ;
813
823
/* Compute an upper bound on recv block size */
814
- recv_count = count - vrank * scatter_count ;
815
- if (recv_count < = 0 ) {
824
+ recv_count = rectify_diff ( count , ( size_t )( vrank * scatter_count )) ;
825
+ if (recv_count = = 0 ) {
816
826
curr_count = 0 ;
817
827
} else {
818
828
/* Recv data from parent */
@@ -832,7 +842,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
832
842
mask >>= 1 ;
833
843
while (mask > 0 ) {
834
844
if (vrank + mask < comm_size ) {
835
- send_count = curr_count - scatter_count * mask ;
845
+ send_count = rectify_diff ( curr_count , ( size_t )( scatter_count * mask )) ;
836
846
if (send_count > 0 ) {
837
847
int child = (rank + mask ) % comm_size ;
838
848
err = MCA_PML_CALL (send ((char * )buf + (ptrdiff_t )scatter_count * (vrank + mask ) * extent ,
@@ -850,10 +860,8 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
850
860
* Allgather by recursive doubling
851
861
* Each process has the curr_count elems in the buf[vrank * scatter_count, ...]
852
862
*/
853
- size_t rem_count = count - vrank * scatter_count ;
863
+ size_t rem_count = rectify_diff ( count , ( size_t )( vrank * scatter_count )) ;
854
864
curr_count = (scatter_count < rem_count ) ? scatter_count : rem_count ;
855
- if (curr_count < 0 )
856
- curr_count = 0 ;
857
865
858
866
mask = 0x1 ;
859
867
while (mask < comm_size ) {
@@ -866,9 +874,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
866
874
if (vremote < comm_size ) {
867
875
ptrdiff_t send_offset = vrank_tree_root * scatter_count * extent ;
868
876
ptrdiff_t recv_offset = vremote_tree_root * scatter_count * extent ;
869
- recv_count = count - vremote_tree_root * scatter_count ;
870
- if (recv_count < 0 )
871
- recv_count = 0 ;
877
+ recv_count = rectify_diff (count , (size_t )(vremote_tree_root * scatter_count ));
872
878
err = ompi_coll_base_sendrecv ((char * )buf + send_offset ,
873
879
curr_count , datatype , remote ,
874
880
MCA_COLL_BASE_TAG_BCAST ,
@@ -877,7 +883,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
877
883
MCA_COLL_BASE_TAG_BCAST ,
878
884
comm , & status , rank );
879
885
if (MPI_SUCCESS != err ) { goto cleanup_and_return ; }
880
- recv_count = (int )(status ._ucount / datatype_size );
886
+ recv_count = (size_t )(status ._ucount / datatype_size );
881
887
curr_count += recv_count ;
882
888
}
883
889
@@ -913,7 +919,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
913
919
MCA_COLL_BASE_TAG_BCAST ,
914
920
comm , & status ));
915
921
if (MPI_SUCCESS != err ) { goto cleanup_and_return ; }
916
- recv_count = (int )(status ._ucount / datatype_size );
922
+ recv_count = (size_t )(status ._ucount / datatype_size );
917
923
curr_count += recv_count ;
918
924
}
919
925
}
@@ -988,8 +994,8 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring(
988
994
if (vrank & mask ) {
989
995
int parent = (rank - mask + comm_size ) % comm_size ;
990
996
/* Compute an upper bound on recv block size */
991
- recv_count = count - vrank * scatter_count ;
992
- if (recv_count <= 0 ) {
997
+ recv_count = rectify_diff ( count , ( size_t )( vrank * scatter_count )) ;
998
+ if (0 == recv_count ) {
993
999
curr_count = 0 ;
994
1000
} else {
995
1001
/* Recv data from parent */
@@ -1009,7 +1015,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring(
1009
1015
mask >>= 1 ;
1010
1016
while (mask > 0 ) {
1011
1017
if (vrank + mask < comm_size ) {
1012
- send_count = curr_count - scatter_count * mask ;
1018
+ send_count = rectify_diff ( curr_count , ( size_t )( scatter_count * mask )) ;
1013
1019
if (send_count > 0 ) {
1014
1020
int child = (rank + mask ) % comm_size ;
1015
1021
err = MCA_PML_CALL (send ((char * )buf + (ptrdiff_t )scatter_count * (vrank + mask ) * extent ,
@@ -1023,33 +1029,41 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring(
1023
1029
mask >>= 1 ;
1024
1030
}
1025
1031
1026
- /* Allgather by a ring algorithm */
1032
+ /* Allgather by a ring algorithm, using only unsigned types */
1027
1033
int left = (rank - 1 + comm_size ) % comm_size ;
1028
1034
int right = (rank + 1 ) % comm_size ;
1035
+
1036
+ /* The block we will send/recv in each step */
1029
1037
int send_block = vrank ;
1030
1038
int recv_block = (vrank - 1 + comm_size ) % comm_size ;
1031
1039
1032
- for (int i = 1 ; i < comm_size ; i ++ ) {
1033
- recv_count = (scatter_count < count - recv_block * scatter_count ) ?
1034
- scatter_count : count - recv_block * scatter_count ;
1035
- if (recv_count < 0 )
1036
- recv_count = 0 ;
1037
- ptrdiff_t recv_offset = recv_block * scatter_count * extent ;
1038
-
1039
- send_count = (scatter_count < count - send_block * scatter_count ) ?
1040
- scatter_count : count - send_block * scatter_count ;
1041
- if (send_count < 0 )
1042
- send_count = 0 ;
1043
- ptrdiff_t send_offset = send_block * scatter_count * extent ;
1044
-
1045
- err = ompi_coll_base_sendrecv ((char * )buf + send_offset , send_count ,
1040
+ for (int i = 1 ; i < comm_size ; ++ i ) {
1041
+ /* how many elements remain in recv_block? */
1042
+ size_t recv_offset_elems = recv_block * scatter_count ;
1043
+ size_t recv_remaining = rectify_diff (count , recv_offset_elems );
1044
+ recv_count = (recv_remaining < scatter_count ) ?
1045
+ recv_remaining : scatter_count ;
1046
+ size_t recv_offset = recv_offset_elems * extent ;
1047
+
1048
+ /* same logic for send */
1049
+ size_t send_offset_elems = send_block * scatter_count ;
1050
+ size_t send_remaining = rectify_diff (count , send_offset_elems );
1051
+ send_count = (send_remaining < scatter_count ) ?
1052
+ send_remaining : scatter_count ;
1053
+ size_t send_offset = send_offset_elems * extent ;
1054
+
1055
+ err = ompi_coll_base_sendrecv ((char * )buf + send_offset , send_count ,
1046
1056
datatype , right , MCA_COLL_BASE_TAG_BCAST ,
1047
- (char * )buf + recv_offset , recv_count ,
1057
+ (char * )buf + recv_offset , recv_count ,
1048
1058
datatype , left , MCA_COLL_BASE_TAG_BCAST ,
1049
1059
comm , MPI_STATUS_IGNORE , rank );
1050
- if (MPI_SUCCESS != err ) { goto cleanup_and_return ; }
1060
+ if (MPI_SUCCESS != err ) {
1061
+ goto cleanup_and_return ;
1062
+ }
1063
+
1064
+ /* rotate blocks */
1051
1065
send_block = recv_block ;
1052
- recv_block = (recv_block - 1 + comm_size ) % comm_size ;
1066
+ recv_block = (recv_block + comm_size - 1 ) % comm_size ;
1053
1067
}
1054
1068
1055
1069
cleanup_and_return :
0 commit comments