Skip to content

Commit d3afa26

Browse files
committed
lora constants.
1 parent bc74fe8 commit d3afa26

File tree

2 files changed

+4
-10
lines changed

2 files changed

+4
-10
lines changed

src/diffusers/loaders/lora_base.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@
5151

5252
logger = logging.get_logger(__name__)
5353

54+
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
55+
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
56+
5457

5558
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
5659
"""
@@ -195,8 +198,6 @@ def _fetch_state_dict(
195198
user_agent,
196199
allow_pickle,
197200
):
198-
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
199-
200201
model_file = None
201202
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
202203
# Let's first try to load .safetensors weights
@@ -260,8 +261,6 @@ def _fetch_state_dict(
260261
def _best_guess_weight_name(
261262
pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
262263
):
263-
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
264-
265264
if local_files_only or HF_HUB_OFFLINE:
266265
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
267266

@@ -722,8 +721,6 @@ def write_lora_layers(
722721
save_function: Callable,
723722
safe_serialization: bool,
724723
):
725-
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
726-
727724
if os.path.isfile(save_directory):
728725
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
729726
return

src/diffusers/loaders/lora_pipeline.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
logging,
3333
scale_lora_layers,
3434
)
35-
from .lora_base import LoraBaseMixin, _fetch_state_dict
35+
from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict # noqa
3636
from .lora_conversion_utils import (
3737
_convert_kohya_flux_lora_to_diffusers,
3838
_convert_non_diffusers_lora_to_diffusers,
@@ -61,9 +61,6 @@
6161
UNET_NAME = "unet"
6262
TRANSFORMER_NAME = "transformer"
6363

64-
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
65-
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
66-
6764

6865
class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
6966
r"""

0 commit comments

Comments
 (0)