Last active
May 16, 2020 10:37
-
-
Save hawkrobe/4c04410f7142d642f1abea98de6be499 to your computer and use it in GitHub Desktop.
prior_oddities.wppl
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Shape of the latent score matrix: dims[0] rows by dims[1] columns.
var dims = [4, 6];
var numRows = dims[0];
var numCols = dims[1];
// Column specs scored for each row; '_' joins the column indices of a conjunction.
var conjunctions = ['0_2', '1_2', '1_0', '1_3'];
/**
 * Log-odds of a probability: log(p) - log(1 - p).
 * Built from WebPPL's ad.scalar operations so it can act on AD-tracked scalars.
 *
 * @param p - probability in (0, 1)
 * @returns log-odds of p
 */
var logit = function(p) {
  var logP = ad.scalar.log(p);
  var logOneMinusP = ad.scalar.log(ad.scalar.sub(1, p));
  return ad.scalar.sub(logP, logOneMinusP);
};
/**
 * Elementwise softplus over a tensor: log(exp(x) + 1).
 * Not referenced anywhere else in the visible code.
 * NOTE(review): exp(x) can overflow for large x; a stable form would be needed
 * if this is ever used on large inputs.
 *
 * @param x - tensor input
 * @returns tensor of softplus values
 */
var tensorSoftplus = function(x) {
  var expPlusOne = T.add(T.exp(x), 1);
  return T.log(expPlusOne);
};
/**
 * Draw a random score matrix with shape `dims`.
 * Each entry comes from a DiagCovGaussian with mean 0 and sigma 5.
 *
 * @returns sampled tensor of shape dims
 */
var sampleMatrix = function() {
  var mean = zeros(dims);
  var scale = T.mul(ones(dims), 5);
  var prior = DiagCovGaussian({mu: mean, sigma: scale});
  return sample(prior);
};
/**
 * Score of `row` under a column spec, on the log-odds scale.
 *
 * A plain column index (e.g. '3') reads the entry at (row, 3) from the
 * row-major flattened matrix (offset row * numCols + col).
 * A conjunction of columns joined by '_' (e.g. '1_3' or '0_1_2') works in
 * three steps: each conjunct is mapped to a probability with a sigmoid,
 * ALL of those probabilities are multiplied together, and the logit of the
 * product is returned.
 *
 * @param {number} row - row index into the numRows x numCols matrix
 * @param {string} col_conjunction - a column index, or '_'-joined column indices
 * @param matrix - score tensor, as produced by sampleMatrix
 * @returns scalar score (log-odds)
 */
var getMatrixElement = function(row, col_conjunction, matrix) {
  var components = col_conjunction.split('_');
  if (components.length === 1) {
    // Single column: direct lookup in the flattened, row-major matrix.
    return T.get(matrix, row * numCols + _.toInteger(col_conjunction));
  } else {
    // Product of the sigmoid probabilities of conjuncts i..last.
    // The previous version read only components[0] and components[1],
    // so any further conjuncts were silently ignored.
    var conjunctProb = function(i) {
      var p = ad.scalar.sigmoid(getMatrixElement(row, components[i], matrix));
      return i === components.length - 1 ? p : ad.scalar.mul(p, conjunctProb(i + 1));
    };
    return logit(conjunctProb(0));
  }
};
/**
 * Exact distribution over `conjunctions` for one row of `matrix`.
 * Conjunctions start uniform and are reweighted by factor(5 * score), so
 * P(c) is proportional to exp(5 * getMatrixElement(row, c, matrix)).
 *
 * @param {number} row - row index
 * @param matrix - score tensor, as produced by sampleMatrix
 * @returns enumerated distribution over conjunction strings
 */
var normalizeRow = function(row, matrix) {
  var model = function() {
    var conjunction = uniformDraw(conjunctions);
    var score = getMatrixElement(row, conjunction, matrix);
    factor(5 * score);
    return conjunction;
  };
  return Infer({method: 'enumerate'}, model);
};
/**
 * Monte Carlo estimate of the average probability that normalizeRow gives
 * `conjunction` in `row`. The average is taken over matrices drawn from
 * sampleMatrix, using 50000 forward samples.
 *
 * @param {number} row - row index
 * @param {string} conjunction - one of `conjunctions`
 * @returns {number} estimated expected probability
 */
var expectedProb = function(row, conjunction) {
  var probForRandomMatrix = function() {
    var rowDist = normalizeRow(row, sampleMatrix());
    return Math.exp(rowDist.score(conjunction));
  };
  var estimates = Infer({method: 'forward', samples: 50000}, probForRandomMatrix);
  return expectation(estimates);
};
// Print the estimated prior probability of each conjunction for row 0.
var reportExpectedProb = function(label, conjunction) {
  display(label + expectedProb(0, conjunction));
};
reportExpectedProb("P(conjunction 1):", conjunctions[0]);
reportExpectedProb("P(conjunction 2):", conjunctions[1]);
reportExpectedProb("P(conjunction 3):", conjunctions[2]);
reportExpectedProb("P(conjunction 4):", conjunctions[3]);
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment