Skip to content

Commit 5686600

Browse files
committed
coll: make bcast ring unsigned-safe
In the conversion to support big count, there are several places where signed int's were replaced by unsigned types (size_t). Unfortunately there were a few places where signedness was being used and these need to be refactored. To find these places the -Wtype-limit gnu compile option was used. This compile option is added to the --enable-picky compile option list as part of this PR. Signed-off-by: Howard Pritchard <hppritcha@gmail.com>
1 parent e0177a9 commit 5686600

File tree

2 files changed

+47
-32
lines changed

2 files changed

+47
-32
lines changed

config/opal_setup_cc.m4

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ AC_DEFUN([OPAL_SETUP_CC],[
277277
_OPAL_CHECK_SPECIFIC_CFLAGS(-Wstrict-prototypes, Wstrict_prototypes)
278278
_OPAL_CHECK_SPECIFIC_CFLAGS(-Wcomment, Wcomment)
279279
_OPAL_CHECK_SPECIFIC_CFLAGS(-Wshadow, Wshadow)
280+
_OPAL_CHECK_SPECIFIC_CFLAGS(-Wtype-limits,Wtype_limits)
280281
_OPAL_CHECK_SPECIFIC_CFLAGS(-Werror-implicit-function-declaration, Werror_implicit_function_declaration)
281282
_OPAL_CHECK_SPECIFIC_CFLAGS(-Wno-long-double, Wno_long_double, int main() { long double x; })
282283
_OPAL_CHECK_SPECIFIC_CFLAGS(-fno-strict-aliasing, fno_strict_aliasing, int main() { long double x; })

ompi/mca/coll/base/coll_base_bcast.c

Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
* Copyright (c) 2016 Research Organization for Information Science
1515
* and Technology (RIST). All rights reserved.
1616
* Copyright (c) 2017 IBM Corporation. All rights reserved.
17+
* Copyright (c) 2025 Triad National Security, LLC. All rights reserved.
1718
* $COPYRIGHT$
1819
*
1920
* Additional copyrights may follow
@@ -34,6 +35,15 @@
3435
#include "coll_base_topo.h"
3536
#include "coll_base_util.h"
3637

38+
/*
39+
* if a > b return a- b otherwise 0
40+
*/
41+
static inline size_t
42+
rectify_diff(size_t a, size_t b)
43+
{
44+
return a > b ? a - b : 0;
45+
}
46+
3747
int
3848
ompi_coll_base_bcast_intra_generic( void* buffer,
3949
size_t original_count,
@@ -811,8 +821,8 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
811821
if (vrank & mask) {
812822
int parent = (rank - mask + comm_size) % comm_size;
813823
/* Compute an upper bound on recv block size */
814-
recv_count = count - vrank * scatter_count;
815-
if (recv_count <= 0) {
824+
recv_count = rectify_diff(count, (size_t)(vrank * scatter_count));
825+
if (recv_count == 0) {
816826
curr_count = 0;
817827
} else {
818828
/* Recv data from parent */
@@ -832,7 +842,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
832842
mask >>= 1;
833843
while (mask > 0) {
834844
if (vrank + mask < comm_size) {
835-
send_count = curr_count - scatter_count * mask;
845+
send_count = rectify_diff(curr_count, (size_t)(scatter_count * mask));
836846
if (send_count > 0) {
837847
int child = (rank + mask) % comm_size;
838848
err = MCA_PML_CALL(send((char *)buf + (ptrdiff_t)scatter_count * (vrank + mask) * extent,
@@ -850,10 +860,8 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
850860
* Allgather by recursive doubling
851861
* Each process has the curr_count elems in the buf[vrank * scatter_count, ...]
852862
*/
853-
size_t rem_count = count - vrank * scatter_count;
863+
size_t rem_count = (count > vrank * scatter_count) ? count - vrank * scatter_count : 0;
854864
curr_count = (scatter_count < rem_count) ? scatter_count : rem_count;
855-
if (curr_count < 0)
856-
curr_count = 0;
857865

858866
mask = 0x1;
859867
while (mask < comm_size) {
@@ -866,9 +874,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
866874
if (vremote < comm_size) {
867875
ptrdiff_t send_offset = vrank_tree_root * scatter_count * extent;
868876
ptrdiff_t recv_offset = vremote_tree_root * scatter_count * extent;
869-
recv_count = count - vremote_tree_root * scatter_count;
870-
if (recv_count < 0)
871-
recv_count = 0;
877+
recv_count = rectify_diff(count, (size_t)(vremote_tree_root * scatter_count));
872878
err = ompi_coll_base_sendrecv((char *)buf + send_offset,
873879
curr_count, datatype, remote,
874880
MCA_COLL_BASE_TAG_BCAST,
@@ -877,7 +883,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
877883
MCA_COLL_BASE_TAG_BCAST,
878884
comm, &status, rank);
879885
if (MPI_SUCCESS != err) { goto cleanup_and_return; }
880-
recv_count = (int)(status._ucount / datatype_size);
886+
recv_count = (size_t)(status._ucount / datatype_size);
881887
curr_count += recv_count;
882888
}
883889

@@ -913,7 +919,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
913919
MCA_COLL_BASE_TAG_BCAST,
914920
comm, &status));
915921
if (MPI_SUCCESS != err) { goto cleanup_and_return; }
916-
recv_count = (int)(status._ucount / datatype_size);
922+
recv_count = (size_t)(status._ucount / datatype_size);
917923
curr_count += recv_count;
918924
}
919925
}
@@ -988,8 +994,8 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring(
988994
if (vrank & mask) {
989995
int parent = (rank - mask + comm_size) % comm_size;
990996
/* Compute an upper bound on recv block size */
991-
recv_count = count - vrank * scatter_count;
992-
if (recv_count <= 0) {
997+
recv_count = rectify_diff(count, (size_t)(vrank * scatter_count));
998+
if (0 == recv_count) {
993999
curr_count = 0;
9941000
} else {
9951001
/* Recv data from parent */
@@ -1009,7 +1015,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring(
10091015
mask >>= 1;
10101016
while (mask > 0) {
10111017
if (vrank + mask < comm_size) {
1012-
send_count = curr_count - scatter_count * mask;
1018+
send_count = rectify_diff(curr_count, (size_t)(scatter_count * mask));
10131019
if (send_count > 0) {
10141020
int child = (rank + mask) % comm_size;
10151021
err = MCA_PML_CALL(send((char *)buf + (ptrdiff_t)scatter_count * (vrank + mask) * extent,
@@ -1023,33 +1029,41 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring(
10231029
mask >>= 1;
10241030
}
10251031

1026-
/* Allgather by a ring algorithm */
1032+
/* Allgather by a ring algorithm, using only unsigned types */
10271033
int left = (rank - 1 + comm_size) % comm_size;
10281034
int right = (rank + 1) % comm_size;
1035+
1036+
/* The block we will send/recv in each step */
10291037
int send_block = vrank;
10301038
int recv_block = (vrank - 1 + comm_size) % comm_size;
10311039

1032-
for (int i = 1; i < comm_size; i++) {
1033-
recv_count = (scatter_count < count - recv_block * scatter_count) ?
1034-
scatter_count : count - recv_block * scatter_count;
1035-
if (recv_count < 0)
1036-
recv_count = 0;
1037-
ptrdiff_t recv_offset = recv_block * scatter_count * extent;
1038-
1039-
send_count = (scatter_count < count - send_block * scatter_count) ?
1040-
scatter_count : count - send_block * scatter_count;
1041-
if (send_count < 0)
1042-
send_count = 0;
1043-
ptrdiff_t send_offset = send_block * scatter_count * extent;
1044-
1045-
err = ompi_coll_base_sendrecv((char *)buf + send_offset, send_count,
1040+
for (int i = 1; i < comm_size; ++i) {
1041+
/* how many elements remain in recv_block? */
1042+
size_t recv_offset_elems = recv_block * scatter_count;
1043+
size_t recv_remaining = rectify_diff(count, recv_offset_elems);
1044+
recv_count = (recv_remaining < scatter_count) ?
1045+
recv_remaining : scatter_count;
1046+
size_t recv_offset = recv_offset_elems * extent;
1047+
1048+
/* same logic for send */
1049+
size_t send_offset_elems = send_block * scatter_count;
1050+
size_t send_remaining = rectify_diff(count, send_offset_elems);
1051+
send_count = (send_remaining < scatter_count) ?
1052+
send_remaining : scatter_count;
1053+
size_t send_offset = send_offset_elems * extent;
1054+
1055+
err = ompi_coll_base_sendrecv((char*)buf + send_offset, send_count,
10461056
datatype, right, MCA_COLL_BASE_TAG_BCAST,
1047-
(char *)buf + recv_offset, recv_count,
1057+
(char*)buf + recv_offset, recv_count,
10481058
datatype, left, MCA_COLL_BASE_TAG_BCAST,
10491059
comm, MPI_STATUS_IGNORE, rank);
1050-
if (MPI_SUCCESS != err) { goto cleanup_and_return; }
1060+
if (MPI_SUCCESS != err) {
1061+
goto cleanup_and_return;
1062+
}
1063+
1064+
/* rotate blocks */
10511065
send_block = recv_block;
1052-
recv_block = (recv_block - 1 + comm_size) % comm_size;
1066+
recv_block = (recv_block + comm_size - 1) % comm_size;
10531067
}
10541068

10551069
cleanup_and_return:

0 commit comments

Comments
 (0)