Skip to content

[WIP] low rank sinkhorn, solve_sample, OTResultLazy + test functions #542

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from

Conversation

laudavid
Copy link
Contributor

Hello Rémi,

Here are the modifications we made to the repo:

  • a new lowrank.py file with the following functions: LR-Dykstra and lowrank_sinkhorn
  • a solve_sample function in solvers.py (not complete yet)
  • a OTResultLazy class in utils.py

We also added test functions (and a new test_lowrank.py file).

Don't hesitate to give us feedback.

Copy link
Collaborator

@rflamary rflamary left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, a few comments about the code


################################## LR-DYSKTRA ALGORITHM ##########################################

def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a bit more documentation with reeference to the paper and to the algorithm number in the paper please

q3_1, q3_2, v1_, v2_, q1, q2 = dykstra_p

# POT backend
eps1, eps2, eps3, p1, p2 = list_to_array(eps1, eps2, eps3, p1, p2)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not needs LR_dykstra is used inside POT function and will never be interfaced: i will receive arrays not lists

eps1, eps2, eps3, p1, p2 = list_to_array(eps1, eps2, eps3, p1, p2)
q3_1, q3_2, v1_, v2_, q1, q2 = list_to_array(q3_1, q3_2, v1_, v2_, q1, q2)

nx = get_backend(eps1, eps2, eps3, p1, p2, q3_1, q3_2, v1_, v2_, q1, q2)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be done only when nx is not given to the function and set to None, else use the new from the calling function (you need to add nx as a parameter to the function)


ns, nt = X_s.shape[0], X_t.shape[0]
if a is None:
a = nx.from_numpy(unif(ns), type_as=X_s)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
a = nx.from_numpy(unif(ns), type_as=X_s)
a = unif(ns, type_as=X_s)

if a is None:
a = nx.from_numpy(unif(ns), type_as=X_s)
if b is None:
b = nx.from_numpy(unif(nt), type_as=X_s)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

b = nx.from_numpy(unif(nt), type_as=X_s)

# Compute cost matrix
M = dist(X_s,X_t, metric=metric)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

M should never be computed as an ntimes n marix It shoud be stired as afctorized version (see discussion on the paper with D=AB) and used only as factorized version later (when computing dot ).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in fcat I'm OK with implementin only squared_euclidena and raise NotImplemented for other metrics until we have afficent function for obtainin low rank factorization of the metrics

C_t_Q = nx.dot(M.T,Q)
diag_g = (1/g)[:,None]

eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

adda efw more comments with references to equations and alg line nume rin the paper please

## Function isn't complete, still work in progress for reg == 0 / reg is None case (and unbalanced cases)


from .utils import unif, list_to_array, dist, OTResultLazy
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

those should be at the top (move it when nearing merge)

potentials = None
lazy_plan = None

if max_iter is None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move those before calling empirical_sinkhorn function those values are relevant to this function

@rflamary
Copy link
Collaborator

rflamary commented Nov 9, 2023

OK i'm closing this PR sice the new #568 and #563 are more current and better split

@rflamary rflamary closed this Nov 9, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants