Skip to content

Commit 4157177

Browse files
authored
[training] Convert to ImageFolder script (#10664)
* [training] Convert to ImageFolder script * make
1 parent 18f7d1d commit 4157177

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

examples/dreambooth/README.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -742,3 +742,29 @@ accelerate launch train_dreambooth.py \
742742
## Stable Diffusion XL
743743

744744
We support fine-tuning of the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with DreamBooth and LoRA via the `train_dreambooth_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).
745+
746+
## Dataset
747+
748+
We support 🤗 [Datasets](https://huggingface.co/docs/datasets/index), you can find a dataset on the [Hugging Face Hub](https://huggingface.co/datasets) or use your own.
749+
750+
The quickest way to get started with your custom dataset is 🤗 Datasets' [`ImageFolder`](https://huggingface.co/docs/datasets/image_dataset#imagefolder).
751+
752+
We need to create a file `metadata.jsonl` in the directory with our images:
753+
754+
```
755+
{"file_name": "01.jpg", "prompt": "prompt 01"}
756+
{"file_name": "02.jpg", "prompt": "prompt 02"}
757+
```
758+
759+
If we have a directory with image-text pairs e.g. `01.jpg` and `01.txt` then `convert_to_imagefolder.py` can create `metadata.jsonl`.
760+
761+
```sh
762+
python convert_to_imagefolder.py --path my_dataset/
763+
```
764+
765+
We use `--dataset_name` and `--caption_column` with training scripts.
766+
767+
```
768+
--dataset_name=my_dataset/
769+
--caption_column=prompt
770+
```
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import argparse
2+
import json
3+
import pathlib
4+
5+
6+
parser = argparse.ArgumentParser()
7+
parser.add_argument(
8+
"--path",
9+
type=str,
10+
required=True,
11+
help="Path to folder with image-text pairs.",
12+
)
13+
parser.add_argument("--caption_column", type=str, default="prompt", help="Name of caption column.")
14+
args = parser.parse_args()
15+
16+
path = pathlib.Path(args.path)
17+
if not path.exists():
18+
raise RuntimeError(f"`--path` '{args.path}' does not exist.")
19+
20+
all_files = list(path.glob("*"))
21+
captions = list(path.glob("*.txt"))
22+
images = set(all_files) - set(captions)
23+
images = {image.stem: image for image in images}
24+
caption_image = {caption: images.get(caption.stem) for caption in captions if images.get(caption.stem)}
25+
26+
metadata = path.joinpath("metadata.jsonl")
27+
28+
with metadata.open("w", encoding="utf-8") as f:
29+
for caption, image in caption_image.items():
30+
caption_text = caption.read_text(encoding="utf-8")
31+
json.dump({"file_name": image.name, args.caption_column: caption_text}, f)
32+
f.write("\n")

0 commit comments

Comments
 (0)