Skip to content

Commit fe8a7f0

Browse files
committed
Label normalization performs copy only when necessary
1 parent d969b24 commit fe8a7f0

File tree

3 files changed

+26
-15
lines changed

3 files changed

+26
-15
lines changed

ot/da.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -644,7 +644,7 @@ def transform_labels(self, ys=None):
644644
transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
645645

646646
# compute propagated labels
647-
labels = label_normalization(nx.copy(ys))
647+
labels = label_normalization(ys)
648648
masks = labels_to_masks(labels, nx=nx, type_as=transp)
649649
transp_ys = nx.dot(masks.T, transp)
650650

@@ -739,7 +739,7 @@ def inverse_transform_labels(self, yt=None):
739739
transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)
740740

741741
# compute propagated samples
742-
labels = label_normalization(nx.copy(yt))
742+
labels = label_normalization(yt)
743743
masks = labels_to_masks(labels, nx=nx, type_as=transp)
744744
transp_ys = nx.dot(masks.T, transp.T)
745745

@@ -2126,7 +2126,7 @@ def transform_labels(self, ys=None):
21262126
type_as=ys[0]
21272127
)
21282128
for i in range(len(ys)):
2129-
ysTemp = label_normalization(nx.copy(ys[i]))
2129+
ysTemp = label_normalization(ys[i])
21302130
classes = nx.unique(ysTemp)
21312131
n = len(classes)
21322132
ns = len(ysTemp)
@@ -2169,7 +2169,7 @@ def inverse_transform_labels(self, yt=None):
21692169
# check the necessary inputs parameters are here
21702170
if check_params(yt=yt):
21712171
transp_ys = []
2172-
ytTemp = label_normalization(nx.copy(yt))
2172+
ytTemp = label_normalization(yt)
21732173
classes = nx.unique(ytTemp)
21742174
n = len(classes)
21752175
D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_[0])

ot/utils.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ def is_all_finite(*args):
390390
return all(not nx.any(~nx.isfinite(arg)) for arg in args)
391391

392392

393-
def label_normalization(y, start=0):
393+
def label_normalization(y, start=0, nx=None):
394394
r""" Transform labels to start at a given value
395395
396396
Parameters
@@ -399,31 +399,31 @@ def label_normalization(y, start=0):
399399
The vector of labels to be normalized.
400400
start : int
401401
Desired value for the smallest label in :math:`\mathbf{y}` (default=0)
402+
nx : Backend, optional
403+
Backend to perform computations on. If omitted, the backend defaults to that of `y`.
402404
403405
Returns
404406
-------
405407
y : array-like, shape (`n1`, )
406408
The input vector of labels normalized according to given start value.
407409
"""
408-
nx = get_backend(y)
409-
410+
if nx is None:
411+
nx = get_backend(y)
410412
diff = nx.min(nx.unique(y)) - start
411-
if diff != 0:
412-
y -= diff
413-
return y
413+
return y if diff == 0 else (y - diff)
414414

415415

416-
def labels_to_masks(y, nx=None, type_as=None):
417-
r"""Transforms (n_samples,) vector of labels into a (n_samples, n_labels) matrix of masks.
416+
def labels_to_masks(y, type_as=None, nx=None):
417+
r"""Transforms (n_samples,) vector of labels into a (n_samples, n_labels) matrix of masks.
418418
419419
Parameters
420420
----------
421421
y : array-like, shape (n_samples, )
422422
The vector of labels.
423-
nx : Backend, optional
424-
Backend to perform computations on. If omitted, the backend defaults to that of `y`.
425423
type_as : array_like
426424
Array of the same type of the expected output.
425+
nx : Backend, optional
426+
Backend to perform computations on. If omitted, the backend defaults to that of `y`.
427427
428428
Returns
429429
-------

test/test_utils.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,7 @@ def test_lowrank_LazyTensor(nx):
585585
np.testing.assert_allclose(nx.to_numpy(T[:]), nx.to_numpy(T0))
586586

587587

588-
def test_label_to_mask_helper(nx):
588+
def test_labels_to_mask_helper(nx):
589589
y = np.array([1, 0, 2, 2, 1])
590590
out = np.array([
591591
[0, 1, 0],
@@ -597,3 +597,14 @@ def test_label_to_mask_helper(nx):
597597
y = nx.from_numpy(y)
598598
masks = ot.utils.labels_to_masks(y)
599599
np.testing.assert_array_equal(out, masks)
600+
601+
602+
def test_label_normalization(nx):
603+
y = nx.from_numpy(np.arange(5) + 1)
604+
out = np.arange(5)
605+
# labels are shifted
606+
y_normalized = ot.utils.label_normalization(y)
607+
np.testing.assert_array_equal(out, y_normalized)
608+
# labels are shifted but the shift if expected
609+
y_normalized_start = ot.utils.label_normalization(y, start=1)
610+
np.testing.assert_array_equal(y, y_normalized_start)

0 commit comments

Comments
 (0)