Last active
March 12, 2025 21:08
-
-
Save tmck-code/0ec5d579884bb7779fc4e9474fd3e363 to your computer and use it in GitHub Desktop.
Python helpers
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from functools import wraps | |
| from collections import Counter, namedtuple | |
| import pickle | |
| import operator | |
| import time, sys, os | |
| import pp | |
# One benchmark case: positional args, keyword args, the expected result
# and the number of timed repetitions.
Test = namedtuple('Test', ['args', 'kwargs', 'expected', 'n'])


class NoExpectation:
    'Sentinel type meaning "no expected result was supplied for this test"'
def set_function_module(func):
    'Replace a function\'s "__main__" module name with something more descriptive'
    current = func.__module__
    if current == '__main__':
        owner = sys.modules[current]
        if not hasattr(owner, '__file__'):
            # the main module has no backing file: use the name of the current directory
            func.__module__ = os.path.basename(os.getcwd())
        else:
            # use the main module's file name, cut at its first dot
            script = os.path.basename(owner.__file__)
            func.__module__ = script.split('.')[0]
def timeit_func(func, args, kwargs, expected: object = NoExpectation, n: int = 10_000):
    '''
    Time `n` calls of `func(*args, **kwargs)` and check its result.

    - func:     the callable under test
    - args:     positional arguments; they must be picklable, as a fresh copy is unpickled for every call
    - kwargs:   keyword arguments, passed as-is (not copied) to every call
    - expected: value the result is compared against; `NoExpectation` skips the comparison
    - n:        number of timed calls

    Returns a tuple `(result, correct, times)`:
    - result:  return value of one extra, untimed call (or the exception instance it raised)
    - correct: True when no expectation was given, otherwise `result == expected`
    - times:   Counter mapping each observed call duration (in seconds) to how often it occurred
    '''
    times = Counter()
    # some functions may modify the input arguments, so a new copy is needed for every test
    # "pickle" is used instead of "deepcopy" as it's much faster
    margs = pickle.dumps(args)

    # ensure that the function module is meaningful (replace it if it's just "__main__")
    set_function_module(func)

    for _ in range(n):
        # unpickle before starting the clock so that only the call itself is measured
        fresh_args = pickle.loads(margs)
        # perf_counter() is the monotonic, high-resolution clock intended for measuring short durations
        start = time.perf_counter()
        try:
            func(*fresh_args, **kwargs)
        except Exception:
            # failing calls are still timed; the failure itself is reported through `result` below
            pass
        finally:
            times[time.perf_counter() - start] += 1

    try:
        result = func(*pickle.loads(margs), **kwargs)
    except Exception as e:
        result = e
    return result, expected is NoExpectation or result == expected, times
| def _sum_times(times: Counter) -> float: | |
| 'sum the values*counts in a Counter' | |
| return sum(map(operator.mul, *zip(*times.items()))) | |
# Coloured status label for each test outcome, keyed by "was the result correct?"
TEST_STATUS = {
    passed: pp.ps(label, colour)
    for passed, label, colour in (
        (False, 'fail', 'red'),
        (True, 'pass', 'green'),
    )
}
def _test(func, test, width=1):
    '''
    Run a timed test on a function and print the result as a one-line report.

    - func:  the callable under test
    - test:  a `Test` (args, kwargs, expected, n); it is unpacked into `timeit_func`
    - width: width used to align the function-name column of the report
    '''
    import shutil

    result, correct, times = timeit_func(func, *test)

    status_msg = ''
    if not correct:
        # shutil.get_terminal_size() falls back to a default size instead of raising OSError
        # when the output is not attached to a terminal (e.g. piped or redirected)
        fail_sep = ' ' if shutil.get_terminal_size().columns >= 100 else '\n'
        status_msg = pp.ps(f'{fail_sep}>> {result=}', 'yellow')

    print('{module:s}.{func_name:<{width}s}{status:s}, time: {times:.6f}{status_msg:s}'.format(**{
        'module': func.__module__,
        'func_name': func.__name__ + ', ',
        'times': _sum_times(times),
        # `correct` is the result of an `==` comparison, so normalise it to a real bool for the lookup
        'status': TEST_STATUS[bool(correct)],
        'status_msg': status_msg,
        'width': width+2,
    }))
def timeit(*args, n=10_000, **kwargs):
    '''
    Decorator to time a function.

    Works both as `@timeit(n=...)` and as bare `@timeit`. Every call of the decorated function
    runs a timed test of the original function with the call's arguments through `_test`;
    the call itself returns None.

    - n: number of timed calls per invocation
    '''
    def decorator_with_args(func):
        @wraps(func)
        def wrapper(*call_args, **call_kwargs):
            _test(func, Test(call_args, call_kwargs, NoExpectation, n), width=len(func.__name__))
        return wrapper

    # bare usage: `@timeit` hands the decorated function over as the only positional argument
    if len(args) == 1 and callable(args[0]) and not kwargs:
        return decorator_with_args(args[0])
    return decorator_with_args
def test(tests, funcs, n=10_000):
    '''
    Run a series of timed tests on a list of functions.

    - tests: iterable of `(args, kwargs, expected)` tuples
    - funcs: functions that are run against every test
    - n:     number of timed calls per function and test
    '''
    # default=0 keeps an empty `funcs` from crashing max()
    width = max((len(func.__name__) for func in funcs), default=0)

    leading = ''
    for spec in tests:
        case = Test(*spec, n=n)
        print('{s:s}{border:s}\n{n_s:s}: {n:,d}, {args_s:s}: {args:s}, {kwargs_s:s}: {kwargs:s}'.format(**{
            's': leading,
            'border': pp.ps('-'*60, 'yellow'),
            'n_s': pp.ps('n', 'bold'),
            'n': case.n,
            'args_s': pp.ps('args', 'bold'),
            'args': str(case.args),
            'kwargs_s': pp.ps('kwargs', 'bold'),
            'kwargs': str(case.kwargs),
        }))
        for func in funcs:
            _test(func, case, width)
        # every header after the first one is preceded by a blank line
        leading = '\n'
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# IPython session timing three ways to build the same string from `r`: its first 10 items,
# a single space, then items 10 up to (but not including) the last one.
# NOTE(review): `r` is not defined in this snippet - presumably a string from earlier in the session; confirm.
from itertools import islice, chain
from more_itertools import value_chain

# variant a: chain the three pieces lazily
%timeit ''.join(chain(islice(r, 10), ' ', islice(r, 10, len(r)-1)))
# 117 μs ± 2.41 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

# variant b: materialise both slices and unpack them into one list
%timeit ''.join([*list(islice(r, 10)), ' ', *list(islice(r, 10, len(r)-1))])
# 88.6 μs ± 3.1 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

# variant c: more_itertools.value_chain instead of itertools.chain
%timeit ''.join(value_chain(islice(r, 10), ' ', islice(r, 10, len(r)-1)))
# 232 μs ± 18.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

# sanity check: all three variants produce the same string
a = ''.join(chain(islice(r, 10), ' ', islice(r, 10, len(r)-1)))
b = ''.join([*list(islice(r, 10)), ' ', *list(islice(r, 10, len(r)-1))])
c = ''.join(value_chain(islice(r, 10), ' ', islice(r, 10, len(r)-1)))
a == b == c
# True
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| from itertools import zip_longest, islice, chain, tee, count, filterfalse | |
| from functools import wraps, partial | |
| import operator | |
| import sys, os | |
| import bench | |
def listify(func):
    'Decorate a function returning an iterable of iterables so that it returns a list of lists'
    @wraps(func)
    def inner(*args, **kwargs):
        return [list(group) for group in func(*args, **kwargs)]
    return inner
def chunk_zip_longest(items, n):
    'Yield lists of up to `n` consecutive items, grouping with zip_longest over one shared iterator'
    # a unique fill value (rather than None) pads the last group, so None items in the input
    # are kept and only the padding is dropped; this also avoids filtering on the truth value
    # of NotImplemented, which is what `None.__ne__(x)` returns for any non-None x
    fill = object()
    for group in zip_longest(*[iter(items)] * n, fillvalue=fill):
        yield [item for item in group if item is not fill]
# predicate that is True for None; no longer used by chunk_zip_longest_p, kept as a module-level name
p = partial(operator.is_, None)


def chunk_zip_longest_p(items, n):
    'Yield lists of up to `n` consecutive items, dropping the zip_longest padding with filterfalse'
    # pad with a unique object instead of None so that None items in the input survive the filter
    fill = object()
    is_fill = partial(operator.is_, fill)
    for group in zip_longest(*[iter(items)] * n, fillvalue=fill):
        yield list(filterfalse(is_fill, group))
def chunk_slices(items, n):
    'Yield consecutive slices of `items`, each at most `n` long (needs len() and slicing support)'
    offsets = range(0, len(items), n)
    yield from (items[lo:lo + n] for lo in offsets)
def chunk_slices2(items, n):
    'Yield consecutive slices of `items`, each at most `n` long, addressed by chunk index'
    # ceil(len(items) / n) chunks cover every item exactly once; the previous `n + 1` was only
    # right by coincidence and otherwise lost items or produced empty trailing chunks
    n_chunks = -(-len(items) // n)
    for i in range(n_chunks):
        yield items[i*n:(i+1)*n]
def chunk_islice_list(items, n):
    'Yield lists of up to `n` consecutive items, taken with islice from any iterable'
    # a single shared iterator is advanced chunk by chunk: this works for one-shot iterators
    # (absolute offsets skipped items there) and never re-reads the start of the input
    source = iter(items)
    while True:
        chunk = list(islice(source, n))
        if not chunk:
            return
        yield chunk
def chunk_islice_iter(items, n):
    '''
    Yield lazy iterators over consecutive groups of up to `n` items.

    All groups share one underlying iterator, so each group has to be consumed
    before the next one is requested.
    '''
    source = iter(items)
    # pulling the first item of every group detects exhaustion without __length_hint__
    # (which generators do not provide) and avoids a trailing empty group when the
    # number of items is a multiple of `n`
    for first in source:
        yield chain((first,), islice(source, n - 1))
# every chunker is wrapped in `listify` so its output can be compared with a list of lists
funcs = [
    listify(chunker)
    for chunker in (
        chunk_zip_longest,
        chunk_zip_longest_p,
        chunk_slices,
        chunk_slices2,
        chunk_islice_list,
        chunk_islice_iter,
    )
]

# the same ten numbers as a list, a range, a list iterator and a range iterator;
# each test is (args, kwargs, expected)
tests = [
    ((numbers, 3), {}, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]])
    for numbers in (
        list(range(10)),
        range(10),
        iter(list(range(10))),
        iter(range(10)),
    )
]

if __name__ == '__main__':
    bench.test(tests, funcs, n=100_000)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| ''' | |
| A module to create loggers with custom handlers and a custom formatter. | |
| usage examples to initialise a logger: | |
| ```python | |
| # 1. initialise logger to stderr: | |
| logger = getLogger('my_logger', level=logging.DEBUG, print_stream=sys.stderr) | |
| # 2. initialise logger to file: | |
| logger = getLogger('my_logger', level=logging.DEBUG, filename='my_log.log') | |
| # 3. initialise logger to both stderr and file: | |
| logger = getLogger('my_logger', level=logging.DEBUG, print_stream=sys.stderr, filename='my_log.log') | |
| ``` | |
| usage examples to log messages: | |
| ```python | |
| logger.info('This is a basic info message') | |
| # {"timestamp": "2024-12-09T15:05:43.904417", "msg": "This is a basic info message", "data": {}} | |
| logger.info('This is an info message', {'key': 'value'}) | |
| # {"timestamp": "2024-12-09T15:05:43.904600", "msg": "This is an info message", "data": {"key": "value"}} | |
| logger.debug('This is a debug message', 'arg1', 'arg2', {'key': 'value'}) | |
| # {"timestamp": "2024-12-09T15:05:43.904749", "msg": "This is a debug message", "data": {"args": ["arg1", "arg2"], "key": "value"}} | |
| ``` | |
| ''' | |
| from dataclasses import asdict, dataclass, is_dataclass, field | |
| from datetime import datetime | |
| import io | |
| import json | |
| import logging | |
| import sys | |
# Name of the base logger; `_getLogger` names the loggers it creates f'{LOG_ROOT_NAME}.{name}'
LOG_ROOT_NAME = 'root'
| def _json_default(obj: object) -> str: | |
| 'Default JSON serializer, supports most main class types' | |
| if isinstance(obj, str): return obj | |
| if is_dataclass(obj): return asdict(obj) | |
| if isinstance(obj, datetime): return obj.isoformat() | |
| if hasattr(obj, '__dict__'): return obj.__dict__ | |
| if hasattr(obj, '__name__'): return obj.__name__ | |
| if hasattr(obj, '__slots__'): return {k: getattr(obj, k) for k in obj.__slots__} | |
| if hasattr(obj, '_asdict'): return obj._asdict() | |
| return str(obj) | |
class LogFormatter(logging.Formatter):
    'Custom log formatter that formats log messages as JSON, aka "Structured Logging".'

    def format(self, record) -> str:
        '''
        Formats the log message as JSON.

        The arguments of the logging call end up under "data": a trailing mapping is merged in
        as key/value pairs, the remaining positional arguments are listed under "args".

        NOTE: the record is modified in place (`msg` becomes the JSON document and `args` is
        cleared), so anything that processes the same record afterwards sees the JSON text.
        '''
        from collections.abc import Mapping

        args, kwargs = [], {}
        if isinstance(record.args, Mapping):
            kwargs = record.args
        elif isinstance(record.args, tuple) and record.args:
            *args, last = record.args
            if args and isinstance(last, Mapping):
                kwargs = last
            else:
                # a single argument, or a trailing argument that is not a mapping, is an
                # ordinary positional argument (merging it with ** would raise a TypeError)
                args.append(last)

        record.msg = json.dumps(
            {
                # use the time the record was created rather than the time it gets formatted
                'timestamp': datetime.fromtimestamp(record.created).isoformat(),
                'msg': record.msg,
                'data': {
                    **({'args': args} if args else {}),
                    **(kwargs if kwargs else {}),
                }
            },
            default=_json_default,
        )
        record.args = ()
        return super().format(record)
