diff --git a/generate/generate.py b/generate/generate.py
index 992eccf..8473542 100644
--- a/generate/generate.py
+++ b/generate/generate.py
@@ -8,7 +8,7 @@
 
 # tpl imports
 import torch
-from transformers import pipeline
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
 
 # local imports
 from utils import BalancedBracketsCriteria, PromptDataset, clean_output, get_inference_config
@@ -33,6 +33,10 @@
 parser.add_argument('--do_sample', action='store_true', help='Enable sampling (default: False)')
 parser.add_argument('--batch_size', type=int, default=16, help='Batch size for generation (default: 8)')
 parser.add_argument('--prompted', action='store_true', help='Use prompted generation. See StarCoder paper (default: False)')
+device_group = parser.add_mutually_exclusive_group()
+device_group.add_argument('--device_map', help='Path to the device map JSON file or the string "auto"')
+device_group.add_argument('--device', type=int, help='Device to use for inference')
+device_group.add_argument('--axonn', action='store_true', help='Use AxoNN for inference')
 args = parser.parse_args()
 
 """ Load prompts """
@@ -96,12 +100,48 @@
 # and repeat them for however many samples we want to generate per prompt
 prompts_repeated = [p for p in prompts for _ in range(args.num_samples_per_prompt)]
 
+""" Set device kwarg for inference """
+device_kwarg = {}
+USE_AXONN = False
+if args.device_map:
+    if args.device_map == "auto":
+        device_kwarg["device_map"] = "auto"
+    else:
+        with open(args.device_map, 'r') as json_file:
+            device_map = json.load(json_file)
+        device_kwarg["device_map"] = device_map
+elif args.device is not None:
+    device_kwarg["device"] = args.device
+elif args.axonn:
+    from mpi4py import MPI
+    from axonn import axonn as ax
+    from modify_llama import monkey_patch_llama_with_axonn
+    world_size = MPI.COMM_WORLD.Get_size()
+    rank = MPI.COMM_WORLD.Get_rank()
+    if rank == 0:
+        print(f"Using AxoNN with {world_size} GPUs.")
+    ax.init(G_data=1, G_inter=1, G_intra_r=1, G_intra_c=1, G_intra_d=world_size)
+    if "llama" in args.model:
+        monkey_patch_llama_with_axonn()
+    USE_AXONN = True
+    device_kwarg["device"] = "cuda"
+else:
+    device_kwarg["device"] = 0
+
+""" Load model and tokenizer """
+model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=inference_config.get_dtype())
+if USE_AXONN:
+    model = model.to("cuda")
+tokenizer = AutoTokenizer.from_pretrained(args.model)
+
 """ Initialize HuggingFace pipeline for generation """
-generator = pipeline(model=args.model, torch_dtype=inference_config.get_dtype(), device=0)
+generator = pipeline(task='text-generation', model=model, tokenizer=tokenizer, **device_kwarg)
 inference_config.init_padding(generator.tokenizer)
 
 """ Create a prompt data set to pass to generate method """
 prompt_dataset = PromptDataset([inference_config.format_prompt(p["prompt"]) for p in prompts_repeated])
+if USE_AXONN:
+    prompt_dataset = prompt_dataset#.to("cuda")
 generated_outputs = generator(
     prompt_dataset,
     max_new_tokens=args.max_new_tokens,
@@ -114,7 +154,7 @@
 )
 
 """ Iterate over prompts and generate code """
-if not args.restart and args.cache is not None:
+if not args.restart and args.cache is not None and os.path.exists(args.cache):
     with open(args.cache, 'r') as jsonl_file:
         responses = [json.loads(line) for line in jsonl_file]
         responses = [r for r in responses if r["temperature"] == args.temperature and r["prompted"] == args.prompted
@@ -140,8 +180,9 @@
     responses.append(cur_prompt)
 
     if not args.restart and args.cache is not None:
-        with open(args.cache, 'a') as jsonl_file:
-            jsonl_file.write(json.dumps(cur_prompt) + "\n")
+        if not USE_AXONN or rank == 0:
+            with open(args.cache, 'a') as jsonl_file:
+                jsonl_file.write(json.dumps(cur_prompt) + "\n")
 
     if idx != 0 and idx % args.num_samples_per_prompt == 0:
         print(f"Tokens per second: {total_tokens / (time.time() - start_time):.2f}")
@@ -151,5 +192,6 @@
 print(f"Generated {len(responses)} code samples in {end_time - start_time:.2f} seconds ({tokens_per_second:.2f} tokens per second)")
 
 """ Save responses to JSON file """
-with open(args.output, 'w') as output_file:
-    json.dump(responses, output_file, indent=4)
+if not USE_AXONN or rank == 0:
+    with open(args.output, 'w') as output_file:
+        json.dump(responses, output_file, indent=4)
diff --git a/generate/select_gpu_device b/generate/select_gpu_device
new file mode 100755
index 0000000..2914f9e
--- /dev/null
+++ b/generate/select_gpu_device
@@ -0,0 +1,6 @@
+#!/bin/bash
+# select_gpu_device wrapper script
+export RANK=${SLURM_PROCID}
+export WORLD_SIZE=${SLURM_NTASKS}
+export LOCAL_RANK=${SLURM_LOCALID}
+exec "$@"
\ No newline at end of file
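Usage sketch for the new device-selection flags. The model name, output path, and device-map contents are placeholders; the script's prompt-input and sampling flags are omitted, and only --model and --output are taken from arguments the diff actually references.

# Pin generation to a single GPU (integer index forwarded to the HF pipeline).
python generate/generate.py --model codellama/CodeLlama-7b-hf --output outputs.json --device 0

# Ask for automatic placement across all visible GPUs ("auto" is forwarded as the device_map).
python generate/generate.py --model codellama/CodeLlama-7b-hf --output outputs.json --device_map auto

# Explicit placement: --device_map also takes a path to a JSON file mapping
# submodule names to device indices (the names depend on the model architecture).
cat > device_map.json << 'EOF'
{
    "model.embed_tokens": 0,
    "model.layers": 0,
    "model.norm": 1,
    "lm_head": 1
}
EOF
python generate/generate.py --model codellama/CodeLlama-7b-hf --output outputs.json --device_map device_map.json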
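A sketch of a multi-GPU AxoNN launch, assuming srun is the launcher with one task per GPU; node and GPU counts, the model name, and the output path are illustrative. The select_gpu_device wrapper only re-exports SLURM's per-task variables under the RANK, WORLD_SIZE, and LOCAL_RANK names that a torch.distributed-style environment initialization typically reads, then execs the wrapped command, so it sits between srun and python.

# One task per GPU; with --axonn, only rank 0 writes the cache and output files.
srun -N 2 --ntasks-per-node=4 --gpus-per-node=4 \
    ./generate/select_gpu_device \
    python generate/generate.py \
        --model codellama/CodeLlama-7b-hf \
        --output outputs.json \
        --axonn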