@@ -1,12 +1,71 @@
# std imports
from abc import ABC, abstractmethod
+import re

# tpl imports
import torch
from torch.utils.data import Dataset
from transformers import StoppingCriteria


+def clean_output(output : str, prompt : str) -> str:
+ """ Remove `prompt` from the begging of `output`.
+        Also truncate at the end of the function definition (i.e. matching closing brace).
+    """
+    # replace up to the end of the first instance of prompt
+    prompt_loc = output.find(prompt)
+    if prompt_loc == -1:
+        raise ValueError(f"Prompt not found in output: {prompt}")
+    output = output[prompt_loc + len(prompt):].strip()
+
+    # temporarily add opening brace to the beginning
+    output = '{' + output
+
+    # find the matching brace to output[0]
+    stack = []
+    index = 0
+    while index < len(output):
+        token = output[index]
+        if token == '{':
+            stack.append(token)
+        elif token == '}':
+            stack.pop()
+            if len(stack) == 0:
+                break
+
+        index += 1
+
+    # truncate at the matching brace
+    output = output[1:index + 1]
+    return output
+
+GPU_FUNCTION_NAME_PATTERN = re.compile(r"__global__ void ([a-zA-Z0-9_]+)\(")
+CPU_FUNCTION_NAME_PATTERN = re.compile(r"\s*[a-zA-Z_]+ ([a-zA-Z0-9_]+)\(")
+def get_function_name(prompt : str, execution_model : str) -> str:
+    if execution_model in ['cuda', 'hip']:
+        match = GPU_FUNCTION_NAME_PATTERN.match(prompt.splitlines()[-1])
+    else:
+        match = CPU_FUNCTION_NAME_PATTERN.match(prompt.splitlines()[-1])
+    if match is None:
+        raise ValueError(f"Could not find function name in prompt: {prompt}")
+    return match.group(1)
+
+
+def find_matching_brace_index(code : str, open_brace_index : int) -> int:
+    """Finds the index of the closing brace that matches the opening brace at the given index."""
+
+    brace_count = 1
+    for i in range(open_brace_index + 1, len(code)):
+        if code[i] == "{":
+            brace_count += 1
+        elif code[i] == "}":
+            brace_count -= 1
+            if brace_count == 0:
+                return i
+
+    raise ValueError("Unmatched opening brace")
+
+
class InferenceConfig(ABC):

    def __init__(self, prompted : bool = False):
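For illustration (this block is not part of the commit above): a minimal sketch of how the relocated clean_output helper behaves, assuming it is imported from the module modified here (the file path is not shown in this diff); the prompt and completion strings are invented.

# clean_output strips the echoed prompt and truncates at the brace that closes
# the function, discarding anything else the model kept generating.
prompt = "int sum(int a, int b) {"
completion = prompt + " return a + b; }\n\nint main() { return 0; }"
print(clean_output(completion, prompt))
# -> 'return a + b; }'   (the trailing main() is dropped)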
@@ -36,6 +95,10 @@ def trust_remote_code(self) -> bool:
    def format_prompt(self, prompt : str) -> str:
        pass

+    @abstractmethod
+    def clean_output(self, output : str, prompt : str) -> str:
+        pass
+

class StarCoderConfig(InferenceConfig):

@@ -63,6 +126,9 @@ def format_prompt(self, prompt : str) -> str:
            return f"<filename>solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n{prompt}"
        return prompt.strip()

+    def clean_output(self, output : str, prompt : str) -> str:
+        return clean_output(output, prompt)
+

class CodeLlamaConfig(InferenceConfig):

    def __init__(self, prompted : bool = False):
@@ -90,6 +156,8 @@ def format_prompt(self, prompt : str) -> str:
            return f"// filename: solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n{prompt}"
        return prompt.strip()

+    def clean_output(self, output : str, prompt : str) -> str:
+        return clean_output(output, prompt)

class PolyCoderConfig(InferenceConfig):

@@ -116,6 +184,9 @@ def format_prompt(self, prompt : str) -> str:
        if self.prompted:
            return f"// filename: solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n{prompt}"
        return prompt.strip()
+
+    def clean_output(self, output : str, prompt : str) -> str:
+        return clean_output(output, prompt)


class PhindConfig(InferenceConfig):
@@ -144,6 +215,9 @@ def format_prompt(self, prompt : str) -> str:
            return f"// filename: solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n{prompt}"
        return prompt.strip()

+    def clean_output(self, output : str, prompt : str) -> str:
+        return clean_output(output, prompt)
+

class ReplitConfig(InferenceConfig):

@@ -174,6 +248,92 @@ def format_prompt(self, prompt : str) -> str:
            return f"// filename: solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n{prompt}"
        return prompt.strip()

+    def clean_output(self, output : str, prompt : str) -> str:
+        return clean_output(output, prompt)
+
+
+class MagicoderConfig(InferenceConfig):
+
+    PROMPT_TEMPLATE = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.
+
+@@ Instruction
+{instruction}
+
+@@ Response
+"""
+
+    def __init__(self, prompted : bool = False):
+        super().__init__(prompted=prompted)
+
+    def get_dtype(self):
+        return torch.bfloat16
+
+    def init_padding(self, tokenizer):
+        tokenizer.pad_token_id = tokenizer.eos_token_id  # for batching
+        tokenizer.padding_side = "left"  # for decoder-only models
+        pass
+
+    def get_pad_token_id(self, tokenizer) -> int:
+        return tokenizer.pad_token_id
+
+    def get_eos_token_id(self, tokenizer) -> int:
+        return tokenizer.eos_token_id
+
+    def trust_remote_code(self) -> bool:
+        return False
+
+    def format_prompt(self, prompt : str) -> str:
+        if self.prompted:
+            function_name = get_function_name(prompt, "cuda" if "__global__" in prompt else "serial")
+            prompt = f"Complete the following c++ function.\n```c++{prompt.strip()}```\nWrite only the function {function_name} and no other code. Enclose your solution in ```c++ and ```."
+            return self.PROMPT_TEMPLATE.format(instruction=prompt)
+        return prompt.strip()
+
+    def clean_output(self, output : str, prompt : str) -> str:
+        """ Clean LLM output to find code solution. The output should be in a ```c++ ``` code block. If there are
+            multiple, then it tries to find the block with the function definition (as contained in the prompt).
+            The code block itself may include the function definition and body OR just the body. This will try
+            to parse both.
+        """
+        # 0. replace up to the end of the first instance of prompt
+        prompt_loc = output.find("@@ Response")
+        if prompt_loc == -1:
+            raise ValueError(f"Prompt not found in output: {prompt}")
+        output = output[prompt_loc + len("@@ Response"):].strip()
+
+        # 1. Find all code blocks enclosed in triple backticks with "c++" language tag
+        code_blocks = re.findall(r"```c\+\+\n(.*?)\n```", output, flags=re.DOTALL)
+        code_blocks = [block.lstrip('```c++').rstrip('```') for block in code_blocks]
+
+        # 2. Prioritize code blocks containing the function definition from the prompt
+        sub_prompt = prompt.rstrip().removesuffix("@@ Response").rstrip().removesuffix("```").split("```")[-1]
+        function_name = get_function_name(sub_prompt, "cuda" if "__global__" in sub_prompt else "serial")
+        prioritized_blocks = [block for block in code_blocks if function_name in block]
+
+        # 3. Choose the first block if multiple match, or any block if none match
+        if len(code_blocks) > 0:
+            selected_block = prioritized_blocks[0] if prioritized_blocks else code_blocks[0]
+        else:
+            if '```c++' in output:  # starts with ```c++ but it didn't finish
+                code_idx = output.find('```c++')
+                selected_block = output[code_idx:].removeprefix('```c++')
+            else:
+                selected_block = output

+        # 4. Handle cases where the block contains only the function body
+        if function_name not in selected_block:
+            return selected_block
+        else:
+            function_start_index = selected_block.index(function_name)
+            open_brace_index = selected_block.find("{", function_start_index)
+            try:
+                close_brace_index = find_matching_brace_index(selected_block, open_brace_index)
+            except ValueError:
+                close_brace_index = len(selected_block)
+
+            function_body = selected_block[open_brace_index + 1 : close_brace_index]
+            return function_body + "}"
+

def get_inference_config(model_name : str, **kwargs) -> InferenceConfig:
    if model_name == "bigcode/starcoderbase":
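For illustration (not part of the commit): a minimal sketch of the extraction steps MagicoderConfig.clean_output performs, namely keep the text after "@@ Response", collect the ```c++ fenced blocks, prefer the block containing the prompt's function name, and trim to the function body. The response string below is invented, and get_function_name / find_matching_brace_index from the hunk above are assumed to be in scope.

import re

response = (
    "@@ Response\nHere is the solution:\n"
    "```c++\n__global__ void add(int *a, int *b, int *c) {\n"
    "    int i = threadIdx.x;\n    c[i] = a[i] + b[i];\n}\n```"
)

# Step 0: keep only the text that follows the "@@ Response" marker.
tail = response[response.find("@@ Response") + len("@@ Response"):].strip()

# Step 1: collect every ```c++ fenced block.
blocks = re.findall(r"```c\+\+\n(.*?)\n```", tail, flags=re.DOTALL)
assert len(blocks) == 1 and blocks[0].startswith("__global__ void add")

# Steps 2-4 in the real method: pick the block that mentions the function name
# from the prompt, then use find_matching_brace_index() to keep only the body
# between the function's braces, re-appending the closing "}".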
@@ -186,41 +346,12 @@ def get_inference_config(model_name : str, **kwargs) -> InferenceConfig:
        return PhindConfig(**kwargs)
    elif model_name == 'replit/replit-code-v1_5-3b':
        return ReplitConfig(**kwargs)
+    elif model_name.startswith('ise-uiuc/Magicoder'):
+        return MagicoderConfig(**kwargs)
    else:
        raise ValueError(f"Unknown model name: {model_name}")


-def clean_output(output : str, prompt : str) -> str:
- """ Remove `prompt` from the begging of `output`.
-        Also truncate at the end of the function definition (i.e. matching closing brace).
-    """
-    # replace up to the end of the first instance of prompt
-    prompt_loc = output.find(prompt)
-    if prompt_loc == -1:
-        raise ValueError(f"Prompt not found in output: {prompt}")
-    output = output[prompt_loc + len(prompt):].strip()
-
-    # temporarily add opening brace to the beginning
-    output = '{' + output
-
-    # find the matching brace to output[0]
-    stack = []
-    index = 0
-    while index < len(output):
-        token = output[index]
-        if token == '{':
-            stack.append(token)
-        elif token == '}':
-            stack.pop()
-            if len(stack) == 0:
-                break
-
-        index += 1
-
-    # truncate at the matching brace
-    output = output[1:index + 1]
-    return output
-
class PromptDataset(Dataset):
    ''' PyTorch dataset that simply wraps a list of strings. They do not have to have the same length.
    '''
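For orientation (not part of the commit): a sketch of how the new per-config clean_output hook appears intended to be used; the model id, the prompted flag, and the generated-text variable below are placeholders, not taken from the diff.

config = get_inference_config("ise-uiuc/Magicoder-S-DS-6.7B", prompted=True)

raw_prompt = "int sum(int a, int b) {"
model_prompt = config.format_prompt(raw_prompt)

# ... run the model on model_prompt and decode the result into `generated` ...

cleaned = config.clean_output(generated, model_prompt)  # new abstract hook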