@@ -46,6 +46,9 @@ def add_to_shard(i, filename):
46
46
needs_gpu_nvidia_medium = list (
47
47
filter (lambda x : get_needs_machine (x ) == "gpu.nvidia.large" , all_files ,)
48
48
)
49
+ needs_a10g = list (
50
+ filter (lambda x : get_needs_machine (x ) == "linux.g5.4xlarge.nvidia.gpu" , all_files ,)
51
+ )
49
52
for filename in needs_gpu_nvidia_small_multi :
50
53
# currently, the only job that uses gpu.nvidia.small.multi is the 0th worker,
51
54
# so we'll add all the jobs that need this machine to the 0th worker
@@ -56,6 +59,11 @@ def add_to_shard(i, filename):
56
59
# so we'll add all the jobs that need this machine to the 1st worker
57
60
add_to_shard (1 , filename )
58
61
all_other_files .remove (filename )
62
+ for filename in needs_a10g :
63
+ # currently, workers 2-5th use linux.g5.4xlarge.nvidia.gpu, so, arbitrarily,
64
+ # we'll add all the jobs that need this machine to the 5th worker
65
+ add_to_shard (5 , filename )
66
+ all_other_files .remove (filename )
59
67
60
68
sorted_files = sorted (all_other_files , key = get_duration , reverse = True ,)
61
69
0 commit comments