10
10
from enum import Enum , auto
11
11
from math import prod
12
12
from pathlib import Path
13
- from queue import Empty , Queue
14
13
from io import BufferedWriter
15
14
from typing import IO , Any , Sequence , Mapping
16
15
from string import ascii_letters , digits
16
+ from concurrent .futures import FIRST_EXCEPTION , Future , ThreadPoolExecutor , wait
17
17
18
18
import numpy as np
19
19
@@ -62,20 +62,49 @@ class WriterState(Enum):
62
62
WEIGHTS = auto ()
63
63
64
64
65
# To close files which were opened in thread-local context
# Necessary because ThreadPoolExecutor doesn't allow setting a custom finalizer
# ref: https://github.com/python/cpython/issues/89502
class _ThreadedOpenFiles:
    """Lazily opened file handles, keyed by path, that are closed together.

    Indexing an instance with a path opens that file in ``"r+b"`` mode on
    first use and returns the same handle on every later lookup.  All handles
    are closed by :meth:`close`, which also runs when the instance is
    finalized.
    """

    files: dict[Path, BufferedWriter]

    def __init__(self):
        self.files = {}

    def close(self) -> None:
        """Close every cached handle and forget it.

        Safe to call more than once.  Every handle is attempted even if an
        earlier one fails to close.

        Raises:
            OSError: the first error raised while closing a handle, re-raised
                after all remaining handles have been attempted.
        """
        # `files` is missing when __init__ never ran (e.g. __new__ only).
        files = getattr(self, "files", None)
        if not files:
            return

        first_error = None
        while files:
            _, file = files.popitem()
            try:
                file.close()
            except OSError as err:
                # Keep going so that one bad handle does not leak the others.
                if first_error is None:
                    first_error = err

        if first_error is not None:
            raise first_error

    def __del__(self):
        self.close()

    def __getitem__(self, key: Path, /) -> BufferedWriter:
        """Return the handle for ``key``, opening it on first access.

        Raises:
            OSError: if the file cannot be opened in ``"r+b"`` mode
                (for example because it does not exist).
        """
        try:
            return self.files[key]
        except KeyError:
            # "r+b" requires an existing file and does not truncate it.
            file = self.files[key] = open(key, "r+b")
            return file

    @classmethod
    def init_thread_local(cls, local_data):
        """Store a fresh instance of ``cls`` as ``local_data.open_files``.

        The signature fits the ``initializer``/``initargs`` pair of an
        executor, with ``local_data`` as the only argument.
        """
        local_data.open_files = cls()
86
+
87
+
88
+ # Exit quickly instead of waiting
89
+ class _InterruptibleThreadPoolExecutor (ThreadPoolExecutor ):
90
+ def __exit__ (self , exc_type , exc_val , exc_tb ) -> bool | None :
91
+ del exc_type , exc_val , exc_tb
92
+ self .shutdown (wait = False , cancel_futures = True )
93
+ return False
94
+
95
+
65
96
@dataclass
66
- class ThreadedTensorWriteInfo :
97
+ class _ThreadedTensorWriteInfo :
67
98
filename : Path
68
99
offset : int
69
100
post_pad : int
70
101
tensor : np .ndarray
71
102
bar : Any | None # optional tqdm progress bar
72
103
73
- def write_chunk (self , open_files : dict [ Path , BufferedWriter ] ):
104
+ def write_chunk (self , open_files : _ThreadedOpenFiles ):
74
105
# This is called from a thread pool,
75
106
# and each thread should have its own file handle per output file
76
107
# so that they can have different seek locations.
77
- if self .filename not in open_files :
78
- open_files [self .filename ] = open (self .filename , "r+b" )
79
108
f = open_files [self .filename ]
80
109
81
110
f .seek (self .offset )
@@ -462,9 +491,6 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
462
491
463
492
if self .temp_file is None :
464
493
bar = None
465
- # Distribute writing the tensors between multiple threads
466
- tensor_queue : Queue [ThreadedTensorWriteInfo ] = Queue ()
467
-
468
494
# Initial file offsets before writing the tensor data
469
495
offsets : list [int ] = [fout .tell () for fout in self .fout ]
470
496
@@ -476,60 +502,58 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
476
502
477
503
bar = tqdm (desc = "Writing" , total = total_bytes , unit = "byte" , unit_scale = True )
478
504
479
- # Fill the tensor queue with all the pending tensor writes
480
- for i , (filename , tensors ) in enumerate (zip (self .filenames , self .tensors )):
481
- offset = offsets [i ]
482
-
483
- # relying on the fact that Python dicts preserve insertion order (since 3.7)
484
- for ti in tensors .values ():
485
- assert ti .tensor is not None # can only iterate once over the tensors
486
- assert ti .tensor .nbytes == ti .nbytes
487
- start_offset = offset
488
- nbytes = ti .tensor .nbytes
489
- offset = self .ggml_pad (start_offset + nbytes , self .data_alignment )
490
- padding = offset - (start_offset + nbytes )
491
- tensor_queue .put (
492
- ThreadedTensorWriteInfo (
493
- filename = filename ,
494
- offset = start_offset ,
495
- post_pad = padding ,
496
- tensor = ti .tensor ,
497
- bar = bar ,
505
+ # Allow opening the files only once per worker
506
+ local_data = threading .local ()
507
+
508
+ # Unit of work
509
+ def thread_write_tensor (tensor : _ThreadedTensorWriteInfo ):
510
+ tensor .write_chunk (local_data .open_files )
511
+
512
+ with _InterruptibleThreadPoolExecutor (
513
+ max_workers = self .thread_count ,
514
+ initializer = _ThreadedOpenFiles .init_thread_local ,
515
+ initargs = (local_data ,),
516
+ ) as executor :
517
+
518
+ futures : list [Future ] = []
519
+
520
+ # Submit all the pending tensor writes to the thread pool
521
+ for i , (filename , tensors ) in enumerate (zip (self .filenames , self .tensors )):
522
+ offset = offsets [i ]
523
+
524
+ # relying on the fact that Python dicts preserve insertion order (since 3.7)
525
+ for ti in tensors .values ():
526
+ assert ti .tensor is not None # can only iterate once over the tensors
527
+ assert ti .tensor .nbytes == ti .nbytes
528
+ start_offset = offset
529
+ nbytes = ti .tensor .nbytes
530
+ offset = self .ggml_pad (start_offset + nbytes , self .data_alignment )
531
+ padding = offset - (start_offset + nbytes )
532
+ futures .append (
533
+ executor .submit (
534
+ thread_write_tensor ,
535
+ _ThreadedTensorWriteInfo (
536
+ filename = filename ,
537
+ offset = start_offset ,
538
+ post_pad = padding ,
539
+ tensor = ti .tensor ,
540
+ bar = bar ,
541
+ ),
542
+ )
498
543
)
499
- )
500
- ti .tensor = None # avoid keeping a reference to written tensors
501
-
502
- # Write tensors in parallel
503
- # TODO: total tensor size limit for the running threads
504
- def write_tensors_from_thread (queue : Queue [ThreadedTensorWriteInfo ]):
505
- # Opening the files only once per thread
506
- open_files : dict [Path , BufferedWriter ] = {}
507
- try :
508
- while tensor := queue .get_nowait ():
509
- tensor .write_chunk (open_files )
510
- del tensor
511
- queue .task_done ()
512
- except Empty :
513
- pass
514
-
515
- for f in open_files .values ():
516
- f .close ()
517
-
518
- threads = [
519
- threading .Thread (target = write_tensors_from_thread , args = (tensor_queue ,))
520
- for _ in range (self .thread_count )
521
- ]
522
-
523
- for t in threads :
524
- t .start ()
525
-
526
- # NOTE: thread joining has weird interactions with KeyboardInterrupt,
527
- # so waiting for the queue to be "done" first.
528
- tensor_queue .join ()
529
-
530
- for t in threads :
531
- t .join ()
532
-
544
+ ti .tensor = None # avoid keeping a reference to written tensors
545
+
546
+ # FIXME: there's still some weird behavior with KeyboardInterrupt
547
+ # not being able to interrupt a future mid-execution
548
+ done , not_done = wait (futures , return_when = FIRST_EXCEPTION )
549
+ exc = None
550
+ if any (f for f in done
551
+ if not f .cancelled () and (exc := f .exception ()) is not None ):
552
+ raise RuntimeError ("Error writing tensors" ) from exc
553
+ elif len (not_done ) != 0 :
554
+ raise RuntimeError ("Not all tensors were written" )
555
+
556
+ del local_data
533
557
else :
534
558
self .temp_file .seek (0 )
535
559
0 commit comments