Fix for fetching variants only #10646
@@ -104,7 +104,7 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
extension is replaced with ".safetensors"
"""
passed_components = passed_components or []
if folder_names is not None:
if folder_names:
filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}

# extract all components of the pipeline and their associated files
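The one-line change above swaps an `is not None` check for a plain truthiness check. A minimal sketch (not the library code) of why that matters when an empty collection of folder names is passed:

```python
import os

filenames = {"unet/diffusion_pytorch_model.safetensors", "vae/diffusion_pytorch_model.safetensors"}
folder_names = set()  # hypothetical: caller passed no folders to restrict to

# old behaviour: an empty set is still "not None", so the filter runs and drops everything
old_result = {f for f in filenames if os.path.split(f)[0] in folder_names}
assert old_result == set()

# new behaviour: an empty set is falsy, so the full file list is kept
new_result = filenames if not folder_names else {f for f in filenames if os.path.split(f)[0] in folder_names}
assert new_result == filenames
```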
@@ -141,7 +141,25 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
return True


def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
def filter_model_files(filenames):
"""Filter model repo files for just files/folders that contain model weights"""
weight_names = [
WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
FLAX_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
ONNX_EXTERNAL_WEIGHTS_NAME,
]

if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]

allowed_extensions = [wn.split(".")[-1] for wn in weight_names]

return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)]


def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]:
weight_names = [
WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
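The new `filter_model_files` helper in the hunk above keeps only files whose names end with a known weight-file extension. A hedged, self-contained sketch of that behaviour, using illustrative weight-file names rather than the exact constants imported in diffusers:

```python
weight_names = [
    "diffusion_pytorch_model.bin",
    "diffusion_pytorch_model.safetensors",
    "diffusion_flax_model.msgpack",
    "model.onnx",
]
allowed_extensions = [wn.split(".")[-1] for wn in weight_names]  # ['bin', 'safetensors', 'msgpack', 'onnx']

filenames = [
    "unet/diffusion_pytorch_model.safetensors",
    "text_encoder/model.fp16.safetensors",
    "model_index.json",
    "scheduler/scheduler_config.json",
]

# config/json files are dropped, only weight files remain
model_files = [f for f in filenames if any(f.endswith(ext) for ext in allowed_extensions)]
print(model_files)
# ['unet/diffusion_pytorch_model.safetensors', 'text_encoder/model.fp16.safetensors']
```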
@@ -169,6 +187,10 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
variant_index_re = re.compile(
rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
)
legacy_variant_file_re = re.compile(rf".*-{transformers_index_format}\.{variant}\.[a-z]+$")
legacy_variant_index_re = re.compile(
rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.{variant}\.index\.json$"
)

# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
non_variant_file_re = re.compile(
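The added `legacy_variant_file_re`/`legacy_variant_index_re` patterns cover an older naming scheme for sharded variant checkpoints. A small illustration of the two index-file layouts, with the prefix/suffix alternations simplified to a single fixed name (an assumption made only for this example):

```python
import re

variant = "fp16"

# current scheme: <prefix>.<suffix>.index.<variant>.json
variant_index_re = re.compile(rf"diffusion_pytorch_model\.safetensors\.index\.{variant}\.json$")
# legacy scheme: <prefix>.<suffix>.<variant>.index.json
legacy_variant_index_re = re.compile(rf"diffusion_pytorch_model\.safetensors\.{variant}\.index\.json$")

print(bool(variant_index_re.match("diffusion_pytorch_model.safetensors.index.fp16.json")))         # True
print(bool(legacy_variant_index_re.match("diffusion_pytorch_model.safetensors.fp16.index.json")))  # True
```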
@@ -177,54 +199,68 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
# `text_encoder/pytorch_model.bin.index.json`
non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")

if variant is not None:
variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
variant_filenames = variant_weights | variant_indexes
else:
variant_filenames = set()
def filter_for_compatible_extensions(filenames, ignore_patterns=None):
if not ignore_patterns:
return filenames

# ignore patterns uses glob style patterns e.g *.safetensors but we're only
# interested in the extension name
return {f for f in filenames if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)}

def filter_with_regex(filenames, pattern_re):
return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}

# Group files by component
components = {}
for filename in filenames:
if not len(filename.split("/")) == 2:
components.setdefault("", []).append(filename)
continue

non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
non_variant_filenames = non_variant_weights | non_variant_indexes
component, _ = filename.split("/")
components.setdefault(component, []).append(filename)

# all variant filenames will be used by default
usable_filenames = set(variant_filenames)
usable_filenames = set()
variant_filenames = set()
for component, component_filenames in components.items():
component_filenames = filter_for_compatible_extensions(component_filenames, ignore_patterns=ignore_patterns)

component_variants = set()
component_legacy_variants = set()
component_non_variants = set()
if variant is not None:
component_variants = filter_with_regex(component_filenames, variant_file_re)
component_variant_index_files = filter_with_regex(component_filenames, variant_index_re)

component_legacy_variants = filter_with_regex(component_filenames, legacy_variant_file_re)
component_legacy_variant_index_files = filter_with_regex(component_filenames, legacy_variant_index_re)

if component_variants or component_legacy_variants:
variant_filenames.update(
component_variants | component_variant_index_files
if component_variants
else component_legacy_variants | component_legacy_variant_index_files
)

Review comment (on the `if component_variants` condition): so this
Reply: Oh missed this. Great catch! Updated to account for this.

def convert_to_variant(filename):
if "index" in filename:
variant_filename = filename.replace("index", f"index.{variant}")
elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
else:
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
return variant_filename
component_non_variants = filter_with_regex(component_filenames, non_variant_file_re)

Review comment: awesome!

component_variant_index_files = filter_with_regex(component_filenames, non_variant_index_re)

def find_component(filename):
if not len(filename.split("/")) == 2:
return
component = filename.split("/")[0]
return component

def has_sharded_variant(component, variant, variant_filenames):
# If component exists check for sharded variant index filename
# If component doesn't exist check main dir for sharded variant index filename
component = component + "/" if component else ""
variant_index_re = re.compile(
rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
)
return any(f for f in variant_filenames if variant_index_re.match(f) is not None)
usable_filenames.update(component_non_variants | component_variant_index_files)

for filename in non_variant_filenames:
if convert_to_variant(filename) in variant_filenames:
continue
usable_filenames.update(variant_filenames)

component = find_component(filename)
# If a sharded variant exists skip adding to allowed patterns
if has_sharded_variant(component, variant, variant_filenames):
continue
if len(variant_filenames) == 0 and variant is not None:
error_message = f"You are trying to load model files of the `variant={variant}`, but no such modeling files are available. "
raise ValueError(error_message)

usable_filenames.add(filename)
if len(variant_filenames) > 0 and usable_filenames != variant_filenames:
logger.warning(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
f"[{', '.join(variant_filenames)}]\nLoaded non-{variant} filenames:\n"
f"[{', '.join(usable_filenames - variant_filenames)}\nIf this behavior is not "
f"expected, please check your folder structure."
)

return usable_filenames, variant_filenames

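The rewritten body of `variant_compatible_siblings` groups repo files by component folder and, per component, prefers variant files while falling back to non-variant files only where no variant exists. A rough, self-contained sketch of that selection logic, using simplified substring matching in place of the real regexes:

```python
def pick_files(filenames, variant):
    # group files by their top-level component folder ("" for root-level files)
    components = {}
    for f in filenames:
        parts = f.split("/")
        component = parts[0] if len(parts) == 2 else ""
        components.setdefault(component, []).append(f)

    usable_filenames, variant_filenames = set(), set()
    for component, files in components.items():
        variant_files = {f for f in files if f".{variant}." in f}
        non_variant_files = {f for f in files if f".{variant}." not in f}
        if variant_files:
            # this component has variant weights: use them exclusively
            variant_filenames.update(variant_files)
            usable_filenames.update(variant_files)
        else:
            # otherwise fall back to the non-variant weights for this component
            usable_filenames.update(non_variant_files)
    return usable_filenames, variant_filenames

usable, variants = pick_files(
    [
        "unet/diffusion_pytorch_model.fp16.safetensors",
        "unet/diffusion_pytorch_model.safetensors",
        "vae/diffusion_pytorch_model.safetensors",
    ],
    variant="fp16",
)
# unet resolves to the fp16 file only; vae falls back to its non-variant file
print(sorted(usable))
```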
@@ -922,18 +958,13 @@ def _get_custom_components_and_folders(
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
)

if len(variant_filenames) == 0 and variant is not None:
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
raise ValueError(error_message)

Review comment: We don't need this in this function. The

return custom_components, folder_names


def _get_ignore_patterns(
passed_components,
model_folder_names: List[str],
model_filenames: List[str],
variant_filenames: List[str],
use_safetensors: bool,
from_flax: bool,
allow_pickle: bool,
@@ -964,33 +995,13 @@ def _get_ignore_patterns(
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]

safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames:
logger.warning(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not "
f"expected, please check your folder structure."
)

else:
ignore_patterns = ["*.safetensors", "*.msgpack"]

use_onnx = use_onnx if use_onnx is not None else is_onnx
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]

bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
logger.warning(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check "
f"your folder structure."
)

return ignore_patterns
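`_get_ignore_patterns` returns glob-style patterns such as `*.safetensors`, and the new `filter_for_compatible_extensions` helper earlier in the diff only looks at the extension portion of each pattern. A hedged illustration of that interaction (the file names here are made up for the example):

```python
ignore_patterns = ["*.safetensors", "*.msgpack"]

filenames = {
    "unet/diffusion_pytorch_model.fp16.bin",
    "unet/diffusion_pytorch_model.fp16.safetensors",
}

# strip the leading "*." so each pattern can be compared via str.endswith
compatible = {
    f
    for f in filenames
    if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)
}
print(compatible)  # {'unet/diffusion_pytorch_model.fp16.bin'}
```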