@rflamary
Last active April 14, 2021 14:28
API POT test
# OT solution classes
class OTResults:  # class for exact OT and small pre-computed problems
    def __init__(self, duals, loss, loss_linear, primal=None, log=None):
        self.duals = duals  # dual solutions
        self.primal = primal  # primal solution
        self.log = log  # log information from the solver
        self.loss = loss  # total loss (including the regularization term, e.g. entropy, if regularized)
        self.loss_linear = loss_linear  # linear part of the loss
class OTLazyResults(OTResults):
    def __init__(self, duals, loss, loss_linear, reg_type='entropic', log=None):
        super().__init__(duals, loss, loss_linear, primal=None, log=log)  # no full primal is stored
        self.reg_type = reg_type  # regularization needed to rebuild the primal from the duals
    def get_primal(self, i=None, j=None):
        # return the full primal by default, or a submatrix if i/j are integers,
        # lists of integers or slices
        # to do: rebuild P (or the requested block of P) from self.duals and return it
        raise NotImplementedError
class OTLazyResultsSamples:
    pass  # to be defined
# classical OT solvers
def solve(C, a=None, b=None, reg_param=0, reg_type='none', n_iter_max=100000,
          stop_thr=1e-7, method='auto'):  # ot.solve
    # solve exact OT by default
    # solve generic regularization ('entropic', 'l2', 'entropic+group lasso'); reg_type can also be a function
    # by default, a and b are uniform
    # stuff
    if lazy_sol:
        return OTLazyResults((alpha, beta), loss, loss_linear, log=log)
    else:
        return OTResults((alpha, beta), loss, loss_linear, primal=P, log=log)
def unbalanced_solve(C, a=None, b=None, reg_param=0, reg_type='entropic',
                     n_iter_max=10000, stop_thr=1e-7, method='auto'):  # ot.unbalanced_solve
    # solve unbalanced OT
    ...
def partial_solve(C, a=None, b=None, reg_param=0, n_iter_max=10000,
                  stop_thr=1e-7, method='auto'):  # ot.partial_solve
    # solve partial OT
    ...
# OT on empirical distributions
def solve_samples(xa, xb, a=None, b=None, metric='sqeuclidean', reg_param=0,
                  reg_type='entropic', n_iter_max=1000000):  # ot.solve_samples
    # metric can be a float p for an lp loss
    # use Jean's solvers for large sizes
    return OTLazyResults((alpha, beta), loss, loss_linear, log=log)
# OT on grids
def solve_grid(a=None, b=None, metric='sqeuclidean', reg_param=0, reg_type='entropic',
               n_iter_max=1000000):  # ot.solve_grid
    # a and b are grids (1D, 2D or 3D) with the same support (convolutional solvers)
    return OTLazyResults((alpha, beta), loss, loss_linear, log=log)
# example of code
import ot
# exact OT
res = ot.solve(C, a, b)  # solve once, then read several attributes
P = res.primal
wass = res.loss
# Sinkhorn OT (uniform weights)
res = ot.solve(C, reg_param=1, reg_type='entropic')
P = res.primal
sinkhorn_loss = res.loss
# compute one component of the OT matrix for a lazy result
ot.solve_samples(xs, xt, reg_param=1, reg_type='entropic').get_primal(0, 1)
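To make the result classes above concrete, here is a minimal NumPy sketch of an entropic solver that returns either an eager or a lazy result. It assumes the log-domain convention P_ij = a_i b_j exp((f_i + g_j - C_ij) / reg_param) for the duals; `solve_entropic_sketch`, its `lazy` flag and the extra constructor arguments of `OTLazyResults` are hypothetical and not part of POT. Storing `C` in the lazy object is a simplification: a real lazy solver would keep the samples and recompute costs on the fly.

```python
import numpy as np


def _lse(M, axis):
    # numerically stable log(sum(exp(M))) along one axis
    m = M.max(axis=axis, keepdims=True)
    return np.squeeze(m, axis) + np.log(np.exp(M - m).sum(axis=axis))


class OTResults:
    # eager result: duals, full primal (transport plan) and losses
    def __init__(self, duals, loss, loss_linear, primal=None, log=None):
        self.duals = duals              # dual potentials (f, g)
        self.primal = primal            # full transport plan, shape (n, m)
        self.log = log                  # solver information
        self.loss = loss                # total loss, including the entropic term
        self.loss_linear = loss_linear  # linear part <C, P>


class OTLazyResults(OTResults):
    # lazy result: keeps the duals and rebuilds blocks of the plan on demand
    def __init__(self, duals, loss, loss_linear, C, a, b, reg_param, log=None):
        super().__init__(duals, loss, loss_linear, primal=None, log=log)
        self._C, self._a, self._b, self._reg = C, a, b, reg_param

    def get_primal(self, i=slice(None), j=slice(None)):
        # i and j may be integers, lists of integers or slices; returns a 2D block
        f, g = self.duals
        ii = np.atleast_1d(np.arange(len(f))[i])
        jj = np.atleast_1d(np.arange(len(g))[j])
        # P_ij = a_i b_j exp((f_i + g_j - C_ij) / reg_param), on the block only
        expo = (f[ii][:, None] + g[jj][None, :] - self._C[np.ix_(ii, jj)]) / self._reg
        return self._a[ii][:, None] * self._b[jj][None, :] * np.exp(expo)


def solve_entropic_sketch(C, a=None, b=None, reg_param=1.0, n_iter_max=1000,
                          stop_thr=1e-9, lazy=False):
    n, m = C.shape
    a = np.full(n, 1.0 / n) if a is None else a  # uniform weights by default
    b = np.full(m, 1.0 / m) if b is None else b
    log_a, log_b = np.log(a), np.log(b)
    f, g = np.zeros(n), np.zeros(m)
    for it in range(n_iter_max):
        # log-domain Sinkhorn updates of the dual potentials
        f = -reg_param * _lse(log_b[None, :] + (g[None, :] - C) / reg_param, axis=1)
        g = -reg_param * _lse(log_a[:, None] + (f[:, None] - C) / reg_param, axis=0)
        P = a[:, None] * b[None, :] * np.exp((f[:, None] + g[None, :] - C) / reg_param)
        err = np.abs(P.sum(axis=1) - a).sum()  # column marginals are exact after the g update
        if err < stop_thr:
            break
    loss_linear = np.sum(C * P)
    loss = np.sum(P * (f[:, None] + g[None, :]))  # = <C, P> + reg_param * sum P log(P / (a b))
    log = {"n_iter": it + 1, "err": err}
    if lazy:
        return OTLazyResults((f, g), loss, loss_linear, C, a, b, reg_param, log=log)
    return OTResults((f, g), loss, loss_linear, primal=P, log=log)
```

With this sketch, `solve_entropic_sketch(C, lazy=True).get_primal(0, 1)` returns a (1, 1) block matching `solve_entropic_sketch(C).primal[0, 1]`, without ever storing the full plan.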
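Since `reg_type` may be a string or a function, one way to support both is to map every string to a callable Omega(P, a, b) and to compute `loss = loss_linear + reg_param * Omega(P)`, matching the `loss` / `loss_linear` pair stored in the result. The sketch assumes the generalized KL divergence to a⊗b for `'entropic'` and 0.5 ||P||² for `'l2'`; `get_regularizer` and `regularized_loss` are hypothetical names, and `'entropic+group lasso'` is left out.

```python
import numpy as np


def _entropic(P, a, b):
    # generalized KL(P | a b^T) = sum P log(P / ab) - sum P + sum ab, with 0 log 0 = 0
    ab = a[:, None] * b[None, :]
    mask = P > 0
    return np.sum(P[mask] * np.log(P[mask] / ab[mask])) - P.sum() + ab.sum()


def _l2(P, a, b):
    # quadratic regularization 0.5 * ||P||_F^2
    return 0.5 * np.sum(P ** 2)


_REGULARIZERS = {"none": lambda P, a, b: 0.0, "entropic": _entropic, "l2": _l2}


def get_regularizer(reg_type):
    # reg_type may be a known string or a user-provided function Omega(P, a, b)
    if callable(reg_type):
        return reg_type
    if reg_type not in _REGULARIZERS:
        raise NotImplementedError(f"reg_type={reg_type!r} is not supported yet")
    return _REGULARIZERS[reg_type]


def regularized_loss(P, C, a, b, reg_param=0.0, reg_type="none"):
    # returns (loss, loss_linear) with loss = <C, P> + reg_param * Omega(P)
    loss_linear = np.sum(C * P)
    return loss_linear + reg_param * get_regularizer(reg_type)(P, a, b), loss_linear
```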
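For small point clouds, `solve_samples` could simply build the dense cost matrix and fall back to the histogram solver. A minimal sketch of that first step follows; it assumes that a float `metric=p` means C_ij = ||x_i - y_j||_2^p (the comment above only says "lp loss"), and `cost_from_samples` is a hypothetical helper. For large sizes, this dense matrix is exactly what the lazy solvers avoid building.

```python
import numpy as np


def cost_from_samples(xa, xb, metric="sqeuclidean"):
    # hypothetical helper: dense (n, m) cost matrix between two point clouds
    diff = xa[:, None, :] - xb[None, :, :]
    sq = np.sum(diff ** 2, axis=-1)
    if metric == "sqeuclidean":
        return sq
    if isinstance(metric, (int, float)) and not isinstance(metric, bool):
        # assumed meaning of a float metric p: C_ij = ||x_i - y_j||_2 ** p
        return np.sqrt(sq) ** metric
    raise NotImplementedError(f"metric={metric!r} is not supported")
```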
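The "convolutional" idea behind `solve_grid` can be sketched as follows: with a squared Euclidean cost on a regular grid, the Gibbs kernel K = exp(-C / reg_param) factorizes along the axes, so applying K to a 2D image takes two small matrix products instead of one (n0·n1) x (n0·n1) product. This is a scaling-form Sinkhorn under stated assumptions (2D grids only, coordinates in [0, 1] on each axis); `gibbs_kernel_1d`, `apply_kernel_2d` and `solve_grid_sketch` are hypothetical names.

```python
import numpy as np


def gibbs_kernel_1d(n, reg_param):
    # K_ik = exp(-(t_i - t_k)^2 / reg_param) for n regularly spaced points t in [0, 1]
    t = np.linspace(0.0, 1.0, n)
    return np.exp(-((t[:, None] - t[None, :]) ** 2) / reg_param)


def apply_kernel_2d(U, K0, K1):
    # the 2D squared Euclidean cost is a sum over axes, so exp(-C / reg) = kron(K0, K1)
    # and the (n0*n1) x (n0*n1) product reduces to K0 @ U @ K1^T
    return K0 @ U @ K1.T


def solve_grid_sketch(a, b, reg_param=0.1, n_iter_max=1000):
    # scaling-form Sinkhorn between two 2D histograms a and b on the same grid
    K0 = gibbs_kernel_1d(a.shape[0], reg_param)
    K1 = gibbs_kernel_1d(a.shape[1], reg_param)
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iter_max):
        u = a / apply_kernel_2d(v, K0, K1)  # K is symmetric, so K^T = K
        v = b / apply_kernel_2d(u, K0, K1)
    return u, v
```

The plan itself is diag(u) K diag(v) and never needs to be formed, which fits the choice of returning a lazy result from `solve_grid`.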
@jeanfeydy

Hi @rflamary,

That's great, thanks a lot!
I list a few remarks below; I assume it is easier to discuss them than to write a fork directly:

  1. Object results are the way to go indeed. If SciPy does it too, there is no hesitation to have :-)
    Don't you think that OTResult (without an "s", instead of OTResults) would be a bit more idiomatic?

  2. Most attributes will be "virtual" and rely on getter/setter methods.
    I understand that the @property decorator is the Pythonic way of implementing this?

  3. As far as I can tell, referring to the transport plan as the "primal" solution and to the potentials/scaling vectors as the "dual" solutions is a bit ambiguous and complex. These names require prior knowledge of Kantorovich duality, which is too much to ask of most users. I would also say that the natural "dual" variables are the "prices" in the log domain, but you may see things differently?
    In any case, exposing plan, potentials and scalings attributes is probably the most explicit option?

  4. Instead of creating a separate OTLazyResults class, don't you think that we could define a lazy_plan attribute for OTResult?
    By default, plan would always return a straightforward NumPy/PyTorch array, while lazy_plan would return a symbolic object (or raise a RuntimeError/NotImplementedError in cases where it does not make sense).

  5. We could catch out-of-memory errors in plan and advise users to use lazy_plan instead - if possible?

  6. Please also note that by this summer, we should have added support for indexing in KeOps LazyTensors: this would make get_primal(i=..., j=...) redundant, as users could simply type result.lazy_plan[i, j].

  7. I would tend to merge the "unbalanced" and "partial" OT problems with the "exact" ones: they have the same input/output and will be handled by the exact same solvers (at least in the Sinkhorn case). Apart from the long (but factored!) docstrings, do you see a case where we really don't want to do this? For unsupported/untested configurations, we may just throw a NotImplementedError?

  8. However, I totally agree that making separate APIs for histograms, point clouds and grids makes a lot of sense. ot.solve, ot.solve_samples and ot.solve_grid all seem like good names to me.

  9. Out of curiosity, what types of regularization do you intend to support? Just vanilla OT (=none), Sinkhorn (=entropic) and quadratic?

  10. I agree that having a reg_param which is homogeneous to the cost function is overall simpler and more coherent. Do you think that it is a better name than e.g. reg_strength? (I have no idea.)

  11. Would some redundant parameters be OK for you? For instance, specifying regularization strengths through blur and reach scales (which are homogeneous to the point coordinates, not the cost function) could make sense in many settings.

  12. Likewise, the number of iterations could be specified through an n_iter parameter and/or a tolerance.

  13. Do you want to use a linter for POT? I found black to be very easy to use and suitable for the job - this is what we now use for both KeOps and GeomLoss.
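Items 1 to 5 could fit together roughly as follows. This is a sketch under stated assumptions: `OTResult` and the attribute names come from the discussion above, the potentials follow the convention P_ij = a_i b_j exp((f_i + g_j - C_ij) / reg_param), and `_LazyPlan` is a plain NumPy stand-in for a symbolic object such as a KeOps LazyTensor, not the KeOps API.

```python
import numpy as np


class _LazyPlan:
    # NumPy stand-in for a symbolic plan (e.g. a KeOps LazyTensor); evaluates blocks only
    def __init__(self, result):
        self._r = result

    def __getitem__(self, key):
        i, j = key  # integers, lists of integers or slices
        r = self._r
        f, g = r.potentials
        ii = np.atleast_1d(np.arange(len(f))[i])
        jj = np.atleast_1d(np.arange(len(g))[j])
        expo = (f[ii][:, None] + g[jj][None, :] - r._C[np.ix_(ii, jj)]) / r._reg
        return r._a[ii][:, None] * r._b[jj][None, :] * np.exp(expo)


class OTResult:
    # hypothetical result object: explicit attribute names, computed on access
    def __init__(self, potentials, C, a, b, reg_param):
        self._f, self._g = potentials
        self._C, self._a, self._b, self._reg = C, a, b, reg_param

    @property
    def potentials(self):
        # dual potentials (f, g), homogeneous to the cost
        return self._f, self._g

    @property
    def scalings(self):
        # scaling vectors (u, v) such that plan = diag(u) @ exp(-C / reg_param) @ diag(v)
        return self._a * np.exp(self._f / self._reg), self._b * np.exp(self._g / self._reg)

    @property
    def lazy_plan(self):
        return _LazyPlan(self)

    @property
    def plan(self):
        # dense array; on out-of-memory errors, point users to lazy_plan (item 5)
        try:
            return self.lazy_plan[:, :]
        except MemoryError as err:
            raise MemoryError(
                "The transport plan does not fit in memory: "
                "use result.lazy_plan[i, j] to evaluate blocks instead."
            ) from err
```

Here `result.plan` materializes the whole matrix, while `result.lazy_plan[i, j]` only evaluates the requested block, which is the usage item 6 anticipates for KeOps.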
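Items 7, 11 and 12 can coexist in a single entry point: with KL marginal penalties of strength rho, the log-domain Sinkhorn update only gains a damping factor rho / (rho + eps), and blur/reach scales can be converted into eps and rho. The sketch assumes eps = blur**p and rho = reach**p for a cost C_ij = ||x_i - y_j||^p, which is close to, but not guaranteed to match, the GeomLoss conventions; `solve_sketch` and its signature are hypothetical.

```python
import numpy as np


def _lse(M, axis):
    # numerically stable log(sum(exp(M))) along one axis
    m = M.max(axis=axis, keepdims=True)
    return np.squeeze(m, axis) + np.log(np.exp(M - m).sum(axis=axis))


def solve_sketch(C, a=None, b=None, reg_param=None, blur=None, reach=None,
                 p=2, n_iter_max=1000, tol=1e-9):
    # hypothetical single entry point for balanced and unbalanced entropic OT
    n, m = C.shape
    a = np.full(n, 1.0 / n) if a is None else a
    b = np.full(m, 1.0 / m) if b is None else b
    if reg_param is None:
        if blur is None:
            raise ValueError("specify reg_param or blur")
        reg_param = blur ** p  # assumed: blur is homogeneous to the coordinates
    rho = None if reach is None else reach ** p  # assumed: same scaling for reach
    # damping 1 gives balanced OT; rho / (rho + eps) relaxes the marginals (KL penalties)
    damping = 1.0 if rho is None else rho / (rho + reg_param)
    log_a, log_b = np.log(a), np.log(b)
    f, g = np.zeros(n), np.zeros(m)
    for it in range(n_iter_max):
        f_prev = f
        f = -damping * reg_param * _lse(log_b[None, :] + (g[None, :] - C) / reg_param, axis=1)
        g = -damping * reg_param * _lse(log_a[:, None] + (f[:, None] - C) / reg_param, axis=0)
        if np.max(np.abs(f - f_prev)) < tol:  # stop on the tolerance, or after n_iter_max
            break
    P = a[:, None] * b[None, :] * np.exp((f[:, None] + g[None, :] - C) / reg_param)
    return P, (f, g), it + 1
```

With `reach=None` the damping is 1 and the loop is plain balanced Sinkhorn, so the "exact" and "unbalanced" problems share one code path; unsupported combinations could still raise NotImplementedError, as suggested in item 7.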

These types of "laundry lists of discussion points" may seem a bit discouraging... But of course, I am 100% supportive :-D

Best regards,
And see you soon,
Jean

P.S.: We'll discuss the unbalanced/barycenter solvers with Hicham and Thibault tomorrow. We'll keep you updated on our progress!
