@@ -370,7 +370,6 @@ def clean_output(self, output: str, prompt: str) -> str:
370
370
371
371
372
372
class InstructConfig (InferenceConfig ):
373
-
374
373
def __init__ (self , prompted : bool = False , instruction_tag : str = "### Instruction" , response_tag : str = "### Response" ):
375
374
super ().__init__ (prompted = prompted )
376
375
self .instruction_tag = instruction_tag
@@ -401,6 +400,63 @@ def format_prompt(self, prompt : str) -> str:
401
400
def clean_output (self , output : str , prompt : str ) -> str :
402
401
return clean_instruct_output (output , prompt , self .response_tag )
403
402
403
class QwenConfig(InferenceConfig):
    """Inference configuration for base (non-instruct) Qwen models.

    Loads weights in float16 and configures left-padded batching, reusing
    the EOS token as the pad token (these checkpoints ship no pad token).
    """

    def __init__(self, prompted: bool = False):
        super().__init__(prompted=prompted)

    def get_dtype(self):
        """Dtype used when loading the model weights."""
        return torch.float16

    def init_padding(self, tokenizer):
        """Prepare the tokenizer for batched decoder-only generation."""
        tokenizer.pad_token_id = tokenizer.eos_token_id  # for batching
        tokenizer.padding_side = "left"  # for decoder-only models

    def get_pad_token_id(self, tokenizer) -> int:
        return tokenizer.eos_token_id

    def get_eos_token_id(self, tokenizer) -> "int | None":
        # Fixed annotation: the original declared `-> int` but this method
        # always returns None, letting generation fall back to the
        # tokenizer/model default EOS handling.
        return None

    def trust_remote_code(self) -> bool:
        return False

    def format_prompt(self, prompt: str) -> str:
        """Prepend a C++ "correct solution" header when in prompted mode."""
        if self.prompted:
            return f"// filename: solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n{prompt}"
        return prompt.strip()

    def clean_output(self, output: str, prompt: str) -> str:
        """Strip the prompt from the generation via the module-level helper."""
        return clean_output(output, prompt)
431
class ChatMLConfig(InferenceConfig):
    """Inference configuration for instruct models speaking the ChatML format."""

    def __init__(self, prompted: bool = False):
        super().__init__(prompted=prompted)

    def get_dtype(self):
        """Load weights in bfloat16."""
        return torch.bfloat16

    def init_padding(self, tokenizer):
        """Left-pad with EOS standing in as the pad token."""
        tokenizer.pad_token_id = tokenizer.eos_token_id  # for batching
        tokenizer.padding_side = "left"  # for decoder-only models

    def get_pad_token_id(self, tokenizer) -> int:
        return tokenizer.pad_token_id

    def get_eos_token_id(self, tokenizer) -> int:
        return tokenizer.eos_token_id

    def trust_remote_code(self) -> bool:
        return False

    def format_prompt(self, prompt: str) -> str:
        """Wrap the raw prompt into a ChatML system/user/assistant exchange."""
        kind = "cuda" if "__global__" in prompt else "serial"
        fn_name = get_function_name(prompt, kind)
        task = (
            f"Complete the following c++ function.\n```c++{prompt.strip()}```\n"
            f"Write only the function {fn_name} and no other code. "
            "Enclose your solution in ```c++ and ```."
        )
        return (
            "<|im_start|>system\nYou are an exceptionally intelligent coding "
            "assistant that consistently delivers accurate and reliable responses "
            "to user instructions.<|im_end|>\n"
            f"<|im_start|>user\n{task}<|im_end|>\n<|im_start|>assistant\n"
        )

    def clean_output(self, output: str, prompt: str) -> str:
        """Keep only the text produced after the assistant turn marker."""
        return clean_instruct_output(output, prompt, "<|im_start|>assistant\n")
404
460
def get_inference_config (model_name : str , ** kwargs ) -> InferenceConfig :
405
461
if model_name == "bigcode/starcoderbase" :
406
462
return StarCoderConfig (** kwargs )
@@ -422,6 +478,12 @@ def get_inference_config(model_name : str, **kwargs) -> InferenceConfig:
422
478
return InstructConfig (instruction_tag = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n \n ### Instruction:' , response_tag = '### Response:' , ** kwargs )
423
479
elif model_name .startswith ('hpcgroup/rlpf' ):
424
480
return InstructConfig (instruction_tag = '### Instruction' , response_tag = '### Response' , ** kwargs )
481
+ elif model_name .startswith ('Qwen/Qwen2.5' ) and 'Instruct' in model_name :
482
+ return ChatMLConfig (** kwargs )
483
+ elif model_name .startswith ('Qwen/Qwen3' ):
484
+ return ChatMLConfig (** kwargs )
485
+ elif model_name .startswith ('Qwen/Qwen2.5' ):
486
+ return QwenConfig (** kwargs )
425
487
else :
426
488
raise ValueError (f"Unknown model name: { model_name } " )
427
489
0 commit comments