Skip to content

Instantly share code, notes, and snippets.

@horothesun
Created August 31, 2025 00:21
Show Gist options
  • Select an option

  • Save horothesun/c8b72bb178ef4b7a66545f7f1e19ffd0 to your computer and use it in GitHub Desktop.

Select an option

Save horothesun/c8b72bb178ef4b7a66545f7f1e19ffd0 to your computer and use it in GitHub Desktop.
TL2 STM experiments
//> using scala 3.7.2
//> using jvm temurin:21
//
//> using dep org.typelevel::cats-core:2.13.0
//> using dep org.typelevel::cats-effect:3.6.3
//
//> using test.dep org.scalameta::munit:1.1.1
//> using test.dep org.typelevel::munit-cats-effect:2.1.0
//> using test.dep org.scalameta::munit-scalacheck:1.1.0
//> using test.dep org.scalacheck::scalacheck:1.18.1
//> using test.dep org.typelevel::scalacheck-effect-munit:1.0.4
//> using test.dep org.typelevel::cats-effect-testkit:3.6.3
//
import cats.effect._
import cats.syntax.all._
import scala.collection.mutable
/** ---- TVar + Stamp ----
 * Versioned lock word attached to every TVar.
 *
 * @param version  clock value stored by the last commit that wrote the TVar (0 for a fresh TVar)
 * @param lockedBy token of the transaction attempt that currently holds the commit lock,
 *                 or `None` when the TVar is unlocked
 */
private final case class Stamp(version: Long, lockedBy: Option[Long])
/** A transactional variable.
 *
 * Instances are created through `TVar.of`; the constructor is private.
 *
 * @param value   the currently committed value
 * @param stamp   version and lock owner guarding `value`
 * @param waiters gates of transactions that are blocked until this TVar is written
 */
final class TVar[A] private (
  val value: Ref[IO, A],
  val stamp: Ref[IO, Stamp],
  val waiters: Ref[IO, List[Deferred[IO, Unit]]]
)
object TVar {

  /** Allocates a TVar holding `a`, at version 0, unlocked and with no waiters. */
  def of[A](a: A): IO[TVar[A]] =
    Ref.of[IO, A](a).flatMap { valueRef =>
      Ref.of[IO, Stamp](Stamp(0L, None)).flatMap { stampRef =>
        Ref.of[IO, List[Deferred[IO, Unit]]](Nil).map { waitersRef =>
          new TVar[A](valueRef, stampRef, waitersRef)
        }
      }
    }
}
/** ---- Existential wrappers (no Any in TRec) ----
 * SomeTVar and SomeWrite let TVars of different element types share one collection.
 *
 * SomeTVar hides the element type of a TVar behind the type member `A`.
 */
sealed trait SomeTVar {

  /** Element type of the wrapped TVar. */
  type A

  /** The wrapped TVar. */
  val tvar: TVar[A]
}
object SomeTVar {

  /** Wraps `tv`, keeping its element type available as the refinement `A = X`. */
  def apply[X](tv: TVar[X]): SomeTVar { type A = X } =
    new SomeTVar {
      type A = X
      val tvar = tv
    }
}
/** A pending write: a TVar paired with the value to store in it, with the element type hidden
 * behind the type member `A`.
 */
sealed trait SomeWrite {

  /** Element type shared by `tvar` and `value`. */
  type A

  /** Target of the write. */
  val tvar: TVar[A]

  /** Value to be stored in `tvar`. */
  val value: A

  /** True when `other` wraps the very same TVar instance (reference equality). */
  def sameTVar(other: SomeTVar): Boolean =
    tvar.eq(other.tvar)
}
object SomeWrite {

  /** Pairs `tv` with the value `v` destined for it; the refinement `A = X` keeps both in sync. */
  def apply[X](tv: TVar[X], v: X): SomeWrite { type A = X } =
    new SomeWrite {
      type A    = X
      val tvar  = tv
      val value = v
    }
}
/** ---- STM AST ----
 * A description of a transaction that yields an `A`.
 *
 * Values are plain data built from the nodes and constructors in the companion object;
 * nothing runs until the description is handed to `STMRuntime.atomically`.
 */
sealed trait STM[+A]
object STM {

  // ---- AST nodes ----

  /** Yields `a` without touching any TVar. */
  final case class Pure[A](a: A) extends STM[A]

  /** Reads the value of `t`. */
  final case class Read[A](t: TVar[A]) extends STM[A]

  /** Records `a` as the new value of `t`. */
  final case class Write[A](t: TVar[A], a: A) extends STM[Unit]

  /** Allocates a fresh TVar initialised with `a`. */
  final case class New[A](a: A) extends STM[TVar[A]]

  /** Abandons the current alternative. */
  case object Retry extends STM[Nothing]

  /** Runs `left`; if it retries, runs `right` instead. */
  final case class OrElse[A](left: STM[A], right: STM[A]) extends STM[A]

  /** Runs `fa` and feeds its result to `f`. */
  final case class Bind[A, B](fa: STM[A], f: A => STM[B]) extends STM[B]

  // ---- constructors returning the widened STM type ----

  def pure[A](a: A): STM[A] = Pure(a)

  def read[A](t: TVar[A]): STM[A] = Read(t)

  def write[A](t: TVar[A], a: A): STM[Unit] = Write(t, a)

  def newTVar[A](a: A): STM[TVar[A]] = New(a)

  val retry: STM[Nothing] = Retry

  def orElse[A](l: STM[A], r: STM[A]): STM[A] = OrElse(l, r)

  // ---- monadic syntax, which also enables for-comprehensions over STM ----

  extension [A](s: STM[A]) {
    def flatMap[B](f: A => STM[B]): STM[B] = Bind(s, f)
    def map[B](f: A => B): STM[B]          = s.flatMap(a => Pure(f(a)))
  }
}
/** ---- TRec without Any ----
 * Mutable bookkeeping for a single transaction attempt.
 *
 * @param readVersion clock value the attempt's reads are checked against
 * @param readSet     TVars read so far
 * @param writeSet    pending writes, newest first
 * @param frames      pending `orElse` alternatives
 * @param token       identifier of this attempt, used as lock owner in `Stamp.lockedBy`
 */
private final class TRec(
  var readVersion: Long,
  val readSet: mutable.ArrayBuffer[SomeTVar],
  var writeSet: List[SomeWrite],
  var frames: List[OrFrame[?]],
  val token: Long
)
/** A pending `orElse` alternative.
 *
 * @param writeSnapshot the write set as it was when the `orElse` was entered
 * @param alt           the right-hand branch to run if the left-hand branch retries
 */
private final case class OrFrame[A](
  writeSnapshot: List[SomeWrite],
  alt: STM[A]
)
/** ---- STM runtime (TL2-style) ----
 * Interprets [[STM]] descriptions against a global version clock.
 *
 * One attempt of a transaction:
 *  1. samples the clock as its read version and evaluates the body, collecting a read set and a
 *     write set (no shared state is modified during evaluation);
 *  2. on commit, locks the written TVars, re-validates the read set, advances the clock and
 *     publishes the writes, waking transactions blocked on those TVars.
 *
 * An attempt that observes a conflict is restarted from scratch; an attempt that ends in a
 * top-level `retry` blocks until one of the TVars it has read is written.
 *
 * @param clockRef     global version clock
 * @param tokenCounter source of unique attempt tokens
 */
final class STMRuntime private (
  clockRef: Ref[IO, Long],
  tokenCounter: Ref[IO, Long]
) {

  /** Completed when a blocked transaction should run again. */
  private type Gate = Deferred[IO, Unit]

  /** What evaluating a transaction body produced, before any commit is attempted. */
  private sealed trait EvalResult[+A]

  /** The body ran to completion and yielded `result`. */
  private final case class Finished[A](result: A) extends EvalResult[A]

  /** The body hit a top-level `retry`; wait on `gate` before running it again. */
  private final case class Suspended(gate: Gate) extends EvalResult[Nothing]

  /** A read could not be served from a consistent snapshot; the attempt must be restarted. */
  private case object Conflicted extends EvalResult[Nothing]

  /** Create a new unique token id for each attempt. */
  private def nextToken: IO[Long] =
    tokenCounter.updateAndGet(_ + 1L)

  /** Public entry: run `prog` atomically.
   *
   * Re-runs `prog` until it commits. When `prog` ends in a top-level `retry`, the fiber is
   * suspended until one of the TVars read by that attempt is written.
   */
  def atomically[A](prog: STM[A]): IO[A] =
    attemptOnce(prog).flatMap {
      case Right(a)   => IO.pure(a)
      case Left(gate) => gate.get >> atomically(prog)
    }

  /** Runs attempts until one either commits (`Right(result)`) or blocks on a top-level `retry`
   * (`Left(gate)`). Read conflicts and failed commits start a fresh attempt.
   */
  private def attemptOnce[A](prog: STM[A]): IO[Either[Gate, A]] = {
    // Yield before restarting so that the fiber we conflicted with gets a chance to finish.
    def restart: IO[Either[Gate, A]] = IO.cede >> attemptOnce(prog)

    nextToken.flatMap { tok =>
      clockRef.get.flatMap { rv =>
        val trec = new TRec(
          readVersion = rv,
          readSet = mutable.ArrayBuffer.empty,
          writeSet = Nil,
          frames = Nil,
          token = tok
        )
        evaluate(prog, trec).flatMap[Either[Gate, A]] {
          case Finished(a) =>
            commit(trec).flatMap[Either[Gate, A]] { committed =>
              if (committed) IO.pure(Right(a)) else restart
            }
          case Suspended(gate) => IO.pure(Left(gate))
          case Conflicted      => restart
        }
      }
    }
  }

  /** Evaluates `prog`, filling the read and write sets of `trec` and handling `orElse`/`retry`.
   *
   * The STM AST is untyped at this level: continuations are stored as `Any => STM[Any]` and the
   * final result is cast back to `A`. The casts are sound because every continuation is only
   * ever applied to the result of the `Bind` source it was paired with.
   */
  private def evaluate[A](prog: STM[A], trec: TRec): IO[EvalResult[A]] = {
    import STM.*
    type K = Any => STM[Any]

    IO.defer {
      // Pending continuations, innermost first.
      var conts: List[K] = Nil
      // For every entry of trec.frames (same order, same length): the continuation stack as it
      // was when that orElse was entered. OrFrame itself has no slot for it.
      var savedConts: List[List[K]] = Nil

      // Runs when the left branch of an orElse completes without retrying: its frame is
      // discarded so that a later `retry` outside the orElse cannot jump back to the alternative.
      val leaveOrElse: K = (a: Any) => {
        trec.frames = trec.frames.drop(1)
        savedConts = savedConts.drop(1)
        Pure(a)
      }

      // Every step is wrapped in IO.defer, so long Bind chains do not grow the JVM stack.
      def step(cur: STM[Any]): IO[EvalResult[Any]] = IO.defer {
        cur match {
          case Pure(a) =>
            conts match {
              case k :: rest =>
                conts = rest
                step(k(a))
              case Nil =>
                IO.pure(Finished(a))
            }

          case Read(t) =>
            val tv = t.asInstanceOf[TVar[Any]]
            // A TVar written earlier in this attempt is read back from the write set.
            findWriteFor(tv, trec.writeSet) match {
              case Some(pending) =>
                step(Pure(pending))
              case None =>
                readTVar(tv, trec).flatMap {
                  case Some(v) => step(Pure(v))
                  case None    => IO.pure(Conflicted)
                }
            }

          case Write(t, a) =>
            trec.writeSet = SomeWrite(t.asInstanceOf[TVar[Any]], a.asInstanceOf[Any]) :: trec.writeSet
            step(Pure(()))

          case New(a) =>
            TVar.of(a).flatMap(tv => step(Pure(tv)))

          case Retry =>
            (trec.frames, savedConts) match {
              case (OrFrame(snap, alt) :: outerFrames, ks :: outerSaved) =>
                // Undo the left branch: drop its writes and its continuations, then run the
                // alternative in the context the orElse was entered in.
                trec.writeSet = snap
                trec.frames = outerFrames
                conts = ks
                savedConts = outerSaved
                step(alt.asInstanceOf[STM[Any]])
              case _ =>
                // Top-level retry: subscribe to the read set and hand back the gate.
                subscribeAndSleep(trec).map(Suspended(_))
            }

          case OrElse(l, r) =>
            trec.frames = OrFrame(trec.writeSet, r) :: trec.frames
            savedConts = conts :: savedConts
            conts = leaveOrElse :: conts
            step(l.asInstanceOf[STM[Any]])

          case Bind(fa, f) =>
            conts = f.asInstanceOf[K] :: conts
            step(fa.asInstanceOf[STM[Any]])
        }
      }

      step(prog.asInstanceOf[STM[Any]]).map[EvalResult[A]] {
        case Finished(a)     => Finished(a.asInstanceOf[A])
        case Suspended(gate) => Suspended(gate)
        case Conflicted      => Conflicted
      }
    }
  }

  /** Returns the newest pending write to `t` in `ws` (which is ordered newest first), if any.
   *
   * The cast is sound: a SomeWrite is always built from a TVar and a value of that TVar's
   * element type, and the TVar is matched by reference.
   */
  private def findWriteFor[A](t: TVar[A], ws: List[SomeWrite]): Option[A] =
    ws.collectFirst {
      case sw if sw.tvar eq t => sw.value.asInstanceOf[A]
    }

  /** TL2-style read of `t`.
   *
   * The value is accepted when the stamp is identical before and after reading the value, the
   * TVar is not locked by another attempt, and its version is not newer than the attempt's read
   * version. Otherwise the read version is extended (see `extendReadVersion`) and the read is
   * tried once more.
   *
   * @return `Some(value)` on success (and `t` is added to the read set), `None` when the attempt
   *         has to be restarted
   */
  private def readTVar[A](t: TVar[A], trec: TRec): IO[Option[A]] = {
    // stamp / value / stamp: equal, unlocked stamps prove the value belongs to that version.
    val sample: IO[Option[(Stamp, A)]] =
      for {
        before <- t.stamp.get
        value  <- t.value.get
        after  <- t.stamp.get
      } yield Option.when(before == after && after.lockedBy.forall(_ == trec.token))((after, value))

    def accept(v: A): Option[A] = {
      trec.readSet += SomeTVar(t)
      Some(v)
    }

    sample.flatMap {
      case Some((stamp, v)) if stamp.version <= trec.readVersion =>
        IO(accept(v))
      case _ =>
        extendReadVersion(trec).flatMap { extended =>
          if (!extended) IO.pure(None)
          else
            sample.map {
              case Some((stamp, v)) if stamp.version <= trec.readVersion => accept(v)
              case _                                                     => None
            }
        }
    }
  }

  /** Tries to move the attempt's read version forward to the current clock value.
   *
   * This is only legal when nothing read so far has changed, i.e. the read set is still valid
   * under the *current* read version. The clock is sampled before validating so that any commit
   * not seen by the validation gets a version newer than the extended read version.
   *
   * @return true when the read version was extended
   */
  private def extendReadVersion(trec: TRec): IO[Boolean] =
    for {
      latest    <- clockRef.get
      unchanged <- validateReadSet(trec, trec.readVersion)
      _         <- if (unchanged) IO { trec.readVersion = latest } else IO.unit
    } yield unchanged

  /** True when every TVar in the read set is unlocked (or locked by this attempt) and has a
   * version not newer than `rv`.
   */
  private def validateReadSet(trec: TRec, rv: Long): IO[Boolean] =
    trec.readSet.toList
      .traverse { stv =>
        stv.tvar.asInstanceOf[TVar[Any]].stamp.get.map { s =>
          s.lockedBy.forall(_ == trec.token) && s.version <= rv
        }
      }
      .map(_.forall(identity))

  /** Commit: lock the written TVars, validate the read set, advance the clock, publish the
   * writes and wake waiters. A read-only attempt commits without doing any of that.
   *
   * The locking part runs uncancelable so that a canceled fiber cannot leave TVars locked.
   *
   * @return true on success, false when the attempt has to be restarted
   */
  private def commit(trec: TRec): IO[Boolean] = {
    // writeSet is newest-first, so keeping the first entry per TVar keeps its final value.
    // TVar does not override equals, so the key compares by reference.
    val finalWrites: List[SomeWrite] = trec.writeSet.distinctBy(sw => (sw.tvar: AnyRef))
    val targets: List[TVar[Any]]     = finalWrites.map(_.tvar.asInstanceOf[TVar[Any]])

    // Lock only a TVar that is unlocked and has not been written since our snapshot.
    def tryLock(tv: TVar[Any]): IO[Boolean] =
      tv.stamp.modify {
        case Stamp(v, None) if v <= trec.readVersion => (Stamp(v, Some(trec.token)), true)
        case other                                   => (other, false)
      }

    def acquireAll(pending: List[TVar[Any]]): IO[Boolean] = pending match {
      case Nil => IO.pure(true)
      case tv :: rest =>
        tryLock(tv).flatMap { locked =>
          if (locked) acquireAll(rest) else IO.pure(false)
        }
    }

    // Unlocks whichever targets are currently locked by this attempt.
    val releaseAll: IO[Unit] =
      targets.traverse_ { tv =>
        tv.stamp.update(s => if (s.lockedBy.contains(trec.token)) s.copy(lockedBy = None) else s)
      }

    val publish: IO[Unit] =
      clockRef.updateAndGet(_ + 2L).flatMap { wv =>
        finalWrites.traverse_ { sw =>
          val tv = sw.tvar
          for {
            _        <- tv.value.set(sw.value)
            _        <- tv.stamp.set(Stamp(wv, None)) // new version; also releases the lock
            sleepers <- tv.waiters.getAndSet(Nil)
            _        <- sleepers.traverse_(_.complete(()).void)
          } yield ()
        }
      }

    if (finalWrites.isEmpty) IO.pure(true)
    else
      IO.uncancelable { _ =>
        acquireAll(targets).flatMap {
          case false => releaseAll.as(false)
          case true =>
            validateReadSet(trec, trec.readVersion).flatMap {
              case true  => publish.as(true)
              case false => releaseAll.as(false)
            }
        }
      }
  }

  /** Registers a fresh gate with every TVar in the read set and returns it.
   *
   * After registering, the read set is validated once more: a commit that landed between the
   * attempt's reads and the registration would otherwise never complete the gate. In that case
   * the gate is completed right away, which makes the caller re-run the transaction.
   *
   * With an empty read set nothing can ever complete the gate.
   */
  private def subscribeAndSleep(trec: TRec): IO[Gate] =
    for {
      gate <- Deferred[IO, Unit]
      _ <- trec.readSet.toList.traverse_ { stv =>
        stv.tvar.asInstanceOf[TVar[Any]].waiters.update(gate :: _)
      }
      stillCurrent <- validateReadSet(trec, trec.readVersion)
      _            <- if (stillCurrent) IO.unit else gate.complete(()).void
    } yield gate
}
object STMRuntime {

  /** Builds a runtime whose version clock and token counter both start at 0. */
  def create: IO[STMRuntime] =
    Ref.of[IO, Long](0L).flatMap { clock =>
      Ref.of[IO, Long](0L).map { tokens =>
        new STMRuntime(clock, tokens)
      }
    }
}
/** ---- Example helpers & simple test usage ---- */
object Examples:
  import STM.*

  /** Runs ten concurrent increments of a TVar that starts at 42, then prints its final value. */
  def demo: IO[Unit] =
    STMRuntime.create.flatMap { rt =>
      TVar.of(42).flatMap { t =>
        // Transaction: bump the counter and yield the new value.
        val inc: STM[Int] =
          for {
            x <- read(t)
            newVal = x + 1
            _ <- write(t, newVal)
          } yield newVal

        val increments: List[IO[Int]] = List.fill(10)(rt.atomically(inc))

        increments.parSequence
          .flatMap(_ => rt.atomically(read(t)))
          .flatMap(finalVal => IO.println(s"final = $finalVal"))
      }
    }
/** Program entry point: runs the demo from `Examples`. */
object Main extends IOApp.Simple:
  override def run: IO[Unit] =
    Examples.demo
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment