Skip to content

Commit 886c66f

Browse files
committed
Add more type-hints
1 parent 22e2df9 commit 886c66f

File tree

1 file changed

+17
-19
lines changed

1 file changed

+17
-19
lines changed

adaptive/runner.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def _get_ncores(
140140
| concurrent.ThreadPoolExecutor
141141
| SequentialExecutor
142142
),
143-
):
143+
) -> int:
144144
"""Return the maximum number of cores that an executor can use."""
145145
if with_ipyparallel and isinstance(ex, ipyparallel.client.view.ViewExecutor):
146146
return len(ex.view)
@@ -237,7 +237,7 @@ class BaseRunner(metaclass=abc.ABCMeta):
237237

238238
def __init__(
239239
self,
240-
learner,
240+
learner: BaseLearner,
241241
goal: Callable[[BaseLearner], bool] | None = None,
242242
*,
243243
loss_goal: float | None = None,
@@ -257,7 +257,7 @@ def __init__(
257257
shutdown_executor: bool = False,
258258
retries: int = 0,
259259
raise_if_retries_exceeded: bool = True,
260-
allow_running_forever=False,
260+
allow_running_forever: bool = False,
261261
):
262262

263263
self.executor = _ensure_executor(executor)
@@ -301,7 +301,7 @@ def __init__(
301301
def _get_max_tasks(self) -> int:
302302
return self._max_tasks or _get_ncores(self.executor)
303303

304-
def _do_raise(self, e: Exception, pid: int):
304+
def _do_raise(self, e: Exception, pid: int) -> None:
305305
tb = self._tracebacks[pid]
306306
x = self._id_to_point[pid]
307307
raise RuntimeError(
@@ -331,7 +331,7 @@ def _ask(self, n: int) -> tuple[list[int], list[float]]:
331331
pids.append(pid)
332332
return pids, loss_improvements
333333

334-
def overhead(self):
334+
def overhead(self) -> float:
335335
"""Overhead of using Adaptive and the executor in percent.
336336
337337
This is measured as ``100 * (1 - t_function / t_elapsed)``.
@@ -424,8 +424,8 @@ def _cleanup(self) -> None:
424424
self.end_time = time.time()
425425

426426
@property
427-
def failed(self) -> set[Any]:
428-
"""Set of points that failed ``runner.retries`` times."""
427+
def failed(self) -> set[int]:
428+
"""Set of points ids that failed ``runner.retries`` times."""
429429
return set(self._tracebacks) - set(self._to_retry)
430430

431431
@abc.abstractmethod
@@ -533,7 +533,7 @@ class BlockingRunner(BaseRunner):
533533

534534
def __init__(
535535
self,
536-
learner,
536+
learner: BaseLearner,
537537
goal: Callable[[BaseLearner], bool] | None = None,
538538
*,
539539
loss_goal: float | None = None,
@@ -549,10 +549,10 @@ def __init__(
549549
| None
550550
) = None,
551551
ntasks: int | None = None,
552-
log=False,
553-
shutdown_executor=False,
554-
retries=0,
555-
raise_if_retries_exceeded=True,
552+
log: bool = False,
553+
shutdown_executor: bool = False,
554+
retries: int = 0,
555+
raise_if_retries_exceeded: bool = True,
556556
) -> None:
557557
if inspect.iscoroutinefunction(learner.function):
558558
raise ValueError("Coroutine functions can only be used with 'AsyncRunner'.")
@@ -597,7 +597,7 @@ def _run(self) -> None:
597597
self._process_futures(with_result)
598598
self._cleanup()
599599

600-
def elapsed_time(self):
600+
def elapsed_time(self) -> float:
601601
"""Return the total time elapsed since the runner
602602
was started."""
603603
if self.end_time is None:
@@ -699,7 +699,7 @@ class AsyncRunner(BaseRunner):
699699

700700
def __init__(
701701
self,
702-
learner,
702+
learner: BaseLearner,
703703
goal: Callable[[BaseLearner], bool] | None = None,
704704
*,
705705
loss_goal: float | None = None,
@@ -777,16 +777,14 @@ def __init__(
777777
"'adaptive.notebook_extension()'"
778778
)
779779

780-
def _submit(
781-
self, x: tuple[int, int] | int | tuple[float, float] | float
782-
) -> Task | Future:
780+
def _submit(self, x: Any) -> Task | Future:
783781
ioloop = self.ioloop
784782
if inspect.iscoroutinefunction(self.learner.function):
785783
return ioloop.create_task(self.learner.function(x))
786784
else:
787785
return ioloop.run_in_executor(self.executor, self.learner.function, x)
788786

789-
def status(self):
787+
def status(self) -> str:
790788
"""Return the runner status as a string.
791789
792790
The possible statuses are: running, cancelled, failed, and finished.
@@ -802,7 +800,7 @@ def status(self):
802800
else:
803801
return "finished"
804802

805-
def cancel(self):
803+
def cancel(self) -> None:
806804
"""Cancel the runner.
807805
808806
This is equivalent to calling ``runner.task.cancel()``.

0 commit comments

Comments
 (0)