Skip to content

Commit a615ec4

Browse files
committed
Merge remote-tracking branch 'origin/master' into typing-sequence-learner
2 parents cc8be65 + 50fae43 commit a615ec4

15 files changed

+119
-25
lines changed

adaptive/learner/average_learner.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ def __init__(
7474
self.sum_f: Real = 0.0
7575
self.sum_f_sq: Real = 0.0
7676

77+
def new(self) -> AverageLearner:
78+
"""Create a copy of `~adaptive.AverageLearner` without the data."""
79+
return AverageLearner(self.function, self.atol, self.rtol, self.min_npoints)
80+
7781
@property
7882
def n_requested(self) -> int:
7983
return self.npoints + len(self.pending_points)

adaptive/learner/average_learner1D.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,20 @@ def __init__(
125125
# {xii: error[xii]/min(_distances[xi], _distances[xii], ...}
126126
self.rescaled_error: dict[Real, float] = decreasing_dict()
127127

128+
def new(self) -> AverageLearner1D:
129+
"""Create a copy of `~adaptive.AverageLearner1D` without the data."""
130+
return AverageLearner1D(
131+
self.function,
132+
self.bounds,
133+
self.loss_per_interval,
134+
self.delta,
135+
self.alpha,
136+
self.neighbor_sampling,
137+
self.min_samples,
138+
self.max_samples,
139+
self.min_error,
140+
)
141+
128142
@property
129143
def nsamples(self) -> int:
130144
"""Returns the total number of samples"""

adaptive/learner/balancing_learner.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import itertools
24
from collections import defaultdict
35
from collections.abc import Iterable
@@ -96,6 +98,14 @@ def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
9698

9799
self.strategy = strategy
98100

101+
def new(self) -> BalancingLearner:
102+
"""Create a new `BalancingLearner` with the same parameters."""
103+
return BalancingLearner(
104+
[learner.new() for learner in self.learners],
105+
cdims=self._cdims_default,
106+
strategy=self.strategy,
107+
)
108+
99109
@property
100110
def data(self):
101111
data = {}

adaptive/learner/base_learner.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,11 @@ def _get_data(self):
149149
def _set_data(self):
150150
pass
151151

152+
@abc.abstractmethod
153+
def new(self):
154+
"""Return a new learner with the same function and parameters."""
155+
pass
156+
152157
def copy_from(self, other):
153158
"""Copy over the data from another learner.
154159

adaptive/learner/data_saver.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ def __init__(self, learner, arg_picker):
4545
self.function = learner.function
4646
self.arg_picker = arg_picker
4747

48+
def new(self) -> DataSaver:
49+
"""Return a new `DataSaver` with the same `arg_picker` and `learner`."""
50+
return DataSaver(self.learner.new(), self.arg_picker)
51+
4852
def __getattr__(self, attr):
4953
return getattr(self.learner, attr)
5054

adaptive/learner/integrator_learner.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Based on an adaptive quadrature algorithm by Pedro Gonnet
2+
from __future__ import annotations
23

34
import sys
45
from collections import defaultdict
@@ -381,6 +382,10 @@ def __init__(self, function, bounds, tol):
381382
self.add_ival(ival)
382383
self.first_ival = ival
383384

385+
def new(self) -> IntegratorLearner:
386+
"""Create a copy of `~adaptive.Learner2D` without the data."""
387+
return IntegratorLearner(self.function, self.bounds, self.tol)
388+
384389
@property
385390
def approximating_intervals(self):
386391
return self.first_ival.done_leaves

adaptive/learner/learner1D.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@
2222
partial_function_from_dataframe,
2323
)
2424

25+
try:
26+
from typing import TypeAlias
27+
except ImportError:
28+
# Remove this when we drop support for Python 3.9
29+
from typing_extensions import TypeAlias
30+
2531
try:
2632
import pandas
2733

@@ -33,21 +39,21 @@
3339
# -- types --
3440

3541
# Commonly used types
36-
Interval = Union[Tuple[float, float], Tuple[float, float, int]]
37-
NeighborsType = Dict[float, List[Union[float, None]]]
42+
Interval: TypeAlias = Union[Tuple[float, float], Tuple[float, float, int]]
43+
NeighborsType: TypeAlias = Dict[float, List[Union[float, None]]]
3844

3945
# Types for loss_per_interval functions
40-
NoneFloat = Union[Float, None]
41-
NoneArray = Union[np.ndarray, None]
42-
XsType0 = Tuple[Float, Float]
43-
YsType0 = Union[Tuple[Float, Float], Tuple[np.ndarray, np.ndarray]]
44-
XsType1 = Tuple[NoneFloat, NoneFloat, NoneFloat, NoneFloat]
45-
YsType1 = Union[
46+
NoneFloat: TypeAlias = Union[Float, None]
47+
NoneArray: TypeAlias = Union[np.ndarray, None]
48+
XsType0: TypeAlias = Tuple[Float, Float]
49+
YsType0: TypeAlias = Union[Tuple[Float, Float], Tuple[np.ndarray, np.ndarray]]
50+
XsType1: TypeAlias = Tuple[NoneFloat, NoneFloat, NoneFloat, NoneFloat]
51+
YsType1: TypeAlias = Union[
4652
Tuple[NoneFloat, NoneFloat, NoneFloat, NoneFloat],
4753
Tuple[NoneArray, NoneArray, NoneArray, NoneArray],
4854
]
49-
XsTypeN = Tuple[NoneFloat, ...]
50-
YsTypeN = Union[Tuple[NoneFloat, ...], Tuple[NoneArray, ...]]
55+
XsTypeN: TypeAlias = Tuple[NoneFloat, ...]
56+
YsTypeN: TypeAlias = Union[Tuple[NoneFloat, ...], Tuple[NoneArray, ...]]
5157

5258

5359
__all__ = [
@@ -303,11 +309,15 @@ def __init__(
303309
# The precision in 'x' below which we set losses to 0.
304310
self._dx_eps = 2 * max(np.abs(bounds)) * np.finfo(float).eps
305311

306-
self.bounds = list(bounds)
312+
self.bounds = tuple(bounds)
307313
self.__missing_bounds = set(self.bounds) # cache of missing bounds
308314

309315
self._vdim: int | None = None
310316

317+
def new(self) -> Learner1D:
318+
"""Create a copy of `~adaptive.Learner1D` without the data."""
319+
return Learner1D(self.function, self.bounds, self.loss_per_interval)
320+
311321
@property
312322
def vdim(self) -> int:
313323
"""Length of the output of ``learner.function``.

adaptive/learner/learner2D.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,9 @@ def __init__(self, function, bounds, loss_per_triangle=None):
384384

385385
self.stack_size = 10
386386

387+
def new(self) -> Learner2D:
388+
return Learner2D(self.function, self.bounds, self.loss_per_triangle)
389+
387390
@property
388391
def xy_scale(self):
389392
xy_scale = self._xy_scale

adaptive/learner/learnerND.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,10 @@ def __init__(self, func, bounds, loss_per_simplex=None):
376376
# _pop_highest_existing_simplex
377377
self._simplex_queue = SortedKeyList(key=_simplex_evaluation_priority)
378378

379+
def new(self) -> LearnerND:
380+
"""Create a new learner with the same function and bounds."""
381+
return LearnerND(self.function, self.bounds, self.loss_per_simplex)
382+
379383
@property
380384
def npoints(self):
381385
"""Number of evaluated points."""

adaptive/learner/sequence_learner.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ def __init__(self, function, sequence):
8686
self.data = SortedDict()
8787
self.pending_points = set()
8888

89+
def new(self) -> SequenceLearner:
90+
"""Return a new `~adaptive.SequenceLearner` without the data."""
91+
return SequenceLearner(self._original_function, self.sequence)
92+
8993
def ask(
9094
self, n: int, tell_pending: bool = True
9195
) -> tuple[list[PointType], list[float]]:
@@ -186,7 +190,7 @@ def to_dataframe(
186190
df.attrs["inputs"] = [index_name]
187191
df.attrs["output"] = y_name
188192
if with_default_function_args:
189-
assign_defaults(self.function, df, function_prefix)
193+
assign_defaults(self._original_function, df, function_prefix)
190194
return df
191195

192196
def load_dataframe(
@@ -223,7 +227,7 @@ def load_dataframe(
223227
self.tell_many(df[[index_name, x_name]].values, df[y_name].values)
224228
if with_default_function_args:
225229
self.function = partial_function_from_dataframe(
226-
self.function, df, function_prefix
230+
self._original_function, df, function_prefix
227231
)
228232

229233
def _get_data(self) -> dict[int, Any]:

adaptive/learner/skopt_learner.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import collections
24

35
import numpy as np
@@ -27,8 +29,13 @@ def __init__(self, function, **kwargs):
2729
self.function = function
2830
self.pending_points = set()
2931
self.data = collections.OrderedDict()
32+
self._kwargs = kwargs
3033
super().__init__(**kwargs)
3134

35+
def new(self) -> SKOptLearner:
36+
"""Return a new `~adaptive.SKOptLearner` without the data."""
37+
return SKOptLearner(self.function, **self._kwargs)
38+
3239
def tell(self, x, y, fit=True):
3340
if isinstance(x, collections.abc.Iterable):
3441
self.pending_points.discard(tuple(x))

adaptive/tests/test_learners.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def test_adding_existing_data_is_idempotent(learner_type, f, learner_kwargs):
294294
"""
295295
f = generate_random_parametrization(f)
296296
learner = learner_type(f, **learner_kwargs)
297-
control = learner_type(f, **learner_kwargs)
297+
control = learner.new()
298298
if learner_type in (Learner1D, AverageLearner1D):
299299
learner._recompute_losses_factor = 1
300300
control._recompute_losses_factor = 1
@@ -345,7 +345,7 @@ def test_adding_non_chosen_data(learner_type, f, learner_kwargs):
345345
# XXX: learner, control and bounds are not defined
346346
f = generate_random_parametrization(f)
347347
learner = learner_type(f, **learner_kwargs)
348-
control = learner_type(f, **learner_kwargs)
348+
control = learner.new()
349349

350350
if learner_type is Learner2D:
351351
# If the stack_size is bigger then the number of points added,
@@ -395,7 +395,7 @@ def test_point_adding_order_is_irrelevant(learner_type, f, learner_kwargs):
395395
"""
396396
f = generate_random_parametrization(f)
397397
learner = learner_type(f, **learner_kwargs)
398-
control = learner_type(f, **learner_kwargs)
398+
control = learner.new()
399399

400400
if learner_type in (Learner1D, AverageLearner1D):
401401
learner._recompute_losses_factor = 1
@@ -581,7 +581,7 @@ def test_balancing_learner(learner_type, f, learner_kwargs):
581581
def test_saving(learner_type, f, learner_kwargs):
582582
f = generate_random_parametrization(f)
583583
learner = learner_type(f, **learner_kwargs)
584-
control = learner_type(f, **learner_kwargs)
584+
control = learner.new()
585585
if learner_type in (Learner1D, AverageLearner1D):
586586
learner._recompute_losses_factor = 1
587587
control._recompute_losses_factor = 1
@@ -614,7 +614,7 @@ def test_saving(learner_type, f, learner_kwargs):
614614
def test_saving_of_balancing_learner(learner_type, f, learner_kwargs):
615615
f = generate_random_parametrization(f)
616616
learner = BalancingLearner([learner_type(f, **learner_kwargs)])
617-
control = BalancingLearner([learner_type(f, **learner_kwargs)])
617+
control = learner.new()
618618

619619
if learner_type in (Learner1D, AverageLearner1D):
620620
for l, c in zip(learner.learners, control.learners):
@@ -654,7 +654,7 @@ def test_saving_with_datasaver(learner_type, f, learner_kwargs):
654654
g = lambda x: {"y": f(x), "t": random.random()} # noqa: E731
655655
arg_picker = operator.itemgetter("y")
656656
learner = DataSaver(learner_type(g, **learner_kwargs), arg_picker)
657-
control = DataSaver(learner_type(g, **learner_kwargs), arg_picker)
657+
control = learner.new()
658658

659659
if learner_type in (Learner1D, AverageLearner1D):
660660
learner.learner._recompute_losses_factor = 1
@@ -742,7 +742,7 @@ def test_to_dataframe(learner_type, f, learner_kwargs):
742742
assert len(df) == learner.npoints
743743

744744
# Add points from the DataFrame to a new empty learner
745-
learner2 = learner_type(learner.function, **learner_kwargs)
745+
learner2 = learner.new()
746746
learner2.load_dataframe(df, **kw)
747747
assert learner2.npoints == learner.npoints
748748

@@ -787,8 +787,7 @@ def test_to_dataframe(learner_type, f, learner_kwargs):
787787
assert len(df) == data_saver.npoints
788788

789789
# Test loading from a DataFrame into a new DataSaver
790-
learner2 = learner_type(learner.function, **learner_kwargs)
791-
data_saver2 = DataSaver(learner2, operator.itemgetter("result"))
790+
data_saver2 = data_saver.new()
792791
data_saver2.load_dataframe(df, **kw)
793792
assert data_saver2.extra_data.keys() == data_saver.extra_data.keys()
794793
assert all(

adaptive/types.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22

33
import numpy as np
44

5-
Float = Union[float, np.float_]
6-
Int = Union[int, np.int_]
7-
Real = Union[Float, Int]
5+
try:
6+
from typing import TypeAlias
7+
except ImportError:
8+
# Remove this when we drop support for Python 3.9
9+
from typing_extensions import TypeAlias
10+
11+
Float: TypeAlias = Union[float, np.float_]
12+
Int: TypeAlias = Union[int, np.int_]
13+
Real: TypeAlias = Union[Float, Int]

adaptive/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import inspect
55
import os
66
import pickle
7+
import warnings
78
from contextlib import contextmanager
89
from itertools import product
910

@@ -135,4 +136,19 @@ def partial_function_from_dataframe(function, df, function_prefix: str = "functi
135136
kwargs[k] = v
136137
if not kwargs:
137138
return function
139+
140+
sig = inspect.signature(function)
141+
for k, v in kwargs.items():
142+
if k not in sig.parameters:
143+
raise ValueError(
144+
f"The DataFrame contains a default parameter"
145+
f" ({k}={v}) but the function does not have that parameter."
146+
)
147+
default = sig.parameters[k].default
148+
if default != inspect._empty and kwargs[k] != default:
149+
warnings.warn(
150+
f"The DataFrame contains a default parameter"
151+
f" ({k}={v}) but the function already has a default ({k}={default})."
152+
" The DataFrame's value will be used."
153+
)
138154
return functools.partial(function, **kwargs)

setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ def get_version_and_cmdclass(package_name):
3131
"cloudpickle",
3232
"loky >= 2.9",
3333
]
34+
if sys.version_info < (3, 10):
35+
install_requires.append("typing_extensions")
3436

3537
extras_require = {
3638
"notebook": [
@@ -78,6 +80,7 @@ def get_version_and_cmdclass(package_name):
7880
"Programming Language :: Python :: 3.7",
7981
"Programming Language :: Python :: 3.8",
8082
"Programming Language :: Python :: 3.9",
83+
"Programming Language :: Python :: 3.10",
8184
],
8285
packages=find_packages("."),
8386
install_requires=install_requires,

0 commit comments

Comments
 (0)