Last active
October 14, 2024 11:58
-
-
Save alexgian/f02ba4f81685809bc9dba9a96a8784a1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #lang racket | |
| (require malt) | |
; Hypers.
; malt already provides the hypers that its own procedures read: `revs`
; (read by malt's gradient-descent) and `batch-size` (read by malt's
; sampling-obj). Declaring them again here would create separate,
; module-local hypers with the same names. `with-hypers` in this file would
; then set those local copies while malt's procedures keep reading their
; own, still-unset ones. That is what produces the 'unset-hyper-revs /
; 'unset-hyper-batch-size errors recorded at the bottom of this file.
; So only hypers that are read exclusively by code in this module are
; declared here.
(declare-hyper α) ; learning rate, read by rms-u
(declare-hyper μ) ; declared but not read anywhere in this file
(declare-hyper β) ; decay rate passed to smooth in rms-u
; Given data for the plane model, used by try-plane.
; Both tensors are built with malt's `tensor` constructor and inexact
; numbers, as plane-ys already was. This way the data does not depend on
; tensors being plain Racket vectors holding exact integers.
(define plane-xs
  (tensor (tensor 1.0 2.05)
          (tensor 1.0 3.0)
          (tensor 2.0 2.0)
          (tensor 2.0 3.91)
          (tensor 3.0 6.13)
          (tensor 4.0 8.09)))
(define plane-ys (tensor 13.99 15.99 18.0 22.4 30.2 37.94))
; Generic gradient descent, final form from p 141:38 of The Little Learner.
;   inflate : maps each plain parameter to its "accompanied" form
;   deflate : maps an accompanied parameter back to the plain parameter
;   update  : (accompanied-param gradient) -> next accompanied-param
; Returns a procedure of (objf θ). That procedure inflates θ, applies the
; step `revs` times via revise, and returns the deflated result.
; `revs` is a hyper and must be set (e.g. with with-hypers) before calling.
(define ((gradient-descent_final inflate deflate update) objf θ)
  ; One revision: take the gradient at the plain parameters, then update
  ; each accompanied parameter with its matching gradient.
  (define (step big-theta)
    (let ([plain-theta (map deflate big-theta)])
      (map update big-theta (gradient-of objf plain-theta))))
  (map deflate (revise step revs (map inflate θ))))
; Stabilizer: a small constant added to the denominator in rms-u so that
; the denominator stays strictly positive.
(define ϵ 1e-8)
; RMSProp update (p 169:27-29).
;   P : accompanied parameter (list p r), as built by rms-i, where r is the
;       smoothed average of squared gradients so far
;   g : gradient for p
; Returns the next accompanied parameter (list p' r').
(define rms-u
  (λ (P g)
    ; new smoothed average of the squared gradient, using hyper β
    (define r (smooth β (ref P 1) (sqr g)))
    ; effective step size: hyper α divided by the root of that average;
    ; ϵ keeps the divisor away from zero
    (define α-hat (/ α (+ (sqrt r) ϵ)))
    (list (- (ref P 0) (* α-hat g))
          r)))
; Inflate for RMSProp: pair a plain parameter with a fresh accumulator,
; (zeroes p) — presumably an all-zero tensor shaped like p.
(define rms-i
  (λ (p)
    (list p (zeroes p))))
; Deflate for RMSProp: keep only the plain parameter (element 0).
(define rms-d
  (λ (P)
    (ref P 0)))
; RMSProp gradient descent: a generic gradient descent specialized with
; the RMS inflate (rms-i), deflate (rms-d) and update (rms-u) procedures.
; Either malt's `gradient-descent` or this file's `gradient-descent_final`
; can be used; the alternative is kept commented out.
; NOTE: both fail at run time if this module re-declares malt's own hypers
; `revs` and `batch-size`, and each reports a different hyper as unset.
; malt's gradient-descent reads malt's `revs`. gradient-descent_final reads
; this module's `revs`, but malt's sampling-obj still reads malt's
; `batch-size`.
(define rms-gradient-descent
  (let ([inflate rms-i]
        [deflate rms-d]
        [update  rms-u])
    ; (gradient-descent_final inflate deflate update)
    (gradient-descent inflate deflate update)))
; Fit the plane model to plane-xs / plane-ys using sampled l2-loss.
;   a-gradient-descent : procedure of (objective θ), e.g. rms-gradient-descent
;   a-revs             : number of revisions (bound to the revs hyper)
;   an-α               : learning rate (bound to the α hyper)
;   a-batch-size       : optional sample size used by sampling-obj (bound to
;                        the batch-size hyper); defaults to 4, the value
;                        previously hard-coded
; Starts from θ = (list (tensor 0.0 0.0) 0.0) and returns whatever
; a-gradient-descent returns (the revised θ).
; Any other hyper the chosen descent reads (e.g. β for rms-u) must be set
; by the caller, e.g. (with-hypers ((β 0.9)) (try-plane ...)).
(define (try-plane a-gradient-descent a-revs an-α [a-batch-size 4])
  (with-hypers ((revs a-revs)
                (α an-α)
                (batch-size a-batch-size))
    (a-gradient-descent
     (sampling-obj (l2-loss plane) plane-xs plane-ys)
     (list (tensor 0.0 0.0) 0.0))))
| #| | |
| ; testing with this, as per 170:33 | |
| (with-hypers ((β 0.9)) | |
| (try-plane rms-gradient-descent 3000 0.01)) | |
| |# | |
| ; errors: | |
| #| | |
| . . zero?: contract violation | |
| expected: number? | |
| given: 'unset-hyper-revs | |
| or | |
| . . zero?: contract violation | |
| expected: number? | |
| given: 'unset-hyper-batch-size | |
| |# |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment