-
Notifications
You must be signed in to change notification settings - Fork 524
[WIP] low rank sinkhorn, solve_sample, OTResultLazy + test functions #542
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, a few comments about the code
|
||
################################## LR-DYSKTRA ALGORITHM ########################################## | ||
|
||
def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a bit more documentation with reeference to the paper and to the algorithm number in the paper please
q3_1, q3_2, v1_, v2_, q1, q2 = dykstra_p | ||
|
||
# POT backend | ||
eps1, eps2, eps3, p1, p2 = list_to_array(eps1, eps2, eps3, p1, p2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not needs LR_dykstra is used inside POT function and will never be interfaced: i will receive arrays not lists
eps1, eps2, eps3, p1, p2 = list_to_array(eps1, eps2, eps3, p1, p2) | ||
q3_1, q3_2, v1_, v2_, q1, q2 = list_to_array(q3_1, q3_2, v1_, v2_, q1, q2) | ||
|
||
nx = get_backend(eps1, eps2, eps3, p1, p2, q3_1, q3_2, v1_, v2_, q1, q2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this should be done only when nx is not given to the function and set to None, else use the new from the calling function (you need to add nx as a parameter to the function)
|
||
ns, nt = X_s.shape[0], X_t.shape[0] | ||
if a is None: | ||
a = nx.from_numpy(unif(ns), type_as=X_s) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a = nx.from_numpy(unif(ns), type_as=X_s) | |
a = unif(ns, type_as=X_s) |
if a is None: | ||
a = nx.from_numpy(unif(ns), type_as=X_s) | ||
if b is None: | ||
b = nx.from_numpy(unif(nt), type_as=X_s) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here
b = nx.from_numpy(unif(nt), type_as=X_s) | ||
|
||
# Compute cost matrix | ||
M = dist(X_s,X_t, metric=metric) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
M should never be computed as an ntimes n marix It shoud be stired as afctorized version (see discussion on the paper with D=AB) and used only as factorized version later (when computing dot ).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in fcat I'm OK with implementin only squared_euclidena and raise NotImplemented for other metrics until we have afficent function for obtainin low rank factorization of the metrics
C_t_Q = nx.dot(M.T,Q) | ||
diag_g = (1/g)[:,None] | ||
|
||
eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
adda efw more comments with references to equations and alg line nume rin the paper please
## Function isn't complete, still work in progress for reg == 0 / reg is None case (and unbalanced cases) | ||
|
||
|
||
from .utils import unif, list_to_array, dist, OTResultLazy |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
those should be at the top (move it when nearing merge)
potentials = None | ||
lazy_plan = None | ||
|
||
if max_iter is None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
move those before calling empirical_sinkhorn function those values are relevant to this function
Hello Rémi,
Here are the modifications we made to the repo:
We also added test functions (and a new test_lowrank.py file).
Don't hesitate to give us feedback.