16
16
"""
17
17
import enum
18
18
19
+ from .logging import get_logger
20
+
21
+
22
+ logger = get_logger (__name__ )
23
+
19
24
20
25
class StateDictType(enum.Enum):
    """
    The mode to use when converting state dicts.
    """

    # Legacy diffusers LoRA key layout (pre-PEFT integration).
    DIFFUSERS_OLD = "diffusers_old"
    # Kohya-ss / sd-scripts layout (AUTOMATIC1111, ComfyUI, etc.).
    KOHYA_SS = "kohya_ss"
    # PEFT library layout (`lora_A` / `lora_B` keys).
    PEFT = "peft"
    # Current diffusers layout.
    DIFFUSERS = "diffusers"
29
34
@@ -100,6 +105,14 @@ class StateDictType(enum.Enum):
100
105
".to_out_lora.down" : ".out_proj.lora_linear_layer.down" ,
101
106
}
102
107
108
# Key-substring renames applied when converting PEFT keys to Kohya naming.
PEFT_TO_KOHYA_SS = {
    "lora_A": "lora_down",
    "lora_B": "lora_up",
    # This is not a comprehensive dict as kohya format requires replacing `.` with `_` in keys,
    # adding prefixes and adding alpha values
    # Check `convert_state_dict_to_kohya` for more
}
115
+
103
116
PEFT_STATE_DICT_MAPPINGS = {
104
117
StateDictType .DIFFUSERS_OLD : DIFFUSERS_OLD_TO_PEFT ,
105
118
StateDictType .DIFFUSERS : DIFFUSERS_TO_PEFT ,
@@ -110,6 +123,8 @@ class StateDictType(enum.Enum):
110
123
StateDictType .PEFT : PEFT_TO_DIFFUSERS ,
111
124
}
112
125
126
+ KOHYA_STATE_DICT_MAPPINGS = {StateDictType .PEFT : PEFT_TO_KOHYA_SS }
127
+
113
128
# NOTE(review): name suggests these substring replacements are applied to every
# key regardless of the source format — confirm against `convert_state_dict`.
KEYS_TO_ALWAYS_REPLACE = {
    ".processor.": ".",
}
@@ -228,3 +243,82 @@ def convert_unet_state_dict_to_peft(state_dict):
228
243
"""
229
244
mapping = UNET_TO_DIFFUSERS
230
245
return convert_state_dict (state_dict , mapping )
246
+
247
+
248
def convert_all_state_dict_to_peft(state_dict):
    r"""
    Convert any supported LoRA state dict to the PEFT layout.

    First attempts the generic `convert_state_dict_to_peft`; if that fails
    because the state dict type could not be inferred (e.g. a UNet-only LoRA
    without `lora_linear_layer` keys), falls back to
    `convert_unet_state_dict_to_peft`.
    """
    infer_failure_msg = "Could not automatically infer state dict type"
    try:
        converted = convert_state_dict_to_peft(state_dict)
    except Exception as err:
        # Only the "could not infer" failure triggers the UNet fallback;
        # every other exception is a genuine error and propagates unchanged.
        if str(err) != infer_failure_msg:
            raise
        converted = convert_unet_state_dict_to_peft(state_dict)

    # A successful PEFT conversion must have produced lora_A / lora_B keys.
    has_peft_keys = any("lora_A" in key or "lora_B" in key for key in converted)
    if not has_peft_keys:
        raise ValueError("Your LoRA was not converted to PEFT")

    return converted
265
+
266
+
267
def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
    r"""
    Converts a `PEFT` state dict to `Kohya` format that can be used in AUTOMATIC1111, ComfyUI, SD.Next, InvokeAI, etc.
    The method only supports the conversion from PEFT to Kohya for now.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The state dict to convert.
        original_type (`StateDictType`, *optional*):
            The original type of the state dict, if not provided, the method will try to infer it automatically.
        kwargs (`dict`, *args*):
            Additional arguments to pass to the method.

            - **adapter_name**: For example, in case of PEFT, some keys will be pre-pended
              with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in
              `get_peft_model_state_dict` method:
              https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
              but we add it here in case we don't want to rely on that method.

    Raises:
        ValueError: If `original_type` is not (and cannot be inferred to be) a supported source format.
        ImportError: If torch is not installed.
    """
    try:
        import torch
    except ImportError:
        logger.error("Converting PEFT state dicts to Kohya requires torch to be installed.")
        raise

    peft_adapter_name = kwargs.pop("adapter_name", None)
    if peft_adapter_name is not None:
        peft_adapter_name = "." + peft_adapter_name
    else:
        peft_adapter_name = ""

    if original_type is None:
        # PEFT LoRA layers always expose `.lora_A(.adapter).weight` keys.
        if any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()):
            original_type = StateDictType.PEFT

    if original_type not in KOHYA_STATE_DICT_MAPPINGS:
        raise ValueError(f"Original type {original_type} is not supported")

    # Use the mapping registered for the detected/provided type. (Fix: the
    # previous code validated `original_type` but then hard-coded
    # `StateDictType.PEFT` in the lookup — identical today since PEFT is the
    # only entry, but wrong once more source formats are registered.)
    kohya_ss_partial_state_dict = convert_state_dict(state_dict, KOHYA_STATE_DICT_MAPPINGS[original_type])
    kohya_ss_state_dict = {}

    # Additional logic for replacing header, alpha parameters `.` with `_` in all keys
    for kohya_key, weight in kohya_ss_partial_state_dict.items():
        if "text_encoder_2." in kohya_key:
            kohya_key = kohya_key.replace("text_encoder_2.", "lora_te2.")
        elif "text_encoder." in kohya_key:
            kohya_key = kohya_key.replace("text_encoder.", "lora_te1.")
        elif "unet" in kohya_key:
            kohya_key = kohya_key.replace("unet", "lora_unet")
        # Turn all dots into underscores except the last two separators
        # (keeps e.g. `.lora_down.weight` dotted, Kohya-style).
        kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
        kohya_key = kohya_key.replace(peft_adapter_name, "")  # Kohya doesn't take names
        kohya_ss_state_dict[kohya_key] = weight
        if "lora_down" in kohya_key:
            # Emit one `.alpha` scalar per module; `len(weight)` is the first
            # dimension of the down-projection — presumably the LoRA rank
            # (alpha == rank means a scaling factor of 1). TODO confirm.
            alpha_key = f'{kohya_key.split(".")[0]}.alpha'
            kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight))

    return kohya_ss_state_dict
0 commit comments