API POT test
rflamary/c5493846620df878a7c02f439ffe1d37 (GitHub gist), last active April 14, 2021 14:28
```python
# OT solution classes
class OTResults:
    """Class for exact OT and small, pre-computed problems."""

    def __init__(self, duals, loss, loss_linear, primal=None, log=None):
        self.duals = duals  # dual solutions
        self.primal = primal  # primal solution
        self.log = log  # log stuff from the solver
        self.loss = loss  # total loss (with entropy if reg)
        self.loss_linear = loss_linear  # linear part of the loss


class OTLazyResults(OTResults):
    def __init__(self, duals, loss, loss_linear, reg_type="entropic", log=None):
        super().__init__(duals, loss, loss_linear, primal=None, log=log)
        self.reg_type = reg_type

    def get_primal(self, i=None, j=None):
        # return the full primal by default, or a submatrix if i/j are
        # integers, lists of integers or slices
        raise NotImplementedError  # to be computed lazily from the duals


class OTLazyResultsSamples:
    pass  # stuff
```
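`get_primal` never needs to store the full matrix: in the entropic case the primal is a closed-form function of the duals. Below is a minimal sketch, assuming POT's scaling convention P = diag(u) exp(-C/reg) diag(v) with duals alpha = reg·log(u), beta = reg·log(v), and a squared Euclidean cost recomputed from the samples for the requested block only. The class `LazyEntropicPlan` and the helper `_as_index` are hypothetical, not part of POT. The duals in the usage lines are random, so they only exercise the indexing, not optimality.

```python
import numpy as np


def _as_index(idx, n):
    """Map None / int / list of ints / slice to a 1D integer index array."""
    if idx is None:
        return np.arange(n)
    if isinstance(idx, slice):
        return np.arange(n)[idx]
    return np.atleast_1d(np.asarray(idx, dtype=int))


class LazyEntropicPlan:
    """Hypothetical lazy primal: keeps duals and samples, never the n x m plan."""

    def __init__(self, alpha, beta, xa, xb, reg):
        self.alpha = np.asarray(alpha, dtype=float)  # dual on the source samples
        self.beta = np.asarray(beta, dtype=float)  # dual on the target samples
        self.xa = np.asarray(xa, dtype=float)  # (n, d)
        self.xb = np.asarray(xb, dtype=float)  # (m, d)
        self.reg = reg

    def get_primal(self, i=None, j=None):
        rows = _as_index(i, len(self.xa))
        cols = _as_index(j, len(self.xb))
        # squared Euclidean cost, computed for the requested block only
        diff = self.xa[rows][:, None, :] - self.xb[cols][None, :, :]
        C_block = np.sum(diff ** 2, axis=-1)
        # P_ij = exp((alpha_i + beta_j - C_ij) / reg)
        return np.exp((self.alpha[rows][:, None] + self.beta[cols][None, :] - C_block) / self.reg)


# usage with random duals: this only exercises the indexing
rng = np.random.default_rng(0)
plan = LazyEntropicPlan(rng.normal(size=5), rng.normal(size=4),
                        rng.normal(size=(5, 2)), rng.normal(size=(4, 2)), reg=1.0)
full = plan.get_primal()                      # shape (5, 4)
block = plan.get_primal([0, 2], slice(1, 3))  # rows 0 and 2, columns 1 and 2
```

Integer indices still return a 2D block here, which keeps the output shape predictable at the cost of a `(1, 1)` result for a single entry.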
```python
# classical OT solvers (exposed as ot.solve)
def solve(C, a=None, b=None, reg_param=0, reg_type="none", n_iter_max=100000,
          stop_thr=1e-7, method="auto"):
    # solve exact OT by default
    # solve generic regularization ('entropic', 'l2', 'entropic+group lasso')
    # (reg_type can be a function)
    # default a and b are uniform
    # stuff: compute alpha, beta, P, loss, loss_linear, log and lazy_sol
    if lazy_sol:
        return OTLazyResults((alpha, beta), loss, loss_linear, log=log)
    else:
        return OTResults((alpha, beta), loss, loss_linear, primal=P, log=log)
```
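The defaults described above (exact OT when `reg_param=0`, uniform `a` and `b`, dispatch on `reg_type`) can be sketched with NumPy and SciPy. This is a minimal illustration, not POT's implementation: the exact branch solves the linear program with `scipy.optimize.linprog` (only sensible for small problems), the entropic branch runs plain Sinkhorn scalings, and the entropic `loss` assumes the convention <P, C> + reg·sum(P log P). Any other `reg_type` raises `NotImplementedError` in this sketch. A compact copy of `OTResults` is included so the block runs on its own.

```python
import numpy as np
from scipy.optimize import linprog


class OTResults:
    """Compact copy of the result class so this block runs on its own."""

    def __init__(self, duals, loss, loss_linear, primal=None, log=None):
        self.duals, self.loss, self.loss_linear = duals, loss, loss_linear
        self.primal, self.log = primal, log


def solve(C, a=None, b=None, reg_param=0, reg_type="none",
          n_iter_max=100000, stop_thr=1e-7):
    C = np.asarray(C, dtype=float)
    n, m = C.shape
    # uniform weights by default
    a = np.full(n, 1.0 / n) if a is None else np.asarray(a, dtype=float)
    b = np.full(m, 1.0 / m) if b is None else np.asarray(b, dtype=float)

    if reg_param == 0 or reg_type == "none":
        # exact OT: linear program over the row-major flattened plan
        A_eq = np.vstack([np.kron(np.eye(n), np.ones((1, m))),   # row sums = a
                          np.kron(np.ones((1, n)), np.eye(m))])  # column sums = b
        res = linprog(C.ravel(), A_eq=A_eq, b_eq=np.concatenate([a, b]),
                      bounds=(0, None), method="highs")
        P = res.x.reshape(n, m)
        duals = (res.eqlin.marginals[:n], res.eqlin.marginals[n:])
        loss_linear = float(np.sum(P * C))
        return OTResults(duals, loss_linear, loss_linear, primal=P, log=res)

    if reg_type == "entropic":
        K = np.exp(-C / reg_param)
        u, v = np.ones(n), np.ones(m)
        for it in range(n_iter_max):
            u = a / (K @ v)
            v = b / (K.T @ u)
            # columns now match b exactly, so only the row marginals are checked
            if np.max(np.abs(u * (K @ v) - a)) < stop_thr:
                break
        P = u[:, None] * K * v[None, :]
        loss_linear = float(np.sum(P * C))
        pos = P[P > 0]
        loss = loss_linear + reg_param * float(np.sum(pos * np.log(pos)))
        duals = (reg_param * np.log(u), reg_param * np.log(v))
        return OTResults(duals, loss, loss_linear, primal=P, log={"n_iter": it + 1})

    raise NotImplementedError(f"reg_type={reg_type!r} is not covered by this sketch")
```

Calling `solve` once and reading both `.primal` and `.loss` from the same result avoids running the solver twice.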
```python
def unbalanced_solve(C, a=None, b=None, reg_param=0, reg_type="entropic",
                     n_iter_max=10000, stop_thr=1e-7, method="auto"):
    # solve unbalanced OT (exposed as ot.unbalanced_solve)
    ...


def partial_solve(C, a=None, b=None, reg_param=0, n_iter_max=10000,
                  stop_thr=1e-7, method="auto"):
    # solve partial OT (exposed as ot.partial_solve)
    ...
```
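In the entropic case, unbalanced OT only changes the Sinkhorn update. With KL penalties of strength `reg_m` on both marginals, the standard scaling update raises each ratio to the power reg_m / (reg_m + reg), and this exponent tends to 1 (balanced Sinkhorn) as `reg_m` grows. A minimal NumPy sketch, assuming this KL formulation. The function name `sinkhorn_unbalanced_sketch` and the `reg_m` argument are hypothetical, since the proposed `unbalanced_solve` signature has no marginal-penalty parameter yet.

```python
import numpy as np


def sinkhorn_unbalanced_sketch(C, a, b, reg, reg_m, n_iter_max=10000, stop_thr=1e-9):
    """Entropic OT where both marginal constraints become KL penalties of weight reg_m."""
    K = np.exp(-np.asarray(C, dtype=float) / reg)
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    u, v = np.ones(len(a)), np.ones(len(b))
    fi = reg_m / (reg_m + reg)  # fi -> 1 (reg_m -> infinity) gives balanced Sinkhorn
    for _ in range(n_iter_max):
        u_prev = u
        u = (a / (K @ v)) ** fi
        v = (b / (K.T @ u)) ** fi
        if np.max(np.abs(u - u_prev)) < stop_thr:
            break
    return u[:, None] * K * v[None, :]


# a source of mass 1 and a target of mass 0.25: the plan's mass ends up in between
P = sinkhorn_unbalanced_sketch([[0.0]], [1.0], [0.25], reg=1.0, reg_m=1.0)
```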
```python
# OT on empirical distributions (exposed as ot.solve_samples)
def solve_samples(xa, xb, a=None, b=None, metric="sqeuclidean", reg_param=0,
                  reg_type="entropic", n_iter_max=1000000):
    # metric can be a float p for an lp loss
    # use Jean's solvers for large sizes
    return OTLazyResults((alpha, beta), loss, loss_linear, log=log)
```
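Before solving, `solve_samples` needs a cost between the two point clouds. A minimal sketch of the `metric` argument, assuming that a float `p` means the Euclidean distance raised to the power p (so `p=2` coincides with `'sqeuclidean'`). The helper name `cost_from_samples` is hypothetical. For large sizes, this dense n x m matrix is exactly what the lazy solvers avoid building.

```python
import numpy as np


def cost_from_samples(xa, xb, metric="sqeuclidean"):
    """Dense cost matrix between samples xa (n, d) and xb (m, d)."""
    xa, xb = np.asarray(xa, dtype=float), np.asarray(xb, dtype=float)
    # |x - y|^2 = |x|^2 + |y|^2 - 2 <x, y>, clipped at 0 to absorb round-off
    sq = (xa ** 2).sum(axis=1)[:, None] + (xb ** 2).sum(axis=1)[None, :] - 2.0 * xa @ xb.T
    sq = np.maximum(sq, 0.0)
    if isinstance(metric, str):
        if metric == "sqeuclidean":
            return sq
        raise ValueError(f"unknown metric {metric!r}")
    # float metric p: assumed to mean |x - y|^p with the Euclidean norm
    return np.sqrt(sq) ** float(metric)
```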
```python
# OT on grids (exposed as ot.solve_grid)
def solve_grid(a=None, b=None, metric="sqeuclidean", reg_param=0,
               reg_type="entropic", n_iter_max=1000000):
    # a and b are grids (1D, 2D, 3D) with the same support (convolutional stuff)
    return OTLazyResults((alpha, beta), loss, loss_linear, log=log)
```
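On a regular grid with the squared Euclidean cost, the Gibbs kernel K = exp(-C/reg) factorizes along the axes, because exp(-(dx² + dy²)/reg) = exp(-dx²/reg)·exp(-dy²/reg). This is what makes the convolutional approach cheap: applying K to an N x N histogram takes one small N x N product per axis instead of a dense N² x N² matrix. A NumPy sketch for a 2D grid with unit spacing (an assumption). `grid_kernel_apply` is a hypothetical helper that a convolutional Sinkhorn would call in place of `K @ v`.

```python
import numpy as np


def grid_kernel_apply(v, reg):
    """Apply K = exp(-|x - y|^2 / reg) to a 2D histogram v on an N x N unit grid."""
    n = v.shape[0]
    t = np.arange(n, dtype=float)
    k1 = np.exp(-(t[:, None] - t[None, :]) ** 2 / reg)  # 1D kernel, N x N
    # separable kernel: one 1D product along each axis
    return k1 @ v @ k1.T
```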
```python
# example of code
# exact OT
res = ot.solve(C, a, b)
P = res.primal
wass = res.loss

# Sinkhorn OT (uniform weights)
res = ot.solve(C, reg_param=1, reg_type="entropic")
P = res.primal
sinkhorn_loss = res.loss

# compute one OT matrix component for the lazy solution
P_01 = ot.solve_samples(xs, xt, reg_param=1, reg_type="entropic").get_primal(0, 1)
```
Hi @rflamary,
That's great, thanks a lot!
I list a few remarks below; I assume that it is easier to discuss them here than to write a fork directly:
- Object results are the way to go indeed. If SciPy does it too, there is no reason to hesitate :-)
- Don't you think that `OTResult` (without an "s", instead of `OTResults`) would be a bit more idiomatic?
- Most attributes will be "virtual" and rely on getter/setter methods. I understand that the `@property` decorator is the Pythonic way of implementing this?
- As far as I can tell, referring to the transport plan as a "primal" solution and to the potentials/scaling vectors as "dual" solutions is a bit ambiguous and complex. These names rely on prerequisite knowledge of Kantorovich duality, which is too much to ask of most users. I would also say that the natural "dual" variables are the "prices" in the log domain, but you may see things differently?
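
A minimal sketch of the `@property` idea applied to the naming question: the object stores only the log-domain potentials (the "prices") and exposes `scalings` as a read-only, computed attribute, assuming the Sinkhorn convention u = exp(f / reg). The class name `OTResult` follows the suggestion above, but its constructor and attributes are illustrative, not an existing API.

```python
import numpy as np


class OTResult:
    """Hypothetical result object: few stored fields, "virtual" attributes via @property."""

    def __init__(self, potentials, reg=None, loss=None):
        self._potentials = tuple(np.asarray(f, dtype=float) for f in potentials)
        self._reg = reg
        self._loss = loss

    @property
    def potentials(self):
        """Dual potentials (f, g) in the log domain, homogeneous to the cost."""
        return self._potentials

    @property
    def scalings(self):
        """Scaling vectors (u, v) = (exp(f / reg), exp(g / reg)), entropic case only."""
        if not self._reg:
            raise NotImplementedError("scalings are only defined for entropic OT")
        return tuple(np.exp(f / self._reg) for f in self._potentials)

    @property
    def loss(self):
        """Read-only: there is no setter, so `result.loss = ...` raises AttributeError."""
        return self._loss


result = OTResult(([0.0, 1.0], [2.0]), reg=0.5, loss=3.0)
u, v = result.scalings
```

Properties without setters give read-only attributes for free, and derived quantities such as `scalings` are recomputed from a single source of truth instead of being stored twice.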
- In any case, relying on `plan`, `potentials` and `scalings` attributes is probably the most explicit thing to do?
- Instead of creating a separate `OTLazyResults` class, don't you think that we could define a `lazy_plan` attribute for `OTResult`? By default, `plan` would always return a straightforward NumPy/PyTorch array, while `lazy_plan` would return a symbolic object (or raise a `RuntimeError`/`NotImplementedError` in cases where it does not make sense).
- We could catch out-of-memory errors in `plan` and advise users to use `lazy_plan` instead, if possible?
- Please also note that by this summer, we should have added support for indexing in KeOps LazyTensors: this would make `get_primal(i=..., j=...)` redundant, as users would just type `result.lazy_plan[i, j]`.
- I would tend to merge the "unbalanced" and "partial" OT problems with the "exact" ones: they have the same input/output and will be handled by the exact same solvers (at least in the Sinkhorn case). Apart from the long (but factored!) docstrings, do you see a case where we really don't want to do this? For unsupported/untested configurations, we may just throw a `NotImplementedError`?
- However, I totally agree that making separate APIs for histograms, point clouds and grids makes a lot of sense. `ot.solve`, `ot.solve_samples` and `ot.solve_grid` all seem like good names to me.
- Out of curiosity, what types of regularization do you intend to support? Just vanilla OT (= none), Sinkhorn (= entropic) and quadratic?
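
The `plan` / `lazy_plan` split and the out-of-memory advice could look like the sketch below. It rests on stated assumptions: `LazyPlan` is a small NumPy stand-in that evaluates blocks on demand through `__getitem__` (it is not the KeOps LazyTensor API), and all names are hypothetical. The product coupling used in the usage lines is feasible but not optimal; it only exercises the API.

```python
import numpy as np


class LazyPlan:
    """Stand-in for a symbolic plan: calls block_fn(rows, cols) on demand."""

    def __init__(self, block_fn, shape):
        self.block_fn, self.shape = block_fn, shape

    def __getitem__(self, key):
        i, j = key
        rows = np.atleast_1d(np.arange(self.shape[0])[i])
        cols = np.atleast_1d(np.arange(self.shape[1])[j])
        return self.block_fn(rows, cols)  # always a 2D block


class OTResult:
    def __init__(self, dense_plan_fn, lazy_plan=None):
        self._dense_plan_fn = dense_plan_fn  # builds the full array when called
        self._lazy_plan = lazy_plan

    @property
    def plan(self):
        try:
            return self._dense_plan_fn()
        except MemoryError as err:
            raise MemoryError(
                "the transport plan does not fit in memory; "
                "consider result.lazy_plan[i, j] instead"
            ) from err

    @property
    def lazy_plan(self):
        if self._lazy_plan is None:
            raise NotImplementedError("this solver does not provide a lazy plan")
        return self._lazy_plan


# product coupling outer(a, b): feasible, not optimal
a, b = np.full(5, 0.2), np.full(4, 0.25)
lazy = LazyPlan(lambda rows, cols: np.outer(a[rows], b[cols]), shape=(5, 4))
result = OTResult(lambda: lazy[:, :], lazy_plan=lazy)
```

Re-raising with `from err` keeps the original traceback while adding the hint, and `NotImplementedError` signals solvers for which a lazy plan does not make sense.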
- I agree that having a `reg_param` which is homogeneous to the cost function is overall simpler and more coherent. Do you think that it is a better name than e.g. `reg_strength`? (I have no idea.)
- Would some redundant parameters be OK for you? For instance, specifying regularization strengths through `blur` and `reach` scales (which are homogeneous to the point coordinates, not the cost function) could make sense in many settings.
- Likewise, the number of iterations could be specified through an `n_iter` parameter and/or a tolerance.
- Do you want to use a linter for POT? I found black to be very easy to use and suitable for the job; this is what we now use for both KeOps and GeomLoss.

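
Redundant parameters can be reconciled by converting them to cost units up front. A minimal sketch, assuming GeomLoss's convention that, for a cost |x - y|^p, the entropic strength is blur**p and the marginal-penalty strength is reach**p. The function names `resolve_reg` and `iterate` are hypothetical. The loop stops at whichever comes first: `n_iter` calls or an error below the tolerance.

```python
def resolve_reg(reg_param=None, blur=None, reach=None, p=2):
    """Return (reg, reg_m) in cost units from cost-homogeneous or scale parameters."""
    if reg_param is not None and blur is not None:
        raise ValueError("give either reg_param or blur, not both")
    reg = blur ** p if blur is not None else (reg_param or 0.0)
    reg_m = reach ** p if reach is not None else None  # None means balanced OT
    return reg, reg_m


def iterate(step, n_iter=1000, tol=None):
    """Call step() (which returns an error estimate) until n_iter calls or err < tol."""
    it, err = 0, None
    for it in range(1, n_iter + 1):
        err = step()
        if tol is not None and err < tol:
            break
    return it, err
```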
These types of "laundry lists of discussion points" may seem a bit discouraging... But of course, I am 100% supportive :-D
Best regards,
And see you soon,
Jean
P.S.: We'll discuss the unbalanced/barycenter solvers with Hicham and Thibault tomorrow. We'll keep you updated on our progress!