Skip to content

Commit b52d1b4

Browse files
authored
Add attention backend tests to more-tests.yml
Tests from #1477 only, without the generate.py refactor
1 parent 5684175 commit b52d1b4

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed

.github/workflows/more-tests.yml

Lines changed: 62 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@ jobs:
1919
gpu-arch-version: "12.4"
2020
timeout: 60
2121
script: |
22+
set -xeou pipefail
2223
echo "::group::Print machine info"
2324
uname -a
2425
echo "::endgroup::"
@@ -83,3 +84,64 @@ jobs:
8384
echo "tests complete"
8485
echo "******************************************"
8586
echo "::endgroup::"
87+
88+
test-sdpa-backends:
89+
permissions:
90+
id-token: write
91+
contents: read
92+
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
93+
with:
94+
runner: linux.g5.4xlarge.nvidia.gpu
95+
gpu-arch-type: cuda
96+
gpu-arch-version: "12.4"
97+
timeout: 60
98+
script: |
99+
set -xeou pipefail
100+
echo "::group::Print machine info"
101+
uname -a
102+
echo "::endgroup::"
103+
104+
echo "::group::Install requirements"
105+
# Install requirements
106+
./install/install_requirements.sh cuda
107+
pip3 list
108+
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
109+
echo "::endgroup::"
110+
111+
echo "::group::Download checkpoints"
112+
mkdir -p checkpoints/stories15M
113+
pushd checkpoints/stories15M
114+
wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
115+
wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
116+
popd
117+
echo "::endgroup::"
118+
119+
echo "::group::Run inference"
120+
export MODEL_PATH=checkpoints/stories15M/stories15M.pt
121+
export MODEL_NAME=stories15M
122+
export MODEL_DIR=/tmp
123+
124+
for DEVICE in cpu cuda; do
125+
for DTYPE in bfloat16 float16 float32; do
126+
for SDPA in 'math' 'flash_attention' 'efficient_attention' 'cudnn_attention'; do
127+
echo "******************************************************************"
128+
echo "******* $DEVICE $DTYPE $SDPA "
129+
###################################################################
130+
# Python execution interpreted vanilla
131+
python torchchat.py generate --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE} --temperature 0
132+
###################################################################
133+
# prefill, and compile and prefill compile
134+
python torchchat.py generate --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE} --temperature 0 --compile --compile-prefill
135+
###################################################################
136+
# sequential prefill
137+
python torchchat.py generate --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE} --temperature 0 --sequential-prefill
138+
###################################################################
139+
# prefill, and compile
140+
python torchchat.py generate --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE} --temperature 0 --sequential-prefill --compile
141+
done
142+
done
143+
done
144+
145+
echo "tests complete"
146+
echo "******************************************"
147+
echo "::endgroup::"

0 commit comments

Comments (0)