# std imports
import argparse
import json
import os
import sys
import time
from tqdm import tqdm
import torch

# tpl imports
from vllm import LLM, SamplingParams

# local imports
from utils import BalancedBracketsCriteria, PromptDataset, clean_output, get_inference_config

""" Parse command line arguments """
parser = argparse.ArgumentParser(description='Generate code with vLLM')
parser.add_argument('--prompts', required=True, help='Path to the prompt JSON file')
parser.add_argument('--model', required=True, help='Path to the language model')
parser.add_argument('--output', required=True, help='Path to the output JSON file')
parser.add_argument('--restart', action='store_true', help='Restart generation from scratch (default: False)')
parser.add_argument('--cache', help='JSONL file to cache intermediate results in. Will be restored from if it ' +
    'already exists and --restart is not specified')
parser.add_argument('--restore_from', help='JSON file to restore old results from. Will be restored from ' +
    'if it already exists and --restart is not specified. Is different from --cache in that it is a JSON file, not a ' +
    'JSONL file, and it is only used to restore old results where the prompt is equivalent. Cached results are ' +
    'prioritized over restored results.')
parser.add_argument('--max_new_tokens', type=int, default=1024, help='Maximum number of new tokens to generate (default: 1024)')
parser.add_argument('--num_samples_per_prompt', type=int, default=50, help='Number of code samples to generate (default: 50)')
parser.add_argument('--temperature', type=float, default=0.2, help='Temperature for controlling randomness (default: 0.2)')
parser.add_argument('--top_p', type=float, default=0.95, help='Top p value for nucleus sampling (default: 0.95)')
parser.add_argument('--do_sample', action='store_true', help='Enable sampling (default: False)')
parser.add_argument('--prompted', action='store_true', help='Use prompted generation. See StarCoder paper (default: False)')
args = parser.parse_args()
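
# Example invocation (script and file names are illustrative):
#   python generate_vllm.py --prompts prompts.json --model /path/to/model \
#       --output generations.json --cache generations.jsonl --do_sample --prompted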

""" Load prompts """
with open(args.prompts, 'r') as json_file:
    prompts = json.load(json_file)
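
# each prompt record is expected to provide at least 'name', 'parallelism_model', and
# 'prompt' keys; these are used below to match prompts against cached/restored results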

""" Load existing responses if they exist """
if not args.restart and args.cache is not None and os.path.exists(args.cache):
    with open(args.cache, 'r') as jsonl_file:
        responses = [json.loads(line) for line in jsonl_file]

    # skip prompts that already have a cached response with matching settings and a
    # full set of num_samples_per_prompt outputs
    original_len = len(prompts)
    prompts = [p for p in prompts if
               not any(p["name"] == r["name"] and
                       p["parallelism_model"] == r["parallelism_model"] and
                       p["prompt"] == r["prompt"] and
                       args.temperature == r["temperature"] and
                       args.prompted == r["prompted"] and
                       args.num_samples_per_prompt == len(r["outputs"])
                       for r in responses)]
    print(f"[cache] Skipping {original_len - len(prompts)} prompts that already have responses")

""" Restore responses from the --restore_from JSON file if it exists """
if not args.restart and args.restore_from and os.path.exists(args.restore_from):
    with open(args.restore_from, 'r') as json_file:
        restored_responses = json.load(json_file)

    # move prompts that already have a restored response with matching settings and a
    # full set of num_samples_per_prompt outputs out of the work list
    original_len = len(prompts)
    responses_to_keep = []
    prompts_without_existing_responses = []
    for p in prompts:
        for r in restored_responses:
            if p["name"] == r["name"] and \
               p["parallelism_model"] == r["parallelism_model"] and \
               p["prompt"] == r["prompt"] and \
               args.temperature == r["temperature"] and \
               args.prompted == r["prompted"] and \
               args.num_samples_per_prompt == len(r["outputs"]):
                responses_to_keep.append(r)
                break
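        # note: this is a for/else - the else branch below runs only when the inner
        # loop finishes without a break, i.e. no restored response matched this prompt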
        else:
            prompts_without_existing_responses.append(p)
    prompts = prompts_without_existing_responses
    print(f"[restore_from] Skipping {original_len - len(prompts)} prompts that already have responses. " +
          f"{len(prompts)} prompts left.")

    # write restored responses to cache
    if args.cache is not None:
        with open(args.cache, 'a') as jsonl_file:
            for response in responses_to_keep:
                jsonl_file.write(json.dumps(response) + "\n")
        print(f"[restore_from] Wrote {len(responses_to_keep)} restored responses to cache")

""" Initialize inference config """
inference_config = get_inference_config(args.model, prompted=args.prompted)
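# inference_config supplies the model-specific format_prompt() and clean_output()
# helpers used below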

prompts_repeated = [p for p in prompts for _ in range(args.num_samples_per_prompt)]
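# each prompt appears num_samples_per_prompt times in a row so that all samples can be
# produced by a single batched llm.generate call below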

""" Initialize vLLM engine """
llm = LLM(model=args.model, tensor_parallel_size=torch.cuda.device_count())
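# tensor_parallel_size shards the model across every visible GPU; restrict with
# CUDA_VISIBLE_DEVICES if fewer devices should be used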

# Configure sampling parameters
sampling_params = SamplingParams(
    temperature=args.temperature if args.do_sample else 0,
    top_p=args.top_p if args.do_sample else 1.0,
    max_tokens=args.max_new_tokens,
    n=1,  # We handle multiple samples manually
)
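# note: without --do_sample, temperature=0 makes vLLM decode greedily and top_p=1.0
# disables nucleus filtering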

""" Generate code """
if not args.restart and args.cache is not None and os.path.exists(args.cache):
    with open(args.cache, 'r') as jsonl_file:
        responses = [json.loads(line) for line in jsonl_file]
    responses = [r for r in responses if r["temperature"] == args.temperature and r["prompted"] == args.prompted
                 and args.num_samples_per_prompt == len(r["outputs"])
                 and any(p["name"] == r["name"] and p["prompt"] == r["prompt"] and p["parallelism_model"] == r["parallelism_model"] for p in prompts)]
else:
    responses = []

cur_prompt = None
start_time = time.time()
total_tokens = 0

# Format all prompts
formatted_prompts = [inference_config.format_prompt(p["prompt"]) for p in prompts_repeated]

# Generate all outputs at once
outputs = llm.generate(formatted_prompts, sampling_params)
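# llm.generate returns one RequestOutput per input prompt, in input order, so
# zip(prompts_repeated, outputs) below pairs each output with its original prompt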

# Process outputs
for idx, (prompt, output) in enumerate(zip(prompts_repeated, outputs)):
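    # outputs are grouped: every num_samples_per_prompt consecutive entries belong to
    # the same original prompt, so start a fresh record at each group boundary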
    if idx % args.num_samples_per_prompt == 0:
        cur_prompt = prompt.copy()
        cur_prompt.update({
            "temperature": args.temperature,
            "top_p": args.top_p,
            "do_sample": args.do_sample,
            "max_new_tokens": args.max_new_tokens,
            "prompted": args.prompted
        })
        cur_prompt["outputs"] = []
        cur_prompt["raw_outputs"] = []
        prompt_str = cur_prompt["prompt"]

    # Count tokens and clean output
    # FIXME: This is to keep the same behavior as generate.py
    huggingface_style_output = output.prompt + output.outputs[0].text
    total_tokens += len(llm.get_tokenizer().encode(huggingface_style_output))
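    # (this token count also includes the prompt tokens, since the prompt text is
    # re-encoded as part of huggingface_style_output)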
    cleaned_output = inference_config.clean_output(huggingface_style_output, prompt_str)
    cur_prompt["outputs"].append(cleaned_output)
    cur_prompt["raw_outputs"].append(huggingface_style_output)

    if idx % args.num_samples_per_prompt == args.num_samples_per_prompt - 1:
        responses.append(cur_prompt)

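        # flush the completed record to the JSONL cache so an interrupted run can
        # resume without regenerating these samples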
        if not args.restart and args.cache is not None:
            with open(args.cache, 'a') as jsonl_file:
                jsonl_file.write(json.dumps(cur_prompt) + "\n")

end_time = time.time()
tokens_per_second = total_tokens / (end_time - start_time)
print(f"Generated {len(responses)} code samples in {end_time - start_time:.2f} seconds ({tokens_per_second:.2f} tokens per second)")

""" Save responses to JSON file """
with open(args.output, 'w') as output_file:
    json.dump(responses, output_file, indent=4)