@@ -140,7 +140,7 @@ def _get_ncores(
140
140
| concurrent .ThreadPoolExecutor
141
141
| SequentialExecutor
142
142
),
143
- ):
143
+ ) -> int :
144
144
"""Return the maximum number of cores that an executor can use."""
145
145
if with_ipyparallel and isinstance (ex , ipyparallel .client .view .ViewExecutor ):
146
146
return len (ex .view )
@@ -237,7 +237,7 @@ class BaseRunner(metaclass=abc.ABCMeta):
237
237
238
238
def __init__ (
239
239
self ,
240
- learner ,
240
+ learner : BaseLearner ,
241
241
goal : Callable [[BaseLearner ], bool ] | None = None ,
242
242
* ,
243
243
loss_goal : float | None = None ,
@@ -257,7 +257,7 @@ def __init__(
257
257
shutdown_executor : bool = False ,
258
258
retries : int = 0 ,
259
259
raise_if_retries_exceeded : bool = True ,
260
- allow_running_forever = False ,
260
+ allow_running_forever : bool = False ,
261
261
):
262
262
263
263
self .executor = _ensure_executor (executor )
@@ -301,7 +301,7 @@ def __init__(
301
301
def _get_max_tasks (self ) -> int :
302
302
return self ._max_tasks or _get_ncores (self .executor )
303
303
304
- def _do_raise (self , e : Exception , pid : int ):
304
+ def _do_raise (self , e : Exception , pid : int ) -> None :
305
305
tb = self ._tracebacks [pid ]
306
306
x = self ._id_to_point [pid ]
307
307
raise RuntimeError (
@@ -331,7 +331,7 @@ def _ask(self, n: int) -> tuple[list[int], list[float]]:
331
331
pids .append (pid )
332
332
return pids , loss_improvements
333
333
334
- def overhead (self ):
334
+ def overhead (self ) -> float :
335
335
"""Overhead of using Adaptive and the executor in percent.
336
336
337
337
This is measured as ``100 * (1 - t_function / t_elapsed)``.
@@ -424,8 +424,8 @@ def _cleanup(self) -> None:
424
424
self .end_time = time .time ()
425
425
426
426
@property
427
- def failed (self ) -> set [Any ]:
428
- """Set of points that failed ``runner.retries`` times."""
427
+ def failed (self ) -> set [int ]:
428
+ """Set of points ids that failed ``runner.retries`` times."""
429
429
return set (self ._tracebacks ) - set (self ._to_retry )
430
430
431
431
@abc .abstractmethod
@@ -533,7 +533,7 @@ class BlockingRunner(BaseRunner):
533
533
534
534
def __init__ (
535
535
self ,
536
- learner ,
536
+ learner : BaseLearner ,
537
537
goal : Callable [[BaseLearner ], bool ] | None = None ,
538
538
* ,
539
539
loss_goal : float | None = None ,
@@ -549,10 +549,10 @@ def __init__(
549
549
| None
550
550
) = None ,
551
551
ntasks : int | None = None ,
552
- log = False ,
553
- shutdown_executor = False ,
554
- retries = 0 ,
555
- raise_if_retries_exceeded = True ,
552
+ log : bool = False ,
553
+ shutdown_executor : bool = False ,
554
+ retries : int = 0 ,
555
+ raise_if_retries_exceeded : bool = True ,
556
556
) -> None :
557
557
if inspect .iscoroutinefunction (learner .function ):
558
558
raise ValueError ("Coroutine functions can only be used with 'AsyncRunner'." )
@@ -597,7 +597,7 @@ def _run(self) -> None:
597
597
self ._process_futures (with_result )
598
598
self ._cleanup ()
599
599
600
- def elapsed_time (self ):
600
+ def elapsed_time (self ) -> float :
601
601
"""Return the total time elapsed since the runner
602
602
was started."""
603
603
if self .end_time is None :
@@ -699,7 +699,7 @@ class AsyncRunner(BaseRunner):
699
699
700
700
def __init__ (
701
701
self ,
702
- learner ,
702
+ learner : BaseLearner ,
703
703
goal : Callable [[BaseLearner ], bool ] | None = None ,
704
704
* ,
705
705
loss_goal : float | None = None ,
@@ -777,16 +777,14 @@ def __init__(
777
777
"'adaptive.notebook_extension()'"
778
778
)
779
779
780
- def _submit (
781
- self , x : tuple [int , int ] | int | tuple [float , float ] | float
782
- ) -> Task | Future :
780
+ def _submit (self , x : Any ) -> Task | Future :
783
781
ioloop = self .ioloop
784
782
if inspect .iscoroutinefunction (self .learner .function ):
785
783
return ioloop .create_task (self .learner .function (x ))
786
784
else :
787
785
return ioloop .run_in_executor (self .executor , self .learner .function , x )
788
786
789
- def status (self ):
787
+ def status (self ) -> str :
790
788
"""Return the runner status as a string.
791
789
792
790
The possible statuses are: running, cancelled, failed, and finished.
@@ -802,7 +800,7 @@ def status(self):
802
800
else :
803
801
return "finished"
804
802
805
- def cancel (self ):
803
+ def cancel (self ) -> None :
806
804
"""Cancel the runner.
807
805
808
806
This is equivalent to calling ``runner.task.cancel()``.
0 commit comments