Skip to content

Commit ee7b084

Browse files
committed
coll: make bcast ring unsigned-safe
In the conversion to support big count, there are several places where signed int's were replaced by unsigned types (size_t). Unfortunately there were a few places where signedness was being used and these need to be refactored. To find these places the -Wtype-limit gnu compile option was used. This compile option is added to the --enable-picky compile option list as part of this PR. Signed-off-by: Howard Pritchard <hppritcha@gmail.com>
1 parent e0177a9 commit ee7b084

File tree

2 files changed

+41
-32
lines changed

2 files changed

+41
-32
lines changed

config/opal_setup_cc.m4

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ AC_DEFUN([OPAL_SETUP_CC],[
277277
_OPAL_CHECK_SPECIFIC_CFLAGS(-Wstrict-prototypes, Wstrict_prototypes)
278278
_OPAL_CHECK_SPECIFIC_CFLAGS(-Wcomment, Wcomment)
279279
_OPAL_CHECK_SPECIFIC_CFLAGS(-Wshadow, Wshadow)
280+
_OPAL_CHECK_SPECIFIC_CFLAGS(-Wtype-limits,Wtype_limits)
280281
_OPAL_CHECK_SPECIFIC_CFLAGS(-Werror-implicit-function-declaration, Werror_implicit_function_declaration)
281282
_OPAL_CHECK_SPECIFIC_CFLAGS(-Wno-long-double, Wno_long_double, int main() { long double x; })
282283
_OPAL_CHECK_SPECIFIC_CFLAGS(-fno-strict-aliasing, fno_strict_aliasing, int main() { long double x; })

ompi/mca/coll/base/coll_base_bcast.c

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
* Copyright (c) 2016 Research Organization for Information Science
1515
* and Technology (RIST). All rights reserved.
1616
* Copyright (c) 2017 IBM Corporation. All rights reserved.
17+
* Copyright (c) 2025 Triad National Security, LLC. All rights reserved.
1718
* $COPYRIGHT$
1819
*
1920
* Additional copyrights may follow
@@ -811,7 +812,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
811812
if (vrank & mask) {
812813
int parent = (rank - mask + comm_size) % comm_size;
813814
/* Compute an upper bound on recv block size */
814-
recv_count = count - vrank * scatter_count;
815+
recv_count = (count > vrank * scatter_count) ? (count - vrank * scatter_count) : 0;
815816
if (recv_count <= 0) {
816817
curr_count = 0;
817818
} else {
@@ -832,7 +833,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
832833
mask >>= 1;
833834
while (mask > 0) {
834835
if (vrank + mask < comm_size) {
835-
send_count = curr_count - scatter_count * mask;
836+
send_count = (curr_count > scatter_count * mask) ? curr_count - scatter_count * mask : 0;
836837
if (send_count > 0) {
837838
int child = (rank + mask) % comm_size;
838839
err = MCA_PML_CALL(send((char *)buf + (ptrdiff_t)scatter_count * (vrank + mask) * extent,
@@ -850,10 +851,8 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
850851
* Allgather by recursive doubling
851852
* Each process has the curr_count elems in the buf[vrank * scatter_count, ...]
852853
*/
853-
size_t rem_count = count - vrank * scatter_count;
854+
size_t rem_count = (count > vrank * scatter_count) ? count - vrank * scatter_count : 0;
854855
curr_count = (scatter_count < rem_count) ? scatter_count : rem_count;
855-
if (curr_count < 0)
856-
curr_count = 0;
857856

858857
mask = 0x1;
859858
while (mask < comm_size) {
@@ -866,9 +865,8 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
866865
if (vremote < comm_size) {
867866
ptrdiff_t send_offset = vrank_tree_root * scatter_count * extent;
868867
ptrdiff_t recv_offset = vremote_tree_root * scatter_count * extent;
869-
recv_count = count - vremote_tree_root * scatter_count;
870-
if (recv_count < 0)
871-
recv_count = 0;
868+
recv_count = (count > vremote_tree_root * scatter_count) ?
869+
(count - vremote_tree_root * scatter_count) : 0;
872870
err = ompi_coll_base_sendrecv((char *)buf + send_offset,
873871
curr_count, datatype, remote,
874872
MCA_COLL_BASE_TAG_BCAST,
@@ -877,7 +875,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
877875
MCA_COLL_BASE_TAG_BCAST,
878876
comm, &status, rank);
879877
if (MPI_SUCCESS != err) { goto cleanup_and_return; }
880-
recv_count = (int)(status._ucount / datatype_size);
878+
recv_count = (size_t)(status._ucount / datatype_size);
881879
curr_count += recv_count;
882880
}
883881

@@ -913,7 +911,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
913911
MCA_COLL_BASE_TAG_BCAST,
914912
comm, &status));
915913
if (MPI_SUCCESS != err) { goto cleanup_and_return; }
916-
recv_count = (int)(status._ucount / datatype_size);
914+
recv_count = (size_t)(status._ucount / datatype_size);
917915
curr_count += recv_count;
918916
}
919917
}
@@ -988,8 +986,8 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring(
988986
if (vrank & mask) {
989987
int parent = (rank - mask + comm_size) % comm_size;
990988
/* Compute an upper bound on recv block size */
991-
recv_count = count - vrank * scatter_count;
992-
if (recv_count <= 0) {
989+
recv_count = (count > vrank * scatter_count) ? (count - vrank * scatter_count) : 0;
990+
if (0 == recv_count) {
993991
curr_count = 0;
994992
} else {
995993
/* Recv data from parent */
@@ -1009,7 +1007,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring(
10091007
mask >>= 1;
10101008
while (mask > 0) {
10111009
if (vrank + mask < comm_size) {
1012-
send_count = curr_count - scatter_count * mask;
1010+
send_count = (curr_count > scatter_count * mask) ? (curr_count - scatter_count * mask) : 0;
10131011
if (send_count > 0) {
10141012
int child = (rank + mask) % comm_size;
10151013
err = MCA_PML_CALL(send((char *)buf + (ptrdiff_t)scatter_count * (vrank + mask) * extent,
@@ -1023,33 +1021,43 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring(
10231021
mask >>= 1;
10241022
}
10251023

1026-
/* Allgather by a ring algorithm */
1024+
/* Allgather by a ring algorithm, using only unsigned types */
10271025
int left = (rank - 1 + comm_size) % comm_size;
10281026
int right = (rank + 1) % comm_size;
1027+
1028+
/* The block we will send/recv in each step */
10291029
int send_block = vrank;
10301030
int recv_block = (vrank - 1 + comm_size) % comm_size;
10311031

1032-
for (int i = 1; i < comm_size; i++) {
1033-
recv_count = (scatter_count < count - recv_block * scatter_count) ?
1034-
scatter_count : count - recv_block * scatter_count;
1035-
if (recv_count < 0)
1036-
recv_count = 0;
1037-
ptrdiff_t recv_offset = recv_block * scatter_count * extent;
1038-
1039-
send_count = (scatter_count < count - send_block * scatter_count) ?
1040-
scatter_count : count - send_block * scatter_count;
1041-
if (send_count < 0)
1042-
send_count = 0;
1043-
ptrdiff_t send_offset = send_block * scatter_count * extent;
1044-
1045-
err = ompi_coll_base_sendrecv((char *)buf + send_offset, send_count,
1032+
for (int i = 1; i < comm_size; ++i) {
1033+
/* how many elements remain in recv_block? */
1034+
size_t recv_offset_elems = recv_block * scatter_count;
1035+
size_t recv_remaining = (recv_offset_elems < count) ?
1036+
(count - recv_offset_elems) : 0;
1037+
recv_count = (recv_remaining < scatter_count) ?
1038+
recv_remaining : scatter_count;
1039+
size_t recv_offset = recv_offset_elems * extent;
1040+
1041+
/* same logic for send */
1042+
size_t send_offset_elems = send_block * scatter_count;
1043+
size_t send_remaining = (send_offset_elems < count) ?
1044+
(count - send_offset_elems) : 0;
1045+
send_count = (send_remaining < scatter_count) ?
1046+
send_remaining : scatter_count;
1047+
size_t send_offset = send_offset_elems * extent;
1048+
1049+
err = ompi_coll_base_sendrecv((char*)buf + send_offset, send_count,
10461050
datatype, right, MCA_COLL_BASE_TAG_BCAST,
1047-
(char *)buf + recv_offset, recv_count,
1051+
(char*)buf + recv_offset, recv_count,
10481052
datatype, left, MCA_COLL_BASE_TAG_BCAST,
10491053
comm, MPI_STATUS_IGNORE, rank);
1050-
if (MPI_SUCCESS != err) { goto cleanup_and_return; }
1051-
send_block = recv_block;
1052-
recv_block = (recv_block - 1 + comm_size) % comm_size;
1054+
if (MPI_SUCCESS != err) {
1055+
goto cleanup_and_return;
1056+
}
1057+
1058+
/* rotate blocks */
1059+
send_block = recv_block;
1060+
recv_block = (recv_block + comm_size - 1) % comm_size;
10531061
}
10541062

10551063
cleanup_and_return:

0 commit comments

Comments
 (0)