Skip to content

Commit 1d8c6ff

Browse files
authored
Adding Qwen 2.5 and Qwen 3 Series Support (#51)
1 parent f4b3083 commit 1d8c6ff

File tree

1 file changed

+63
-1
lines changed

1 file changed

+63
-1
lines changed

generate/utils.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,6 @@ def clean_output(self, output: str, prompt: str) -> str:
370370

371371

372372
class InstructConfig(InferenceConfig):
373-
374373
def __init__(self, prompted : bool = False, instruction_tag : str = "### Instruction", response_tag : str = "### Response"):
375374
super().__init__(prompted=prompted)
376375
self.instruction_tag = instruction_tag
@@ -401,6 +400,63 @@ def format_prompt(self, prompt : str) -> str:
401400
def clean_output(self, output: str, prompt: str) -> str:
402401
return clean_instruct_output(output, prompt, self.response_tag)
403402

403+
class QwenConfig(InferenceConfig):
    """Inference configuration for Qwen 2.5 base (non-instruct) checkpoints.

    Loads weights in fp16 and uses plain completion-style prompting: when
    ``prompted`` is set, a C++ source-file header comment is prepended to bias
    the model toward emitting a complete implementation.
    """

    def __init__(self, prompted: bool = False):
        super().__init__(prompted=prompted)

    def get_dtype(self):
        # Half precision for Qwen 2.5 base models (ChatMLConfig uses bfloat16).
        return torch.float16

    def init_padding(self, tokenizer):
        tokenizer.pad_token_id = tokenizer.eos_token_id  # for batching
        tokenizer.padding_side = "left"  # for decoder-only models

    def get_pad_token_id(self, tokenizer) -> int:
        # Mirrors init_padding: EOS doubles as the pad token.
        return tokenizer.eos_token_id

    def get_eos_token_id(self, tokenizer) -> "int | None":
        # No explicit EOS override — returning None lets generation fall back
        # to the model's default. FIX: the original annotated this ``-> int``,
        # which contradicted the unconditional ``None`` return; the string
        # annotation avoids needing a ``typing`` import.
        return None

    def trust_remote_code(self) -> bool:
        return False

    def format_prompt(self, prompt: str) -> str:
        """Return the raw (stripped) prompt, or a C++-file-framed variant when prompted."""
        if self.prompted:
            return f"// filename: solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n{prompt}"
        return prompt.strip()

    def clean_output(self, output: str, prompt: str) -> str:
        # Delegates to the module-level clean_output helper. Not recursion:
        # a bare name inside a method body skips the class scope and resolves
        # at module level.
        return clean_output(output, prompt)
430+
431+
class ChatMLConfig(InferenceConfig):
    """Inference configuration for ChatML-format chat models (Qwen instruct / Qwen 3).

    Wraps the code-completion task in a system/user/assistant ChatML transcript
    and strips everything before the assistant turn from the generated output.
    """

    def __init__(self, prompted: bool = False):
        super().__init__(prompted=prompted)

    def get_dtype(self):
        return torch.bfloat16

    def init_padding(self, tokenizer):
        # Decoder-only models pad on the left; EOS doubles as the pad token
        # so batched generation works.
        tokenizer.padding_side = "left"
        tokenizer.pad_token_id = tokenizer.eos_token_id

    def get_pad_token_id(self, tokenizer) -> int:
        return tokenizer.pad_token_id

    def get_eos_token_id(self, tokenizer) -> int:
        return tokenizer.eos_token_id

    def trust_remote_code(self) -> bool:
        return False

    def format_prompt(self, prompt: str) -> str:
        """Build a ChatML transcript asking the model to complete *prompt*."""
        parallelism = "cuda" if "__global__" in prompt else "serial"
        fn = get_function_name(prompt, parallelism)
        # NOTE(review): there is no newline between ```c++ and the code below —
        # presumably intentional, but confirm the fence renders as expected.
        user_msg = f"Complete the following c++ function.\n```c++{prompt.strip()}```\nWrite only the function {fn} and no other code. Enclose your solution in ```c++ and ```."
        system_msg = "You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions."
        return (
            "<|im_start|>system\n" + system_msg + "<|im_end|>\n"
            "<|im_start|>user\n" + user_msg + "<|im_end|>\n"
            "<|im_start|>assistant\n"
        )

    def clean_output(self, output: str, prompt: str) -> str:
        """Drop everything up to and including the assistant turn marker."""
        return clean_instruct_output(output, prompt, "<|im_start|>assistant\n")
459+
404460
def get_inference_config(model_name : str, **kwargs) -> InferenceConfig:
405461
if model_name == "bigcode/starcoderbase":
406462
return StarCoderConfig(**kwargs)
@@ -422,6 +478,12 @@ def get_inference_config(model_name : str, **kwargs) -> InferenceConfig:
422478
return InstructConfig(instruction_tag='Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:', response_tag='### Response:', **kwargs)
423479
elif model_name.startswith('hpcgroup/rlpf'):
424480
return InstructConfig(instruction_tag='### Instruction', response_tag='### Response', **kwargs)
481+
elif model_name.startswith('Qwen/Qwen2.5') and 'Instruct' in model_name:
482+
return ChatMLConfig(**kwargs)
483+
elif model_name.startswith('Qwen/Qwen3'):
484+
return ChatMLConfig(**kwargs)
485+
elif model_name.startswith('Qwen/Qwen2.5'):
486+
return QwenConfig(**kwargs)
425487
else:
426488
raise ValueError(f"Unknown model name: {model_name}")
427489

0 commit comments

Comments
 (0)