
Commit 17e4dc5: update scripts
1 parent de7080b

File tree: 2 files changed, +164 -33 lines

generate/generate.py (2 additions, 2 deletions)

````diff
@@ -114,7 +114,7 @@
 )

 """ Iterate over prompts and generate code """
-if not args.restart and args.cache is not None:
+if not args.restart and args.cache is not None and os.path.exists(args.cache):
     with open(args.cache, 'r') as jsonl_file:
         responses = [json.loads(line) for line in jsonl_file]
         responses = [r for r in responses if r["temperature"] == args.temperature and r["prompted"] == args.prompted
@@ -133,7 +133,7 @@
     prompt_str = cur_prompt["prompt"]

     total_tokens += len(generator.tokenizer.encode(output[0]["generated_text"]))
-    cleaned_output = clean_output(output[0]["generated_text"], prompt_str)
+    cleaned_output = inference_config.clean_output(output[0]["generated_text"], prompt_str)
     cur_prompt["outputs"].append(cleaned_output)

     if idx % args.num_samples_per_prompt == args.num_samples_per_prompt - 1:
````
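The `os.path.exists` guard makes the cache load a no-op when the cache file does not exist yet, instead of crashing on the first run. A minimal sketch of the guarded read, using a hypothetical `args` namespace in place of the script's parsed CLI arguments:

```python
import json
import os
from types import SimpleNamespace

# hypothetical stand-in for the script's parsed CLI arguments
args = SimpleNamespace(restart=False, cache="responses.jsonl")

responses = []
# the added os.path.exists() check skips the read (instead of raising
# FileNotFoundError) when the cache has not been written yet
if not args.restart and args.cache is not None and os.path.exists(args.cache):
    with open(args.cache, "r") as jsonl_file:
        responses = [json.loads(line) for line in jsonl_file]
```

The second hunk routes output cleaning through the model-specific `inference_config.clean_output` rather than the old module-level `clean_output`, which lets chat-style models such as Magicoder strip their own response markup (see `generate/utils.py` below).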

generate/utils.py (162 additions, 31 deletions)
````diff
@@ -1,12 +1,71 @@
 # std imports
 from abc import ABC, abstractmethod
+import re

 # tpl imports
 import torch
 from torch.utils.data import Dataset
 from transformers import StoppingCriteria


+def clean_output(output : str, prompt : str) -> str:
+    """ Remove `prompt` from the beginning of `output`.
+        Also truncate at the end of the function definition (i.e. matching closing brace).
+    """
+    # remove up to the end of the first instance of prompt
+    prompt_loc = output.find(prompt)
+    if prompt_loc == -1:
+        raise ValueError(f"Prompt not found in output: {prompt}")
+    output = output[prompt_loc + len(prompt):].strip()
+
+    # temporarily add opening brace to the beginning
+    output = '{' + output
+
+    # find the matching brace to output[0]
+    stack = []
+    index = 0
+    while index < len(output):
+        token = output[index]
+        if token == '{':
+            stack.append(token)
+        elif token == '}':
+            stack.pop()
+            if len(stack) == 0:
+                break
+
+        index += 1
+
+    # truncate at the matching brace
+    output = output[1:index+1]
+    return output
+
+GPU_FUNCTION_NAME_PATTERN = re.compile(r"__global__ void ([a-zA-Z0-9_]+)\(")
+CPU_FUNCTION_NAME_PATTERN = re.compile(r"\s*[a-zA-Z_]+ ([a-zA-Z0-9_]+)\(")
+def get_function_name(prompt: str, execution_model: str) -> str:
+    if execution_model in ['cuda', 'hip']:
+        match = GPU_FUNCTION_NAME_PATTERN.match(prompt.splitlines()[-1])
+    else:
+        match = CPU_FUNCTION_NAME_PATTERN.match(prompt.splitlines()[-1])
+    if match is None:
+        raise ValueError(f"Could not find function name in prompt: {prompt}")
+    return match.group(1)
+
+
+def find_matching_brace_index(code: str, open_brace_index: int) -> int:
+    """Finds the index of the closing brace that matches the opening brace at the given index."""
+
+    brace_count = 1
+    for i in range(open_brace_index + 1, len(code)):
+        if code[i] == "{":
+            brace_count += 1
+        elif code[i] == "}":
+            brace_count -= 1
+            if brace_count == 0:
+                return i
+
+    raise ValueError("Unmatched opening brace")
+
+
 class InferenceConfig(ABC):

     def __init__(self, prompted : bool = False):
````
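A hypothetical round trip through the new module-level helpers (inputs are illustrative, not from the repo):

```python
# assumes clean_output and get_function_name from generate/utils.py are in scope
prompt = "__global__ void add(int *a, int *b, int *c) {"
output = prompt + "\n    c[0] = a[0] + b[0];\n}\nint main() { return 0; }"

body = clean_output(output, prompt)
# -> "c[0] = a[0] + b[0];\n}"
# the prompt prefix is removed and generation is truncated at the brace
# matching the function's opening brace, dropping the stray main()

name = get_function_name(prompt, "cuda")
# -> "add", captured by GPU_FUNCTION_NAME_PATTERN from the prompt's last line
```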
````diff
@@ -36,6 +95,10 @@ def trust_remote_code(self) -> bool:
     def format_prompt(self, prompt : str) -> str:
         pass

+    @abstractmethod
+    def clean_output(self, output: str, prompt: str) -> str:
+        pass
+

 class StarCoderConfig(InferenceConfig):
````
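Declaring `clean_output` as an `@abstractmethod` makes it part of the `InferenceConfig` contract. A sketch of the effect (class name hypothetical):

```python
# a subclass that omits clean_output can no longer be instantiated
class IncompleteConfig(InferenceConfig):
    def format_prompt(self, prompt: str) -> str:
        return prompt

try:
    IncompleteConfig()
except TypeError as err:
    # the message lists clean_output among the missing abstract methods
    print(err)
```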

````diff
@@ -63,6 +126,9 @@ def format_prompt(self, prompt : str) -> str:
             return f"<filename>solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n{prompt}"
         return prompt.strip()

+    def clean_output(self, output: str, prompt: str) -> str:
+        return clean_output(output, prompt)
+
 class CodeLlamaConfig(InferenceConfig):

     def __init__(self, prompted : bool = False):
````
````diff
@@ -90,6 +156,8 @@ def format_prompt(self, prompt : str) -> str:
             return f"// filename: solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n{prompt}"
         return prompt.strip()

+    def clean_output(self, output: str, prompt: str) -> str:
+        return clean_output(output, prompt)

 class PolyCoderConfig(InferenceConfig):
````
````diff
@@ -116,6 +184,9 @@
         if self.prompted:
             return f"// filename: solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n{prompt}"
         return prompt.strip()
+
+    def clean_output(self, output: str, prompt: str) -> str:
+        return clean_output(output, prompt)


 class PhindConfig(InferenceConfig):
````
````diff
@@ -144,6 +215,9 @@ def format_prompt(self, prompt : str) -> str:
             return f"// filename: solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n{prompt}"
         return prompt.strip()

+    def clean_output(self, output: str, prompt: str) -> str:
+        return clean_output(output, prompt)
+

 class ReplitConfig(InferenceConfig):
````
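The five existing configs all satisfy the new contract by delegating to the module-level helper, so their behavior is unchanged; e.g. (illustrative strings):

```python
cfg = StarCoderConfig()
text = "prompt body }"
assert cfg.clean_output(text, "prompt") == clean_output(text, "prompt")
```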

````diff
@@ -174,6 +248,92 @@ def format_prompt(self, prompt : str) -> str:
             return f"// filename: solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n{prompt}"
         return prompt.strip()

+    def clean_output(self, output: str, prompt: str) -> str:
+        return clean_output(output, prompt)
+
+
+class MagicoderConfig(InferenceConfig):
+
+    PROMPT_TEMPLATE = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.
+
+@@ Instruction
+{instruction}
+
+@@ Response
+"""
+
+    def __init__(self, prompted : bool = False):
+        super().__init__(prompted=prompted)
+
+    def get_dtype(self):
+        return torch.bfloat16
+
+    def init_padding(self, tokenizer):
+        tokenizer.pad_token_id = tokenizer.eos_token_id   # for batching
+        tokenizer.padding_side = "left"   # for decoder-only models
+
+    def get_pad_token_id(self, tokenizer) -> int:
+        return tokenizer.pad_token_id
+
+    def get_eos_token_id(self, tokenizer) -> int:
+        return tokenizer.eos_token_id
+
+    def trust_remote_code(self) -> bool:
+        return False
+
+    def format_prompt(self, prompt : str) -> str:
+        if self.prompted:
+            function_name = get_function_name(prompt, "cuda" if "__global__" in prompt else "serial")
+            prompt = f"Complete the following c++ function.\n```c++{prompt.strip()}```\nWrite only the function {function_name} and no other code. Enclose your solution in ```c++ and ```."
+            return self.PROMPT_TEMPLATE.format(instruction=prompt)
+        return prompt.strip()
+
+    def clean_output(self, output: str, prompt: str) -> str:
+        """ Clean LLM output to find code solution. The output should be in a ```c++ ``` code block. If there are
+            multiple, then it tries to find the block with the function definition (as contained in the prompt).
+            The code block itself may include the function definition and body OR just the body. This will try
+            to parse both.
+        """
+        # 0. remove everything up to the end of the first "@@ Response" marker
+        prompt_loc = output.find("@@ Response")
+        if prompt_loc == -1:
+            raise ValueError(f"Prompt not found in output: {prompt}")
+        output = output[prompt_loc + len("@@ Response"):].strip()
+
+        # 1. Find all code blocks enclosed in triple backticks with "c++" language tag
+        code_blocks = re.findall(r"```c\+\+\n(.*?)\n```", output, flags=re.DOTALL)
+        code_blocks = [block.lstrip('```c++').rstrip('```') for block in code_blocks]
+
+        # 2. Prioritize code blocks containing the function definition from the prompt
+        sub_prompt = prompt.rstrip().removesuffix("@@ Response").rstrip().removesuffix("```").split("```")[-1]
+        function_name = get_function_name(sub_prompt, "cuda" if "__global__" in sub_prompt else "serial")
+        prioritized_blocks = [block for block in code_blocks if function_name in block]
+
+        # 3. Choose the first block if multiple match, or any block if none match
+        if len(code_blocks) > 0:
+            selected_block = prioritized_blocks[0] if prioritized_blocks else code_blocks[0]
+        else:
+            if '```c++' in output:   # starts with ```c++ but it didn't finish
+                code_idx = output.find('```c++')
+                selected_block = output[code_idx:].removeprefix('```c++')
+            else:
+                selected_block = output
+
+        # 4. Handle cases where the block contains only the function body
+        if function_name not in selected_block:
+            return selected_block
+        else:
+            function_start_index = selected_block.index(function_name)
+            open_brace_index = selected_block.find("{", function_start_index)
+            try:
+                close_brace_index = find_matching_brace_index(selected_block, open_brace_index)
+            except ValueError:
+                close_brace_index = len(selected_block)
+
+            function_body = selected_block[open_brace_index + 1 : close_brace_index]
+            return function_body + "}"

 def get_inference_config(model_name : str, **kwargs) -> InferenceConfig:
     if model_name == "bigcode/starcoderbase":
````
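A hypothetical end-to-end pass through the Magicoder path (the model reply is made up for illustration):

````python
cfg = MagicoderConfig(prompted=True)

raw_prompt = "__global__ void add(int *a, int *b, int *c) {"
formatted = cfg.format_prompt(raw_prompt)  # instruction template ending in "@@ Response\n"

# fabricated model reply: the formatted prompt followed by a fenced solution
reply = formatted + (
    "```c++\n"
    "__global__ void add(int *a, int *b, int *c) {\n"
    "    c[0] = a[0] + b[0];\n"
    "}\n"
    "```"
)

print(cfg.clean_output(reply, raw_prompt))
# step 0 drops everything up to "@@ Response", step 1 extracts the fenced
# c++ block, and step 4 trims it to the function body plus a closing "}"
````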
````diff
@@ -186,41 +346,12 @@ def get_inference_config(model_name : str, **kwargs) -> InferenceConfig:
         return PhindConfig(**kwargs)
     elif model_name == 'replit/replit-code-v1_5-3b':
         return ReplitConfig(**kwargs)
+    elif model_name.startswith('ise-uiuc/Magicoder'):
+        return MagicoderConfig(**kwargs)
     else:
         raise ValueError(f"Unknown model name: {model_name}")


-def clean_output(output : str, prompt : str) -> str:
-    """ Remove `prompt` from the beginning of `output`.
-        Also truncate at the end of the function definition (i.e. matching closing brace).
-    """
-    # remove up to the end of the first instance of prompt
-    prompt_loc = output.find(prompt)
-    if prompt_loc == -1:
-        raise ValueError(f"Prompt not found in output: {prompt}")
-    output = output[prompt_loc + len(prompt):].strip()
-
-    # temporarily add opening brace to the beginning
-    output = '{' + output
-
-    # find the matching brace to output[0]
-    stack = []
-    index = 0
-    while index < len(output):
-        token = output[index]
-        if token == '{':
-            stack.append(token)
-        elif token == '}':
-            stack.pop()
-            if len(stack) == 0:
-                break
-
-        index += 1
-
-    # truncate at the matching brace
-    output = output[1:index+1]
-    return output
-
 class PromptDataset(Dataset):
     ''' PyTorch dataset that simply wraps a list of strings. They do not have to have the same length.
     '''
````
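Usage of the extended factory; `ise-uiuc/Magicoder-S-DS-6.7B` is one checkpoint name matching the new `startswith` prefix:

```python
config = get_inference_config("ise-uiuc/Magicoder-S-DS-6.7B", prompted=True)
assert isinstance(config, MagicoderConfig)
```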
