Skip to content

Instantly share code, notes, and snippets.

@kodejuice
Last active January 26, 2026 14:34
Show Gist options
  • Select an option

  • Save kodejuice/49cac36be9b02b06bfdf586640af1f96 to your computer and use it in GitHub Desktop.

Select an option

Save kodejuice/49cac36be9b02b06bfdf586640af1f96 to your computer and use it in GitHub Desktop.
class gSST:
    """Index sequences over the alphabet ``0 .. A-1`` in partition order.

    The sequences are grouped by the integer partition formed by their
    sorted symbol counts. The partitions come from
    ``generate_multiple_sorted_partitions_range`` (described by the author as
    ordered by increasing entropy, i.e. by how evenly the symbol frequencies
    are spread). A prefix-sum table over the per-partition sequence counts
    lets ``nth_sequence`` and ``sequence_index`` translate between a global
    index and a (partition, offset) pair.

    Attributes set by ``__init__``:
        N, A, M: the constructor arguments (``M`` after defaulting).
        sequences_count: ``A ** N``.
        set_size: ``sum(A ** i for i in M)``.
        partitions: the generated partitions, in the order they are indexed.
        sequences_count_sum: running totals of ``compute_sequences_count``
            over ``partitions``.
        min_partitions: number of leading partitions whose combined sequence
            count reaches ``A ** N``, or ``-1`` if all partitions together
            never reach it.
    """

    # Upper bound on the work ``count_iterations`` reports for a configuration.
    MAX_ITERATIONS = 9_000_000

    def __init__(self, N: int, A: int, M: range, debug=True, sort=True) -> None:
        """Build the partition list and its prefix-sum table.

        Args:
            N (int): The base sequence length.
            A (int): The number of distinct symbols.
            M (range): Multiple shaping order interval, the range of sequence
                lengths to be considered. A falsy value (``None`` or an empty
                range) is replaced by ``range(1, N)``.
            debug (bool): When true, print the set size and the iteration
                count to stdout.
            sort (bool): Forwarded unchanged to
                ``generate_multiple_sorted_partitions_range``.

        Raises:
            AssertionError: If ``count_iterations(A, M)`` is not below
                ``MAX_ITERATIONS``.
        """
        if not M:
            M = range(1, N)
        self.N = N
        self.A = A
        self.M = M
        self.sequences_count = A ** N
        self.set_size = sum(A ** i for i in M)
        if debug:
            print(self.set_size, 'set size')

        # Refuse configurations that would take too long to enumerate.
        iterations_required = count_iterations(A, M)
        if debug:
            print('len', iterations_required)
        assert iterations_required < self.MAX_ITERATIONS, \
            "Iterations too large (%d) [reduce M or |A|]" % (iterations_required)

        partitions = generate_multiple_sorted_partitions_range(N, A, M, sort)
        self.partitions = partitions

        # prefix_sums[i] = number of sequences generated by partitions[0..i].
        prefix_sums = []
        running_total = 0
        min_partitions = -1
        for position, partition in enumerate(partitions):
            if min_partitions == -1 and running_total >= self.sequences_count:
                # The first `position` partitions already hold enough
                # sequences to maintain a bijection with the set A^N.
                min_partitions = position
            running_total += compute_sequences_count(partition)
            prefix_sums.append(running_total)
        if min_partitions == -1 and running_total >= self.sequences_count:
            # The threshold is only reached once every partition is included.
            min_partitions = len(prefix_sums)

        # Always defined, so later reads never fail with AttributeError.
        self.min_partitions = min_partitions
        self.sequences_count_sum = prefix_sums

    def nth_sequence(self, n):
        """Return the sequence stored at global index ``n``.

        Args:
            n (int): The index of the sequence to retrieve.

        Returns:
            The value of ``get_nth_sequence`` for the partition that contains
            index ``n`` and the offset of ``n`` inside that partition.

        Raises:
            AssertionError: If n is not within the valid range [0, set_size).
            IndexError: If ``n`` lies beyond the sequences covered by the
                generated partitions.
        """
        from bisect import bisect_right

        assert 0 <= n < self.set_size, "n must be less than set size and >=0"

        prefix_sums = self.sequences_count_sum
        # First partition whose running total is strictly greater than n.
        index = bisect_right(prefix_sums, n)
        if index >= len(prefix_sums):
            raise IndexError(
                "index %d is not covered by the generated partitions" % n)

        # Convert the global index into an offset within that partition.
        if index:
            n -= prefix_sums[index - 1]
        return get_nth_sequence(self.partitions[index], n, self.A)

    def sequence_index(self, seq):
        """Return the global index of ``seq`` (inverse of ``nth_sequence``).

        Args:
            seq (List[int]): The input sequence; must provide ``count``.

        Returns:
            int: ``compute_sequence_index(seq, A)`` plus the number of
            sequences in all partitions that precede the partition of ``seq``.

        Raises:
            AssertionError: If a symbol is outside ``0 .. A-1``.
            ValueError: If the partition of ``seq`` is not in ``partitions``.
        """
        assert all(0 <= symbol < self.A for symbol in seq), \
            "Invalid symbol found in sequence"

        # Index of the given sequence within its own generating partition.
        offset = compute_sequence_index(seq, self.A)

        # The generating partition is the sorted tuple of per-symbol counts.
        symbol_counts = tuple(sorted(seq.count(symbol)
                                     for symbol in range(self.A)))
        partition_position = self.partitions.index(symbol_counts)

        # Skip over every sequence that belongs to an earlier partition.
        if partition_position > 0:
            offset += self.sequences_count_sum[partition_position - 1]
        return offset

    def g(self, sequence):
        """Map ``sequence`` to the sequence stored at its base-``A`` value.

        The sequence is converted to an integer with ``to_base10`` and that
        integer is used as the index passed to ``nth_sequence``.

        Raises:
            AssertionError: If the sequence is longer than ``M.stop``, holds a
                symbol outside ``0 .. A-1``, or maps to an index outside
                ``[0, set_size)``.
        """
        # NOTE(review): M.stop is the exclusive end of the range, so this
        # admits a length that M itself does not contain; confirm intent.
        assert len(sequence) <= self.M.stop, "Invalid sequence length"
        assert all(0 <= symbol < self.A for symbol in sequence), \
            "Invalid symbol found in sequence"
        index = to_base10(sequence, self.A)
        return self.nth_sequence(index)

    def g_inv(self, sequence):
        """Map ``sequence`` back to base-``A`` digits of its global index.

        Returns:
            The digits from ``to_baseN(index, A)``, left-padded with zeros up
            to length ``N`` (no padding when already ``N`` or longer).
        """
        index = self.sequence_index(sequence)
        digits = to_baseN(index, self.A)
        padding = [0] * (self.N - len(digits))
        return padding + digits
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment