Skip to content

Instantly share code, notes, and snippets.

@kory33
Last active December 22, 2024 07:33
Show Gist options
  • Select an option

  • Save kory33/2969c287a61432227cefb3ac619814d8 to your computer and use it in GitHub Desktop.

Select an option

Save kory33/2969c287a61432227cefb3ac619814d8 to your computer and use it in GitHub Desktop.
R-Tree packing
//> using options -deprecation -feature -unchecked -Xfatal-warnings -no-indent -Xkind-projector:underscores
//> using dep "org.typelevel::cats-core::2.12.0"
import cats.data.NonEmptyVector
import cats.syntax.all.*
import cats.instances.string
import cats.kernel.Semigroup
/** A pair of integer x and z coordinates (no y component). */
case class XZInt2(x: Int, z: Int)

/** A point with integer x, y and z coordinates. */
case class Int3(x: Int, y: Int, z: Int)

/** A point with floating-point x, y and z coordinates. */
case class Double3(x: Double, y: Double, z: Double)

/**
 * An axis-aligned box spanned by two corner points.
 *
 * The corners may be given in any order: every query below goes through
 * [[min]] / [[max]] (or absolute differences), so `BoundingBox(a, b)` and
 * `BoundingBox(b, a)` describe the same region. Note that case-class equality
 * is still structural, i.e. it compares `p1` and `p2` as given.
 *
 * @param p1 one corner of the box
 * @param p2 the opposite corner of the box
 */
case class BoundingBox(p1: Int3, p2: Int3) {

  /** The corner holding the smallest coordinate on every axis. */
  def min: Int3 = Int3(p1.x.min(p2.x), p1.y.min(p2.y), p1.z.min(p2.z))

  /** The corner holding the largest coordinate on every axis. */
  def max: Int3 = Int3(p1.x.max(p2.x), p1.y.max(p2.y), p1.z.max(p2.z))

  /**
   * The exact midpoint between the two corners.
   *
   * The sum is computed in `Double` so that odd sums keep their `.5`
   * (integer division would truncate them) and large coordinates cannot
   * overflow `Int`.
   */
  def center: Double3 = {
    def midpoint(a: Int, b: Int): Double = (a.toDouble + b.toDouble) / 2.0
    Double3(midpoint(p1.x, p2.x), midpoint(p1.y, p2.y), midpoint(p1.z, p2.z))
  }

  /** Absolute difference of the corners' x coordinates. */
  def xWidth: Int = (p2.x - p1.x).abs

  /** Absolute difference of the corners' y coordinates. */
  def yWidth: Int = (p2.y - p1.y).abs

  /** Absolute difference of the corners' z coordinates. */
  def zWidth: Int = (p2.z - p1.z).abs

  /**
   * The smallest box enclosing both this box and `other`.
   *
   * The result is always normalized: its `p1` is the minimum corner and its
   * `p2` is the maximum corner.
   */
  def boxBoundingThisAnd(other: BoundingBox): BoundingBox = {
    val (lowA, highA) = (this.min, this.max)
    val (lowB, highB) = (other.min, other.max)
    BoundingBox(
      Int3(lowA.x.min(lowB.x), lowA.y.min(lowB.y), lowA.z.min(lowB.z)),
      Int3(highA.x.max(highB.x), highA.y.max(highB.y), highA.z.max(highB.z))
    )
  }

  /** Whether `p` lies inside this box; all six faces are inclusive. */
  def contains(p: Int3): Boolean = {
    val low = this.min
    val high = this.max
    low.x <= p.x && p.x <= high.x &&
    low.y <= p.y && p.y <= high.y &&
    low.z <= p.z && p.z <= high.z
  }
}
/**
 * A tree in which every node carries a bounding box (`mbr`) and only the
 * leaves carry a payload of type `LeafData`.
 *
 * Nothing in this declaration enforces any relationship between a node's box
 * and the boxes of its subtrees; `isWellFormedRTree` is the check that a
 * given value actually satisfies the R-tree invariants.
 */
enum RLikeTree[LeafData] {
// Internal node: a bounding box plus a non-empty collection of subtrees.
case Node(mbr: BoundingBox, subtrees: NonEmptyVector[RLikeTree[LeafData]])
// Leaf: a bounding box together with the payload stored for that region.
case Leaf(mbr: BoundingBox, data: LeafData)
}
extension [L](rlt: RLikeTree[L]) {

  /** The bounding box stored on the root of this tree, whichever case it is. */
  def mbr: BoundingBox = rlt match {
    case RLikeTree.Node(box, _) => box
    case RLikeTree.Leaf(box, _) => box
  }

  /**
   * Checks the structural invariants of an R-tree.
   *
   * The result is `true` exactly when all of the following hold:
   *  - every internal node's stored box equals the union of the boxes stored
   *    on its direct subtrees;
   *  - all leaves sit at the same depth, and that depth is greater than 1
   *    (so a bare leaf is not accepted as a tree);
   *  - no internal node has more than `degree` subtrees.
   *
   * @param degree the maximum number of subtrees allowed per internal node
   */
  def isWellFormedRTree(degree: Int): Boolean = {
    // Every node's stored box must be exactly the union of its children's stored boxes.
    def storedBoxesMatchChildren(tree: RLikeTree[L]): Boolean = tree match {
      case RLikeTree.Leaf(_, _) => true
      case RLikeTree.Node(box, subtrees) =>
        val children = subtrees.toVector
        children.forall(storedBoxesMatchChildren) &&
        box == children.map(_.mbr).reduceLeft(_.boxBoundingThisAnd(_))
    }

    // Some(depth) when every leaf below `tree` is at the same depth, None otherwise.
    def uniformLeafDepth(tree: RLikeTree[L]): Option[Int] = tree match {
      case RLikeTree.Leaf(_, _) => Some(1)
      case RLikeTree.Node(_, subtrees) =>
        val childDepths = subtrees.toVector.map(uniformLeafDepth)
        val first = childDepths.head
        // A single None or a single differing depth makes the whole node non-uniform.
        first.filter(_ => childDepths.forall(_ == first)).map(_ + 1)
    }

    def degreeRespected(tree: RLikeTree[L]): Boolean = tree match {
      case RLikeTree.Leaf(_, _) => true
      case RLikeTree.Node(_, subtrees) =>
        subtrees.length <= degree && subtrees.toVector.forall(degreeRespected)
    }

    storedBoxesMatchChildren(rlt) &&
    uniformLeafDepth(rlt).exists(_ > 1) &&
    degreeRespected(rlt)
  }

  /** Number of nodes on the longest root-to-leaf path; a lone leaf has height 1. */
  def height: Int = rlt match {
    case RLikeTree.Leaf(_, _) => 1
    case RLikeTree.Node(_, subtrees) => subtrees.toVector.map(_.height).max + 1
  }

  /**
   * All leaves whose box contains `p`, in depth-first, left-to-right order.
   *
   * Subtrees whose box does not contain `p` are pruned without being visited.
   * The box of the root itself is only tested when the root is a leaf.
   */
  def findRegionsContaining(p: Int3): Vector[RLikeTree.Leaf[L]] = {
    def matchingLeaves(tree: RLikeTree[L]): Vector[RLikeTree.Leaf[L]] = tree match {
      case leaf @ RLikeTree.Leaf(box, _) =>
        if (box.contains(p)) Vector(leaf) else Vector.empty
      case RLikeTree.Node(_, subtrees) =>
        subtrees.toVector.filter(_.mbr.contains(p)).flatMap(matchingLeaves)
    }
    matchingLeaves(rlt)
  }
}
extension [A](v: NonEmptyVector[A]) {

  /**
   * Cuts this vector into consecutive chunks of `size` elements, preserving
   * order. Only the last chunk may be shorter.
   *
   * @param size the chunk length; must be positive
   * @throws IllegalArgumentException if `size` is not positive
   */
  def chunked(size: Int): NonEmptyVector[NonEmptyVector[A]] = {
    require(size > 0, s"chunk size must be positive, but was $size")
    // `grouped` never yields an empty group, and `v` is non-empty, so both
    // unsafe conversions below are in fact safe.
    NonEmptyVector.fromVectorUnsafe(
      v.toVector
        .grouped(size)
        .map(chunk => NonEmptyVector.fromVectorUnsafe(chunk))
        .toVector
    )
  }

  /**
   * Cuts this vector into exactly `pieceCount` consecutive pieces whose
   * lengths differ by at most one; the longer pieces come first.
   *
   * @param pieceCount the number of pieces; must satisfy
   *                   `0 < pieceCount <= v.length` so that no piece is empty
   * @throws IllegalArgumentException if `pieceCount` is out of that range
   */
  def splitInto(pieceCount: Int): NonEmptyVector[NonEmptyVector[A]] = {
    require(pieceCount > 0, s"piece count must be positive, but was $pieceCount")
    // Without this check an empty piece would only be detected later, as an
    // unexplained failure while constructing a NonEmptyVector.
    require(
      pieceCount <= v.length,
      s"cannot split ${v.length} element(s) into $pieceCount non-empty pieces"
    )

    val elements = v.toVector
    val baseSize = elements.length / pieceCount
    val longerPieceCount = elements.length % pieceCount

    // The first `longerPieceCount` pieces absorb the remainder, one element each.
    val pieceSizes =
      Vector.tabulate(pieceCount)(i => if (i < longerPieceCount) baseSize + 1 else baseSize)
    // boundaries(i) is the start offset of piece i; the last entry is the total length.
    val boundaries = pieceSizes.scanLeft(0)(_ + _)
    assert(boundaries.last == elements.length)

    val pieces = boundaries.zip(boundaries.tail).map { case (from, until) =>
      NonEmptyVector.fromVectorUnsafe(elements.slice(from, until))
    }
    assert(pieces.length == pieceCount)
    NonEmptyVector.fromVectorUnsafe(pieces)
  }
}
/**
 * Bulk-loads a tree from `leaves`, layer by layer, from the bottom up.
 *
 * Each layer is grouped by `packLayer`: the entries are sorted by the x
 * coordinate of their box centre and split into slabs, oversized slabs are
 * split again by z, and every resulting slab is finally sorted by y and cut
 * into groups of at most `maxInternalDegree` entries. Each group becomes one
 * node of the next layer up. Packing repeats until one node remains.
 *
 * The returned tree is asserted (via `ensuring`) to satisfy
 * `isWellFormedRTree(maxInternalDegree)`.
 *
 * @param leaves            the regions to index, each with its payload
 * @param maxInternalDegree the maximum number of children per node; must be > 1
 */
def strPack[D](
    leaves: NonEmptyVector[(BoundingBox, D)],
    maxInternalDegree: Int
): RLikeTree[D] = {
  require(maxInternalDegree > 1)

  // Groups one layer's entries into parents, returning each parent's
  // enclosing box together with the entries assigned to it.
  def packLayer[D1](
      children: NonEmptyVector[(BoundingBox, D1)]
  ): NonEmptyVector[(BoundingBox, NonEmptyVector[(BoundingBox, D1)])] = {
    val keyedByCenter = children.map { case (box, payload) => (box.center, box, payload) }

    val groups: NonEmptyVector[NonEmptyVector[(Double3, BoundingBox, D1)]] =
      if (keyedByCenter.length <= maxInternalDegree) {
        // Everything fits under a single parent.
        NonEmptyVector.of(keyedByCenter)
      } else {
        // Number of parents needed if every parent were completely full.
        val parentCount = math.ceil(children.length.toDouble / maxInternalDegree).toInt
        val slabsAlongX = math.ceil(math.pow(parentCount.toDouble, 3.0 / 8.0)).toInt
        val slabsAlongZ = math.ceil(math.pow(parentCount.toDouble, 3.0 / 8.0)).toInt

        keyedByCenter
          .sortBy(_._1.x)
          .splitInto(slabsAlongX)
          .flatMap { xSlab =>
            // Slabs that already fit under one parent are not split along z.
            if (xSlab.length <= maxInternalDegree) NonEmptyVector.of(xSlab)
            else xSlab.sortBy(_._1.z).splitInto(slabsAlongZ)
          }
          .flatMap(xzSlab => xzSlab.sortBy(_._1.y).chunked(maxInternalDegree))
      }

    groups.map { group =>
      val enclosingBox = group.toVector.map(_._2).reduceLeft(_.boxBoundingThisAnd(_))
      (enclosingBox, group.map { case (_, box, payload) => (box, payload) })
    }
  }

  // The lowest internal layer: every node here holds only leaves.
  val lowestInternalLayer: NonEmptyVector[RLikeTree.Node[D]] = packLayer(leaves).map {
    case (box, members) =>
      RLikeTree.Node(box, members.map(RLikeTree.Leaf(_, _)))
  }

  // Keeps packing the current layer into a smaller one until a single root is left.
  @annotation.tailrec
  def packUntilSingleRoot(layer: NonEmptyVector[RLikeTree.Node[D]]): RLikeTree.Node[D] =
    if (layer.length <= 1) {
      layer.head
    } else {
      packUntilSingleRoot(
        packLayer(layer.map(node => (node.mbr, node.subtrees))).map {
          case (box, members) =>
            RLikeTree.Node(box, members.map(RLikeTree.Node(_, _)))
        }
      )
    }

  packUntilSingleRoot(lowestInternalLayer)
}.ensuring(r => r.isWellFormedRTree(maxInternalDegree))
/**
 * Renders the x/z footprint of every node of `rtree` as one SVG document.
 *
 * Coordinates are shifted so that the minimum corner of the root box lands
 * on the SVG origin. Internal nodes are drawn as faint blue rectangles and
 * leaves as half-transparent red ones, in depth-first order (a node is
 * emitted before its subtrees). The y axis of the tree is ignored.
 */
def visualizeAsSvg(rtree: RLikeTree[Int3]): String = {
  val rootBox = rtree.mbr
  val xOffset = -rootBox.min.x
  val zOffset = -rootBox.min.z

  // One <rect> element for the x/z footprint of `box`.
  def rect(box: BoundingBox, fillOpacity: String, fill: String): String =
    s"<rect x=\"${xOffset + box.min.x}\" y=\"${zOffset + box.min.z}\" width=\"${box.xWidth}\" height=\"${box.zWidth}\" fill-opacity=\"${fillOpacity}\" fill=\"${fill}\" stroke=\"black\" stroke-width=\"1\" />"

  def rectsOf(tree: RLikeTree[Int3]): Vector[String] = tree match {
    case RLikeTree.Leaf(box, _) =>
      Vector(rect(box, "50%", "red"))
    case RLikeTree.Node(box, subtrees) =>
      rect(box, "7%", "blue") +: subtrees.toVector.flatMap(rectsOf)
  }

  val openingTag =
    s"<svg viewBox=\"0 0 ${xOffset + rootBox.max.x} ${zOffset + rootBox.max.z}\" xmlns=\"http://www.w3.org/2000/svg\">"

  ((openingTag +: rectsOf(rtree)) :+ "</svg>").mkString
}
/**
 * Prints the height of `rtree` and point-query cost statistics sampled over it.
 *
 * Sample points are chosen as follows: for every integer (x, z) inside the
 * root box's footprint, the leaf regions whose footprint covers (x, z) are
 * looked up, and one query is simulated at the bottom (minimum y) of each
 * such region, with duplicate heights removed.
 *
 * For every sample point two counters are collected while descending the tree:
 *  - the number of child boxes (internal or leaf) that contained the point;
 *  - the number of box-containment tests performed.
 *
 * The worst case and the mean of both counters are printed to stdout.
 */
def analyzeCosts(rtree: RLikeTree[Int3]): Unit = {
  // What a projected leaf remembers about the region it came from. The
  // original bottom height must be carried along explicitly, because the
  // projected boxes themselves no longer contain any real y information.
  case class ProjectedRegion(xz: XZInt2, bottomY: Int)

  // Squashes a box onto y in [-1, 1], so that a query at y = 0 tests
  // containment of the x/z footprint only.
  def flattenY(box: BoundingBox): BoundingBox =
    BoundingBox(
      Int3(box.min.x, -1, box.min.z),
      Int3(box.max.x, 1, box.max.z)
    )

  def project(tree: RLikeTree[Int3]): RLikeTree[ProjectedRegion] = tree match {
    case RLikeTree.Node(mbr, subtrees) =>
      RLikeTree.Node(flattenY(mbr), subtrees.map(project))
    case RLikeTree.Leaf(mbr, data) =>
      RLikeTree.Leaf(
        flattenY(mbr),
        ProjectedRegion(XZInt2(data.x, data.z), bottomY = mbr.min.y)
      )
  }
  val projected = project(rtree)

  case class IntStatFromInt3(stat: Int, attainedAt: Int3)
  case class Statistics(
      dataPointCount: Int,
      worstOverlappingIntermediateNodeCount: IntStatFromInt3,
      worstRegionMembershipTestCount: IntStatFromInt3,
      meanOverlappingIntermediateNodeCount: Double,
      meanRegionMembershipTestCount: Double
  ) {
    require(dataPointCount > 0)
  }

  given Semigroup[Statistics] with {
    def combine(x: Statistics, y: Statistics): Statistics = {
      val totalCount = x.dataPointCount + y.dataPointCount
      // On a tie the right-hand operand wins.
      def worse(a: IntStatFromInt3, b: IntStatFromInt3): IntStatFromInt3 =
        if (a.stat > b.stat) a else b
      // Mean of the union, weighting each side by its number of data points.
      def pooledMean(xMean: Double, yMean: Double): Double =
        (xMean * x.dataPointCount + yMean * y.dataPointCount) / totalCount
      Statistics(
        totalCount,
        worse(x.worstOverlappingIntermediateNodeCount, y.worstOverlappingIntermediateNodeCount),
        worse(x.worstRegionMembershipTestCount, y.worstRegionMembershipTestCount),
        pooledMean(x.meanOverlappingIntermediateNodeCount, y.meanOverlappingIntermediateNodeCount),
        pooledMean(x.meanRegionMembershipTestCount, y.meanRegionMembershipTestCount)
      )
    }
  }

  // Simulates one point query against the original (unprojected) tree.
  def statsAtSinglePoint(p: Int3): Statistics = {
    var overlappingIntermediateNodeCount: Int = 0
    var regionMembershipTestCount: Int = 0
    def traverse(tree: RLikeTree[Int3]): Unit = tree match {
      case RLikeTree.Node(_, subtrees) =>
        subtrees.toVector.foreach { subtree =>
          // Every child costs one containment test, whether or not it matches.
          regionMembershipTestCount += 1
          if (subtree.mbr.contains(p)) {
            overlappingIntermediateNodeCount += 1
            traverse(subtree)
          }
        }
      case RLikeTree.Leaf(_, _) =>
    }
    traverse(rtree)
    Statistics(
      dataPointCount = 1,
      worstOverlappingIntermediateNodeCount = IntStatFromInt3(
        stat = overlappingIntermediateNodeCount,
        attainedAt = p
      ),
      worstRegionMembershipTestCount = IntStatFromInt3(
        stat = regionMembershipTestCount,
        attainedAt = p
      ),
      meanOverlappingIntermediateNodeCount =
        overlappingIntermediateNodeCount.toDouble,
      meanRegionMembershipTestCount = regionMembershipTestCount.toDouble
    )
  }

  println("Tree height: " + rtree.height)

  val rootBox = rtree.mbr
  val sampledStats: Option[Statistics] = Semigroup.combineAllOption {
    for {
      x <- (rootBox.min.x to rootBox.max.x).iterator
      z <- (rootBox.min.z to rootBox.max.z).iterator
      // Bottoms of the ORIGINAL regions covering (x, z); the projected boxes
      // would report -1 for every region.
      y <- projected
        .findRegionsContaining(Int3(x, 0, z))
        .map(_.data.bottomY)
        .distinct
        .iterator
    } yield statsAtSinglePoint(Int3(x, y, z))
  }

  sampledStats match {
    case Some(stats) => println("Sampled stats: " + stats)
    case None => println("Sampled stats: (no sample points)")
  }
}
import scala.util.Using

// Read the input eagerly so that the file handle is released right away
// instead of staying open for the rest of the script.
val lines = Using.resource(scala.io.Source.fromFile("./chest-coords.txt")) { source =>
  source.getLines().toVector
}

// Each usable line is "x,y,z" with three integers; any other line is skipped.
val dataPoints = lines.flatMap { line =>
  line.split(",").map(_.toIntOption) match {
    case Array(Some(x), Some(y), Some(z)) => Some(Int3(x, y, z))
    case _ => None
  }
}

// Every data point is indexed as the cube reaching 15 units from it on each axis.
val regions = dataPoints.map { p =>
  (
    BoundingBox(
      Int3(p.x - 15, p.y - 15, p.z - 15),
      Int3(p.x + 15, p.y + 15, p.z + 15)
    ),
    p
  )
}

// A tree cannot be built from nothing, so fail with an explanation up front.
val nonEmptyRegions = NonEmptyVector.fromVector(regions).getOrElse(
  throw new IllegalArgumentException(
    "./chest-coords.txt contains no valid \"x,y,z\" line; cannot build an R-tree"
  )
)

val rtree = strPack(nonEmptyRegions, 3)

// analyze costs with varying maxInternalDegree
(2 to 10).foreach { maxInternalDegree =>
  println("-----")
  println(
    s"Costs analysis for strPack(maxInternalDegree = ${maxInternalDegree}):"
  )
  analyzeCosts(strPack(nonEmptyRegions, maxInternalDegree))
}

import java.nio.file.*
import java.nio.charset.StandardCharsets
Files.write(
  Paths.get("./strpack-rtree.svg"),
  visualizeAsSvg(rtree).getBytes(StandardCharsets.UTF_8)
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment