Commit 06b3d4a

Add magicoder (#20)
add magicoder outputs
1 parent c276d2b commit 06b3d4a

23 files changed: +595,664 -34 lines

drivers/cpp/cpp_driver_wrapper.py

Lines changed: 3 additions & 0 deletions
@@ -106,6 +106,9 @@ def run(self, executable: PathLike, **run_config) -> RunOutput:
             run_process = run_command(launch_cmd, timeout=self.run_timeout, dry=self.dry)
         except subprocess.TimeoutExpired as e:
             return RunOutput(-1, str(e.stdout), f"[Timeout] {str(e.stderr)}", config=run_config)
+        except UnicodeDecodeError as e:
+            logging.warning(f"UnicodeDecodeError: {str(e)}\nRunning command: {launch_cmd}")
+            return RunOutput(-1, "", f"UnicodeDecodeError: {str(e)}", config=run_config)
         return RunOutput(run_process.returncode, run_process.stdout, run_process.stderr, config=run_config)

     def test_single_output(self, prompt: str, output: str, test_driver_file: PathLike, problem_size: str) -> GeneratedTextResult:
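
The new handler covers generated programs that write bytes which are not valid UTF-8 to stdout or stderr: when the wrapper's run_command captures output in text mode, decoding fails with a UnicodeDecodeError before a RunOutput can be built. A minimal reproduction of the failure mode, assuming a UTF-8 locale and GNU printf (this snippet is illustrative and is not the repository's run_command):

    import subprocess

    try:
        # GNU printf expands \xff to the single byte 0xFF, which is not valid UTF-8;
        # text=True makes subprocess decode the captured output with the locale encoding.
        subprocess.run(["printf", r"\xff"], capture_output=True, text=True)
    except UnicodeDecodeError as e:
        print(f"UnicodeDecodeError: {e}")  # what the wrapper now logs and reports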

drivers/driver_wrapper.py

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ def __init__(self, exit_code: int, stdout: str, stderr: str, config: dict = {}):
             logging.warning(f"Runtime is 0 for run with config {self.config}. Try increasing the problem size.")
         if self.is_valid and self.best_sequential_runtime == 0:
             logging.warning(f"The best sequential runtime is 0 for run with config {self.config}. Try increasing the problem size.")
-        if self.is_valid and self.best_sequential_runtime < 0.001:
+        if self.is_valid and self.best_sequential_runtime and self.best_sequential_runtime < 0.001:
             logging.warning(f"The best sequential runtime is very small ({self.best_sequential_runtime}) for run with config {self.config}. Try increasing the problem size.")

     def __repr__(self) -> str:
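
The added truthiness check matters because best_sequential_runtime can be None when no sequential baseline was recorded, and in Python 3 an ordering comparison against None raises a TypeError. A minimal sketch of the difference (the variable name mirrors the repository's; the value is invented):

    best_sequential_runtime = None  # e.g. no sequential baseline was measured

    # Old check: raises TypeError ('<' not supported between 'NoneType' and 'float')
    # if best_sequential_runtime < 0.001: ...

    # New check: short-circuits on None (and on 0, which the preceding branch reports)
    if best_sequential_runtime and best_sequential_runtime < 0.001:
        print("very small runtime")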
Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+#!/bin/bash
+#SBATCH -n 1
+#SBATCH -c 4
+#SBATCH --ntasks-per-node=1
+#SBATCH --gpus-per-task=1
+#SBATCH --mem=128000
+#SBATCH -t 05:00:00
+#SBATCH -A m2404
+#SBATCH -C gpu&hbm80g
+#SBATCH -q regular
+#SBATCH -J generate-magicoder-s-ds-prompted
+#SBATCH -o generation-job-logs/generate-magicoder-s-ds-prompted-%A.out
+
+# settings
+MODEL="ise-uiuc/Magicoder-S-DS-6.7B"
+TEMP=0.2
+TOPP=0.95
+MAX_NEW_TKNS=2048
+SAMPLES_PER_PROMPT=20
+BATCH_SIZE=16
+hash=$(md5sum ../prompts/generation-prompts.json | cut -d' ' -f1)
+OUTPUT="../outputs/output_${hash:0:8}_${MODEL//\//--}_prompted_temp${TEMP}.json"
+CACHE="../outputs/cache/cache_${hash:0:8}_${MODEL//\//--}_prompted_temp${TEMP}.jsonl"
+echo "Writing to $OUTPUT"
+echo "model=$MODEL MAX_NEW_TKNS=$MAX_NEW_TKNS SAMPLES_PER_PROMPT=$SAMPLES_PER_PROMPT BATCH_SIZE=$BATCH_SIZE"
+
+# setup
+#ml cuda/11.8.0
+source .env/bin/activate
+export HF_HOME=/pscratch/sd/d/dnicho/.cache/huggingface
+export OMP_NUM_THREADS=4
+
+# generate
+srun python generate.py \
+    --model $MODEL \
+    --prompts ../prompts/generation-prompts.json \
+    --output $OUTPUT \
+    --cache $CACHE \
+    --temperature $TEMP \
+    --top_p $TOPP \
+    --do_sample \
+    --max_new_tokens $MAX_NEW_TKNS \
+    --num_samples_per_prompt $SAMPLES_PER_PROMPT \
+    --batch_size $BATCH_SIZE \
+    --prompted
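
The output and cache paths embed the first eight hex characters of the prompt file's MD5 checksum (${hash:0:8}) and a slash-escaped model name (${MODEL//\//--}), so results from different prompt sets and models never collide. A Python equivalent of the naming scheme, with a hypothetical helper name, for illustration only:

    import hashlib

    def output_path(prompts_file: str, model: str, temp: float) -> str:
        # First 8 hex chars of the prompt file's MD5, mirroring ${hash:0:8}
        with open(prompts_file, "rb") as f:
            digest = hashlib.md5(f.read()).hexdigest()[:8]
        # Mirror bash's ${MODEL//\//--}: replace every '/' with '--'
        safe_model = model.replace("/", "--")
        return f"../outputs/output_{digest}_{safe_model}_prompted_temp{temp}.json"

    # e.g. ../outputs/output_1a2b3c4d_ise-uiuc--Magicoder-S-DS-6.7B_prompted_temp0.2.json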

generate/generate.py

Lines changed: 2 additions & 2 deletions
@@ -114,7 +114,7 @@
 )

 """ Iterate over prompts and generate code """
-if not args.restart and args.cache is not None:
+if not args.restart and args.cache is not None and os.path.exists(args.cache):
     with open(args.cache, 'r') as jsonl_file:
         responses = [json.loads(line) for line in jsonl_file]
     responses = [r for r in responses if r["temperature"] == args.temperature and r["prompted"] == args.prompted
@@ -133,7 +133,7 @@
     prompt_str = cur_prompt["prompt"]

     total_tokens += len(generator.tokenizer.encode(output[0]["generated_text"]))
-    cleaned_output = clean_output(output[0]["generated_text"], prompt_str)
+    cleaned_output = inference_config.clean_output(output[0]["generated_text"], prompt_str)
     cur_prompt["outputs"].append(cleaned_output)

     if idx % args.num_samples_per_prompt == args.num_samples_per_prompt - 1:
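
The added os.path.exists guard lets a fresh run proceed when the cache file has not been created yet; previously open(args.cache, 'r') raised FileNotFoundError on the first run with caching enabled. The resume pattern, sketched with hypothetical standalone names:

    import json
    import os

    def load_cached_responses(cache_path, temperature, prompted):
        # On a fresh run the cache does not exist yet: start with nothing.
        if cache_path is None or not os.path.exists(cache_path):
            return []
        # The cache is JSON lines: one generated response per line.
        with open(cache_path, "r") as jsonl_file:
            responses = [json.loads(line) for line in jsonl_file]
        # Only reuse generations produced under the same sampling settings.
        return [r for r in responses
                if r["temperature"] == temperature and r["prompted"] == prompted]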

generate/utils.py

Lines changed: 162 additions & 31 deletions
@@ -1,12 +1,71 @@
 # std imports
 from abc import ABC, abstractmethod
+import re

 # tpl imports
 import torch
 from torch.utils.data import Dataset
 from transformers import StoppingCriteria


+def clean_output(output : str, prompt : str) -> str:
+    """ Remove `prompt` from the beginning of `output`.
+        Also truncate at the end of the function definition (i.e. at the matching closing brace).
+    """
+    # skip past the end of the first instance of prompt
+    prompt_loc = output.find(prompt)
+    if prompt_loc == -1:
+        raise ValueError(f"Prompt not found in output: {prompt}")
+    output = output[prompt_loc + len(prompt):].strip()
+
+    # temporarily add an opening brace to the beginning
+    output = '{' + output
+
+    # find the brace matching output[0]
+    stack = []
+    index = 0
+    while index < len(output):
+        token = output[index]
+        if token == '{':
+            stack.append(token)
+        elif token == '}':
+            stack.pop()
+            if len(stack) == 0:
+                break
+        index += 1
+
+    # truncate at the matching brace
+    output = output[1:index+1]
+    return output
+
+
+GPU_FUNCTION_NAME_PATTERN = re.compile(r"__global__ void ([a-zA-Z0-9_]+)\(")
+CPU_FUNCTION_NAME_PATTERN = re.compile(r"\s*[a-zA-Z_]+ ([a-zA-Z0-9_]+)\(")
+def get_function_name(prompt: str, execution_model: str) -> str:
+    """ Parse the name of the function being completed from the last line of the prompt. """
+    if execution_model in ['cuda', 'hip']:
+        match = GPU_FUNCTION_NAME_PATTERN.match(prompt.splitlines()[-1])
+    else:
+        match = CPU_FUNCTION_NAME_PATTERN.match(prompt.splitlines()[-1])
+    if match is None:
+        raise ValueError(f"Could not find function name in prompt: {prompt}")
+    return match.group(1)
+
+
+def find_matching_brace_index(code: str, open_brace_index: int) -> int:
+    """ Find the index of the closing brace that matches the opening brace at the given index. """
+    brace_count = 1
+    for i in range(open_brace_index + 1, len(code)):
+        if code[i] == "{":
+            brace_count += 1
+        elif code[i] == "}":
+            brace_count -= 1
+            if brace_count == 0:
+                return i
+    raise ValueError("Unmatched opening brace")
+
+
 class InferenceConfig(ABC):

     def __init__(self, prompted : bool = False):
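
These helpers were previously defined at the bottom of the module (see the removal below) and are now shared by the config classes. Toy invocations showing the intended behavior (strings invented for illustration):

    # clean_output strips the echoed prompt, then brace-matches from the
    # function's opening '{' so trailing code (here, main) is discarded:
    prompt = "int add(int a, int b) {"
    generated = prompt + "\n    return a + b;\n}\nint main() { return 0; }"
    print(clean_output(generated, prompt))   # "return a + b;\n}"

    # get_function_name inspects only the last line of the prompt:
    print(get_function_name("__global__ void saxpy(float a, float *x) {", "cuda"))      # "saxpy"
    print(get_function_name("double dot(const double *a, const double *b) {", "serial"))  # "dot"

    # find_matching_brace_index pairs braces by depth counting:
    code = "void f() { if (x) { y(); } }"
    print(find_matching_brace_index(code, code.find("{")))  # index of the final '}'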
@@ -36,6 +95,10 @@ def trust_remote_code(self) -> bool:
     def format_prompt(self, prompt : str) -> str:
         pass

+    @abstractmethod
+    def clean_output(self, output: str, prompt: str) -> str:
+        pass
+

 class StarCoderConfig(InferenceConfig):
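
Declaring clean_output on the base class as an @abstractmethod lets generate.py call inference_config.clean_output(...) without knowing which model produced the text, and it turns a missing override into an instantiation-time TypeError rather than a silent fallback. A quick sketch of the enforcement (the subclass here is hypothetical):

    class IncompleteConfig(InferenceConfig):
        # clean_output (and the other abstract methods) deliberately not defined
        pass

    IncompleteConfig()  # TypeError: Can't instantiate abstract class IncompleteConfig ...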

@@ -63,6 +126,9 @@ def format_prompt(self, prompt : str) -> str:
             return f"<filename>solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n{prompt}"
         return prompt.strip()

+    def clean_output(self, output: str, prompt: str) -> str:
+        return clean_output(output, prompt)
+
 class CodeLlamaConfig(InferenceConfig):

     def __init__(self, prompted : bool = False):
@@ -90,6 +156,8 @@ def format_prompt(self, prompt : str) -> str:
             return f"// filename: solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n{prompt}"
         return prompt.strip()

+    def clean_output(self, output: str, prompt: str) -> str:
+        return clean_output(output, prompt)

 class PolyCoderConfig(InferenceConfig):

@@ -116,6 +184,9 @@ def format_prompt(self, prompt : str) -> str:
         if self.prompted:
             return f"// filename: solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n{prompt}"
         return prompt.strip()
+
+    def clean_output(self, output: str, prompt: str) -> str:
+        return clean_output(output, prompt)


 class PhindConfig(InferenceConfig):
@@ -144,6 +215,9 @@ def format_prompt(self, prompt : str) -> str:
             return f"// filename: solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n{prompt}"
         return prompt.strip()

+    def clean_output(self, output: str, prompt: str) -> str:
+        return clean_output(output, prompt)
+

 class ReplitConfig(InferenceConfig):

@@ -174,6 +248,92 @@ def format_prompt(self, prompt : str) -> str:
             return f"// filename: solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n{prompt}"
         return prompt.strip()

+    def clean_output(self, output: str, prompt: str) -> str:
+        return clean_output(output, prompt)
+
+
+class MagicoderConfig(InferenceConfig):
+
+    PROMPT_TEMPLATE = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.
+
+@@ Instruction
+{instruction}
+
+@@ Response
+"""
+
+    def __init__(self, prompted : bool = False):
+        super().__init__(prompted=prompted)
+
+    def get_dtype(self):
+        return torch.bfloat16
+
+    def init_padding(self, tokenizer):
+        tokenizer.pad_token_id = tokenizer.eos_token_id  # for batching
+        tokenizer.padding_side = "left"  # for decoder-only models
+
+    def get_pad_token_id(self, tokenizer) -> int:
+        return tokenizer.pad_token_id
+
+    def get_eos_token_id(self, tokenizer) -> int:
+        return tokenizer.eos_token_id
+
+    def trust_remote_code(self) -> bool:
+        return False
+
+    def format_prompt(self, prompt : str) -> str:
+        if self.prompted:
+            function_name = get_function_name(prompt, "cuda" if "__global__" in prompt else "serial")
+            prompt = f"Complete the following c++ function.\n```c++{prompt.strip()}```\nWrite only the function {function_name} and no other code. Enclose your solution in ```c++ and ```."
+            return self.PROMPT_TEMPLATE.format(instruction=prompt)
+        return prompt.strip()
+
+    def clean_output(self, output: str, prompt: str) -> str:
+        """ Clean the LLM output to find the code solution. The solution should be in a ```c++ ``` code
+            block. If there are multiple blocks, try to find the one containing the function definition
+            from the prompt. The chosen block may include the function definition and body OR just the
+            body; both are handled.
+        """
+        # 0. skip past the first "@@ Response" marker; everything before it is the echoed prompt
+        prompt_loc = output.find("@@ Response")
+        if prompt_loc == -1:
+            raise ValueError(f"Prompt not found in output: {prompt}")
+        output = output[prompt_loc + len("@@ Response"):].strip()
+
+        # 1. Find all code blocks enclosed in triple backticks with the "c++" language tag
+        code_blocks = re.findall(r"```c\+\+\n(.*?)\n```", output, flags=re.DOTALL)
+        code_blocks = [block.lstrip('```c++').rstrip('```') for block in code_blocks]
+
+        # 2. Prioritize code blocks containing the function definition from the prompt
+        sub_prompt = prompt.rstrip().removesuffix("@@ Response").rstrip().removesuffix("```").split("```")[-1]
+        function_name = get_function_name(sub_prompt, "cuda" if "__global__" in sub_prompt else "serial")
+        prioritized_blocks = [block for block in code_blocks if function_name in block]
+
+        # 3. Choose the first matching block, or the first block if none match
+        if len(code_blocks) > 0:
+            selected_block = prioritized_blocks[0] if prioritized_blocks else code_blocks[0]
+        else:
+            if '```c++' in output:  # starts with ```c++ but the fence was never closed
+                code_idx = output.find('```c++')
+                selected_block = output[code_idx:].removeprefix('```c++')
+            else:
+                selected_block = output
+
+        # 4. Handle the case where the block contains the full definition, not just the body
+        if function_name not in selected_block:
+            return selected_block
+        else:
+            function_start_index = selected_block.index(function_name)
+            open_brace_index = selected_block.find("{", function_start_index)
+            try:
+                close_brace_index = find_matching_brace_index(selected_block, open_brace_index)
+            except ValueError:
+                close_brace_index = len(selected_block)
+
+            function_body = selected_block[open_brace_index + 1 : close_brace_index]
+            return function_body + "}"


 def get_inference_config(model_name : str, **kwargs) -> InferenceConfig:
     if model_name == "bigcode/starcoderbase":
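
A sketch of the Magicoder round trip: format_prompt wraps the raw function header in the instruction template, and clean_output recovers the solution from the fenced block. Note that generate.py passes the raw prompt, not the templated one, as the second argument; the generation text below is invented for illustration:

    cfg = MagicoderConfig(prompted=True)
    raw_prompt = "double dot(const double *a, const double *b, int n) {"
    full_prompt = cfg.format_prompt(raw_prompt)

    # Pretend the model echoed the template and answered in a ```c++ block:
    generated = full_prompt + (
        "\n```c++\n"
        "double dot(const double *a, const double *b, int n) {\n"
        "    double sum = 0.0;\n"
        "    for (int i = 0; i < n; i++)\n"
        "        sum += a[i] * b[i];\n"
        "    return sum;\n"
        "}\n"
        "```"
    )

    print(cfg.clean_output(generated, raw_prompt))
    # prints the body of dot(...) followed by its closing brace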
@@ -186,41 +346,12 @@ def get_inference_config(model_name : str, **kwargs) -> InferenceConfig:
         return PhindConfig(**kwargs)
     elif model_name == 'replit/replit-code-v1_5-3b':
         return ReplitConfig(**kwargs)
+    elif model_name.startswith('ise-uiuc/Magicoder'):
+        return MagicoderConfig(**kwargs)
     else:
         raise ValueError(f"Unknown model name: {model_name}")


-def clean_output(output : str, prompt : str) -> str:
-    """ Remove `prompt` from the beginning of `output`.
-        Also truncate at the end of the function definition (i.e. at the matching closing brace).
-    """
-    # skip past the end of the first instance of prompt
-    prompt_loc = output.find(prompt)
-    if prompt_loc == -1:
-        raise ValueError(f"Prompt not found in output: {prompt}")
-    output = output[prompt_loc + len(prompt):].strip()
-
-    # temporarily add an opening brace to the beginning
-    output = '{' + output
-
-    # find the brace matching output[0]
-    stack = []
-    index = 0
-    while index < len(output):
-        token = output[index]
-        if token == '{':
-            stack.append(token)
-        elif token == '}':
-            stack.pop()
-            if len(stack) == 0:
-                break
-        index += 1
-
-    # truncate at the matching brace
-    output = output[1:index+1]
-    return output
-
 class PromptDataset(Dataset):
     ''' PyTorch dataset that simply wraps a list of strings. They do not have to have the same length.
     '''
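
Because the new branch matches with startswith('ise-uiuc/Magicoder'), any Magicoder checkpoint under the ise-uiuc org resolves to MagicoderConfig. Typical usage (the keyword argument is the config's prompted flag):

    cfg = get_inference_config("ise-uiuc/Magicoder-S-DS-6.7B", prompted=True)
    assert isinstance(cfg, MagicoderConfig)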
