diff --git a/config/opal_setup_cc.m4 b/config/opal_setup_cc.m4 index 394826223fb..5b6f8993319 100644 --- a/config/opal_setup_cc.m4 +++ b/config/opal_setup_cc.m4 @@ -277,6 +277,7 @@ AC_DEFUN([OPAL_SETUP_CC],[ _OPAL_CHECK_SPECIFIC_CFLAGS(-Wstrict-prototypes, Wstrict_prototypes) _OPAL_CHECK_SPECIFIC_CFLAGS(-Wcomment, Wcomment) _OPAL_CHECK_SPECIFIC_CFLAGS(-Wshadow, Wshadow) + _OPAL_CHECK_SPECIFIC_CFLAGS(-Wtype-limits,Wtype_limits) _OPAL_CHECK_SPECIFIC_CFLAGS(-Werror-implicit-function-declaration, Werror_implicit_function_declaration) _OPAL_CHECK_SPECIFIC_CFLAGS(-Wno-long-double, Wno_long_double, int main() { long double x; }) _OPAL_CHECK_SPECIFIC_CFLAGS(-fno-strict-aliasing, fno_strict_aliasing, int main() { long double x; }) diff --git a/ompi/mca/coll/base/coll_base_bcast.c b/ompi/mca/coll/base/coll_base_bcast.c index dece487066f..f1941215b29 100644 --- a/ompi/mca/coll/base/coll_base_bcast.c +++ b/ompi/mca/coll/base/coll_base_bcast.c @@ -14,6 +14,7 @@ * Copyright (c) 2016 Research Organization for Information Science * and Technology (RIST). All rights reserved. * Copyright (c) 2017 IBM Corporation. All rights reserved. + * Copyright (c) 2025 Triad National Security, LLC. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -34,6 +35,15 @@ #include "coll_base_topo.h" #include "coll_base_util.h" +/* + * if a > b return a- b otherwise 0 + */ +static inline size_t +rectify_diff(size_t a, size_t b) +{ + return a > b ? a - b : 0; +} + int ompi_coll_base_bcast_intra_generic( void* buffer, size_t original_count, @@ -811,8 +821,8 @@ int ompi_coll_base_bcast_intra_scatter_allgather( if (vrank & mask) { int parent = (rank - mask + comm_size) % comm_size; /* Compute an upper bound on recv block size */ - recv_count = count - vrank * scatter_count; - if (recv_count <= 0) { + recv_count = rectify_diff(count, (size_t)(vrank * scatter_count)); + if (recv_count == 0) { curr_count = 0; } else { /* Recv data from parent */ @@ -832,7 +842,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather( mask >>= 1; while (mask > 0) { if (vrank + mask < comm_size) { - send_count = curr_count - scatter_count * mask; + send_count = rectify_diff(curr_count, (size_t)(scatter_count * mask)); if (send_count > 0) { int child = (rank + mask) % comm_size; err = MCA_PML_CALL(send((char *)buf + (ptrdiff_t)scatter_count * (vrank + mask) * extent, @@ -850,10 +860,8 @@ int ompi_coll_base_bcast_intra_scatter_allgather( * Allgather by recursive doubling * Each process has the curr_count elems in the buf[vrank * scatter_count, ...] */ - size_t rem_count = count - vrank * scatter_count; + size_t rem_count = rectify_diff(count, (size_t)(vrank * scatter_count)); curr_count = (scatter_count < rem_count) ? scatter_count : rem_count; - if (curr_count < 0) - curr_count = 0; mask = 0x1; while (mask < comm_size) { @@ -866,9 +874,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather( if (vremote < comm_size) { ptrdiff_t send_offset = vrank_tree_root * scatter_count * extent; ptrdiff_t recv_offset = vremote_tree_root * scatter_count * extent; - recv_count = count - vremote_tree_root * scatter_count; - if (recv_count < 0) - recv_count = 0; + recv_count = rectify_diff(count, (size_t)(vremote_tree_root * scatter_count)); err = ompi_coll_base_sendrecv((char *)buf + send_offset, curr_count, datatype, remote, MCA_COLL_BASE_TAG_BCAST, @@ -877,7 +883,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather( MCA_COLL_BASE_TAG_BCAST, comm, &status, rank); if (MPI_SUCCESS != err) { goto cleanup_and_return; } - recv_count = (int)(status._ucount / datatype_size); + recv_count = (size_t)(status._ucount / datatype_size); curr_count += recv_count; } @@ -913,7 +919,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather( MCA_COLL_BASE_TAG_BCAST, comm, &status)); if (MPI_SUCCESS != err) { goto cleanup_and_return; } - recv_count = (int)(status._ucount / datatype_size); + recv_count = (size_t)(status._ucount / datatype_size); curr_count += recv_count; } } @@ -988,8 +994,8 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring( if (vrank & mask) { int parent = (rank - mask + comm_size) % comm_size; /* Compute an upper bound on recv block size */ - recv_count = count - vrank * scatter_count; - if (recv_count <= 0) { + recv_count = rectify_diff(count, (size_t)(vrank * scatter_count)); + if (0 == recv_count) { curr_count = 0; } else { /* Recv data from parent */ @@ -1009,7 +1015,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring( mask >>= 1; while (mask > 0) { if (vrank + mask < comm_size) { - send_count = curr_count - scatter_count * mask; + send_count = rectify_diff(curr_count, (size_t)(scatter_count * mask)); if (send_count > 0) { int child = (rank + mask) % comm_size; err = MCA_PML_CALL(send((char *)buf + (ptrdiff_t)scatter_count * (vrank + mask) * extent, @@ -1023,33 +1029,41 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring( mask >>= 1; } - /* Allgather by a ring algorithm */ + /* Allgather by a ring algorithm, using only unsigned types */ int left = (rank - 1 + comm_size) % comm_size; int right = (rank + 1) % comm_size; + + /* The block we will send/recv in each step */ int send_block = vrank; int recv_block = (vrank - 1 + comm_size) % comm_size; - for (int i = 1; i < comm_size; i++) { - recv_count = (scatter_count < count - recv_block * scatter_count) ? - scatter_count : count - recv_block * scatter_count; - if (recv_count < 0) - recv_count = 0; - ptrdiff_t recv_offset = recv_block * scatter_count * extent; - - send_count = (scatter_count < count - send_block * scatter_count) ? - scatter_count : count - send_block * scatter_count; - if (send_count < 0) - send_count = 0; - ptrdiff_t send_offset = send_block * scatter_count * extent; - - err = ompi_coll_base_sendrecv((char *)buf + send_offset, send_count, + for (int i = 1; i < comm_size; ++i) { + /* how many elements remain in recv_block? */ + size_t recv_offset_elems = recv_block * scatter_count; + size_t recv_remaining = rectify_diff(count, recv_offset_elems); + recv_count = (recv_remaining < scatter_count) ? + recv_remaining : scatter_count; + size_t recv_offset = recv_offset_elems * extent; + + /* same logic for send */ + size_t send_offset_elems = send_block * scatter_count; + size_t send_remaining = rectify_diff(count, send_offset_elems); + send_count = (send_remaining < scatter_count) ? + send_remaining : scatter_count; + size_t send_offset = send_offset_elems * extent; + + err = ompi_coll_base_sendrecv((char*)buf + send_offset, send_count, datatype, right, MCA_COLL_BASE_TAG_BCAST, - (char *)buf + recv_offset, recv_count, + (char*)buf + recv_offset, recv_count, datatype, left, MCA_COLL_BASE_TAG_BCAST, comm, MPI_STATUS_IGNORE, rank); - if (MPI_SUCCESS != err) { goto cleanup_and_return; } + if (MPI_SUCCESS != err) { + goto cleanup_and_return; + } + + /* rotate blocks */ send_block = recv_block; - recv_block = (recv_block - 1 + comm_size) % comm_size; + recv_block = (recv_block + comm_size - 1) % comm_size; } cleanup_and_return: