From 014784c5cc99afa2c284bdda1aa62c9ac4455ad0 Mon Sep 17 00:00:00 2001
From: Loubna Ben Allal <44069155+loubnabnl@users.noreply.github.com>
Date: Tue, 19 Sep 2023 17:38:17 +0200
Subject: [PATCH] update filtering and decontamination

---
 preprocessing/filtering.py | 133 ++++++++++++++++++++++++++-----------
 1 file changed, 95 insertions(+), 38 deletions(-)

diff --git a/preprocessing/filtering.py b/preprocessing/filtering.py
index 3d7d7a9..1966512 100644
--- a/preprocessing/filtering.py
+++ b/preprocessing/filtering.py
@@ -23,7 +23,7 @@
 ALL_FILTERS = ["basic", "basic_per_extension", "stars", "comments", "fertility", "xml", "html", "large_and_small_files"]
 THRESHOLDS_FERTILITY = {"python": 2.5, "java": 2.9, "javascript": 2.6}
-
+LANG = "language"
 class MultiChoice:
     def __init__(self, choices):
         self.choices = choices
@@ -63,7 +63,7 @@ def parse_args():
 def get_comments_ratio(examples):
     """Get ratio of comments to code in each example. Requires a language argument"""
     ratio_list = []
-    for code, language in zip(examples["content"], examples["lang"]):
+    for code, language in zip(examples["content"], examples[LANG]):
         ratio_list.append(get_nl_ratio(code, language.lower()))
     return {"nl_ratio": ratio_list}
@@ -89,6 +89,17 @@
         return False
     return True
 
+def add_stats(example):
+    """Add extra stats:
+    - size of text, mean and max line length of file
+    - % alphanumeric characters
+    - extracts file extension"""
+    size = len(example["content"])
+    line_lengths = [len(line) for line in example["content"].splitlines()] or [0]  # guard: max()/np.mean() fail on empty files
+    alpha_frac = np.mean([c.isalnum() for c in example["content"]]) if example["content"] else 0.0
+    ext = example["path"].split(".")[-1]
+    return {"size": size, "avg_line_length": np.mean(line_lengths), "max_line_length": max(line_lengths), "alphanum_fraction": alpha_frac, "ext": ext}
+
 def basic_filters_per_extension(example, ext_to_filter):
     """Filter files based on line length and % alphanumeric characters.
@@ -97,7 +108,7 @@
     # extension `None` is an empty string in the csv
     try:
         (include, line_max, line_mean, alphanum_frac, alphabetic_frac) = ext_to_filter[(language_format_from_dataset(
-            example["lang"]), example["ext"] if example["ext"] is not None else ""
+            example[LANG]), example["ext"] if example["ext"] is not None else ""
         )]
     except KeyError as e:
         # Some extensions are not in the csv. This happens for dockerfiles.
@@ -187,7 +198,7 @@ def char_token_ratio(examples, tokenizer):
 def filter_tokenizer(examples):
     """Filter files based on char to token ratio"""
     values = []
-    for ratio, lang in zip(examples["fertility_ratio"], examples["lang"]):
+    for ratio, lang in zip(examples["fertility_ratio"], examples[LANG]):
         if ratio < THRESHOLDS_FERTILITY[lang.lower()]:
             values.append(False)
         else:
             values.append(True)
@@ -202,7 +213,7 @@ def filter_xml(example):
 
 def filter_html(example):
     """Filter HTML files based on displayed text VS code ratio"""
-    assert example["lang"] == "HTML", "Filter is only for html examples"
+    assert example[LANG] == "HTML", "Filter is only for html examples"
     html = example["content"]
     try:
         soup = BeautifulSoup(html, features="html.parser")
@@ -226,6 +237,8 @@ def filter_large_and_small_files(example):
 def get_size_text(example):
     return {"size": len(example["content"])}
 
+def get_ext(example):
+    return {"ext": example["path"].split(".")[-1]}
 
 LICENSE_COLUMNS = ['max_stars_repo_licenses', 'max_issues_repo_licenses', 'max_forks_repo_licenses']
 def fix_license_cols(example):
@@ -234,6 +247,7 @@ def fix_license_cols(example):
     return example
 
 
+
 if __name__ == "__main__":
     args = parse_args()
     print(f"Selected filters: {args.filters}")
@@ -258,20 +272,29 @@ def fix_license_cols(example):
     # Load dataset
     t_start = time.time()
     logger.info(f" ===== Loading {args.dataset_name} and subset {args.subset}=====")
+    # make sure out_path/data doesn't exist yet
+    import os
+    if os.path.exists(f"{args.out_path}/data"):
+        raise ValueError(f"Output path already exists: {args.out_path}/data, delete it before filtering")
+
     dataset = load_dataset(
-        args.dataset_name, split=args.split, data_dir=args.subset, use_auth_token=True, num_proc=args.num_workers
+        args.dataset_name, split=args.split, use_auth_token=True, num_proc=args.num_workers
     )
     logger.info(f"Dataset loaded in {time.time() - t_start:.2f} seconds")
     logger.info(f"Dataset: {dataset}")
     if "size" not in dataset.column_names:
-        logger.info("Add text size column")
-        dataset = dataset.map(get_size_text)
+        logger.info("Adding text size, extension and line stats columns")
+        dataset = dataset.map(add_stats, num_proc=args.num_workers)
     if args.fix_license_columns:
         dataset = dataset.map(fix_license_cols, num_proc=args.num_workers)
     logger.info(
-        f"Dataset size before any filtering: {len(dataset)} examples, {sum(dataset['size']) / 1e9:.2f} GB"
+        f"Dataset size before any filtering: {len(dataset)} examples, {sum(dataset['size']) / 1e9:.2f} GB and columns: {dataset.column_names}"
+    )
+    # filter out non-permissive data
+    dataset = dataset.filter(lambda x: x["license_type"] != "non_permissive")
+    logger.info(
+        f"Dataset size after non-permissive filtering: {len(dataset)} examples, {sum(dataset['size']) / 1e9:.2f} GB"
     )
-
     # Run pre-processing if needed
     if "stars" in filters:
         logger.info(f"===== Processing dataset to add proper stars column=====")
@@ -335,6 +358,8 @@ def fix_license_cols(example):
         elif filter == "basic_per_extension":
             assert args.per_extension_filter_csv is not None
             language = language_format_from_data_dir(args.subset.split("/")[-1]) if args.subset is not None else None
+            language = language if language is not None else "python"  # assumed default when no subset is given
+            logger.info(f"Selected language: {language}")
             logger.info(
                 f"===== Language: {language}. Basic filtering with line_max, avg_line, alphanum_frac and alphabetic_frac given by : {args.per_extension_filter_csv} ====="
             )
@@ -536,6 +561,65 @@ def fix_license_cols(example):
         dataset = ds
 
+
+    # Run decontamination
+    if args.run_decontamination:
+        logger.info(
+            f"===== Running decontamination ====="
+        )
+        import sys
+        import os
+        sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir))
+        from decontamination.benchmark_data import FILTER_OUT
+
+        FILTER_OUT.pop('apps_docstrings', None)
+        FILTER_OUT.pop('gsm8k_questions', None)
+        logger.info(f"FILTER OUT Benchmarks: {FILTER_OUT.keys()}")
+        def decontaminate(samples, filter_out=FILTER_OUT):
+            """
+            filter_out: Dict[str, List[str]] mapping from benchmark name to list of strings that need to be
+            filtered out.
+            Return a list where each element is True if the corresponding file should be included in the dataset.
+            Otherwise, the element is False.
+            """
+            output = []
+
+            for content in samples["content"]:
+                content = content.lower()
+                matched = False
+                for benchmark, substrings in filter_out.items():
+                    for substring in substrings:
+                        if substring.lower() in content:
+                            matched = True
+                            break
+                    if matched:
+                        break
+                # we keep files that are not matched
+                output.append(not matched)
+
+            return output
+
+        old_size = len(dataset)
+        old_size_gb = sum(dataset["size"])
+        dataset = dataset.filter(decontaminate, batched=True, batch_size=10_000, num_proc=args.num_workers)
+        filtered_size_gb = sum(dataset["size"])
+        logger.info(
+            f"Removed {old_size - len(dataset)} files out of {old_size} (i.e. {(old_size - len(dataset)) * 100 / old_size:.2f}%), {(old_size_gb - filtered_size_gb) / 1e9:.2f} GB"
+        )
+        logger.info(
+            f"Dataset size after decontamination: {len(dataset)} examples, {filtered_size_gb / 1e9:.2f} GB"
+        )
+
+    if args.add_metadata:
+        from add_content_with_meta import content_with_meta
+
+        logger.info("===== Adding content with metadata =====")
+        dataset = dataset.map(
+            content_with_meta,
+            remove_columns=["content"],
+            num_proc=args.num_workers,
+        )
+
     # Save dataset
     logger.info(
         f"Final dataset has {len(dataset)} samples and {sum(dataset['size']) / 1e9:.2f} GB of code"
     )
@@ -548,7 +632,7 @@ def fix_license_cols(example):
         dataset.push_to_hub(args.remote_repo)
     else:
         print(
-            f"Saving the dataset in manual shards in a clone of {args.hub_username + args.remote_repo}"
+            f"Saving the dataset in manual shards in a clone of {args.hub_username}/{args.remote_repo}"
         )
         try:
             save_manual_shards(
@@ -557,30 +641,3 @@ def fix_license_cols(example):
             logger.info(f"Dataset successfully saved at {args.out_path}/{args.subset} in {time.time() - t_start:.2f} seconds")
         except FileExistsError:
             logger.warning(f"Output dir already exists at {args.out_path}/{args.subset}. Will not save filtered data")
-
-    # Run decontamination
-    if args.run_decontamination:
-        import sys
-        import os
-        sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir))
-        from decontamination.find_substrings import SubstringFilterer
-
-        output_dir_decontaminated = f"{args.out_path}_decontaminate/{args.subset}"
-
-        filterer = SubstringFilterer(
-            output_dir=output_dir_decontaminated,
-            cached_decontamination_dir=None,  # no previous cached run
-            split_languages=False,
-            cache_retrieval_key="",
-            data_dir=output_dir_decontaminated
-        )
-
-        filtered = filterer.run(dataset, args.num_workers, args.batch_size)
-
-        filtered_size_gb = sum(filtered["size"])
-        logger.info(
-            f"Removed {len(dataset) - len(filtered)} / {len(dataset)} files"
-        )
-        logger.info(
-            f"Dataset size after decontamination: {len(filtered)} examples, {filtered_size_gb / 1e9:.2f} GB"
-        )
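
Note: a quick way to sanity-check the new batched `decontaminate` filter outside
the pipeline is shown below. This is a minimal sketch: the FILTER_OUT stub is
hypothetical, standing in for the real mapping imported from
decontamination.benchmark_data, and the sample contents are made up.

    from datasets import Dataset

    # Hypothetical stand-in for decontamination.benchmark_data.FILTER_OUT:
    # benchmark name -> substrings whose presence marks a file as contaminated.
    FILTER_OUT = {"some_benchmark": ["def has_close_elements"]}

    def decontaminate(samples, filter_out=FILTER_OUT):
        # Batched filter: one boolean per file, True = keep,
        # False = drop (a benchmark substring matched).
        output = []
        for content in samples["content"]:
            content = content.lower()
            matched = any(
                substring.lower() in content
                for substrings in filter_out.values()
                for substring in substrings
            )
            output.append(not matched)
        return output

    ds = Dataset.from_dict(
        {"content": ["print('hello')", "def has_close_elements(xs, t): ..."]}
    )
    ds = ds.filter(decontaminate, batched=True, batch_size=10_000)
    print(len(ds))  # 1: the file containing the benchmark substring is dropped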