
Commit 6c0e723

update generate script to be able to use axonn
1 parent eefa94e · commit 6c0e723

File tree: 2 files changed (55 additions, 7 deletions)

generate/generate.py

Lines changed: 49 additions & 7 deletions
@@ -8,7 +8,7 @@

 # tpl imports
 import torch
-from transformers import pipeline
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

 # local imports
 from utils import BalancedBracketsCriteria, PromptDataset, clean_output, get_inference_config
@@ -33,6 +33,10 @@
 parser.add_argument('--do_sample', action='store_true', help='Enable sampling (default: False)')
 parser.add_argument('--batch_size', type=int, default=16, help='Batch size for generation (default: 8)')
 parser.add_argument('--prompted', action='store_true', help='Use prompted generation. See StarCoder paper (default: False)')
+device_group = parser.add_mutually_exclusive_group()
+device_group.add_argument('--device_map', help='Path to the device map JSON file or the string "auto"')
+device_group.add_argument('--device', type=int, help='Device to use for inference')
+device_group.add_argument('--axonn', action='store_true', help='Use AxoNN for inference')
 args = parser.parse_args()

 """ Load prompts """
@@ -96,12 +100,48 @@
 # and repeat them for however many samples we want to generate per prompt
 prompts_repeated = [p for p in prompts for _ in range(args.num_samples_per_prompt)]

+""" Set device kwarg for inference """
+device_kwarg = {}
+USE_AXONN = False
+if args.device_map:
+    if args.device_map == "auto":
+        device_kwarg["device_map"] = "auto"
+    else:
+        with open(args.device_map, 'r') as json_file:
+            device_map = json.load(json_file)
+        device_kwarg["device_map"] = device_map
+elif args.device:
+    device_kwarg["device"] = args.device
+elif args.axonn:
+    from mpi4py import MPI
+    from axonn import axonn as ax
+    from modify_llama import monkey_patch_llama_with_axonn
+    world_size = MPI.COMM_WORLD.Get_size()
+    rank = MPI.COMM_WORLD.Get_rank()
+    if rank == 0:
+        print(f"Using AxoNN with {world_size} GPUs.")
+    ax.init(G_data=1, G_inter=1, G_intra_r=1, G_intra_c=1, G_intra_d=world_size)
+    if "llama" in args.model:
+        monkey_patch_llama_with_axonn()
+    USE_AXONN = True
+    device_kwarg["device"] = "cuda"
+else:
+    device_kwarg["device"] = 0
+
+""" Load model and tokenizer """
+model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=inference_config.get_dtype())
+if USE_AXONN:
+    model = model.to("cuda")
+tokenizer = AutoTokenizer.from_pretrained(args.model)
+
 """ Initialize HuggingFace pipeline for generation """
-generator = pipeline(model=args.model, torch_dtype=inference_config.get_dtype(), device=0)
+generator = pipeline(task='text-generation', model=model, tokenizer=tokenizer, **device_kwarg)
 inference_config.init_padding(generator.tokenizer)

 """ Create a prompt data set to pass to generate method """
 prompt_dataset = PromptDataset([inference_config.format_prompt(p["prompt"]) for p in prompts_repeated])
+if USE_AXONN:
+    prompt_dataset = prompt_dataset#.to("cuda")
 generated_outputs = generator(
     prompt_dataset,
     max_new_tokens=args.max_new_tokens,
@@ -114,7 +154,7 @@
 )

 """ Iterate over prompts and generate code """
-if not args.restart and args.cache is not None:
+if not args.restart and args.cache is not None and os.path.exists(args.cache):
     with open(args.cache, 'r') as jsonl_file:
         responses = [json.loads(line) for line in jsonl_file]
         responses = [r for r in responses if r["temperature"] == args.temperature and r["prompted"] == args.prompted
@@ -140,8 +180,9 @@
     responses.append(cur_prompt)

     if not args.restart and args.cache is not None:
-        with open(args.cache, 'a') as jsonl_file:
-            jsonl_file.write(json.dumps(cur_prompt) + "\n")
+        if not USE_AXONN or rank == 0:
+            with open(args.cache, 'a') as jsonl_file:
+                jsonl_file.write(json.dumps(cur_prompt) + "\n")

     if idx != 0 and idx % args.num_samples_per_prompt == 0:
         print(f"Tokens per second: {total_tokens / (time.time() - start_time):.2f}")
@@ -151,5 +192,6 @@
 print(f"Generated {len(responses)} code samples in {end_time - start_time:.2f} seconds ({tokens_per_second:.2f} tokens per second)")

 """ Save responses to JSON file """
-with open(args.output, 'w') as output_file:
-    json.dump(responses, output_file, indent=4)
+if not USE_AXONN or rank == 0:
+    with open(args.output, 'w') as output_file:
+        json.dump(responses, output_file, indent=4)
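
For context, a minimal sketch of how the updated script might be invoked with the new mutually exclusive device options. Only the flags visible in the diff above (--model, --device, --device_map, --axonn, --output) are taken from the commit; the model name, file paths, launcher, and the trailing "..." standing in for the script's other arguments are placeholders, not part of the repository.

    # single GPU (hypothetical model and paths)
    python generate.py --model codellama/CodeLlama-7b-hf --device 0 --output out.json ...

    # shard across available GPUs via a HuggingFace-style device map
    python generate.py --model codellama/CodeLlama-7b-hf --device_map auto --output out.json ...
    python generate.py --model codellama/CodeLlama-7b-hf --device_map my_device_map.json --output out.json ...

    # tensor-parallel inference with AxoNN: one MPI rank per GPU (assumed launcher)
    mpirun -np 4 python generate.py --model codellama/CodeLlama-7b-hf --axonn --output out.json ...

In the --axonn path, the diff initializes AxoNN with G_intra_d equal to the MPI world size, applies the llama monkey patch when the model name contains "llama", and restricts cache and output writes to rank 0.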

generate/select_gpu_device

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+#!/bin/bash
+# select_gpu_device wrapper script
+export RANK=${SLURM_PROCID}
+export WORLD_SIZE=${SLURM_NTASKS}
+export LOCAL_RANK=${SLURM_LOCALID}
+exec $*
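
The wrapper simply re-exports Slurm's per-task variables under the RANK, WORLD_SIZE, and LOCAL_RANK names before exec'ing the wrapped command. A hypothetical launch line, assuming one node with four tasks and one GPU per task (the exact srun flags and paths are assumptions, not from the commit):

    srun -N 1 -n 4 --gpus-per-task=1 ./generate/select_gpu_device \
        python generate/generate.py --model codellama/CodeLlama-7b-hf --axonn --output out.json

One side note on the script as committed: exec $* re-splits its arguments on whitespace, so an argument containing spaces would be broken apart; exec "$@" is the quoting-safe form, though it makes no difference for simple command lines like the one above.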
