Last active
June 30, 2024 17:44
-
-
Save D4KU/c1e1be9485c41e6e4f21e66fbb3ec787 to your computer and use it in GitHub Desktop.
A tqdm subclass that highlights individual iterations in the progress bar.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from io import StringIO | |
| from math import ceil | |
| from tqdm import tqdm | |
| from colorama import Fore, Style, init | |
# Replace the {bar} placeholder in the bar format with the given string
def replace_bar(bar_format=None, bar=''):
    """Return *bar_format* with its ``{bar}`` placeholder swapped for *bar*.

    bar_format: a tqdm ``bar_format`` template, or ``None`` to build the
        layout ``'{l_bar}' + bar + '{r_bar}'`` (tqdm's documented default
        template with the meter spliced in).
    bar: literal text to put where ``{bar}`` was; ``''`` removes the meter.

    Only the exact text ``{bar}`` is replaced; a template without it is
    returned unchanged.
    """
    if bar_format is not None:
        return bar_format.replace('{bar}', bar)
    return ''.join(('{l_bar}', bar, '{r_bar}'))
# Estimate the char count of the bar meter
def estimate_ncols_bar(pbar):
    """Estimate how many columns the ``{bar}`` meter of *pbar* may occupy.

    A throw-away tqdm mirroring *pbar*'s display settings, but with the
    meter removed from its format, is rendered into a ``StringIO`` in its
    final state (``n == total``, so the counters are as wide as they get).
    The width of that line is subtracted from ``pbar.ncols``.

    Returns 10 when ``pbar.ncols`` is ``None`` (width unknown), matching
    tqdm's fallback meter width; otherwise never less than 1.
    """
    if pbar.ncols is None:
        # Without this guard `None - int` below raises TypeError, e.g. when
        # the output is not a terminal and no ncols was given.
        return 10
    with StringIO() as strio:
        for _ in tqdm(
                iterable=range(1),
                desc=pbar.desc,
                total=pbar.total,
                file=strio,
                ncols=pbar.ncols,
                unit=pbar.unit,
                unit_scale=pbar.unit_scale,
                bar_format=replace_bar(pbar.bar_format),
                # Start one step before the end so the single iteration
                # finishes on n == total.
                initial=0 if pbar.total is None else pbar.total - 1,
                postfix=pbar.postfix,
                unit_divisor=pbar.unit_divisor,
        ):
            pass
        # splitlines() also splits on '\r', so the last entry is the final
        # redraw. Read it before the StringIO is closed.
        stats_len = len(strio.getvalue().splitlines()[-1])
    # Like tqdm, keep at least one column for the meter; a zero or negative
    # width would break the meter arithmetic of callers.
    return max(1, pbar.ncols - stats_len)
# Return a string to be inserted for the {bar} placeholder within
# pbar.bar_format. Iterations passed as marks will be highlighted in the
# given color.
def marked_bar(pbar, marks, color, max_len):
    """Render a meter of ``max_len`` printable chars with marked iterations.

    pbar: tqdm-like object; ``ascii``, ``n`` and ``total`` are read.
    marks: collection of iteration indices to highlight (``None``/empty
        means no highlighting).
    color: escape sequence written before each highlighted run of chars;
        every run is closed with ``Style.RESET_ALL``.
    max_len: width of the meter in characters.

    Returns the meter string. ``n`` is clamped to ``[0, total]`` so the
    meter never exceeds ``max_len``. Returns ``''`` when ``max_len`` is not
    positive and an all-blank meter when ``total`` is falsy.
    """
    if max_len <= 0:
        return ''
    # Mirror tqdm's charset choice: a string `ascii` is a custom charset
    # whose last char fills the meter and whose first char pads it.
    if isinstance(pbar.ascii, str) and pbar.ascii:
        fill, blank = pbar.ascii[-1], pbar.ascii[0]
    else:
        fill, blank = ('#' if pbar.ascii else '█'), ' '
    total = pbar.total
    if not total:
        return blank * max_len
    marks = marks or ()
    n_marks = len(marks)
    # Number of filled chars; a partially reached char counts as filled.
    # Multiply before dividing to round only once, and clamp like tqdm so
    # an overshooting (or negative) n cannot stretch the meter.
    cur_len = min(max(ceil(pbar.n * max_len / total), 0), max_len)
    parts = []
    marked = False  # whether `color` is currently switched on
    for i in range(cur_len):
        # Iterations the i-th char stands for; neighbouring chars may
        # share a boundary iteration.
        span = range(int(i * total / max_len),
                     ceil((i + 1) * total / max_len))
        # Scan the smaller side: a char can cover a huge span of
        # iterations when total is much larger than max_len.
        if n_marks <= len(span):
            mark = any(m in span for m in marks)
        else:
            mark = any(j in marks for j in span)
        if mark != marked:
            parts.append(color if mark else Style.RESET_ALL)
            marked = mark
        parts.append(fill)
    if marked:
        parts.append(Style.RESET_ALL)
    # Pad the not-yet-processed part of the meter.
    parts.append(blank * (max_len - cur_len))
    return ''.join(parts)
class MarkedTqdm(tqdm):
    """tqdm progress bar whose meter highlights selected iterations.

    *args / **kwargs: passed on to tqdm.
    marks: set of iteration indices to highlight; a new empty set per
        instance when omitted. It is read on every redraw, so adding to
        ``self.marks`` while iterating highlights further iterations.
    mark_colour: escape sequence written before highlighted meter chars.
    ncols_bar: length of the meter in characters; estimated if ``None``.
    """
    __slots__ = 'marks', 'color', 'ncols_bar'

    # colorama.init() wraps sys.stdout/sys.stderr again on every call, so
    # it is only invoked for the first instance.
    _colorama_ready = False

    def __init__(self, *args, marks=None, mark_colour=Fore.RED,
                 ncols_bar=None, **kwargs):
        # tqdm.__init__ already redraws; with marks unset, display() keeps
        # the stock meter until color/ncols_bar exist.
        self.marks = None
        super().__init__(*args, **kwargs)
        if self.disable:
            return
        self.color = mark_colour
        self.ncols_bar = estimate_ncols_bar(self) \
            if ncols_bar is None else ncols_bar
        # Set last, so a redraw can never see marks without color and
        # ncols_bar. A fresh set per instance: a shared `set()` default
        # would leak marks added via `self.marks.add(...)` between bars.
        self.marks = set() if marks is None else marks
        if not MarkedTqdm._colorama_ready:
            init()
            MarkedTqdm._colorama_ready = True

    def display(self, *args, **kwargs):
        """Redraw via tqdm, with the marked meter spliced into bar_format."""
        bar_format = self.bar_format
        # If total iterations are unknown, there is no bar meter to
        # manipulate; without marks the stock meter is used.
        if self.total and self.marks:
            bar = marked_bar(self, self.marks, self.color, self.ncols_bar)
            self.bar_format = replace_bar(bar_format, bar)
        try:
            # Propagate tqdm's result: close() checks it to decide whether
            # to write a trailing '\r'.
            return super().display(*args, **kwargs)
        finally:
            # Always restore, so the one-off meter cannot leak into later
            # redraws if display raises.
            self.bar_format = bar_format
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment