Created
October 16, 2025 19:25
-
-
Save algebraic-dev/3204c1075268e695ba9fe84293c92d55 to your computer and use it in GitHub Desktop.
LZ compression algorithm in Lean 4
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import Std.Data.HashMap | |
| namespace LZ | |
/--
State threaded through the LZ encoder.

Each step returns an updated copy. Codes are written LSB-first into a bit accumulator
(`dataVal`/`dataPosition`). Every `bitsPerChar` bits form one output group, which is pushed
onto `outputBuffer` after passing through the caller-supplied `transform`.
-/
structure EncodeState where
  /--
  Dictionary mapping every known byte sequence (single bytes and learned phrases) to its code.
  -/
  sequenceDict : Std.HashMap ByteArray Nat := ∅
  /--
  Single-byte sequences that already have a code but whose literal has not been emitted yet.
  Used as a set (the values are `Unit`). An entry is removed the first time it is output as a literal.
  -/
  pendingChars : Std.HashMap ByteArray Unit := ∅
  /--
  The phrase matched so far: the longest sequence seen in the input that is already in `sequenceDict`.
  -/
  currentSequence : ByteArray := .empty
  /--
  Countdown decremented by `checkEnlargeIn`. When it reaches zero, `currentBitWidth` grows by one
  and this counter is refilled with `2 ^ currentBitWidth` (the width before the increment).
  It is decremented once per emitted phrase and once more when a first-occurrence literal is emitted.
  -/
  bitsUntilExpand : Nat := 2
  /--
  Next code to hand out to a new dictionary entry.
  Starts at 3 because small codes act as markers:
  - 0 (all-zero bits) is written before a literal;
  - 2 is written after the last phrase;
  - 1 is presumably reserved for lz-string's 16-bit literal marker, which is unused here.
    NOTE(review): confirm against the decoder.
  -/
  dictSize : Nat := 3
  /--
  Number of bits used to write each dictionary code or marker at the current point in the stream.
  -/
  currentBitWidth : Nat := 2
  /--
  Output groups produced so far; each entry is `transform` applied to `bitsPerChar` accumulated bits.
  -/
  outputBuffer : ByteArray := ∅
  /--
  Bits accumulated for the group currently being filled (the newest bit is the least significant).
  -/
  dataVal : Nat := 0
  /--
  Number of bits already accumulated in `dataVal` for the current group (0 to bitsPerChar-1).
  -/
  dataPosition : Nat := 0
  deriving Inhabited
| namespace EncodeState | |
/--
Append the low `numBits` bits of `value` to the output stream, least-significant bit first.

Each bit is shifted into `dataVal`. When the current group is complete (`dataPosition` is at
`bitsPerChar - 1`), the group is passed through `transform`, pushed onto `outputBuffer`, and the
accumulator starts over.
-/
private def writeBits (ctx : EncodeState) (value : Nat) (numBits : Nat) (bitsPerChar : Nat) (transform : Nat → UInt8) : EncodeState :=
  -- Shift a single bit into the accumulator, flushing the group when it becomes full.
  let pushBit (st : EncodeState) (bit : Nat) : EncodeState :=
    let acc := (st.dataVal <<< 1) ||| bit
    if st.dataPosition == bitsPerChar - 1 then
      { st with
        dataVal := 0,
        dataPosition := 0,
        outputBuffer := st.outputBuffer.push (transform acc) }
    else
      { st with dataVal := acc, dataPosition := st.dataPosition + 1 }
  -- Bit `i` of `value` is written at step `i`, i.e. LSB first.
  (List.range numBits).foldl (fun st i => pushBit st ((value >>> i) &&& 1)) ctx
/--
Append `numBits` zero bits to the output stream.

This is exactly `writeBits` with a value of `0`: every bit of `0` is zero, so OR-ing it into the
accumulator is a plain left shift. The function delegates instead of duplicating the bit-packing
loop, so the group-flushing logic lives in one place.
-/
private def writeZeroBits (ctx : EncodeState) (numBits : Nat) (bitsPerChar : Nat) (transform : Nat → UInt8) : EncodeState :=
  writeBits ctx 0 numBits bitsPerChar transform
/--
Count one emission against the code-width budget.

Decrements `bitsUntilExpand`. When it reaches zero, `currentBitWidth` grows by one bit and the
budget is refilled with `2 ^ currentBitWidth`, computed from the width before the increment.
-/
private def checkEnlargeIn (ctx : EncodeState) : EncodeState :=
  match ctx.bitsUntilExpand - 1 with
  | 0 =>
    -- Both fields read the *old* `ctx`, so the refill uses the pre-increment width.
    { ctx with
      currentBitWidth := ctx.currentBitWidth + 1,
      bitsUntilExpand := 2 ^ ctx.currentBitWidth }
  | remaining => { ctx with bitsUntilExpand := remaining }
/--
Emit `char` as a literal; this path is taken the first time the sequence is output.

Writes the all-zero code at the current width, then the 8 bits of the first byte of `char`
(LSB first). An empty `char` is written as byte 0. It then removes `char` from `pendingChars` and
counts one emission against the width budget.

Only the 8-bit literal form exists here, because a `UInt8` is always below 256.
NOTE(review): the all-zero code is presumably the decoder's "8-bit literal follows" marker, as in
lz-string; confirm against the decoder.
-/
private def processNewChar (ctx : EncodeState) (char : ByteArray) (bitsPerChar : Nat) (transform : Nat → UInt8) : EncodeState :=
  let code : Nat := ((char[0]?).getD 0).toNat
  let marked := writeZeroBits ctx ctx.currentBitWidth bitsPerChar transform
  let literal := writeBits marked code 8 bitsPerChar transform
  -- The literal is now known to the decoder, so it no longer needs to be sent in literal form.
  let cleared := { literal with pendingChars := literal.pendingChars.erase char }
  checkEnlargeIn cleared
/--
Emit the dictionary code of `word` at the current width.

The state is returned unchanged if `word` has no code. Callers in this file only pass phrases
that were previously inserted into `sequenceDict`.
-/
private def processExistingWord (ctx : EncodeState) (word : ByteArray) (bitsPerChar : Nat) (transform : Nat → UInt8) : EncodeState :=
  if let some code := ctx.sequenceDict.get? word then
    writeBits ctx code ctx.currentBitWidth bitsPerChar transform
  else
    ctx
/--
Emit `word`, then count the emission against the width budget.

A sequence still listed in `pendingChars` is emitted as a literal (`processNewChar`, which counts
one emission of its own). Any other sequence is emitted as its dictionary code.
-/
private def processWord (ctx : EncodeState) (word : ByteArray) (bitsPerChar : Nat) (transform : Nat → UInt8) : EncodeState :=
  let emitted :=
    match ctx.pendingChars.contains word with
    | true => processNewChar ctx word bitsPerChar transform
    | false => processExistingWord ctx word bitsPerChar transform
  checkEnlargeIn emitted
/--
Feed one input byte to the encoder.

- An unseen byte first gets a fresh code and is queued in `pendingChars` for a one-time literal.
- If `currentSequence ++ byte` is already known, the match is simply extended.
- Otherwise the current phrase is emitted, `currentSequence ++ byte` is registered under the next
  code, and matching restarts from this byte.
-/
private def processChar (ctx : EncodeState) (char : UInt8) (bitsPerChar : Nat) (transform : Nat → UInt8) : EncodeState := Id.run do
  let single := ByteArray.mk #[char]
  let mut st := ctx
  unless st.sequenceDict.contains single do
    st := { st with
      sequenceDict := st.sequenceDict.insert single st.dictSize,
      dictSize := st.dictSize + 1,
      pendingChars := st.pendingChars.insert single () }
  let extended := st.currentSequence ++ single
  -- Still a known phrase: keep growing the match.
  if st.sequenceDict.contains extended then
    return { st with currentSequence := extended }
  -- Emit the longest known prefix, then learn the phrase that just failed to match.
  st := processWord st st.currentSequence bitsPerChar transform
  return { st with
    sequenceDict := st.sequenceDict.insert extended st.dictSize,
    dictSize := st.dictSize + 1,
    currentSequence := single }
/--
Pad the partially filled group with zero bits and emit it.

`dataVal` is shifted left so that the emitted group is exactly `bitsPerChar` bits wide. The group
is passed through `transform` and pushed onto `outputBuffer`. A group is always emitted, even when
no bits are pending (`dataPosition = 0`), in which case it is all zeros.

The padding is computed in closed form instead of with a `partial` loop. The definition is
therefore total and cannot diverge if `dataPosition` were ever past `bitsPerChar - 1`. Under the
normal invariant (`dataPosition ≤ bitsPerChar - 1`) the result is identical to shifting one bit at
a time: `dataPosition` ends at `bitsPerChar - 1` and `dataVal` holds the value before the final shift.
-/
private def flushRemainingBits (ctx : EncodeState) (bitsPerChar : Nat) (transform : Nat → UInt8) : EncodeState :=
  let lastPos := bitsPerChar - 1
  -- Number of zero bits still needed before the final shift that completes the group.
  let pad := lastPos - ctx.dataPosition
  let padded := ctx.dataVal <<< pad
  { ctx with
    dataPosition := lastPos,
    dataVal := padded,
    outputBuffer := ctx.outputBuffer.push (transform (padded <<< 1)) }
/--
Run the encoder over every byte of `ba`, starting from `ctx`.

After the input is consumed:
1. any pending phrase is emitted;
2. code 2 is written at the current width;
3. the final partial group is flushed.

NOTE(review): code 2 is presumably the end-of-stream marker, as in lz-string.
-/
private def loopCompress (ba : ByteArray) (bitsPerChar : Nat) (transform : Nat → UInt8) (ctx : EncodeState) : EncodeState :=
  let afterInput := ba.foldl (fun st byte => processChar st byte bitsPerChar transform) ctx
  let afterTail :=
    if afterInput.currentSequence.size == 0 then afterInput
    else processWord afterInput afterInput.currentSequence bitsPerChar transform
  let terminated := writeBits afterTail 2 afterTail.currentBitWidth bitsPerChar transform
  flushRemainingBits terminated bitsPerChar transform
| end EncodeState | |
/--
Compress `uncompressed` with the LZ encoder. Each `bitsPerChar`-bit output group is rendered as
one character via `transform`.

* The input is encoded as its UTF-8 byte sequence (`String.toUTF8`), so every literal is 8 bits.
  NOTE(review): lz-string works on UTF-16 code units, so for non-ASCII input this output
  presumably differs from lz-string's; confirm before relying on cross-compatibility.
* An empty input yields `""`.
* For `bitsPerChar ≤ 8`, the raw group values are collected first and then mapped through
  `transform`. `transform` may therefore return any `Char`, including non-ASCII ones. (Previously
  each `Char` was truncated to its low byte and re-decoded with `String.fromUTF8!`, which corrupted
  or panicked on non-ASCII characters.)
* For wider groups the previous behaviour is kept, because groups cannot be buffered in a
  `ByteArray`: each produced `Char` is truncated to its low byte and the bytes are decoded with
  `String.fromUTF8!`.
-/
def compress (uncompressed : String) (bitsPerChar : Nat) (transform : Nat → Char) : String :=
  if uncompressed.isEmpty then
    ""
  else
    let ba := uncompressed.toUTF8
    if bitsPerChar ≤ 8 then
      -- Each group is < 2 ^ bitsPerChar ≤ 256, so storing it as a byte is lossless.
      let finalState := EncodeState.loopCompress ba bitsPerChar Nat.toUInt8 {}
      finalState.outputBuffer.foldl (fun out group => out.push (transform group.toNat)) ""
    else
      -- Legacy path: groups do not fit in a byte, so keep the original truncating conversion.
      let finalState := EncodeState.loopCompress ba bitsPerChar (Char.toUInt8 ∘ transform) {}
      String.fromUTF8! finalState.outputBuffer
/--
Compress `uncompressed` into a URI-safe string.

Output uses 6 bits per character over the 64-character alphabet `A–Z a–z 0–9 + - $`.
-/
def compressToURIComponent (uncompressed : String) : String :=
  -- Materialize the alphabet as an `Array` once, so each lookup is O(1) rather than walking a `List`.
  let alphabet : Array Char :=
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+-$".toList.toArray
  compress uncompressed 6 (alphabet[·]!)
| end LZ |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment