From 06e1d3119a3a0adfdb40767ea520f2c079ea09f0 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 8 Apr 2025 16:31:45 -0400 Subject: [PATCH 1/3] convert : write tensors in parallel --- convert_hf_to_gguf.py | 11 +++-- gguf-py/gguf/gguf_writer.py | 98 +++++++++++++++++++++++++++++++------ gguf-py/gguf/lazy.py | 5 ++ 3 files changed, 95 insertions(+), 19 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 9549900206b48..e2855aabd63c8 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -73,7 +73,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, use_temp_file: bool = False, eager: bool = False, metadata_override: Path | None = None, model_name: str | None = None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, - small_first_shard: bool = False, hparams: dict[str, Any] | None = None): + small_first_shard: bool = False, hparams: dict[str, Any] | None = None, thread_count: int = 2): if type(self) is Model: raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") @@ -109,7 +109,8 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, # Configure GGUF Writer self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file, - split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard) + split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard, + thread_count=thread_count) @classmethod def __init_subclass__(cls): @@ -5470,6 +5471,10 @@ def parse_args() -> argparse.Namespace: "--print-supported-models", action="store_true", help="Print the supported models" ) + parser.add_argument( + "-t", "--threads", type=int, default=2, + help="Number of threads to use when writing the tensors. 
Make sure you have enough RAM for at least THREADS of the biggest tensors in the model when setting this.", + ) args = parser.parse_args() if not args.print_supported_models and args.model is None: @@ -5554,7 +5559,7 @@ def main() -> None: metadata_override=args.metadata, model_name=args.model_name, split_max_tensors=args.split_max_tensors, split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run, - small_first_shard=args.no_tensor_first_split) + small_first_shard=args.no_tensor_first_split, thread_count=args.threads) if args.vocab_only: logger.info("Exporting model vocab...") diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 485550aad6da4..889d9fdfe6bea 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -5,10 +5,12 @@ import shutil import struct import tempfile +import threading from dataclasses import dataclass from enum import Enum, auto from math import prod from pathlib import Path +from queue import Empty, Queue from io import BufferedWriter from typing import IO, Any, Sequence, Mapping from string import ascii_letters, digits @@ -60,8 +62,31 @@ class WriterState(Enum): WEIGHTS = auto() +@dataclass +class TensorWriteInfo: + filename: Path + offset: int + post_pad: int + tensor: np.ndarray + bar: Any | None + + def write_chunk(self, open_files: dict[Path, BufferedWriter]): + if self.filename not in open_files: + open_files[self.filename] = open(self.filename, "r+b") + f = open_files[self.filename] + + f.seek(self.offset) + f.write(self.tensor.data) + if self.post_pad > 0: + f.write(bytes([0] * self.post_pad)) + if self.bar is not None: + self.bar.update(self.tensor.nbytes) + + class GGUFWriter: fout: list[BufferedWriter] | None + filenames: list[Path] | None + thread_count: int path: Path | None temp_file: tempfile.SpooledTemporaryFile[bytes] | None tensors: list[dict[str, TensorInfo]] @@ -83,7 +108,8 @@ class GGUFWriter: def __init__( self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE, - split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False + split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False, + thread_count: int = 2, ): self.fout = None self.path = Path(path) if path else None @@ -98,6 +124,7 @@ def __init__( self.split_max_size = split_max_size self.dry_run = dry_run self.small_first_shard = small_first_shard + self.thread_count = thread_count logger.info("gguf: This GGUF file is for {0} Endian only".format( "Big" if self.endianess == GGUFEndian.BIG else "Little", )) @@ -173,6 +200,7 @@ def open_output_file(self, path: Path | None = None) -> None: if self.path is not None: filenames = self.print_plan() + self.filenames = filenames self.fout = [open(filename, "wb") for filename in filenames] self.state = WriterState.EMPTY @@ -424,40 +452,78 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None: self.write_ti_data_to_file() assert self.fout is not None + assert self.filenames is not None for fout in self.fout: self.write_padding(fout, fout.tell()) if self.temp_file is None: - shard_bar = None bar = None + # Distribute writing the tensors between multiple threads + tensor_queue: Queue[TensorWriteInfo] = Queue() + + offsets: list[int] = [fout.tell() for fout in self.fout] if progress: + # TODO: add back the shard bar to show which shard is being written when single-threaded from tqdm import tqdm total_bytes = 
sum(ti.nbytes for t in self.tensors for ti in t.values()) - if len(self.fout) > 1: - shard_bar = tqdm(desc=f"Shard (0/{len(self.fout)})", total=None, unit="byte", unit_scale=True) bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True) - for i, (fout, tensors) in enumerate(zip(self.fout, self.tensors)): - if shard_bar is not None: - shard_bar.set_description(f"Shard ({i + 1}/{len(self.fout)})") - total = sum(ti.nbytes for ti in tensors.values()) - shard_bar.reset(total=(total if total > 0 else None)) + for i, (filename, tensors) in enumerate(zip(self.filenames, self.tensors)): + offset = offsets[i] # relying on the fact that Python dicts preserve insertion order (since 3.7) for ti in tensors.values(): assert ti.tensor is not None # can only iterate once over the tensors assert ti.tensor.nbytes == ti.nbytes - ti.tensor.tofile(fout) - if shard_bar is not None: - shard_bar.update(ti.nbytes) - if bar is not None: - bar.update(ti.nbytes) - self.write_padding(fout, ti.nbytes) - ti.tensor = None + start_offset = offset + nbytes = ti.tensor.nbytes + offset = self.ggml_pad(start_offset + nbytes, self.data_alignment) + padding = offset - (start_offset + nbytes) + tensor_queue.put( + TensorWriteInfo( + filename=filename, + offset=start_offset, + post_pad=padding, + tensor=ti.tensor, + bar=bar, + ) + ) + ti.tensor = None # avoid keeping a reference to written tensors + + # Write tensors in parallel + # TODO: total tensor size limit for the running threads + def write_tensors_from_thread(queue: Queue[TensorWriteInfo]): + open_files: dict[Path, BufferedWriter] = {} + try: + while t := queue.get_nowait(): + t.write_chunk(open_files) + del t + queue.task_done() + except Empty: + pass + + for f in open_files.values(): + f.close() + + threads = [ + threading.Thread(target=write_tensors_from_thread, args=(tensor_queue,)) + for _ in range(self.thread_count) + ] + + for t in threads: + t.start() + + # NOTE: thread joining has weird interactions with KeyboardInterrupt, + # so waiting for the queue to be "done" first. + tensor_queue.join() + + for t in threads: + t.join() + else: self.temp_file.seek(0) diff --git a/gguf-py/gguf/lazy.py b/gguf-py/gguf/lazy.py index f9bcadae0224b..e01b5b050b788 100644 --- a/gguf-py/gguf/lazy.py +++ b/gguf-py/gguf/lazy.py @@ -220,4 +220,9 @@ def tofile(self, *args, **kwargs): eager = LazyNumpyTensor.to_eager(self) return eager.tofile(*args, **kwargs) + @property + def data(self): + eager = LazyNumpyTensor.to_eager(self) + return eager.data + # TODO: __array_function__ From d8bab9efa18a48724e9db8803c0a99c1a6ce866b Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 8 Apr 2025 21:55:15 -0400 Subject: [PATCH 2/3] gguf-py : add more clarifying comments for multi-thread writes --- gguf-py/gguf/gguf_writer.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 889d9fdfe6bea..db8ad4f055985 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -63,14 +63,17 @@ class WriterState(Enum): @dataclass -class TensorWriteInfo: +class ThreadedTensorWriteInfo: filename: Path offset: int post_pad: int tensor: np.ndarray - bar: Any | None + bar: Any | None # optional tqdm progress bar def write_chunk(self, open_files: dict[Path, BufferedWriter]): + # This is called from a thread pool, + # and each thread should have its own file handle per output file + # so that they can have different seek locations. 
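+        # (Each worker thread builds its own open_files mapping, so a shard
+        # file may be opened once per thread, but a handle is never shared
+        # between threads.)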
if self.filename not in open_files: open_files[self.filename] = open(self.filename, "r+b") f = open_files[self.filename] @@ -460,8 +463,9 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None: if self.temp_file is None: bar = None # Distribute writing the tensors between multiple threads - tensor_queue: Queue[TensorWriteInfo] = Queue() + tensor_queue: Queue[ThreadedTensorWriteInfo] = Queue() + # Initial file offsets before writing the tensor data offsets: list[int] = [fout.tell() for fout in self.fout] if progress: @@ -472,6 +476,7 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None: bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True) + # Fill the tensor queue with all the pending tensor writes for i, (filename, tensors) in enumerate(zip(self.filenames, self.tensors)): offset = offsets[i] @@ -484,7 +489,7 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None: offset = self.ggml_pad(start_offset + nbytes, self.data_alignment) padding = offset - (start_offset + nbytes) tensor_queue.put( - TensorWriteInfo( + ThreadedTensorWriteInfo( filename=filename, offset=start_offset, post_pad=padding, @@ -496,12 +501,13 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None: # Write tensors in parallel # TODO: total tensor size limit for the running threads - def write_tensors_from_thread(queue: Queue[TensorWriteInfo]): + def write_tensors_from_thread(queue: Queue[ThreadedTensorWriteInfo]): + # Opening the files only once per thread open_files: dict[Path, BufferedWriter] = {} try: - while t := queue.get_nowait(): - t.write_chunk(open_files) - del t + while tensor := queue.get_nowait(): + tensor.write_chunk(open_files) + del tensor queue.task_done() except Empty: pass From 3fe362fe497ff6040d206c5228b181ec2e977024 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 12 Apr 2025 00:00:51 -0400 Subject: [PATCH 3/3] gguf-py : use ThreadPoolExecutor when writing tensors - gguf-py : handle (limited) retries for remote tensors --- gguf-py/gguf/gguf_writer.py | 146 +++++++++++++++++++++--------------- gguf-py/gguf/utility.py | 50 +++++++++--- 2 files changed, 126 insertions(+), 70 deletions(-) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index db8ad4f055985..ea283c57fabcd 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -10,10 +10,10 @@ from enum import Enum, auto from math import prod from pathlib import Path -from queue import Empty, Queue from io import BufferedWriter from typing import IO, Any, Sequence, Mapping from string import ascii_letters, digits +from concurrent.futures import FIRST_EXCEPTION, Future, ThreadPoolExecutor, wait import numpy as np @@ -62,20 +62,49 @@ class WriterState(Enum): WEIGHTS = auto() +# To close files which were opened in thread-local context +# Necessary because ThreadPoolExecutor doesn't allow setting a custom finalizer +# ref: https://github.com/python/cpython/issues/89502 +class _ThreadedOpenFiles: + files: dict[Path, BufferedWriter] + + def __init__(self): + self.files = {} + + def __del__(self): + for file in self.files.values(): + file.close() + + def __getitem__(self, key: Path, /) -> BufferedWriter: + if key not in self.files: + self.files[key] = open(key, "r+b") + return self.files[key] + + @classmethod + def init_thread_local(cls, local_data): + local_data.open_files = _ThreadedOpenFiles() + + +# Exit quickly instead of waiting +class _InterruptibleThreadPoolExecutor(ThreadPoolExecutor): + def __exit__(self, 
exc_type, exc_val, exc_tb) -> bool | None: + del exc_type, exc_val, exc_tb + self.shutdown(wait=False, cancel_futures=True) + return False + + @dataclass -class ThreadedTensorWriteInfo: +class _ThreadedTensorWriteInfo: filename: Path offset: int post_pad: int tensor: np.ndarray bar: Any | None # optional tqdm progress bar - def write_chunk(self, open_files: dict[Path, BufferedWriter]): + def write_chunk(self, open_files: _ThreadedOpenFiles): # This is called from a thread pool, # and each thread should have its own file handle per output file # so that they can have different seek locations. - if self.filename not in open_files: - open_files[self.filename] = open(self.filename, "r+b") f = open_files[self.filename] f.seek(self.offset) @@ -462,9 +491,6 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None: if self.temp_file is None: bar = None - # Distribute writing the tensors between multiple threads - tensor_queue: Queue[ThreadedTensorWriteInfo] = Queue() - # Initial file offsets before writing the tensor data offsets: list[int] = [fout.tell() for fout in self.fout] @@ -476,60 +502,58 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None: bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True) - # Fill the tensor queue with all the pending tensor writes - for i, (filename, tensors) in enumerate(zip(self.filenames, self.tensors)): - offset = offsets[i] - - # relying on the fact that Python dicts preserve insertion order (since 3.7) - for ti in tensors.values(): - assert ti.tensor is not None # can only iterate once over the tensors - assert ti.tensor.nbytes == ti.nbytes - start_offset = offset - nbytes = ti.tensor.nbytes - offset = self.ggml_pad(start_offset + nbytes, self.data_alignment) - padding = offset - (start_offset + nbytes) - tensor_queue.put( - ThreadedTensorWriteInfo( - filename=filename, - offset=start_offset, - post_pad=padding, - tensor=ti.tensor, - bar=bar, + # Allow opening the files only once per worker + local_data = threading.local() + + # Unit of work + def thread_write_tensor(tensor: _ThreadedTensorWriteInfo): + tensor.write_chunk(local_data.open_files) + + with _InterruptibleThreadPoolExecutor( + max_workers=self.thread_count, + initializer=_ThreadedOpenFiles.init_thread_local, + initargs=(local_data,), + ) as executor: + + futures: list[Future] = [] + + # Fill the tensor queue with all the pending tensor writes + for i, (filename, tensors) in enumerate(zip(self.filenames, self.tensors)): + offset = offsets[i] + + # relying on the fact that Python dicts preserve insertion order (since 3.7) + for ti in tensors.values(): + assert ti.tensor is not None # can only iterate once over the tensors + assert ti.tensor.nbytes == ti.nbytes + start_offset = offset + nbytes = ti.tensor.nbytes + offset = self.ggml_pad(start_offset + nbytes, self.data_alignment) + padding = offset - (start_offset + nbytes) + futures.append( + executor.submit( + thread_write_tensor, + _ThreadedTensorWriteInfo( + filename=filename, + offset=start_offset, + post_pad=padding, + tensor=ti.tensor, + bar=bar, + ), + ) ) - ) - ti.tensor = None # avoid keeping a reference to written tensors - - # Write tensors in parallel - # TODO: total tensor size limit for the running threads - def write_tensors_from_thread(queue: Queue[ThreadedTensorWriteInfo]): - # Opening the files only once per thread - open_files: dict[Path, BufferedWriter] = {} - try: - while tensor := queue.get_nowait(): - tensor.write_chunk(open_files) - del tensor - queue.task_done() - except 
Empty:
-                    pass
-
-                for f in open_files.values():
-                    f.close()
-
-            threads = [
-                threading.Thread(target=write_tensors_from_thread, args=(tensor_queue,))
-                for _ in range(self.thread_count)
-            ]
-
-            for t in threads:
-                t.start()
-
-            # NOTE: thread joining has weird interactions with KeyboardInterrupt,
-            # so waiting for the queue to be "done" first.
-            tensor_queue.join()
-
-            for t in threads:
-                t.join()
-
+                        ti.tensor = None  # avoid keeping a reference to written tensors
+
+                # FIXME: there's still some weird behavior with KeyboardInterrupt
+                # not being able to interrupt a future mid-execution
+                done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
+                exc = None
+                if any(f for f in done
+                       if not f.cancelled() and (exc := f.exception()) is not None):
+                    raise RuntimeError("Error writing tensors") from exc
+                elif len(not_done) != 0:
+                    raise RuntimeError("Not all tensors were written")
+
+            del local_data
         else:
             self.temp_file.seek(0)

diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py
index e5251aef8c832..0734b9f25d2ac 100644
--- a/gguf-py/gguf/utility.py
+++ b/gguf-py/gguf/utility.py
@@ -5,6 +5,14 @@
 import os
 import json
+import time
+import logging
+
+import requests
+from urllib.parse import urlparse
+
+
+logger = logging.getLogger(__name__)


 def fill_templated_filename(filename: str, output_type: str | None) -> str:
@@ -75,6 +83,7 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
 @dataclass
 class RemoteTensor:
+    name: str
     dtype: str
     shape: tuple[int, ...]
     offset_start: int
@@ -82,9 +91,30 @@ class RemoteTensor:
     url: str

     def data(self) -> bytearray:
-        # TODO: handle request errors (maybe with limited retries?)
-        # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
-        data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size))
+        data = None
+        MAX_RETRIES = 8
+        for i in range(MAX_RETRIES):
+            try:
+                # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
+                data = bytearray(
+                    SafetensorRemote.get_data_by_range(
+                        url=self.url, start=self.offset_start, size=self.size
+                    )
+                )
+            except (
+                requests.exceptions.ChunkedEncodingError,
+                requests.exceptions.ContentDecodingError,
+                requests.exceptions.ConnectionError,
+            ) as e:
+                if i == MAX_RETRIES - 1:
+                    raise RuntimeError(f"Failed to download tensor {self.name}") from e
+                logger.warning(f"Retry ({i + 1}/{MAX_RETRIES}) downloading tensor {self.name} because of {e}")
+                time.sleep(2 * i + 1)  # 1 3 5 7 9 11 13
+                continue
+
+            if data is not None:
+                break
+
+        if data is None:
+            raise RuntimeError(f"Failed to download tensor {self.name}")
+
         return data


@@ -169,7 +199,14 @@ def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]:
                     offset_start_relative, offset_end_relative = meta["data_offsets"]
                     size = offset_end_relative - offset_start_relative
                     offset_start = data_start_offset + offset_start_relative
-                    res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
+                    res[name] = RemoteTensor(
+                        name=name,
+                        dtype=dtype,
+                        shape=tuple(shape),
+                        offset_start=offset_start,
+                        size=size,
+                        url=url,
+                    )
                 except KeyError as e:
                     raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")

@@ -217,8 +254,6 @@ def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
         Get raw byte data from a remote file by range.
         If size is not specified, it will read the entire file.
""" - import requests - from urllib.parse import urlparse parsed_url = urlparse(url) if not parsed_url.scheme or not parsed_url.netloc: @@ -239,9 +274,6 @@ def check_file_exist(cls, url: str) -> bool: Check if a file exists at the given URL. Returns True if the file exists, False otherwise. """ - import requests - from urllib.parse import urlparse - parsed_url = urlparse(url) if not parsed_url.scheme or not parsed_url.netloc: raise ValueError(f"Invalid URL: {url}")