def _remove_handlers(logger: logging.Logger) -> None:
    'Close and detach every handler currently attached to `logger`.'
    # iterate over a copy: removeHandler() mutates logger.handlers
    for handler in list(logger.handlers):
        handler.close()
        logger.removeHandler(handler)


def _getLogger(name: str, level: int = logging.CRITICAL, handlers: 'list[logging.Handler] | None' = None) -> logging.Logger:
    '''
    Creates a logger with the given name, level, and handlers.
    - If no handlers are provided, the logger will not output any logs.
    - This function requires the handlers to be initialized when passed as args.
    - the same log level is applied to all handlers.
    - any handlers previously attached to the base logger or to the named logger are closed and removed.
    '''
    # reset the base logger
    # NOTE(review): since Python 3.9 logging.getLogger('root') returns the real root logger, so with
    # LOG_ROOT_NAME == 'root' this also strips handlers installed by other code - confirm this is intended
    base = logging.getLogger(LOG_ROOT_NAME)
    base.setLevel(level)
    _remove_handlers(base)

    # create (or reset) the named logger
    logger = logging.getLogger(f'{LOG_ROOT_NAME}.{name}')
    logger.setLevel(level)
    _remove_handlers(logger)

    # add the new handlers
    for handler in handlers or ():
        handler.setLevel(level)
        logger.addHandler(handler)

    if logger.handlers:
        # only set the first handler to use the custom formatter
        logger.handlers[0].setFormatter(LogFormatter())
    return logger
@dataclass
class Logger:
    '''
    A class to create a logger with custom handlers and a custom formatter.
    Logger(name, level, print_stream, filename)

    `handlers` is not a constructor argument; it is filled from `print_stream` and `filename`.
    '''
    name: str
    level: int = logging.CRITICAL
    print_stream: io.TextIOBase = field(default=None)
    filename: str = field(default=None)
    handlers: list[logging.Handler] = field(init=False, default_factory=list)

    def __post_init__(self, filename: str = None):
        # the `filename` parameter is unused (self.filename is read instead); kept so the signature is unchanged
        if self.print_stream:
            self.handlers.append(logging.StreamHandler(self.print_stream))
        if self.filename:
            self.handlers.append(logging.FileHandler(self.filename))
        # plain attribute (not a dataclass field) caching the configured logger for `logger`
        self._logger = None

    def getLogger(self) -> logging.Logger:
        'Configure (or re-configure) the underlying logger and return it.'
        self._logger = _getLogger(self.name, self.level, handlers=self.handlers)
        return self._logger

    @property
    def logger(self) -> logging.Logger:
        'The configured logger; set up on first access and reused afterwards.'
        # re-running the setup on every access would close and re-attach the handlers each time
        if self._logger is None:
            return self.getLogger()
        return self._logger
def getLogger(name: str, level: int = logging.CRITICAL, print_stream: io.TextIOBase = None, filename: str = None) -> logging.Logger:
    '''
    Builds the handlers for the requested outputs and returns the configured logger.
    - `print_stream` set: logs are written to that stream.
    - `filename` set: logs are written to that file.
    - both set: logs are written to both.
    - neither set: the logger gets no handlers and outputs nothing.
    '''
    # (target, handler class) pairs, in the order the handlers are attached
    sinks = (
        (print_stream, logging.StreamHandler),
        (filename, logging.FileHandler),
    )
    handlers = [make_handler(target) for target, make_handler in sinks if target]
    return _getLogger(name, level, handlers)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from dataclasses import asdict, is_dataclass | |
| from datetime import datetime, timedelta | |
| import json | |
| import random | |
| from pygments import highlight, console | |
| from pygments.lexers import JsonLexer, OutputLexer | |
| from pygments.formatters import Terminal256Formatter | |
| from pygments.styles import get_style_by_name, get_all_styles | |
# pygments style names that `ppd` picks from when `random_style` is set
STYLES = [
    'dracula',
    'fruity',
    'gruvbox-dark',
    'gruvbox-light',
    'lightbulb',
    'material',
    'one-dark',
    'perldoc',
    'native',
    'tango',
]
| def _json_default(obj: object): | |
| 'Default JSON serializer, supports most main class types' | |
| if isinstance(obj, str): return obj | |
| if is_dataclass(obj): return asdict(obj) | |
| if isinstance(obj, datetime): return obj.isoformat() | |
| if hasattr(obj, '__dict__'): return obj.__dict__ | |
| if hasattr(obj, '__name__'): return obj.__name__ | |
| if hasattr(obj, '__slots__'): return {k: getattr(obj, k) for k in obj.__slots__} | |
| if hasattr(obj, '_asdict'): return obj._asdict() | |
| return str(obj) | |
def ppd(d, indent=2, style='dracula', random_style=False):
    'prints a dict as indented, syntax-highlighted JSON'
    chosen = random.choice(STYLES) if random_style else style
    text = json.dumps(d, indent=indent, default=_json_default)
    lexer = JsonLexer()
    formatter = Terminal256Formatter(style=get_style_by_name(chosen))
    coloured = highlight(code=text, lexer=lexer, formatter=formatter)
    print(coloured.strip())
def ppj(j, indent=2, style='dracula', random_style=False):
    'parses a JSON string and pretty-prints the result with `ppd`'
    parsed = json.loads(j)
    ppd(parsed, indent=indent, style=style, random_style=random_style)
def ps(s, style='yellow', random_style=False):
    'returns the string coloured with `style`, or with a randomly picked colour'
    if not random_style:
        return console.colorize(style, s)
    palette = console.dark_colors + console.light_colors
    return console.colorize(random.choice(palette), s)
def pps(s, style='yellow', random_style=False):
    'prints a string coloured by `ps`'
    coloured = ps(s, style=style, random_style=random_style)
    print(coloured)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import json | |
| from pygments import highlight | |
| from pygments.lexers import JsonLexer | |
| from pygments.formatters import TerminalTrueColorFormatter as Formatter | |
| from pygments.styles import get_style_by_name | |
def ppd(d, indent=2, style='material'):
    'prints a dict as indented, syntax-highlighted JSON (non-JSON values are converted with str)'
    text = json.dumps(d, indent=indent, default=str)
    lexer = JsonLexer()
    formatter = Formatter(style=get_style_by_name(style))
    print(highlight(text, lexer, formatter).strip())
def ppj(j, indent=2, style='material'):
    'parses a JSON string and pretty-prints the result with `ppd`'
    parsed = json.loads(j)
    ppd(parsed, indent=indent, style=style)
def pps(s, style='yellow'):
    'prints a string in the given pygments console colour'
    # `console` is not among this snippet's imports, so it is imported here
    # (without it the call fails with a NameError)
    from pygments import console
    print(console.colorize(style, s))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment