
Commit 8f82c53

ytzi authored and Dando18 committed
Added vllm generation; improved one error message
1 parent 1d8c6ff commit 8f82c53

2 files changed: +162 -1 lines changed


generate/generate-vllm.py

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
# std imports
import argparse
import json
import os
import sys
import time
from tqdm import tqdm
import torch

# tpl imports
from vllm import LLM, SamplingParams

# local imports
from utils import BalancedBracketsCriteria, PromptDataset, clean_output, get_inference_config

""" Parse command line arguments """
parser = argparse.ArgumentParser(description='Generate code with vLLM')
parser.add_argument('--prompts', required=True, help='Path to the prompt JSON file')
parser.add_argument('--model', required=True, help='Path to the language model')
parser.add_argument('--output', required=True, help='Path to the output JSON file')
parser.add_argument('--restart', action='store_true', help='Restart generation from scratch (default: False)')
parser.add_argument('--cache', help='JSONL file to cache intermediate results in. Will be restored from if it ' +
    'already exists and --restart is not specified')
parser.add_argument('--restore_from', help='JSON file to restore old results from. Will be restored from ' +
    'if it already exists and --restart is not specified. Is different from --cache in that it is a JSON file, not a ' +
    'JSONL file, and it is only used to restore old results where the prompt is equivalent. Cached results are ' +
    'prioritized over restored results.')
parser.add_argument('--max_new_tokens', type=int, default=1024, help='Maximum number of new tokens to generate (default: 1024)')
parser.add_argument('--num_samples_per_prompt', type=int, default=50, help='Number of code samples to generate (default: 50)')
parser.add_argument('--temperature', type=float, default=0.2, help='Temperature for controlling randomness (default: 0.2)')
parser.add_argument('--top_p', type=float, default=0.95, help='Top p value for nucleus sampling (default: 0.95)')
parser.add_argument('--do_sample', action='store_true', help='Enable sampling (default: False)')
parser.add_argument('--prompted', action='store_true', help='Use prompted generation. See StarCoder paper (default: False)')
args = parser.parse_args()

""" Load prompts """
with open(args.prompts, 'r') as json_file:
    prompts = json.load(json_file)

""" Load existing responses if they exist """
if not args.restart and args.cache is not None and os.path.exists(args.cache):
    with open(args.cache, 'r') as jsonl_file:
        responses = [json.loads(line) for line in jsonl_file]

    # skip prompts that already have a cached response with a full set of outputs
    original_len = len(prompts)
    prompts = [p for p in prompts if
        not any(p["name"] == r["name"] and
                p["parallelism_model"] == r["parallelism_model"] and
                p["prompt"] == r["prompt"] and
                args.temperature == r["temperature"] and
                args.prompted == r["prompted"] and
                args.num_samples_per_prompt == len(r["outputs"])
                for r in responses)]
    print(f"[cache] Skipping {original_len - len(prompts)} prompts that already have responses")

""" Restore old responses if a restore file exists """
if not args.restart and args.restore_from and os.path.exists(args.restore_from):
    with open(args.restore_from, 'r') as json_file:
        restored_responses = json.load(json_file)

    # keep restored responses that match a prompt and the current sampling settings;
    # drop those prompts from the work list
    original_len = len(prompts)
    responses_to_keep = []
    prompts_without_existing_responses = []
    for p in prompts:
        for r in restored_responses:
            if p["name"] == r["name"] and \
               p["parallelism_model"] == r["parallelism_model"] and \
               p["prompt"] == r["prompt"] and \
               args.temperature == r["temperature"] and \
               args.prompted == r["prompted"] and \
               args.num_samples_per_prompt == len(r["outputs"]):
                responses_to_keep.append(r)
                break
        else:
            prompts_without_existing_responses.append(p)
    prompts = prompts_without_existing_responses
    print(f"[restore_from] Skipping {original_len - len(prompts)} prompts that already have responses. " +
          f"{len(prompts)} prompts left.")

    # write restored responses to cache
    if args.cache is not None:
        with open(args.cache, 'a') as jsonl_file:
            for response in responses_to_keep:
                jsonl_file.write(json.dumps(response) + "\n")
        print(f"[restore_from] Wrote {len(responses_to_keep)} restored responses to cache")

""" Initialize inference config """
inference_config = get_inference_config(args.model, prompted=args.prompted)

prompts_repeated = [p for p in prompts for _ in range(args.num_samples_per_prompt)]

""" Initialize vLLM engine """
llm = LLM(model=args.model, tensor_parallel_size=torch.cuda.device_count())

# Configure sampling parameters
sampling_params = SamplingParams(
    temperature=args.temperature if args.do_sample else 0,
    top_p=args.top_p if args.do_sample else 1.0,
    max_tokens=args.max_new_tokens,
    n=1,  # We handle multiple samples manually
)

""" Generate code """
if not args.restart and args.cache is not None and os.path.exists(args.cache):
    with open(args.cache, 'r') as jsonl_file:
        responses = [json.loads(line) for line in jsonl_file]
    responses = [r for r in responses if r["temperature"] == args.temperature and r["prompted"] == args.prompted
                 and args.num_samples_per_prompt == len(r["outputs"])
                 and any(p["name"] == r["name"] and p["prompt"] == r["prompt"] and p["parallelism_model"] == r["parallelism_model"] for p in prompts)]
else:
    responses = []

cur_prompt = None
start_time = time.time()
total_tokens = 0

# Format all prompts
formatted_prompts = [inference_config.format_prompt(p["prompt"]) for p in prompts_repeated]

# Generate all outputs at once
outputs = llm.generate(formatted_prompts, sampling_params)

# Process outputs
for idx, (prompt, output) in enumerate(zip(prompts_repeated, outputs)):
    if idx % args.num_samples_per_prompt == 0:
        cur_prompt = prompt.copy()
        cur_prompt.update({
            "temperature": args.temperature,
            "top_p": args.top_p,
            "do_sample": args.do_sample,
            "max_new_tokens": args.max_new_tokens,
            "prompted": args.prompted
        })
        cur_prompt["outputs"] = []
        cur_prompt["raw_outputs"] = []
    prompt_str = cur_prompt["prompt"]

    # Count tokens and clean output
    # FIXME: This is to keep the same behavior as generate.py
    huggingface_style_output = output.prompt + output.outputs[0].text
    total_tokens += len(llm.get_tokenizer().encode(huggingface_style_output))
    cleaned_output = inference_config.clean_output(huggingface_style_output, prompt_str)
    cur_prompt["outputs"].append(cleaned_output)
    cur_prompt["raw_outputs"].append(huggingface_style_output)

    if idx % args.num_samples_per_prompt == args.num_samples_per_prompt - 1:
        responses.append(cur_prompt)

        if not args.restart and args.cache is not None:
            with open(args.cache, 'a') as jsonl_file:
                jsonl_file.write(json.dumps(cur_prompt) + "\n")

end_time = time.time()
tokens_per_second = total_tokens / (end_time - start_time)
print(f"Generated {len(responses)} code samples in {end_time - start_time:.2f} seconds ({tokens_per_second:.2f} tokens per second)")

""" Save responses to JSON file """
with open(args.output, 'w') as output_file:
    json.dump(responses, output_file, indent=4)
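
The sketch below is not part of the commit; it is a minimal illustration of the input and output shapes the script above appears to expect, inferred from the keys it reads ("name", "parallelism_model", "prompt") and the fields it attaches before writing. All file names, model paths, and field values here are placeholders.

# Sketch only: illustrates the data formats implied by generate-vllm.py.
# Every concrete value is a placeholder, not taken from the repository.
import json

# --prompts expects a JSON list of prompt records with at least these keys.
example_prompts = [
    {
        "name": "example_problem",        # placeholder problem name
        "parallelism_model": "serial",    # placeholder parallelism model
        "prompt": "// Complete the function below ...",
    }
]
with open("prompts.json", "w") as f:
    json.dump(example_prompts, f, indent=4)

# A run might then look like (paths and model are placeholders):
#   python generate/generate-vllm.py --prompts prompts.json --model <model-path> \
#       --output output.json --cache cache.jsonl --do_sample --prompted
#
# Each record in --output (and each line of the --cache JSONL) is the prompt
# record extended with the sampling settings plus "outputs" and "raw_outputs",
# each holding num_samples_per_prompt generated strings.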

generate/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def clean_instruct_output(output: str, prompt: str, response_tag: str) -> str:
     # 0. replace up to the end of the first instance of prompt
     prompt_loc = output.find(response_tag)
     if prompt_loc == -1:
-        raise ValueError(f"Prompt not found in output: {prompt}")
+        raise ValueError(f"Response tag {response_tag} not found in output: {prompt}")
     output = output[prompt_loc + len(response_tag):].strip()

     # 1. Find all code blocks enclosed in triple backticks with "c++" language tag
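
The lines in this hunk implement a simple "find the response tag and keep what follows" step. The snippet below is a self-contained sketch of just that step; the full clean_instruct_output in generate/utils.py goes on to do more (e.g. extracting fenced code blocks), and the tag and strings used here are only examples.

# Sketch of the response-tag trimming step shown above; tag and strings are
# illustrative, and the error wording follows the message improved in this commit.
def trim_to_response(output: str, response_tag: str) -> str:
    loc = output.find(response_tag)
    if loc == -1:
        raise ValueError(f"Response tag {response_tag} not found in output: {output}")
    return output[loc + len(response_tag):].strip()

print(trim_to_response(
    "### Instruction: write a C++ add function\n### Response: int add(int a, int b) { return a + b; }",
    "### Response:",
))  # -> "int add(int a, int b) { return a + b; }"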

0 commit comments
