Created
March 21, 2018 13:20
-
-
Save jeanfeydy/155e0f2bd5bb9b1dd8a2a4f6e7109530 to your computer and use it in GitHub Desktop.
Reproduces issue #174 in Tensor Comprehensions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Minimal script reproducing issue #174 in Tensor Comprehensions (TC):
# a log-domain 1D convolution gives different results once an intermediate
# tensor is folded into the reduction that consumes it.
import torch
from torch.autograd import Variable
import tensor_comprehensions as tc

# Use CUDA float tensors when a GPU is available, CPU float tensors otherwise.
use_cuda = torch.cuda.is_available()
if use_cuda:
    dtype = torch.cuda.FloatTensor
else:
    dtype = torch.FloatTensor

# Reference kernel: plain convolution of a signal A of size X with a filter K
# of size Xp = 2*X-1, so that the index X-1+x-xp stays inside K.
convolution_1D_lang = """
def convolution_1D(float(X) A, float(Xp) K) -> (B) {
B(x) +=! A(xp) * K(X-1+x-xp) where xp in 0:X
}
"""

# Log-domain version, written as an explicit step-by-step log-sum-exp:
# take the max, exponentiate the shifted terms into `myExp`, sum them,
# take the log and add the max back.
log_convolution_1D_lang = """
def log_convolution_1D(float(X) log_A, float(Xp) log_K) -> (maxExp,myExp,sumExp,LSE,log_B) {
maxExp(x) max=! log_A(xp) + log_K(X-1+x-xp)
myExp(x,xp) = exp( log_A(xp) + log_K(X-1+x-xp) - maxExp(x) )
sumExp(x) +=! myExp(x,xp)
LSE(x) = log(sumExp(x))
log_B(x) = maxExp(x) + LSE(x)
}
"""

# Same computation without the intermediate `myExp`: the statements
# "myExp(x,xp) = ..." and "sumExp(x) +=! ..." are merged into one reduction,
# which is meant to keep the memory footprint linear.
#
# The motivation is the intended real use case ("log_convolution_3D", etc.),
# where `log_A` is a (256,256,256) 3D tensor and `log_K` is a (511) vector
# (separable log-convolutions). Materialising a full (256,256,256,511)
# `myExp` array would be intractable there.
#
# NOTE: this collapsed variant is the one whose output is reported as
# incorrect below; it is kept exactly as written on purpose.
log_convolution_1Db_lang = """
def log_convolution_1Db(float(X) log_A, float(Xp) log_K) -> (maxExp,sumExp,LSE,log_B) {
maxExp(x) max=! log_A(xp) + log_K(X-1+x-xp)
sumExp(x) +=! exp( log_A(xp) + log_K(X-1+x-xp) - maxExp(x) )
LSE(x) = log(sumExp(x))
log_B(x) = maxExp(x) + LSE(x)
}
"""

# Hand the three TC sources over to Tensor Comprehensions.
convolution_1D = tc.define(
    convolution_1D_lang,
    name="convolution_1D",
)
log_convolution_1D = tc.define(
    log_convolution_1D_lang,
    name="log_convolution_1D",
)
log_convolution_1Db = tc.define(
    log_convolution_1Db_lang,
    name="log_convolution_1Db",
)

# Test signal: log of a dirac, A = [0,0,0,1,0,0,0,0,0,0].
# -1000 plays the role of log(0); index 3 carries log(1) = 0.
X = 10
log_A = torch.ones(X).type(dtype) * -1000
log_A[3] = 0
log_A = Variable(log_A)

# Gaussian-shaped kernel given in the log domain:
# K(c) = exp(-(c / sigma)**2) with sigma = 2, sampled at c = -X+1, ..., X-1,
# i.e. 2*X-1 values as the kernels above expect.
sigma = Variable(torch.Tensor([2]).type(dtype))
C = Variable(torch.arange(1 - X, X).type(dtype))
log_K = -((C / sigma) ** 2)

# Run the plain convolution on exp(log_A), exp(log_K), and both log-domain
# variants on the logs; the log-domain kernels return several tensors and
# only the last one (log_B) is kept.
B = convolution_1D(log_A.exp(), log_K.exp())
outputs_safe = log_convolution_1D(log_A, log_K)
log_B = outputs_safe[-1]
outputs_collapsed = log_convolution_1Db(log_A, log_K)
log_Bb = outputs_collapsed[-1]

# Print every result as a single row so the two log(B) rows can be compared.
report = (
    ("A :\n", log_A.exp()),
    ("K :\n", log_K.exp()),
    ("B = A ★ K :\n", B),
    ("log(B) :\n", log_B),
    ("log(B) - buggy :\n", log_Bb),
)
for title, values in report:
    print(title, values.view(1, -1), "\n")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment