
Commit 3fe362f

gguf-py : use ThreadPoolExecutor when writing tensors
- gguf-py : handle (limited) retries for remote tensors
1 parent d7db159

2 files changed (+126, -70 lines)
gguf-py/gguf/gguf_writer.py

Lines changed: 85 additions & 61 deletions
@@ -10,10 +10,10 @@
 from enum import Enum, auto
 from math import prod
 from pathlib import Path
-from queue import Empty, Queue
 from io import BufferedWriter
 from typing import IO, Any, Sequence, Mapping
 from string import ascii_letters, digits
+from concurrent.futures import FIRST_EXCEPTION, Future, ThreadPoolExecutor, wait
 
 import numpy as np
 
@@ -62,20 +62,49 @@ class WriterState(Enum):
     WEIGHTS = auto()
 
 
+# To close files which were opened in thread-local context
+# Necessary because ThreadPoolExecutor doesn't allow setting a custom finalizer
+# ref: https://github.com/python/cpython/issues/89502
+class _ThreadedOpenFiles:
+    files: dict[Path, BufferedWriter]
+
+    def __init__(self):
+        self.files = {}
+
+    def __del__(self):
+        for file in self.files.values():
+            file.close()
+
+    def __getitem__(self, key: Path, /) -> BufferedWriter:
+        if key not in self.files:
+            self.files[key] = open(key, "r+b")
+        return self.files[key]
+
+    @classmethod
+    def init_thread_local(cls, local_data):
+        local_data.open_files = _ThreadedOpenFiles()
+
+
+# Exit quickly instead of waiting
+class _InterruptibleThreadPoolExecutor(ThreadPoolExecutor):
+    def __exit__(self, exc_type, exc_val, exc_tb) -> bool | None:
+        del exc_type, exc_val, exc_tb
+        self.shutdown(wait=False, cancel_futures=True)
+        return False
+
+
 @dataclass
-class ThreadedTensorWriteInfo:
+class _ThreadedTensorWriteInfo:
     filename: Path
     offset: int
     post_pad: int
     tensor: np.ndarray
     bar: Any | None  # optional tqdm progress bar
 
-    def write_chunk(self, open_files: dict[Path, BufferedWriter]):
+    def write_chunk(self, open_files: _ThreadedOpenFiles):
         # This is called from a thread pool,
         # and each thread should have its own file handle per output file
         # so that they can have different seek locations.
-        if self.filename not in open_files:
-            open_files[self.filename] = open(self.filename, "r+b")
         f = open_files[self.filename]
 
         f.seek(self.offset)
@@ -462,9 +491,6 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
 
         if self.temp_file is None:
             bar = None
-            # Distribute writing the tensors between multiple threads
-            tensor_queue: Queue[ThreadedTensorWriteInfo] = Queue()
-
             # Initial file offsets before writing the tensor data
             offsets: list[int] = [fout.tell() for fout in self.fout]
 
@@ -476,60 +502,58 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
 
                 bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
 
-            # Fill the tensor queue with all the pending tensor writes
-            for i, (filename, tensors) in enumerate(zip(self.filenames, self.tensors)):
-                offset = offsets[i]
-
-                # relying on the fact that Python dicts preserve insertion order (since 3.7)
-                for ti in tensors.values():
-                    assert ti.tensor is not None  # can only iterate once over the tensors
-                    assert ti.tensor.nbytes == ti.nbytes
-                    start_offset = offset
-                    nbytes = ti.tensor.nbytes
-                    offset = self.ggml_pad(start_offset + nbytes, self.data_alignment)
-                    padding = offset - (start_offset + nbytes)
-                    tensor_queue.put(
-                        ThreadedTensorWriteInfo(
-                            filename=filename,
-                            offset=start_offset,
-                            post_pad=padding,
-                            tensor=ti.tensor,
-                            bar=bar,
+            # Allow opening the files only once per worker
+            local_data = threading.local()
+
+            # Unit of work
+            def thread_write_tensor(tensor: _ThreadedTensorWriteInfo):
+                tensor.write_chunk(local_data.open_files)
+
+            with _InterruptibleThreadPoolExecutor(
+                max_workers=self.thread_count,
+                initializer=_ThreadedOpenFiles.init_thread_local,
+                initargs=(local_data,),
+            ) as executor:
+
+                futures: list[Future] = []
+
+                # Fill the tensor queue with all the pending tensor writes
+                for i, (filename, tensors) in enumerate(zip(self.filenames, self.tensors)):
+                    offset = offsets[i]
+
+                    # relying on the fact that Python dicts preserve insertion order (since 3.7)
+                    for ti in tensors.values():
+                        assert ti.tensor is not None  # can only iterate once over the tensors
+                        assert ti.tensor.nbytes == ti.nbytes
+                        start_offset = offset
+                        nbytes = ti.tensor.nbytes
+                        offset = self.ggml_pad(start_offset + nbytes, self.data_alignment)
+                        padding = offset - (start_offset + nbytes)
+                        futures.append(
+                            executor.submit(
+                                thread_write_tensor,
+                                _ThreadedTensorWriteInfo(
+                                    filename=filename,
+                                    offset=start_offset,
+                                    post_pad=padding,
+                                    tensor=ti.tensor,
+                                    bar=bar,
+                                ),
+                            )
                         )
-                    )
-                    ti.tensor = None  # avoid keeping a reference to written tensors
-
-            # Write tensors in parallel
-            # TODO: total tensor size limit for the running threads
-            def write_tensors_from_thread(queue: Queue[ThreadedTensorWriteInfo]):
-                # Opening the files only once per thread
-                open_files: dict[Path, BufferedWriter] = {}
-                try:
-                    while tensor := queue.get_nowait():
-                        tensor.write_chunk(open_files)
-                        del tensor
-                        queue.task_done()
-                except Empty:
-                    pass
-
-                for f in open_files.values():
-                    f.close()
-
-            threads = [
-                threading.Thread(target=write_tensors_from_thread, args=(tensor_queue,))
-                for _ in range(self.thread_count)
-            ]
-
-            for t in threads:
-                t.start()
-
-            # NOTE: thread joining has weird interactions with KeyboardInterrupt,
-            # so waiting for the queue to be "done" first.
-            tensor_queue.join()
-
-            for t in threads:
-                t.join()
-
+                        ti.tensor = None  # avoid keeping a reference to written tensors
+
+                # FIXME: there's still some weird behavior with KeyboardInterrupt
+                # not being able to interrupt a future mid-execution
+                done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
+                exc = None
+                if any(f for f in done
+                       if not f.cancelled() and (exc := f.exception()) is not None):
+                    raise RuntimeError("Error writing tensors") from exc
+                elif len(not_done) != 0:
+                    raise RuntimeError("Not all tensors were written")
+
+            del local_data
         else:
             self.temp_file.seek(0)
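For context on the pattern used in this hunk: the writer now relies on a ThreadPoolExecutor whose initializer stores per-worker state on a threading.local() object, so each worker thread lazily gets its own handle to the output file and can seek/write independently, while wait(..., return_when=FIRST_EXCEPTION) surfaces the first failure. The following is a minimal, self-contained sketch of that idea only; it is not part of the commit, and the file name, sizes, and helper names (_init_thread, _write_at) are made up for illustration.

import threading
from concurrent.futures import FIRST_EXCEPTION, ThreadPoolExecutor, wait

local_data = threading.local()

def _init_thread(local: threading.local):
    # Runs once in each worker thread: give the thread its own file handle,
    # so concurrent seek() + write() calls don't interfere with each other.
    local.handle = open("example.bin", "r+b")

def _write_at(offset: int, payload: bytes):
    f = local_data.handle  # the handle owned by the current worker thread
    f.seek(offset)
    f.write(payload)

# Pre-allocate a small file to write into (stand-in for the tensor data section)
with open("example.bin", "wb") as f:
    f.write(b"\x00" * 128)

with ThreadPoolExecutor(max_workers=4, initializer=_init_thread, initargs=(local_data,)) as pool:
    futures = [pool.submit(_write_at, i * 16, bytes([i]) * 16) for i in range(8)]
    done, not_done = wait(futures, return_when=FIRST_EXCEPTION)
    for fut in done:
        if not fut.cancelled() and fut.exception() is not None:
            raise RuntimeError("a write failed") from fut.exception()

The commit goes further than this sketch: _ThreadedOpenFiles closes the per-thread handles when the thread-local storage is garbage-collected, and _InterruptibleThreadPoolExecutor shuts the pool down without waiting so errors don't leave the process hanging.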

gguf-py/gguf/utility.py

Lines changed: 41 additions & 9 deletions
@@ -5,6 +5,14 @@
 
 import os
 import json
+import time
+import logging
+
+import requests
+from urllib.parse import urlparse
+
+
+logger = logging.getLogger(__name__)
 
 
 def fill_templated_filename(filename: str, output_type: str | None) -> str:
@@ -75,16 +83,38 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
 
 @dataclass
 class RemoteTensor:
+    name: str
     dtype: str
     shape: tuple[int, ...]
     offset_start: int
     size: int
     url: str
 
     def data(self) -> bytearray:
-        # TODO: handle request errors (maybe with limited retries?)
-        # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
-        data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size))
+        data = None
+        MAX_RETRIES = 8
+        for i in range(MAX_RETRIES):
+            try:
+                # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
+                data = bytearray(
+                    SafetensorRemote.get_data_by_range(
+                        url=self.url, start=self.offset_start, size=self.size
+                    )
+                )
+            except (
+                requests.exceptions.ChunkedEncodingError,
+                requests.exceptions.ContentDecodingError,
+                requests.exceptions.ConnectionError,
+            ) as e:
+                if i == MAX_RETRIES - 1:
+                    raise RuntimeError(f"Failed to download tensor {self.name}") from e
+                logger.warning(f"Retry ({i + 1}/{MAX_RETRIES}) downloading tensor {self.name} because of {e}")
+                time.sleep(2 * i + 1)  # 1 3 5 7 9 11 13
+                continue
+
+        if data is None:
+            raise RuntimeError(f"Failed to download tensor {self.name}")
+
         return data
 
 
@@ -169,7 +199,14 @@ def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]:
                 offset_start_relative, offset_end_relative = meta["data_offsets"]
                 size = offset_end_relative - offset_start_relative
                 offset_start = data_start_offset + offset_start_relative
-                res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
+                res[name] = RemoteTensor(
+                    name=name,
+                    dtype=dtype,
+                    shape=tuple(shape),
+                    offset_start=offset_start,
+                    size=size,
+                    url=url,
+                )
             except KeyError as e:
                 raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
 
@@ -217,8 +254,6 @@ def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
         Get raw byte data from a remote file by range.
         If size is not specified, it will read the entire file.
         """
-        import requests
-        from urllib.parse import urlparse
 
         parsed_url = urlparse(url)
         if not parsed_url.scheme or not parsed_url.netloc:
@@ -239,9 +274,6 @@ def check_file_exist(cls, url: str) -> bool:
        Check if a file exists at the given URL.
        Returns True if the file exists, False otherwise.
        """
-        import requests
-        from urllib.parse import urlparse
-
        parsed_url = urlparse(url)
        if not parsed_url.scheme or not parsed_url.netloc:
            raise ValueError(f"Invalid URL: {url}")
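The retry logic added to RemoteTensor.data() caps attempts at 8 and sleeps for an increasing odd number of seconds between them (1, 3, 5, ...), re-raising as a RuntimeError once the last attempt fails, and only retries on connection and decoding errors. A generic sketch of the same shape is below; it is not part of the commit, and fetch_with_retries, the plain GET request, and the timeout value are assumptions for the example.

import time
import logging

import requests

logger = logging.getLogger(__name__)

def fetch_with_retries(url: str, max_retries: int = 8) -> bytes:
    # Bounded attempts with a linearly increasing sleep (1, 3, 5, ... seconds),
    # re-raised as RuntimeError once the attempts are exhausted.
    for i in range(max_retries):
        try:
            resp = requests.get(url, timeout=30)
            resp.raise_for_status()
            return resp.content
        except (
            requests.exceptions.ChunkedEncodingError,
            requests.exceptions.ContentDecodingError,
            requests.exceptions.ConnectionError,
        ) as e:
            if i == max_retries - 1:
                raise RuntimeError(f"Failed to download {url}") from e
            logger.warning(f"Retry ({i + 1}/{max_retries}) for {url}: {e}")
            time.sleep(2 * i + 1)
    raise RuntimeError(f"Failed to download {url}")  # unreachable for max_retries >= 1

The commit's version additionally reports which tensor failed, which is why RemoteTensor now carries a name field that is filled in by get_list_tensors().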
