Last active
January 26, 2026 14:34
-
-
Save kodejuice/49cac36be9b02b06bfdf586640af1f96 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class gSST:
    """Index sequences over an alphabet of size A, ordered by increasing entropy.

    gSST allows efficient operations on sequences of multiple lengths
    (controlled by ``N`` and the range ``M``) made up of symbols from 0 to
    A-1, sorted by increasing entropy.

    Entropy is a measure of the randomness or uncertainty in a dataset,
    computed from the frequency distribution of the symbols in a sequence.

    Attributes set by ``__init__``:
        N, A, M: the constructor arguments (``M`` after defaulting).
        sequences_count: ``A ** N``.
        set_size: ``sum(A ** i for i in M)``.
        partitions: result of ``generate_multiple_sorted_partitions_range``.
        sequences_count_sum: running (prefix) sum of
            ``compute_sequences_count`` over ``partitions``.
        min_partitions: how many leading partitions are needed for the
            running sum to reach ``A ** N``; ``-1`` if it is never reached.
    """

    MAX_ITERATIONS = 9_000_000

    def __init__(self, N: int, A: int, M: range, debug=True, sort=True) -> None:
        """Initializes gSST.

        Args:
            N (int): The length of the sequences.
            A (int): The number of distinct symbols.
            M (range): Multiple shaping order interval, the range of
                sequence lengths to be considered. A falsy value (``None``
                or an empty range) is replaced by ``range(1, N)``.
            debug (bool): When true, print the set size and the iteration
                estimate to stdout.
            sort (bool): Forwarded unchanged to
                ``generate_multiple_sorted_partitions_range``.

        Raises:
            AssertionError: If the iteration estimate returned by
                ``count_iterations(A, M)`` is not below ``MAX_ITERATIONS``.
        """
        if not M:
            M = range(1, N)

        self.N = N
        self.A = A
        self.M = M
        self.sequences_count = A ** N
        self.set_size = sum(A ** i for i in M)
        if debug:
            print(self.set_size, 'set size')

        # Refuse to start a partition enumeration that would take too long.
        iterations_required = count_iterations(A, M)
        if debug:
            print('len', iterations_required)
        assert iterations_required < self.MAX_ITERATIONS, \
            "Iterations too large (%d) [reduce M or |A|]" % (iterations_required)

        partitions = generate_multiple_sorted_partitions_range(N, A, M, sort)
        self.partitions = partitions

        # Build the prefix sum of per-partition sequence counts and, on the
        # way, record how many leading partitions are needed to hold A**N
        # sequences (enough to keep a bijection with the set A^N).
        # min_partitions is always defined: -1 means "never reached".
        prefix_sums = []
        running_total = 0
        self.min_partitions = -1
        for partition in partitions:
            running_total += compute_sequences_count(partition)
            prefix_sums.append(running_total)
            if self.min_partitions == -1 and running_total >= self.sequences_count:
                self.min_partitions = len(prefix_sums)
        self.sequences_count_sum = prefix_sums

    def nth_sequence(self, n):
        """Returns the nth sequence of the set, sorted by increasing entropy.

        Args:
            n (int): The index of the sequence to retrieve.

        Returns:
            The value produced by ``get_nth_sequence`` for the partition
            that contains index ``n`` (documented upstream as a tuple).

        Raises:
            AssertionError: If n is not within the valid range [0, set_size).
            IndexError: If n is not covered by any stored partition, i.e. it
                is at or beyond the last prefix sum.
        """
        from bisect import bisect_right

        assert 0 <= n < self.set_size, "n must be less than set size and >=0"

        # First partition whose cumulative count is strictly greater than n.
        prefix_sums = self.sequences_count_sum
        index = bisect_right(prefix_sums, n)
        if index >= len(prefix_sums):
            # Previously this silently fell back to partition 0 and returned
            # a sequence unrelated to n.
            raise IndexError(
                "n (%d) is beyond the sequences covered by the partitions" % n)

        # Convert n to an offset inside the selected partition.
        if index > 0:
            n -= prefix_sums[index - 1]

        return get_nth_sequence(self.partitions[index], n, self.A)

    def sequence_index(self, seq):
        """Returns the index of ``seq`` in the entropy-sorted set.

        Args:
            seq (List[int]): The input sequence, symbols in ``0..A-1``.

        Returns:
            int: The index of the input sequence in the sorted set.

        Raises:
            AssertionError: If the sequence contains a symbol outside
                ``0..A-1``.
            ValueError: If the sequence is empty, or its symbol-count
                partition is not present in ``self.partitions``.
        """
        # Symbols are 0..A-1, so A itself (and anything negative) is invalid.
        assert 0 <= min(seq) and max(seq) < self.A, \
            "Invalid symbol found in sequence"

        # Index of the given sequence within its own generating partition.
        sequence_index = compute_sequence_index(seq, self.A)

        # The partition is the sorted tuple of per-symbol occurrence counts.
        seq_sorted_partition = tuple(
            sorted(seq.count(symbol) for symbol in range(self.A)))
        partit_index = self.partitions.index(seq_sorted_partition)

        # Offset by the number of sequences in all preceding partitions.
        if partit_index > 0:
            sequence_index += self.sequences_count_sum[partit_index - 1]

        return sequence_index

    def g(self, sequence):
        """Maps ``sequence`` to the entropy-sorted sequence at its numeric index.

        The index is ``to_base10(sequence, self.A)``; presumably the value of
        ``sequence`` read as base-A digits (helper not visible here).

        Raises:
            AssertionError: If the sequence is longer than ``self.M.stop``,
                contains a symbol outside ``0..A-1``, or the resulting index
                is outside ``[0, set_size)``.
        """
        assert len(sequence) <= self.M.stop, "Invalid sequence length"
        # Symbols are 0..A-1, so A itself (and anything negative) is invalid.
        assert 0 <= min(sequence) and max(sequence) < self.A, \
            "Invalid symbol found in sequence"

        index = to_base10(sequence, self.A)
        return self.nth_sequence(index)

    def g_inv(self, sequence):
        """Inverse of ``g``: maps a sequence back to ``N`` base-A digits.

        The index of ``sequence`` in the sorted set is converted with
        ``to_baseN`` and left-padded with zeros to length ``N``. No padding
        is added when the converted value already has ``N`` or more digits.
        """
        index = self.sequence_index(sequence)
        digits = to_baseN(index, self.A)

        padding = [0] * (self.N - len(digits))  # pad with 0
        return padding + digits
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment