Skip to content

Commit 8264de5

Browse files
authored
Merge branch 'master' into ci/add-python-3-12-to-testing
2 parents 9e66ba2 + a8f0ed5 commit 8264de5

File tree

14 files changed

+68
-41
lines changed

14 files changed

+68
-41
lines changed

.circleci/config.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ jobs:
4646
command: |
4747
python -m pip install --user --upgrade --progress-bar off pip
4848
python -m pip install --user -e .
49-
python -m pip install --user --upgrade --no-cache-dir --progress-bar off -r requirements.txt
49+
python -m pip install --user --upgrade --no-cache-dir --progress-bar off -r requirements_all.txt
5050
python -m pip install --user --upgrade --progress-bar off -r docs/requirements.txt
5151
python -m pip install --user --upgrade --progress-bar off ipython sphinx-gallery memory_profiler
5252
# python -m pip install --user --upgrade --progress-bar off ipython "https://api.github.com/repos/sphinx-gallery/sphinx-gallery/zipball/master" memory_profiler

.github/workflows/build_doc.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
- name: Get Python running
2525
run: |
2626
python -m pip install --user --upgrade --progress-bar off pip
27-
python -m pip install --user --upgrade --progress-bar off -r requirements.txt
27+
python -m pip install --user --upgrade --progress-bar off -r requirements_all.txt
2828
python -m pip install --user --upgrade --progress-bar off -r docs/requirements.txt
2929
python -m pip install --user --upgrade --progress-bar off ipython "https://api.github.com/repos/sphinx-gallery/sphinx-gallery/zipball/master" memory_profiler
3030
python -m pip install --user -e .

.github/workflows/build_tests.yml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ jobs:
3535
pip install -e .
3636
- name: Install dependencies
3737
run: |
38-
python -m pip install --upgrade pip
39-
pip install -r requirements.txt
38+
python -m pip install --upgrade pip setuptools
39+
pip install -r requirements_all.txt
4040
pip install pytest pytest-cov
4141
- name: Run tests
4242
run: |
@@ -55,8 +55,8 @@ jobs:
5555
python-version: "3.x"
5656
- name: Install dependencies
5757
run: |
58-
python -m pip install --upgrade pip
59-
pip install flake8
58+
python -m pip install --upgrade pip setuptools
59+
pip install flake8
6060
- name: Lint with flake8
6161
run: |
6262
# stop the build if there are Python syntax errors or undefined names
@@ -76,7 +76,7 @@ jobs:
7676
python-version: "3.12"
7777
- name: Install dependencies
7878
run: |
79-
python -m pip install --upgrade pip
79+
python -m pip install --upgrade pip setuptools
8080
pip install pytest pytest-cov
8181
- name: Install POT
8282
run: |
@@ -106,9 +106,9 @@ jobs:
106106
pip install -e .
107107
- name: Install dependencies
108108
run: |
109-
python -m pip install --upgrade pip
110-
pip install -r requirements.txt
111-
pip install pytest
109+
python -m pip install --upgrade pip setuptools
110+
pip install -r requirements_all.txt
111+
pip install pytest
112112
- name: Run tests
113113
run: |
114114
python -m pytest --durations=20 -v test/ ot/ --color=yes

.github/workflows/build_tests_cuda.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ jobs:
1515
- uses: actions/checkout@v4
1616
- name: Install POT
1717
run: |
18-
python3.10 -m pip install --ignore-installed -e .
18+
python3.10 -m pip install --upgrade pip setuptools
19+
python3.10 -m pip install --ignore-installed -e .
1920
- name: Run tests
2021
run: |
2122
python3.10 -m pytest --durations=20 -v test/ ot/ --doctest-modules --color=yes --ignore=test/test_dr.py --ignore=ot.dr --ignore=ot.plot

.github/workflows/build_wheels.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ jobs:
2626

2727
- name: Install dependencies
2828
run: |
29-
python -m pip install --upgrade pip
29+
python -m pip install --upgrade pip setuptools
3030
3131
- name: Install cibuildwheel
3232
run: |
@@ -61,7 +61,7 @@ jobs:
6161

6262
- name: Install dependencies
6363
run: |
64-
python -m pip install --upgrade pip
64+
python -m pip install --upgrade pip setuptools
6565
6666
- name: Install cibuildwheel
6767
run: |

