18
18
Cluster object.
19
19
"""
20
20
21
- from dataclasses import dataclass , field
22
21
import pathlib
23
- import typing
24
22
import warnings
23
+ from dataclasses import dataclass , field , fields
24
+ from typing import Dict , List , Optional , Union , get_args , get_origin
25
25
26
26
dir = pathlib .Path (__file__ ).parent .parent .resolve ()
27
27
@@ -73,36 +73,37 @@ class ClusterConfiguration:
73
73
"""
74
74
75
75
name : str
76
- namespace : str = None
77
- head_info : list = field (default_factory = list )
78
- head_cpus : typing .Union [int , str ] = 2
79
- head_memory : typing .Union [int , str ] = 8
80
- head_gpus : int = None # Deprecating
81
- head_extended_resource_requests : typing .Dict [str , int ] = field (default_factory = dict )
82
- machine_types : list = field (default_factory = list ) # ["m4.xlarge", "g4dn.xlarge"]
83
- worker_cpu_requests : typing .Union [int , str ] = 1
84
- worker_cpu_limits : typing .Union [int , str ] = 1
85
- min_cpus : typing .Union [int , str ] = None # Deprecating
86
- max_cpus : typing .Union [int , str ] = None # Deprecating
76
+ namespace : Optional [str ] = None
77
+ head_info : List [str ] = field (default_factory = list )
78
+ head_cpus : Union [int , str ] = 2
79
+ head_memory : Union [int , str ] = 8
80
+ head_gpus : Optional [int ] = None # Deprecating
81
+ head_extended_resource_requests : Dict [str , int ] = field (default_factory = dict )
82
+ machine_types : List [str ] = field (
83
+ default_factory = list
84
+ ) # ["m4.xlarge", "g4dn.xlarge"]
85
+ worker_cpu_requests : Union [int , str ] = 1
86
+ worker_cpu_limits : Union [int , str ] = 1
87
+ min_cpus : Optional [Union [int , str ]] = None # Deprecating
88
+ max_cpus : Optional [Union [int , str ]] = None # Deprecating
87
89
num_workers : int = 1
88
- worker_memory_requests : typing . Union [int , str ] = 2
89
- worker_memory_limits : typing . Union [int , str ] = 2
90
- min_memory : typing . Union [int , str ] = None # Deprecating
91
- max_memory : typing . Union [int , str ] = None # Deprecating
92
- num_gpus : int = None # Deprecating
90
+ worker_memory_requests : Union [int , str ] = 2
91
+ worker_memory_limits : Union [int , str ] = 2
92
+ min_memory : Optional [ Union [int , str ] ] = None # Deprecating
93
+ max_memory : Optional [ Union [int , str ] ] = None # Deprecating
94
+ num_gpus : Optional [ int ] = None # Deprecating
93
95
template : str = f"{ dir } /templates/base-template.yaml"
94
96
appwrapper : bool = False
95
- envs : dict = field (default_factory = dict )
97
+ envs : Dict [ str , str ] = field (default_factory = dict )
96
98
image : str = ""
97
- image_pull_secrets : list = field (default_factory = list )
99
+ image_pull_secrets : List [ str ] = field (default_factory = list )
98
100
write_to_file : bool = False
99
101
verify_tls : bool = True
100
- labels : dict = field (default_factory = dict )
101
- worker_extended_resource_requests : typing .Dict [str , int ] = field (
102
- default_factory = dict
103
- )
104
- extended_resource_mapping : typing .Dict [str , str ] = field (default_factory = dict )
102
+ labels : Dict [str , str ] = field (default_factory = dict )
103
+ worker_extended_resource_requests : Dict [str , int ] = field (default_factory = dict )
104
+ extended_resource_mapping : Dict [str , str ] = field (default_factory = dict )
105
105
overwrite_default_resource_mapping : bool = False
106
+ local_queue : Optional [str ] = None
106
107
107
108
def __post_init__ (self ):
108
109
if not self .verify_tls :
@@ -120,6 +121,7 @@ def __post_init__(self):
120
121
self ._validate_extended_resource_requests (
121
122
self .worker_extended_resource_requests
122
123
)
124
+ self ._validate_types ()
123
125
124
126
def _combine_extended_resource_mapping (self ):
125
127
if overwritten := set (self .extended_resource_mapping .keys ()).intersection (
@@ -139,9 +141,7 @@ def _combine_extended_resource_mapping(self):
139
141
** self .extended_resource_mapping ,
140
142
}
141
143
142
- def _validate_extended_resource_requests (
143
- self , extended_resources : typing .Dict [str , int ]
144
- ):
144
+ def _validate_extended_resource_requests (self , extended_resources : Dict [str , int ]):
145
145
for k in extended_resources .keys ():
146
146
if k not in self .extended_resource_mapping .keys ():
147
147
raise ValueError (
@@ -206,4 +206,34 @@ def _memory_to_resource(self):
206
206
warnings .warn ("max_memory is being deprecated, use worker_memory_limits" )
207
207
self .worker_memory_limits = f"{ self .max_memory } G"
208
208
209
- local_queue : str = None
209
+ def _validate_types (self ):
210
+ """Validate the types of all fields in the ClusterConfiguration dataclass."""
211
+ for field_info in fields (self ):
212
+ value = getattr (self , field_info .name )
213
+ expected_type = field_info .type
214
+ if not self ._is_type (value , expected_type ):
215
+ raise TypeError (
216
+ f"'{ field_info .name } ' should be of type { expected_type } , got { type (value )} "
217
+ )
218
+
219
+ @staticmethod
220
+ def _is_type (value , expected_type ):
221
+ """Check if the value matches the expected type."""
222
+
223
+ def check_type (value , expected_type ):
224
+ origin_type = get_origin (expected_type )
225
+ args = get_args (expected_type )
226
+ if origin_type is Union :
227
+ return any (check_type (value , union_type ) for union_type in args )
228
+ if origin_type is list :
229
+ return all (check_type (elem , args [0 ]) for elem in value )
230
+ if origin_type is dict :
231
+ return all (
232
+ check_type (k , args [0 ]) and check_type (v , args [1 ])
233
+ for k , v in value .items ()
234
+ )
235
+ if origin_type is tuple :
236
+ return all (check_type (elem , etype ) for elem , etype in zip (value , args ))
237
+ return isinstance (value , expected_type )
238
+
239
+ return check_type (value , expected_type )
0 commit comments