14
14
* Copyright (c) 2016 Research Organization for Information Science
15
15
* and Technology (RIST). All rights reserved.
16
16
* Copyright (c) 2017 IBM Corporation. All rights reserved.
17
+ * Copyright (c) 2025 Triad National Security, LLC. All rights reserved.
17
18
* $COPYRIGHT$
18
19
*
19
20
* Additional copyrights may follow
@@ -811,7 +812,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
811
812
if (vrank & mask ) {
812
813
int parent = (rank - mask + comm_size ) % comm_size ;
813
814
/* Compute an upper bound on recv block size */
814
- recv_count = count - vrank * scatter_count ;
815
+ recv_count = ( count > vrank * scatter_count ) ? ( count - vrank * scatter_count ) : 0 ;
815
816
if (recv_count <= 0 ) {
816
817
curr_count = 0 ;
817
818
} else {
@@ -832,7 +833,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
832
833
mask >>= 1 ;
833
834
while (mask > 0 ) {
834
835
if (vrank + mask < comm_size ) {
835
- send_count = curr_count - scatter_count * mask ;
836
+ send_count = ( curr_count > scatter_count * mask ) ? curr_count - scatter_count * mask : 0 ;
836
837
if (send_count > 0 ) {
837
838
int child = (rank + mask ) % comm_size ;
838
839
err = MCA_PML_CALL (send ((char * )buf + (ptrdiff_t )scatter_count * (vrank + mask ) * extent ,
@@ -850,10 +851,8 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
850
851
* Allgather by recursive doubling
851
852
* Each process has the curr_count elems in the buf[vrank * scatter_count, ...]
852
853
*/
853
- size_t rem_count = count - vrank * scatter_count ;
854
+ size_t rem_count = ( count > vrank * scatter_count ) ? count - vrank * scatter_count : 0 ;
854
855
curr_count = (scatter_count < rem_count ) ? scatter_count : rem_count ;
855
- if (curr_count < 0 )
856
- curr_count = 0 ;
857
856
858
857
mask = 0x1 ;
859
858
while (mask < comm_size ) {
@@ -866,9 +865,8 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
866
865
if (vremote < comm_size ) {
867
866
ptrdiff_t send_offset = vrank_tree_root * scatter_count * extent ;
868
867
ptrdiff_t recv_offset = vremote_tree_root * scatter_count * extent ;
869
- recv_count = count - vremote_tree_root * scatter_count ;
870
- if (recv_count < 0 )
871
- recv_count = 0 ;
868
+ recv_count = (count > vremote_tree_root * scatter_count ) ?
869
+ (count - vremote_tree_root * scatter_count ) : 0 ;
872
870
err = ompi_coll_base_sendrecv ((char * )buf + send_offset ,
873
871
curr_count , datatype , remote ,
874
872
MCA_COLL_BASE_TAG_BCAST ,
@@ -877,7 +875,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
877
875
MCA_COLL_BASE_TAG_BCAST ,
878
876
comm , & status , rank );
879
877
if (MPI_SUCCESS != err ) { goto cleanup_and_return ; }
880
- recv_count = (int )(status ._ucount / datatype_size );
878
+ recv_count = (size_t )(status ._ucount / datatype_size );
881
879
curr_count += recv_count ;
882
880
}
883
881
@@ -913,7 +911,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
913
911
MCA_COLL_BASE_TAG_BCAST ,
914
912
comm , & status ));
915
913
if (MPI_SUCCESS != err ) { goto cleanup_and_return ; }
916
- recv_count = (int )(status ._ucount / datatype_size );
914
+ recv_count = (size_t )(status ._ucount / datatype_size );
917
915
curr_count += recv_count ;
918
916
}
919
917
}
@@ -988,8 +986,8 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring(
988
986
if (vrank & mask ) {
989
987
int parent = (rank - mask + comm_size ) % comm_size ;
990
988
/* Compute an upper bound on recv block size */
991
- recv_count = count - vrank * scatter_count ;
992
- if (recv_count <= 0 ) {
989
+ recv_count = ( count > vrank * scatter_count ) ? ( count - vrank * scatter_count ) : 0 ;
990
+ if (0 == recv_count ) {
993
991
curr_count = 0 ;
994
992
} else {
995
993
/* Recv data from parent */
@@ -1009,7 +1007,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring(
1009
1007
mask >>= 1 ;
1010
1008
while (mask > 0 ) {
1011
1009
if (vrank + mask < comm_size ) {
1012
- send_count = curr_count - scatter_count * mask ;
1010
+ send_count = ( curr_count > scatter_count * mask ) ? ( curr_count - scatter_count * mask ) : 0 ;
1013
1011
if (send_count > 0 ) {
1014
1012
int child = (rank + mask ) % comm_size ;
1015
1013
err = MCA_PML_CALL (send ((char * )buf + (ptrdiff_t )scatter_count * (vrank + mask ) * extent ,
@@ -1023,33 +1021,43 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring(
1023
1021
mask >>= 1 ;
1024
1022
}
1025
1023
1026
- /* Allgather by a ring algorithm */
1024
+ /* Allgather by a ring algorithm, using only unsigned types */
1027
1025
int left = (rank - 1 + comm_size ) % comm_size ;
1028
1026
int right = (rank + 1 ) % comm_size ;
1027
+
1028
+ /* The block we will send/recv in each step */
1029
1029
int send_block = vrank ;
1030
1030
int recv_block = (vrank - 1 + comm_size ) % comm_size ;
1031
1031
1032
- for (int i = 1 ; i < comm_size ; i ++ ) {
1033
- recv_count = (scatter_count < count - recv_block * scatter_count ) ?
1034
- scatter_count : count - recv_block * scatter_count ;
1035
- if (recv_count < 0 )
1036
- recv_count = 0 ;
1037
- ptrdiff_t recv_offset = recv_block * scatter_count * extent ;
1038
-
1039
- send_count = (scatter_count < count - send_block * scatter_count ) ?
1040
- scatter_count : count - send_block * scatter_count ;
1041
- if (send_count < 0 )
1042
- send_count = 0 ;
1043
- ptrdiff_t send_offset = send_block * scatter_count * extent ;
1044
-
1045
- err = ompi_coll_base_sendrecv ((char * )buf + send_offset , send_count ,
1032
+ for (int i = 1 ; i < comm_size ; ++ i ) {
1033
+ /* how many elements remain in recv_block? */
1034
+ size_t recv_offset_elems = recv_block * scatter_count ;
1035
+ size_t recv_remaining = (recv_offset_elems < count ) ?
1036
+ (count - recv_offset_elems ) : 0 ;
1037
+ recv_count = (recv_remaining < scatter_count ) ?
1038
+ recv_remaining : scatter_count ;
1039
+ size_t recv_offset = recv_offset_elems * extent ;
1040
+
1041
+ /* same logic for send */
1042
+ size_t send_offset_elems = send_block * scatter_count ;
1043
+ size_t send_remaining = (send_offset_elems < count ) ?
1044
+ (count - send_offset_elems ) : 0 ;
1045
+ send_count = (send_remaining < scatter_count ) ?
1046
+ send_remaining : scatter_count ;
1047
+ size_t send_offset = send_offset_elems * extent ;
1048
+
1049
+ err = ompi_coll_base_sendrecv ((char * )buf + send_offset , send_count ,
1046
1050
datatype , right , MCA_COLL_BASE_TAG_BCAST ,
1047
- (char * )buf + recv_offset , recv_count ,
1051
+ (char * )buf + recv_offset , recv_count ,
1048
1052
datatype , left , MCA_COLL_BASE_TAG_BCAST ,
1049
1053
comm , MPI_STATUS_IGNORE , rank );
1050
- if (MPI_SUCCESS != err ) { goto cleanup_and_return ; }
1051
- send_block = recv_block ;
1052
- recv_block = (recv_block - 1 + comm_size ) % comm_size ;
1054
+ if (MPI_SUCCESS != err ) {
1055
+ goto cleanup_and_return ;
1056
+ }
1057
+
1058
+ /* rotate blocks */
1059
+ send_block = recv_block ;
1060
+ recv_block = (recv_block + comm_size - 1 ) % comm_size ;
1053
1061
}
1054
1062
1055
1063
cleanup_and_return :
0 commit comments