.github/workflows/build_wheels_weekly.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525

2626
- name: Install dependencies
2727
run: |
28-
python -m pip install --upgrade pip
28+
python -m pip install --upgrade pip setuptools
2929
3030
- name: Install cibuildwheel
3131
run: |

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,12 @@ or get the very latest version by running:
113113
pip install -U https://github.com/PythonOT/POT/archive/master.zip # with --user for user install (no root)
114114
```
115115

116+
Optional dependencies may be installed with
117+
```console
118+
pip install POT[all]
119+
```
120+
Note that this installs `cvxopt`, which is licensed under GPL 3.0. Alternatively, if you cannot use GPL-licensed software, the specific optional dependencies may be installed individually, or per-submodule. The available optional installations are `backend-jax, backend-tf, backend-torch, cvxopt, dr, gnn, all`.
121+
116122
#### Anaconda installation with conda-forge
117123

118124
If you use the Anaconda python distribution, POT is available in [conda-forge](https://conda-forge.org). To install it and the required dependencies:

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
+ New general unbalanced solvers for `ot.solve` and BFGS solver and illustrative example (PR #620)
1111
+ Add gradient computation with envelope theorem to sinkhorn solver of `ot.solve` with `grad='envelope'` (PR #605).
1212
+ Added support for [Low rank Gromov-Wasserstein](https://proceedings.mlr.press/v162/scetbon22b/scetbon22b.pdf) with `ot.gromov.lowrank_gromov_wasserstein_samples` (PR #614)
13+
+ Optional dependencies may now be installed with `pip install POT[all]` The specific backends or submodules' dependencies may also be installed individually. The pip options are: `backend-jax, backend-tf, backend-torch, cvxopt, dr, gnn, all`. The installation of the `cupy` backend should be done with conda.
1314

1415
#### Closed issues
1516
- Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596)

ot/da.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -376,8 +376,10 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
376376
elif sim == 'knn':
377377
if sim_param is None:
378378
sim_param = 3
379-
380-
from sklearn.neighbors import kneighbors_graph
379+
try:
380+
from sklearn.neighbors import kneighbors_graph
381+
except ImportError:
382+
raise ValueError('scikit-learn must be installed to use knn similarity. Install with `$pip install scikit-learn`.')
381383

382384
sS = nx.from_numpy(kneighbors_graph(
383385
X=nx.to_numpy(xs), n_neighbors=int(sim_param)

ot/dr.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@
1717
# License: MIT License
1818

1919
from scipy import linalg
20-
import autograd.numpy as np
21-
from sklearn.decomposition import PCA
22-
23-
import pymanopt
24-
import pymanopt.manifolds
25-
import pymanopt.optimizers
20+
try:
21+
import autograd.numpy as np
22+
from sklearn.decomposition import PCA
23+
24+
import pymanopt
25+
import pymanopt.manifolds
26+
import pymanopt.optimizers
27+
except ImportError:
28+
raise ImportError("Missing dependency for ot.dr. Requires autograd, pymanopt, scikit-learn. You can install with install with 'pip install POT[dr]', or 'conda install autograd pymanopt scikit-learn'")
2629

2730
from .bregman import sinkhorn as sinkhorn_bregman
2831
from .utils import dist as dist_utils, check_random_state

ot/gromov/_gw.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,11 @@ def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric
4343
4444
Where :
4545
46-
- :math:`\mathbf{C_1}`: Metric cost matrix in the source space
47-
- :math:`\mathbf{C_2}`: Metric cost matrix in the target space
48-
- :math:`\mathbf{p}`: distribution in the source space
49-
- :math:`\mathbf{q}`: distribution in the target space
50-
- `L`: loss function to account for the misfit between the similarity matrices
46+
- :math:`\mathbf{C_1}`: Metric cost matrix in the source space.
47+
- :math:`\mathbf{C_2}`: Metric cost matrix in the target space.
48+
- :math:`\mathbf{p}`: Distribution in the source space.
49+
- :math:`\mathbf{q}`: Distribution in the target space.
50+
- `L`: Loss function to account for the misfit between the similarity matrices.
5151
5252
.. note:: This function is backend-compatible and will work on arrays
5353
from all compatible backends. But the algorithm uses the C++ CPU backend
@@ -62,39 +62,39 @@ def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric
6262
Parameters
6363
----------
6464
C1 : array-like, shape (ns, ns)
65-
Metric cost matrix in the source space
65+
Metric cost matrix in the source space.
6666
C2 : array-like, shape (nt, nt)
67-
Metric cost matrix in the target space
67+
Metric cost matrix in the target space.
6868
p : array-like, shape (ns,), optional
6969
Distribution in the source space.
7070
If let to its default value None, uniform distribution is taken.
7171
q : array-like, shape (nt,), optional
7272
Distribution in the target space.
7373
If let to its default value None, uniform distribution is taken.
7474
loss_fun : str, optional
75-
loss function used for the solver either 'square_loss' or 'kl_loss'
75+
Loss function used for the solver either 'square_loss' or 'kl_loss'.
7676
symmetric : bool, optional
7777
Either C1 and C2 are to be assumed symmetric or not.
7878
If let to its default None value, a symmetry test will be conducted.
7979
Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric).
8080
verbose : bool, optional
81-
Print information along iterations
81+
Print information along iterations.
8282
log : bool, optional
83-
record log if True
83+
Record log if True.
8484
armijo : bool, optional
85-
If True the step of the line-search is found via an armijo research. Else closed form is used.
86-
If there are convergence issues use False.
85+
If True, the step of the line-search is found via an armijo search. Else closed form is used.
86+
If there are convergence issues, use False.
8787
G0: array-like, shape (ns,nt), optional
88-
If None the initial transport plan of the solver is pq^T.
88+
If None, the initial transport plan of the solver is pq^T.
8989
Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver.
9090
max_iter : int, optional
91-
Max number of iterations
91+
Max number of iterations.
9292
tol_rel : float, optional
93-
Stop threshold on relative error (>0)
93+
Stop threshold on relative error (>0).
9494
tol_abs : float, optional
95-
Stop threshold on absolute error (>0)
95+
Stop threshold on absolute error (>0).
9696
**kwargs : dict
97-
parameters can be directly passed to the ot.optim.cg solver
97+
Parameters can be directly passed to the ot.optim.cg solver.
9898
9999
Returns
100100
-------
@@ -175,7 +175,7 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
175175

176176
if not nx.is_floating_point(C10):
177177
warnings.warn(
178-
"Input structure matrix consists of integer. The transport plan will be "
178+
"Input structure matrix consists of integers. The transport plan will be "
179179
"casted accordingly, possibly resulting in a loss of precision. "
180180
"If this behaviour is unwanted, please make sure your input "
181181
"structure matrix consists of floating point elements.",

ot/helpers/openmp_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import textwrap
1010
import subprocess
1111

12-
from distutils.errors import CompileError, LinkError
12+
from setuptools.errors import CompileError, LinkError
1313

1414
from pre_build_helpers import compile_test_program
1515

File renamed without changes.

setup.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@
4646
sdk_path = subprocess.check_output(['xcrun', '--show-sdk-path'])
4747
os.environ['CFLAGS'] = '-isysroot "{}"'.format(sdk_path.rstrip().decode("utf-8"))
4848

49+
with open('requirements_all.txt') as f:
50+
optional_requirements = f.read().splitlines()
51+
4952
setup(
5053
name='POT',
5154
version=__version__,
@@ -70,6 +73,17 @@
7073
scripts=[],
7174
data_files=[],
7275
install_requires=["numpy>=1.16", "scipy>=1.6"],
76+
extras_require={
77+
'backend-numpy': [], # in requirements.
78+
'backend-jax': ['jax<=0.4.24', 'jaxlib<=0.4.24'],
79+
'backend-cupy': [], # should be installed with conda, not pip, or figure out what CUDA version above.
80+
'backend-tf': ['tensorflow'],
81+
'backend-torch': ['torch'],
82+
'cvxopt': ['cvxopt'], # on it's own to prevent accidental GPL violations
83+
'dr': ['scikit-learn', 'pymanopt', 'autograd'],
84+
'gnn': ['torch', 'torch_geometric'],
85+
'all': optional_requirements
86+
},
7387
python_requires=">=3.7",
7488
classifiers=[
7589
'Development Status :: 5 - Production/Stable',

0 commit comments

Comments
 (0)