
Communication and compute on separate Streams do not overlap #599

Open
@garrett361

Description

Describe the bug

Communication and computation do not appear to overlap when launching kernels in different xpu.Streams (on Intel GPU Max 1550s). Being able to overlap communication and computation is crucial for efficiency; DeepSpeed and FSDP both use Stream objects for this purpose, for instance.
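
For reference, the overlap pattern those frameworks rely on looks roughly like the following. This is a minimal sketch using the torch.cuda Stream API (which torch.xpu is expected to mirror); grad_bucket and next_compute are placeholder names for illustration, not taken from any framework:

import torch
import torch.distributed as dist

comm_stream = torch.cuda.Stream()


def overlap_step(grad_bucket: torch.Tensor, next_compute) -> None:
    # Order the side stream after work already queued on the default stream,
    # so the all_reduce only starts once the gradients are ready.
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        dist.all_reduce(grad_bucket, async_op=True)
    # The next chunk of compute keeps running on the default stream and can
    # overlap with the all_reduce queued on comm_stream.
    next_compute()
    # Order the default stream after the comm stream before anything consumes
    # the reduced gradients.
    torch.cuda.current_stream().wait_stream(comm_stream)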

To test this, I launch communication and compute kernels in various permutations, with and without explicit Streams. Driver code which runs on both xpu and cuda:

"""
Tests the ability of Stream objects to overlap computation and communication.

Compute: bfloat16 matmuls
Comms: bfloat16 all_reduce

The script first times the comms and compute operations separately. Then, comms and compute
operations are launched together in various ways:
    * All kernels sent to the default stream
    * Comms and compute kernels sent to separate streams

Expectation:
    * No overlap when all kernels are in the default stream (since they run sequentially).
      Total time is approximately equal to the sum of the individually measured comms and compute
      times.
    * Comms and compute overlap when processed by different streams. Total time is less than the sum
      of the individually measured comms and compute times.

The ratios of the various times are printed out to test for overlap.

Example of running with two gpus on one node:

torchrun --nnodes=1 --nproc-per-node=2 streams_overlap_test.py
"""

import io
import os
from contextlib import contextmanager
from dataclasses import dataclass
from time import perf_counter
from typing import Optional

import torch
import torch.distributed as dist

if torch.cuda.is_available():
    from torch import cuda as accel  # noqa

    DEVICE_TYPE = "cuda"
    BACKEND = "nccl"

else:
    import intel_extension_for_pytorch as ipex  # noqa
    from torch import xpu as accel  # noqa
    import oneccl_bindings_for_pytorch  # noqa

    DEVICE_TYPE = "xpu"
    BACKEND = "ccl"

# Matrix sizes, iterations, and warmups. Dimensions chosen to make the compute and comms times
# similar.
COMPUTE_DIM = 2**14
COMMS_DIM = 4 * COMPUTE_DIM
ITERS = 20
WARMUPS = 3


RANK = int(os.environ["RANK"])
LOCAL_RANK = int(os.environ["LOCAL_RANK"])
WORLD_SIZE = int(os.environ["WORLD_SIZE"])
DEVICE = torch.device(f"{DEVICE_TYPE}:{LOCAL_RANK}")
DTYPE = torch.bfloat16
accel.set_device(DEVICE)

compute_stream = accel.Stream(device=DEVICE)
comms_stream = accel.Stream(device=DEVICE)

compute_matrix = torch.randn(COMPUTE_DIM, COMPUTE_DIM, device=DEVICE, dtype=DTYPE)
comms_matrix = torch.randn(COMMS_DIM, COMMS_DIM, device=DEVICE, dtype=DTYPE)


# Simple timer class via a context manager. Time w/ perf_counter rather than Events, due to
# https://github.com/intel/intel-extension-for-pytorch/issues/568


@dataclass
class Time:
    s: float = 0.0


@contextmanager
def timer():
    t = Time()
    accel.synchronize()
    start = perf_counter()
    yield t
    # Barrier to ensure all comms are finished on all ranks
    dist.barrier()
    # Sync the CPU with all kernels in all streams.
    accel.synchronize()
    stop = perf_counter()
    # Update the elapsed time in the yielded Time object.
    t.s = stop - start


def compute(stream: Optional[accel.Stream] = None) -> None:
    with accel.stream(stream):
        for _ in range(ITERS):
            compute_matrix @ compute_matrix


def comms(stream: Optional[accel.Stream] = None) -> None:
    with accel.stream(stream):
        for _ in range(ITERS):
            dist.all_reduce(comms_matrix)


def main() -> None:
    for _ in range(WARMUPS):
        compute()
        comms()

    # Perform computation and comms in different permutations, sometimes using Streams.

    with timer() as t_compute_only:
        compute()

    with timer() as t_comms_only:
        comms()

    with timer() as t_total_default_stream:
        compute()
        comms()

    with timer() as t_total_compute_stream:
        compute(compute_stream)
        comms()

    with timer() as t_total_comms_stream:
        compute()
        comms(comms_stream)

    with timer() as t_total_compute_and_comms_stream:
        compute(compute_stream)
        comms(comms_stream)

    # Print out results
    str_buffer = io.StringIO()
    str_buffer.write(f"{RANK=}\n")
    str_buffer.write(f"\t Compute matrix shape: {compute_matrix.shape}\n")
    str_buffer.write(f"\t Comms matrix shape: {comms_matrix.shape}\n")

    # Compare the case of submitting all work to the default stream to performing the operations
    # independently. Expect they should take approximately the same amount of time, since all
    # kernels run sequentially (ratio ~= 1).
    str_buffer.write("\n")
    str_buffer.write(f"\t {t_total_default_stream.s / (t_compute_only.s + t_comms_only.s)=}\n")

    # Performing the compute in a non-default stream should allow for overlap (ratio < 1).
    str_buffer.write("\n")
    str_buffer.write(f"\t {t_total_compute_stream.s / t_total_default_stream.s =}\n")

    # Performing the communication in a non-default stream should allow for overlap (ratio < 1).
    str_buffer.write("\n")
    str_buffer.write(f"\t {t_total_comms_stream.s / t_total_default_stream.s=}\n")

    # Performing the compute and communication in separate, non-default streams should allow for
    # overlap (ratio < 1).
    str_buffer.write("\n")
    str_buffer.write(f"\t {t_total_compute_and_comms_stream.s / t_total_default_stream.s =}\n")

    print(str_buffer.getvalue(), flush=True)


if __name__ == "__main__":
    try:
        dist.init_process_group(backend=BACKEND)
        main()
    finally:
        dist.destroy_process_group()

Running the above on two A100s, I get:

# On CUDA:
RANK=1
       Compute matrix shape: torch.Size([16384, 16384])
       Comms matrix shape: torch.Size([65536, 65536])

       t_total_default_stream.s / (t_compute_only.s + t_comms_only.s)=0.999107484420899

       t_total_compute_stream.s / t_total_default_stream.s =0.8255087794284478

       t_total_comms_stream.s / t_total_default_stream.s=0.8239232889706464

       t_total_compute_and_comms_stream.s / t_total_default_stream.s =0.820933508932193

RANK=0
       Compute matrix shape: torch.Size([16384, 16384])
       Comms matrix shape: torch.Size([65536, 65536])

       t_total_default_stream.s / (t_compute_only.s + t_comms_only.s)=0.999110786949743

       t_total_compute_stream.s / t_total_default_stream.s =0.8255076235561873

       t_total_comms_stream.s / t_total_default_stream.s=0.8239272185173557

       t_total_compute_and_comms_stream.s / t_total_default_stream.s =0.820936611580174

Running on two Intel GPU Max 1550s, I get:

# XPU
RANK=0
         Compute matrix shape: torch.Size([16384, 16384])
         Comms matrix shape: torch.Size([65536, 65536])

         t_total_default_stream.s / (t_compute_only.s + t_comms_only.s)=0.9993644113989368

         t_total_compute_stream.s / t_total_default_stream.s =1.0017862128444763

         t_total_comms_stream.s / t_total_default_stream.s=0.9987232523971512

         t_total_compute_and_comms_stream.s / t_total_default_stream.s =0.9996738417529752

RANK=1
         Compute matrix shape: torch.Size([16384, 16384])
         Comms matrix shape: torch.Size([65536, 65536])

         t_total_default_stream.s / (t_compute_only.s + t_comms_only.s)=0.99933541624957

         t_total_compute_stream.s / t_total_default_stream.s =1.001785787678655

         t_total_comms_stream.s / t_total_default_stream.s=0.9987192462460256

         t_total_compute_and_comms_stream.s / t_total_default_stream.s =0.9996800354416536

A clear speedup can be seen when using Streams in their various permutations on the A100s, while no speedup is visible on xpu. Absolute timings are not included above, but I have verified that the individual compute and comms times are comparable to each other in all cases.

Is this expected? Is there anything clearly wrong with the test code? The SYCL docs seem to imply that overlap should be possible.

Are there any relevant environment variables that I might need to set?
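
As a possible cross-check (not run above), the same workload could be issued through asynchronous work handles instead of an explicit side stream; if the backend supports any overlap, this variant should show it too. A minimal sketch, reusing compute_matrix, comms_matrix, and ITERS from the script:

def comms_async_then_compute() -> None:
    # Issue the collectives asynchronously and keep the work handles.
    handles = [dist.all_reduce(comms_matrix, async_op=True) for _ in range(ITERS)]
    # Queue the matmuls on the default stream while the collectives are in flight.
    for _ in range(ITERS):
        compute_matrix @ compute_matrix
    # Only block once all kernels have been issued.
    for handle in handles:
        handle.wait()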

Versions

PyTorch version: 2.1.0a0+cxx11.abi
PyTorch CXX11 ABI: Yes
IPEX version: 2.1.10+xpu
IPEX commit: a12f9f650
Build type: Release

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: N/A
IGC version: N/A
CMake version: N/A
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.14.21-150500.55.31_13.0.62-cray_shasta_c-x86_64-with-glibc2.35
Is XPU available: False
DPCPP runtime version: N/A
MKL version: N/A
GPU models and configuration:

Intel OpenCL ICD version: 23.30.26918.50-736~22.04
Level Zero version: 1.3.26918.50-736~22.04

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      48 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             256
On-line CPU(s) list:                0-255
Vendor ID:                          AuthenticAMD
Model name:                         AMD EPYC 7713 64-Core Processor
CPU family:                         25
Model:                              1
Thread(s) per core:                 2
Core(s) per socket:                 64
Socket(s):                          2
Stepping:                           1
Frequency boost:                    enabled
CPU max MHz:                        3720.7029
CPU min MHz:                        1500.0000
BogoMIPS:                           3992.49
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin brs arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Virtualization:                     AMD-V
L1d cache:                          4 MiB (128 instances)
L1i cache:                          4 MiB (128 instances)
L2 cache:                           64 MiB (128 instances)
L3 cache:                           512 MiB (16 instances)
NUMA node(s):                       8
NUMA node0 CPU(s):                  0-15,128-143
NUMA node1 CPU(s):                  16-31,144-159
NUMA node2 CPU(s):                  32-47,160-175
NUMA node3 CPU(s):                  48-63,176-191
NUMA node4 CPU(s):                  64-79,192-207
NUMA node5 CPU(s):                  80-95,208-223
NUMA node6 CPU(s):                  96-111,224-239
NUMA node7 CPU(s):                  112-127,240-255
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET, no microcode
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] intel-extension-for-pytorch==2.1.10+xpu
[pip3] mypy==1.5.1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.4
[pip3] torch==2.1.0a0+cxx11.abi
[pip3] torch-tb-profiler==0.4.3
[pip3] torchaudio==2.1.0a0+cxx11.abi
[pip3] torchvision==0.16.0a0+cxx11.abi
[conda] N/A
