Skip to content

Instantly share code, notes, and snippets.

@kmizu
Created December 4, 2011 01:09
Show Gist options
  • Select an option

  • Save kmizu/1428717 to your computer and use it in GitHub Desktop.

Select an option

Save kmizu/1428717 to your computer and use it in GitHub Desktop.
An example of automatic differentiation in Scala using dual numbers. It is a straightforward port of http://www.kmonos.net/wlog/123.html#_2257111201
/** Forward-mode automatic differentiation with dual numbers.
  *
  * A generic algorithm written only against `Fractional` (here, Newton's method for the square
  * root) is run once on plain `Double`s and once on [[Dual]] numbers. On dual numbers the
  * result carries the derivative of the computation alongside its value.
  */
object AutomaticDifferentiation {

  /** A dual number `x + dx·ε` with `ε² = 0`.
    *
    * @param x  the primal (ordinary) value
    * @param dx the derivative carried along with `x`
    */
  final case class Dual[X](x: X, dx: X)

  /** Lifts `x` to the variable being differentiated with respect to: `Dual(x, 1.0)`, since dx/dx = 1. */
  def dual(x: Double): Dual[Double] = Dual(x, 1.0)

  /** `Fractional` instance for `Dual[Double]`, implementing the sum, difference, product and
    * quotient rules on the `dx` component.
    *
    * Conversions (`toDouble`, `toInt`, ...) and `compare` look at the primal value `x` only and
    * ignore the derivative, so `compare` can return 0 for two duals that are not `==`.
    */
  implicit val dualDoubleIsFractional: Fractional[Dual[Double]] = new Fractional[Dual[Double]] {
    override def plus(a: Dual[Double], b: Dual[Double]): Dual[Double] =
      Dual(a.x + b.x, a.dx + b.dx)

    override def minus(a: Dual[Double], b: Dual[Double]): Dual[Double] =
      Dual(a.x - b.x, a.dx - b.dx)

    // Product rule: d(a·b) = da·b + db·a.
    override def times(a: Dual[Double], b: Dual[Double]): Dual[Double] =
      Dual(a.x * b.x, a.dx * b.x + b.dx * a.x)

    // Quotient rule: d(a/b) = da/b - a·db/b². Kept in exactly this arithmetic form so results
    // stay bit-identical to the original port.
    override def div(a: Dual[Double], b: Dual[Double]): Dual[Double] =
      Dual(a.x / b.x, a.dx * (1 / b.x) + b.dx * (-a.x / b.x / b.x))

    // Integer constants do not depend on the variable, so their derivative is 0.
    override def fromInt(n: Int): Dual[Double] = Dual(n.toDouble, 0.0)

    override def negate(a: Dual[Double]): Dual[Double] = Dual(-a.x, -a.dx)

    override def toDouble(a: Dual[Double]): Double = a.x
    override def toFloat(a: Dual[Double]): Float = a.x.toFloat
    override def toInt(a: Dual[Double]): Int = a.x.toInt
    override def toLong(a: Dual[Double]): Long = a.x.toLong

    // Ordering by the primal value. This also makes the inherited abs/signum/lt/max/min usable.
    override def compare(a: Dual[Double], b: Dual[Double]): Int = java.lang.Double.compare(a.x, b.x)

    // Abstract in Numeric since Scala 2.13, so it must exist for this object to compile there.
    // On 2.12 it is simply an extra method. A parsed literal is a constant (derivative 0).
    def parseString(str: String): Option[Dual[Double]] =
      scala.util.Try(str.toDouble).toOption.map(v => Dual(v, 0.0))
  }

  /** Approximates `sqrt(x)` with the Newton step `y <- (y + x / y) / 2`, starting from `one`.
    *
    * The method only uses `Fractional` operations, so it works for `Double` as well as [[Dual]].
    * Applied to `dual(v)`, the result's `dx` approximates the derivative `1 / (2 * sqrt(v))`.
    *
    * @param x          the value whose square root is approximated
    * @param iterations number of Newton steps; the default of 101 matches the previously
    *                   hard-coded loop `0 to 100`
    * @tparam Real      any type with a `Fractional` instance
    * @return the approximation after `iterations` steps (`one` when `iterations` is 0)
    * @throws IllegalArgumentException if `iterations` is negative
    */
  def mySqrt[Real: Fractional](x: Real, iterations: Int = 101): Real = {
    require(iterations >= 0, s"iterations must be non-negative, got $iterations")
    val num = implicitly[Fractional[Real]]
    import num._
    val two = fromInt(2)
    (0 until iterations).foldLeft(one)((y, _) => div(plus(y, div(x, y)), two))
  }

  /** Prints square roots of sample inputs. It prints them first as plain `Double`s, then as
    * dual numbers whose `dx` is the derivative of `sqrt` at that input.
    */
  def main(args: Array[String]): Unit = {
    val samples = Seq(2.0, 3.0, 5.0, 0.25)
    samples.foreach(v => printf("%g%n", mySqrt(v)))
    // --------------------------- //
    samples.foreach(v => printf("%s%n", mySqrt(dual(v))))
  }
}
$ scalac AutomaticDifferentiation.scala
$ scala AutomaticDifferentiation
1.41421
1.73205
2.23607
0.500000
Dual(1.414213562373095,0.35355339059327373)
Dual(1.7320508075688772,0.28867513459481287)
Dual(2.23606797749979,0.22360679774997896)
Dual(0.5,1.0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment