Created
August 31, 2025 00:21
-
-
Save horothesun/c8b72bb178ef4b7a66545f7f1e19ffd0 to your computer and use it in GitHub Desktop.
TL2 STM experiments
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| //> using scala 3.7.2 | |
| //> using jvm temurin:21 | |
| // | |
| //> using dep org.typelevel::cats-core:2.13.0 | |
| //> using dep org.typelevel::cats-effect:3.6.3 | |
| // | |
| //> using test.dep org.scalameta::munit:1.1.1 | |
| //> using test.dep org.typelevel::munit-cats-effect:2.1.0 | |
| //> using test.dep org.scalameta::munit-scalacheck:1.1.0 | |
| //> using test.dep org.scalacheck::scalacheck:1.18.1 | |
| //> using test.dep org.typelevel::scalacheck-effect-munit:1.0.4 | |
| //> using test.dep org.typelevel::cats-effect-testkit:3.6.3 | |
| // | |
| import cats.effect._ | |
| import cats.syntax.all._ | |
| import scala.collection.mutable | |
/** ---- TVar + Stamp ----
  * Version stamp of a TVar: the global-clock version of the last committed write, plus the token of
  * the commit attempt currently holding the TVar's lock (if any).
  *
  * @param version  clock value written by the last commit that published to this TVar (0 when new)
  * @param lockedBy token of the attempt that holds the commit lock, or `None` when unlocked
  */
private final case class Stamp(version: Long, lockedBy: Option[Long])
/** A transactional variable.
  *
  * Holds three independent cells: the current value, the version/lock [[Stamp]], and the list of
  * gates that blocked transactions are waiting on. The constructor is private, so instances come
  * from [[TVar.of]]. `TVar` is not a case class, so two instances are equal only if they are the
  * same object.
  */
final class TVar[A] private (
    val value: Ref[IO, A],
    val stamp: Ref[IO, Stamp],
    val waiters: Ref[IO, List[Deferred[IO, Unit]]]
)

object TVar {

  /** Allocate a TVar holding `a`, stamped at version 0, unlocked, with no waiters. */
  def of[A](a: A): IO[TVar[A]] =
    (
      Ref.of[IO, A](a),
      Ref.of[IO, Stamp](Stamp(0L, None)),
      Ref.of[IO, List[Deferred[IO, Unit]]](Nil)
    ).mapN((cell, versionCell, waitingGates) => new TVar[A](cell, versionCell, waitingGates))
}
/** ---- Existential wrappers (no Any in TRec) ----
  * `SomeTVar` packs a TVar together with its element type as an abstract type member, so TVars of
  * different element types can share one collection.
  */
sealed trait SomeTVar {

  /** Element type of the wrapped TVar. */
  type A

  /** The wrapped transactional variable. */
  val tvar: TVar[A]
}

object SomeTVar {

  /** Wrap `tv`, recording its element type in the refinement `A = X`. */
  def apply[X](tv: TVar[X]): SomeTVar { type A = X } = {
    val wrapped = tv
    new SomeTVar {
      type A = X
      val tvar: TVar[X] = wrapped
    }
  }
}
/** A buffered write: a TVar and the value destined for it, sharing one hidden element type `A`. */
sealed trait SomeWrite {

  /** Element type shared by `tvar` and `value`. */
  type A

  /** Target of the write. */
  val tvar: TVar[A]

  /** Value to be stored into `tvar`. */
  val value: A

  /** True when `other` wraps the very same TVar instance (reference identity via `eq`). */
  def sameTVar(other: SomeTVar): Boolean =
    other.tvar eq tvar
}

object SomeWrite {

  /** Pair `tv` with `v`, recording their common type in the refinement `A = X`. */
  def apply[X](tv: TVar[X], v: X): SomeWrite { type A = X } = {
    val target = tv
    val payload = v
    new SomeWrite {
      type A = X
      val tvar: TVar[X] = target
      val value: X = payload
    }
  }
}
/** ---- STM AST ----
  * A transaction is a plain data structure built from the constructors below. Building one has no
  * effect; it is interpreted by a runtime such as [[STMRuntime]].
  */
sealed trait STM[+A]

object STM {

  // ---- constructors ----

  /** Yield `a` without touching any TVar. */
  final case class Pure[A](a: A) extends STM[A]

  /** Read the current value of `t`. */
  final case class Read[A](t: TVar[A]) extends STM[A]

  /** Buffer a write of `a` into `t`. */
  final case class Write[A](t: TVar[A], a: A) extends STM[Unit]

  /** Allocate a fresh TVar initialised to `a`. */
  final case class New[A](a: A) extends STM[TVar[A]]

  /** Abandon the current branch (handled by the nearest `OrElse`, or by blocking at top level). */
  case object Retry extends STM[Nothing]

  /** Run `left`; if it retries, discard its writes and run `right` instead. */
  final case class OrElse[A](left: STM[A], right: STM[A]) extends STM[A]

  /** Sequence `fa` with the continuation `f`. */
  final case class Bind[A, B](fa: STM[A], f: A => STM[B]) extends STM[B]

  // ---- smart constructors (widened to STM) ----

  def pure[A](a: A): STM[A] = Pure(a)

  def read[A](t: TVar[A]): STM[A] = Read(t)

  def write[A](t: TVar[A], a: A): STM[Unit] = Write(t, a)

  def newTVar[A](a: A): STM[TVar[A]] = New(a)

  val retry: STM[Nothing] = Retry

  def orElse[A](l: STM[A], r: STM[A]): STM[A] = OrElse(l, r)

  // ---- for-comprehension support ----

  extension [A](s: STM[A]) {
    def flatMap[B](f: A => STM[B]): STM[B] = Bind(s, f)

    def map[B](f: A => B): STM[B] = Bind(s, (a: A) => Pure(f(a)))
  }
}
/** ---- TRec without Any ----
  * Mutable bookkeeping for a single transaction attempt.
  *
  * @param readVersion global-clock snapshot that reads are validated against
  * @param readSet     TVars read from shared state during this attempt (duplicates allowed)
  * @param writeSet    buffered writes, newest first
  * @param frames      stack of `orElse` alternatives pushed by the evaluator, innermost first
  * @param token       id of this attempt; stored in [[Stamp.lockedBy]] while it holds locks
  */
private final class TRec(
    var readVersion: Long,
    val readSet: mutable.ArrayBuffer[SomeTVar],
    var writeSet: List[SomeWrite],
    var frames: List[OrFrame[?]],
    val token: Long
)

/** An `orElse` alternative: the write set to roll back to, and the branch to run instead. */
private final case class OrFrame[A](
    writeSnapshot: List[SomeWrite],
    alt: STM[A]
)
/** ---- STM runtime (TL2-style) ---- */

/** Internal vocabulary of the [[STMRuntime]] evaluator. */
private object STMEval {

  /** What a single evaluation of a transaction body produced. */
  sealed trait EvalResult[+A]

  /** The body ran to completion; the value is only observable if the commit then succeeds. */
  final case class Finished[A](value: A) extends EvalResult[A]

  /** The body hit `retry` with no enclosing `orElse`; wait on `gate` before re-running. */
  final case class Blocked(gate: Deferred[IO, Unit]) extends EvalResult[Nothing]

  /** A read could not be made consistent with the attempt's snapshot; the attempt must restart. */
  case object Conflict extends EvalResult[Nothing]

  /** Entries of the evaluator's continuation stack. */
  sealed trait Frame

  /** Feed the current value into `f` (the right-hand side of a `Bind`). */
  final case class Continue(f: Any => STM[Any]) extends Frame

  /** End of the left branch of an `orElse`: reaching it means that branch succeeded. */
  case object LeaveOrElse extends Frame
}

/** TL2-style software transactional memory runtime.
  *
  * @param clockRef     global version clock; every committing writer advances it
  * @param tokenCounter source of unique per-attempt tokens used as lock owners
  */
final class STMRuntime private (
    clockRef: Ref[IO, Long],
    tokenCounter: Ref[IO, Long]
) {
  import STMEval.*

  /** Create a new unique token id for each attempt. */
  private def nextToken: IO[Long] =
    tokenCounter.updateAndGet(_ + 1L)

  /** Run `prog` atomically and return its result.
    *
    * Conflicting attempts are restarted transparently. If `prog` retries outside any `orElse`, the
    * fiber waits until some TVar it read is committed to, then re-runs `prog`. A top-level retry
    * that read no TVars therefore waits forever.
    */
  def atomically[A](prog: STM[A]): IO[A] =
    attemptOnce(prog).flatMap {
      case Right(a)   => IO.pure(a)
      case Left(gate) => gate.get >> atomically(prog)
    }

  /** Run attempts until one commits (`Right(result)`) or blocks on a top-level retry (`Left(gate)`). */
  private def attemptOnce[A](prog: STM[A]): IO[Either[Deferred[IO, Unit], A]] =
    for {
      trec   <- freshTRec
      result <- evaluate(prog, trec)
      out <- result match {
        case Finished(value) =>
          commit(trec).flatMap { committed =>
            if committed then IO.pure(Right(value)) else restart(prog)
          }
        case Blocked(gate) => IO.pure(Left(gate))
        case Conflict      => restart(prog)
      }
    } yield out

  /** Yield to other fibers (reducing livelock with the conflicting committer), then try again. */
  private def restart[A](prog: STM[A]): IO[Either[Deferred[IO, Unit], A]] =
    IO.cede >> attemptOnce(prog)

  /** A fresh transaction record snapshotting the current clock. */
  private def freshTRec: IO[TRec] =
    for {
      tok <- nextToken
      rv  <- clockRef.get
    } yield new TRec(
      readVersion = rv,
      readSet = mutable.ArrayBuffer.empty,
      writeSet = Nil,
      frames = Nil,
      token = tok
    )

  /** Interpret `prog` once against `trec`, filling its read and write sets.
    *
    * Every step is suspended in `IO.defer`, so long chains of binds run on the IO run loop rather
    * than the JVM stack.
    */
  private def evaluate[A](prog: STM[A], trec: TRec): IO[EvalResult[A]] = {
    import STM.*

    // Continuation stack, innermost first. Invariant: the k-th LeaveOrElse marker from the top
    // belongs to the k-th entry of trec.frames.
    var stack: List[Frame] = Nil

    def step(cur: STM[Any]): IO[EvalResult[Any]] = IO.defer {
      cur match {
        case Pure(a) =>
          stack match {
            case Nil =>
              IO.pure(Finished(a))
            case Continue(f) :: rest =>
              stack = rest
              step(f(a))
            case LeaveOrElse :: rest =>
              // The innermost orElse's left branch succeeded, so its alternative is unreachable now.
              stack = rest
              trec.frames = trec.frames.drop(1)
              step(cur)
          }

        case Read(t) =>
          val tv = t.asInstanceOf[TVar[Any]]
          // Read-your-own-writes first; only fall back to shared state when nothing is buffered.
          findWriteFor(tv, trec.writeSet) match {
            case Some(pending) => step(Pure(pending))
            case None =>
              readTVar(tv, trec).flatMap {
                case Some(v) => step(Pure(v))
                case None    => IO.pure[EvalResult[Any]](Conflict)
              }
          }

        case Write(t, a) =>
          trec.writeSet = SomeWrite(t.asInstanceOf[TVar[Any]], a.asInstanceOf[Any]) :: trec.writeSet
          step(Pure(()))

        case New(a) =>
          TVar.of[Any](a).flatMap(tv => step(Pure(tv)))

        case Retry =>
          trec.frames match {
            case OrFrame(snapshot, alt) :: rest =>
              // Roll back the left branch: its writes and every continuation it pushed.
              trec.writeSet = snapshot
              trec.frames = rest
              stack = stack.dropWhile(_ != LeaveOrElse).drop(1)
              step(alt)
            case Nil =>
              subscribeAndSleep(trec).map(Blocked(_))
          }

        case OrElse(l, r) =>
          trec.frames = OrFrame(trec.writeSet, r) :: trec.frames
          stack = LeaveOrElse :: stack
          step(l)

        case Bind(fa, f) =>
          stack = Continue(f.asInstanceOf[Any => STM[Any]]) :: stack
          step(fa)
      }
    }

    step(prog).map[EvalResult[A]] {
      case Finished(value)  => Finished(value.asInstanceOf[A])
      case blocked: Blocked => blocked
      case Conflict         => Conflict
    }
  }

  /** Value of the newest buffered write to `t`, if any.
    * The cast is sound because the matching SomeWrite was built with this very TVar.
    */
  private def findWriteFor[A](t: TVar[A], ws: List[SomeWrite]): Option[A] =
    ws.collectFirst { case sw if sw.tvar eq t => sw.value.asInstanceOf[A] }

  /** TL2 read of `t`: `Some(value)` if it is consistent with the attempt's snapshot (possibly after
    * extending the snapshot), `None` if the attempt must restart. Successful reads join the read set.
    */
  private def readTVar[A](t: TVar[A], trec: TRec): IO[Option[A]] = {
    def lockedByOther(s: Stamp): Boolean = s.lockedBy.exists(_ != trec.token)

    // stamp / value / stamp: the value belongs to the stamp only if the stamp did not move between.
    val sampleAndRecord: IO[Option[A]] =
      for {
        before <- t.stamp.get
        value  <- t.value.get
        after  <- t.stamp.get
        hit = Option.when(
          before == after && !lockedByOther(after) && after.version <= trec.readVersion
        )(value)
        _ <- if hit.isDefined then IO(trec.readSet += SomeTVar(t)).void else IO.unit
      } yield hit

    // Timestamp extension. Read the clock first, then check that nothing read so far changed since
    // the *current* snapshot. Validating against `latest` would accept reads that have already been
    // overwritten, which yields an inconsistent view.
    val extendAndResample: IO[Option[A]] =
      for {
        latest <- clockRef.get
        valid  <- validateReadSet(trec, trec.readVersion)
        out <-
          if valid then IO { trec.readVersion = latest } >> sampleAndRecord
          else IO.pure(Option.empty[A])
      } yield out

    sampleAndRecord.flatMap {
      case None => extendAndResample
      case hit  => IO.pure(hit)
    }
  }

  /** True when every TVar in the read set is unlocked (or locked by this attempt) and carries a
    * version no newer than `rv`.
    */
  private def validateReadSet(trec: TRec, rv: Long): IO[Boolean] =
    trec.readSet.toList
      .traverse { stv =>
        stv.tvar.stamp.get.map(s => s.lockedBy.forall(_ == trec.token) && s.version <= rv)
      }
      .map(_.forall(identity))

  /** Commit `trec`: lock the written TVars, re-validate the read set, then publish.
    * Returns true on success and false if the attempt must be restarted.
    */
  private def commit(trec: TRec): IO[Boolean] =
    if trec.writeSet.isEmpty then IO.pure(true) // read-only: each read was validated when it was made
    else {
      // writeSet is newest-first and distinctBy keeps the first occurrence, i.e. the final value.
      val finalWrites: List[SomeWrite] = trec.writeSet.distinctBy(sw => (sw.tvar: AnyRef))
      val targets: List[TVar[?]] = finalWrites.map(sw => (sw.tvar: TVar[?]))
      // Never abandon this section halfway: a cancelation while holding locks would leak them.
      IO.uncancelable { _ =>
        acquireLocks(targets, trec).flatMap { acquired =>
          if !acquired then IO.pure(false)
          else
            validateReadSet(trec, trec.readVersion).flatMap { valid =>
              if valid then publish(finalWrites).as(true)
              else releaseLocks(targets, trec.token).as(false)
            }
        }
      }
    }

  /** Try-lock every target in order. Each target must be unlocked and no newer than the snapshot.
    * On the first failure, release whatever was taken and return false.
    */
  private def acquireLocks(targets: List[TVar[?]], trec: TRec): IO[Boolean] = {
    def tryLock(tv: TVar[?]): IO[Boolean] =
      tv.stamp.modify {
        case Stamp(version, None) if version <= trec.readVersion =>
          (Stamp(version, Some(trec.token)), true)
        case current =>
          (current, false)
      }

    def loop(remaining: List[TVar[?]]): IO[Boolean] =
      remaining match {
        case Nil => IO.pure(true)
        case tv :: rest =>
          tryLock(tv).flatMap { locked =>
            if locked then loop(rest) else releaseLocks(targets, trec.token).as(false)
          }
      }

    loop(targets)
  }

  /** Unlock each of `targets` whose lock is held by `token`; leave the others untouched. */
  private def releaseLocks(targets: List[TVar[?]], token: Long): IO[Unit] =
    targets.traverse_ { tv =>
      tv.stamp.update(s => if s.lockedBy.contains(token) then s.copy(lockedBy = None) else s)
    }

  /** Advance the clock, then store each write. Stamping with the new version also unlocks the TVar.
    * Afterwards, wake that TVar's waiters. Expects exactly one write per TVar and all locks held.
    */
  private def publish(writes: List[SomeWrite]): IO[Unit] =
    clockRef.updateAndGet(_ + 2L).flatMap { writeVersion =>
      writes.traverse_ { sw =>
        val tv = sw.tvar
        val v  = sw.value
        for {
          _       <- tv.value.set(v)
          _       <- tv.stamp.set(Stamp(writeVersion, None))
          waiters <- tv.waiters.getAndSet(Nil)
          _       <- waiters.traverse_(gate => gate.complete(()).void)
        } yield ()
      }
    }

  /** Register a gate on every TVar in the read set and return it for the caller to wait on.
    * A commit may have drained those waiter lists between our reads and the registration. In that
    * case the read set no longer validates, and the gate is opened at once so the wakeup is not lost.
    */
  private def subscribeAndSleep(trec: TRec): IO[Deferred[IO, Unit]] =
    for {
      gate       <- Deferred[IO, Unit]
      _          <- trec.readSet.toList.traverse_(stv => stv.tvar.waiters.update(gate :: _))
      stillValid <- validateReadSet(trec, trec.readVersion)
      _          <- if stillValid then IO.unit else gate.complete(()).void
    } yield gate
}
object STMRuntime {

  /** Build a runtime whose version clock and attempt-token counter both start at 0. */
  def create: IO[STMRuntime] =
    (Ref.of[IO, Long](0L), Ref.of[IO, Long](0L)).mapN { (versionClock, tokens) =>
      new STMRuntime(versionClock, tokens)
    }
}
/** ---- Example helpers & simple test usage ---- */
object Examples {
  import STM.*

  /** Create a TVar holding 42, increment it from ten concurrent transactions, and print the
    * value read back afterwards.
    */
  def demo: IO[Unit] = {
    // Transaction: read the counter, store counter + 1, and yield the new value.
    def increment(counter: TVar[Int]): STM[Int] =
      read(counter).flatMap { current =>
        val next = current + 1
        write(counter, next).map(_ => next)
      }

    for {
      runtime <- STMRuntime.create
      counter <- TVar.of(42)
      _       <- List.fill(10)(runtime.atomically(increment(counter))).parSequence
      result  <- runtime.atomically(read(counter))
      _       <- IO.println(s"final = $result")
    } yield ()
  }
}
/** Program entry point: runs the STM demo. */
object Main extends IOApp.Simple {
  override def run: IO[Unit] = Examples.demo
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment