Created
July 12, 2025 20:34
-
-
Save IntendedConsequence/3624d147875d8eb4fa5f404df0e86f36 to your computer and use it in GitHub Desktop.
Shows the weights (tensor names, dtypes and shapes) of a .safetensors file in a simple table (Windows only)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| # /// script | |
| # requires-python = ">=3.10" | |
| # dependencies = [ | |
| # "imgui[sdl2]", | |
| # "pysdl2-dll", | |
| # "opencv-python-headless", | |
| # ] | |
| # /// | |
| from dataclasses import dataclass, field | |
| import json | |
| from pathlib import Path | |
| import struct | |
| from imgui.integrations.sdl2 import SDL2Renderer | |
| #from testwindow import show_test_window | |
| from sdl2 import ( | |
| SDL_Event, | |
| SDL_GetTicks, | |
| SDL_GetTicks64, | |
| SDL_PollEvent, | |
| SDL_QUIT, | |
| SDL_DROPFILE, | |
| SDL_Delay, | |
| SDL_GL_SwapWindow, | |
| SDL_GL_DeleteContext, | |
| SDL_DestroyWindow, | |
| SDL_Quit, | |
| SDL_Init, | |
| SDL_INIT_EVERYTHING, | |
| SDL_GetError, | |
| SDL_GL_SetAttribute, | |
| SDL_GL_DOUBLEBUFFER, | |
| SDL_GL_DEPTH_SIZE, | |
| SDL_GL_STENCIL_SIZE, | |
| SDL_GL_ACCELERATED_VISUAL, | |
| SDL_GL_MULTISAMPLEBUFFERS, | |
| SDL_GL_MULTISAMPLESAMPLES, | |
| SDL_GL_CONTEXT_FLAGS, | |
| SDL_GL_CONTEXT_FORWARD_COMPATIBLE_FLAG, | |
| SDL_GL_CONTEXT_MAJOR_VERSION, | |
| SDL_GL_CONTEXT_MINOR_VERSION, | |
| SDL_GL_CONTEXT_PROFILE_MASK, | |
| SDL_GL_CONTEXT_PROFILE_CORE, | |
| SDL_SetHint, | |
| SDL_HINT_MAC_CTRL_CLICK_EMULATE_RIGHT_CLICK, | |
| SDL_HINT_VIDEO_HIGHDPI_DISABLED, | |
| SDL_CreateWindow, | |
| SDL_WINDOWPOS_CENTERED, | |
| SDL_WINDOW_OPENGL, | |
| SDL_WINDOW_RESIZABLE, | |
| SDL_GL_CreateContext, | |
| SDL_GL_MakeCurrent, | |
| SDL_GL_SetSwapInterval, | |
| ) | |
| # from sdl2 import * | |
| import OpenGL.GL as gl | |
| import ctypes | |
| import imgui | |
| import sys | |
| import os | |
| import cv2 | |
| import numpy as np | |
| import socket | |
| from multiprocessing.connection import Client, Connection | |
# Win32 waitable-timer API used by wait_hr(). Windows-only: ctypes.WinDLL does
# not exist on other platforms, so importing this script elsewhere fails here.
from ctypes import wintypes

# Load kernel32 once instead of once per function.
_kernel32 = ctypes.WinDLL("kernel32")

# Without explicit signatures ctypes assumes every function returns a C int,
# which truncates 64-bit HANDLE values, and passes Python ints as C ints.
# Optional pointer parameters are declared as c_void_p so callers may pass
# 0/None for NULL and ctypes.byref(...) for real pointers.
CreateWaitableTimerExW = _kernel32.CreateWaitableTimerExW
CreateWaitableTimerExW.argtypes = (
    ctypes.c_void_p,  # lpTimerAttributes (optional)
    ctypes.c_void_p,  # lpTimerName (LPCWSTR, optional; NULL = unnamed timer)
    wintypes.DWORD,   # dwFlags
    wintypes.DWORD,   # dwDesiredAccess
)
CreateWaitableTimerExW.restype = wintypes.HANDLE

SetWaitableTimer = _kernel32.SetWaitableTimer
SetWaitableTimer.argtypes = (
    wintypes.HANDLE,  # hTimer
    ctypes.c_void_p,  # lpDueTime (pointer to a LARGE_INTEGER)
    wintypes.LONG,    # lPeriod
    ctypes.c_void_p,  # pfnCompletionRoutine (optional)
    ctypes.c_void_p,  # lpArgToCompletionRoutine (optional)
    wintypes.BOOL,    # fResume
)
SetWaitableTimer.restype = wintypes.BOOL

WaitForSingleObject = _kernel32.WaitForSingleObject
WaitForSingleObject.argtypes = (wintypes.HANDLE, wintypes.DWORD)
WaitForSingleObject.restype = wintypes.DWORD

CloseHandle = _kernel32.CloseHandle
CloseHandle.argtypes = (wintypes.HANDLE,)
CloseHandle.restype = wintypes.BOOL

CREATE_WAITABLE_TIMER_HIGH_RESOLUTION = 0x00000002  # dwFlags for CreateWaitableTimerExW
TIMER_ALL_ACCESS = 0x1F0003  # dwDesiredAccess for CreateWaitableTimerExW
| import _winapi | |
| from _winapi import INFINITE | |
class LARGE_INTEGER(ctypes.Structure):
    """ctypes stand-in for the Win32 LARGE_INTEGER: one signed 64-bit ``QuadPart``.

    Passed by reference to SetWaitableTimer as the timer's due time.
    """

    _fields_ = [
        ("QuadPart", ctypes.c_longlong),
    ]
def wait_hr(duration_ms):
    """Block the calling thread for about *duration_ms* milliseconds.

    Uses a Win32 waitable timer created with
    CREATE_WAITABLE_TIMER_HIGH_RESOLUTION. If the timer cannot be created or
    armed (presumably when the OS does not support that flag), this falls back
    to ``time.sleep`` so callers still get a delay instead of none at all.

    Args:
        duration_ms: delay in milliseconds (float allowed). Non-positive values
            return immediately.
    """
    import time

    if duration_ms <= 0:
        return

    handle = CreateWaitableTimerExW(
        None, None, CREATE_WAITABLE_TIMER_HIGH_RESOLUTION, TIMER_ALL_ACCESS
    )
    if not handle:
        time.sleep(duration_ms / 1000.0)
        return

    try:
        # Due time is in 100-nanosecond units. A positive value is an absolute
        # FILETIME (UTC) timestamp; a NEGATIVE value is relative to now. Clamp
        # to at least one unit so tiny durations never become 0 (= absolute).
        ticks = max(1, int(duration_ms * 10_000))
        due_time = LARGE_INTEGER(-ticks)
        if SetWaitableTimer(handle, ctypes.byref(due_time), 0, None, None, False):
            WaitForSingleObject(handle, INFINITE)
        else:
            time.sleep(duration_ms / 1000.0)
    finally:
        # Always release the timer handle, even if the wait is interrupted.
        CloseHandle(handle)
| import logging | |
# Route all log records at INFO and above to safetensors_ui.log, which is
# created relative to the current working directory.
logging.basicConfig(
    filename="safetensors_ui.log",
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module logger; its records propagate to the root handler configured above.
logger = logging.getLogger(__name__)
def show_help_marker(desc):
    """Draw a greyed-out "(?)" marker that reveals *desc* in a tooltip on hover.

    The tooltip text wraps at 35 times the current font size.
    """
    imgui.text_disabled("(?)")
    if not imgui.is_item_hovered():
        return

    imgui.begin_tooltip()
    wrap_at = imgui.get_font_size() * 35.0
    imgui.push_text_wrap_pos(wrap_at)
    imgui.text_unformatted(desc)
    imgui.pop_text_wrap_pos()
    imgui.end_tooltip()
# Metadata keys whose values are JSON documents stored as strings;
# process_meta() decodes these before the metadata is displayed.
extra_keys = [
    "ss_dataset_dirs",
    "ss_bucket_info",
    "ss_tag_frequency",
    "ot_config",
]
def loadsftmeta(path, max_header_size=100_000_000):
    """Read and parse the JSON header of a .safetensors file.

    The file is expected to start with an 8-byte little-endian unsigned length
    N, followed by N bytes of UTF-8 JSON. Only that header is read; tensor
    data after it is never touched.

    Args:
        path: file path (str or os.PathLike).
        max_header_size: reject files whose length prefix claims a header larger
            than this many bytes, so a non-safetensors file cannot trigger a
            huge read. The default mirrors the 100 MB header limit of the
            reference safetensors implementation (TODO confirm if it changes).

    Returns:
        The decoded header dict. Entries are tensor name -> tensor info,
        plus an optional "__metadata__" entry.

    Raises:
        OSError: if the file cannot be opened or read.
        ValueError: if the file is too short, the header is truncated or too
            large, is not valid UTF-8/JSON, or is not a JSON object.
    """
    with open(path, "rb") as f:
        prefix = f.read(8)
        if len(prefix) != 8:
            raise ValueError(f"{path}: file too short to be a safetensors file")
        (header_len,) = struct.unpack("<Q", prefix)
        if header_len > max_header_size:
            raise ValueError(
                f"{path}: header length {header_len} exceeds limit of {max_header_size} bytes"
            )
        header = f.read(header_len)

    if len(header) != header_len:
        raise ValueError(
            f"{path}: truncated header ({len(header)} of {header_len} bytes)"
        )

    # UnicodeDecodeError and json.JSONDecodeError are both ValueError subclasses.
    meta = json.loads(header.decode("utf-8"))
    if not isinstance(meta, dict):
        raise ValueError(f"{path}: header is not a JSON object")
    return meta
def process_meta(meta_dict, keys=None):
    """Return the "__metadata__" section of a header with JSON-string values decoded.

    Metadata values are expected to be strings. For the given *keys* (the
    module-level ``extra_keys`` by default), a string value is parsed as JSON
    so it can be shown as structured data. The input is NOT modified: a new
    dict is returned, so calling this twice on the same header is safe.

    Args:
        meta_dict: header dict as returned by loadsftmeta().
        keys: metadata keys whose values should be JSON-decoded; defaults to
            ``extra_keys``.

    Returns:
        A new dict. It is empty when the header has no dict-valued
        "__metadata__" entry. Values that are not strings, or are not valid
        JSON, are kept as-is.
    """
    if keys is None:
        keys = extra_keys

    raw = meta_dict.get("__metadata__")
    meta = dict(raw) if isinstance(raw, dict) else {}

    for key in keys:
        value = meta.get(key)
        if not isinstance(value, str):
            continue
        try:
            meta[key] = json.loads(value)
        except json.JSONDecodeError:
            # Deliberate best effort: keep the raw string so it is still displayed.
            pass
    return meta
def print_onetrainer_stats(meta_dict, meta):
    """Print OneTrainer revision (and branch, if recorded) from decoded metadata.

    Prints nothing when ``ot_revision`` is absent.

    Args:
        meta_dict: full header dict; unused, kept for interface compatibility.
        meta: decoded metadata as returned by process_meta().
    """
    if "ot_revision" not in meta:
        return
    print(f"OneTrainer:\n revision {meta['ot_revision']}")
    # Guard the branch separately: a missing key used to raise KeyError.
    branch = meta.get("ot_branch")
    if branch is not None:
        print(f" branch '{branch}'")
# Frame-rate target for main()'s wait_hr()-based frame-pacing branch.
TARGET_FPS: int = 180
def _draw_weights_table(meta_dict, weight_names):
    """Render a scrollable name/dtype/shape table with one row per tensor entry."""
    if not weight_names:
        return
    flags = (
        imgui.TABLE_SIZING_STRETCH_PROP
        | imgui.TABLE_SCROLL_Y
        | imgui.TABLE_SCROLL_X
        | imgui.TABLE_ROW_BACKGROUND
        | imgui.TABLE_BORDERS
    )
    # With ScrollX/ScrollY the table needs a container size; (0, 0) makes it
    # fill all available space, like a BeginChild() call.
    if imgui.begin_table("weights", 3, flags, 0, 0):
        imgui.table_setup_scroll_freeze(0, 1)  # keep the header row visible
        imgui.table_setup_column("name")
        imgui.table_setup_column("dtype")
        imgui.table_setup_column("shape")
        imgui.table_headers_row()
        for name in weight_names:
            info = meta_dict[name]
            imgui.table_next_row()
            imgui.table_set_column_index(0)
            imgui.text(name)
            imgui.table_set_column_index(1)
            imgui.text(str(info.get("dtype", "?")))
            imgui.table_set_column_index(2)
            imgui.text(str(info.get("shape", "?")))
        imgui.end_table()


def _draw_metadata(metadata):
    """List each metadata key with its (possibly JSON-decoded) value indented below."""
    for key, value in metadata.items():
        imgui.text(key)
        imgui.indent()
        imgui.text_unformatted(str(value))
        imgui.unindent()


def main():
    """Run the safetensors inspector UI until the window is closed.

    An optional file path may be passed as the first command-line argument.
    More files can be dropped onto the window at any time; each drop replaces
    the displayed file. A file that fails to load is reported in the UI and in
    the log, and the previously shown file stays visible.
    """
    meta_dict = {}     # raw header: tensor name -> info, plus optional "__metadata__"
    metadata = {}      # decoded "__metadata__" section, see process_meta()
    weight_names = []  # tensor entries in header order, computed once per load
    load_error = ""    # message for the most recent failed load, shown in the UI

    def load_meta(path):
        """Load *path*; commit the new state only if everything parsed."""
        nonlocal meta_dict, metadata, weight_names, load_error
        try:
            new_meta_dict = loadsftmeta(path)
            new_metadata = process_meta(new_meta_dict)
        except Exception as exc:  # UI boundary: a bad file must not kill the app
            logger.exception("Failed to load %s", path)
            load_error = f"Failed to load {path}: {exc}"
            return
        meta_dict = new_meta_dict
        metadata = new_metadata
        weight_names = [
            k for k, v in meta_dict.items()
            if k != "__metadata__" and isinstance(v, dict)
        ]
        load_error = ""

    if len(sys.argv) > 1:
        load_meta(Path(sys.argv[1]))

    window, gl_context = impl_pysdl2_init()
    imgui.create_context()
    impl = SDL2Renderer(window)

    # Set to True to pace frames to TARGET_FPS with the high-resolution Windows
    # timer (wait_hr) instead of a fixed 10 ms SDL_Delay per frame.
    use_hr_frame_limiter = False

    show_custom_window = True
    running = True
    event = SDL_Event()
    last_frame_timestamp = 0
    try:
        while running:
            while SDL_PollEvent(ctypes.byref(event)) != 0:
                if event.type == SDL_QUIT:
                    running = False
                    break
                elif event.type == SDL_DROPFILE:
                    # SDL reports dropped paths as UTF-8 bytes; undecodable bytes
                    # just produce a load error instead of an exception here.
                    load_meta(event.drop.file.decode("utf-8", errors="replace"))
                impl.process_event(event)
            impl.process_inputs()

            imgui.new_frame()

            if imgui.begin_main_menu_bar():
                if imgui.begin_menu("File", True):
                    clicked_quit, _ = imgui.menu_item("Quit", "Cmd+Q", False, True)
                    if clicked_quit:
                        # End the loop normally so the cleanup in `finally` runs.
                        running = False
                    imgui.end_menu()
                imgui.end_main_menu_bar()

            if show_custom_window:
                is_expand, show_custom_window = imgui.begin("Custom window", True)
                if is_expand:
                    if load_error:
                        imgui.text(load_error)
                    _draw_weights_table(meta_dict, weight_names)
                imgui.end()

            is_expand, _ = imgui.begin("__metadata__", True)
            if is_expand:
                _draw_metadata(metadata)
            imgui.end()

            gl.glClearColor(0.5, 0.5, 0.7, 1)
            gl.glClear(gl.GL_COLOR_BUFFER_BIT)
            imgui.render()
            impl.render(imgui.get_draw_data())

            if use_hr_frame_limiter:
                elapsed_ms = SDL_GetTicks64() - last_frame_timestamp
                time_to_wait_ms = (1000.0 / TARGET_FPS) - elapsed_ms - 1
                if time_to_wait_ms > 0:
                    wait_hr(time_to_wait_ms)
            else:
                SDL_Delay(10)
            SDL_GL_SwapWindow(window)
            last_frame_timestamp = SDL_GetTicks64()
    finally:
        impl.shutdown()
        SDL_GL_DeleteContext(gl_context)
        SDL_DestroyWindow(window)
        SDL_Quit()
def impl_pysdl2_init():
    """Initialise SDL and create the OpenGL window the UI renders into.

    The context requested is:
    - double-buffered with 8x multisampling;
    - 24-bit depth and 8-bit stencil buffers;
    - forward-compatible OpenGL 4.1 core profile.

    It opens a resizable 1280x960 window, makes the context current, and
    enables vsync if the driver allows it.

    Returns:
        tuple: ``(window, gl_context)``; the caller must destroy both and call
        SDL_Quit() when done.

    Exits the process with status 1 if SDL, the window, or the GL context
    cannot be created.
    """
    width, height = 1280, 960
    window_name = "safetensors inspector"

    def sdl_error():
        # Never raise while already reporting an error.
        return SDL_GetError().decode("utf-8", errors="replace")

    if SDL_Init(SDL_INIT_EVERYTHING) < 0:
        print(
            "Error: SDL could not initialize! SDL Error: " + sdl_error(),
            file=sys.stderr,
        )
        sys.exit(1)

    gl_attributes = (
        (SDL_GL_DOUBLEBUFFER, 1),
        (SDL_GL_DEPTH_SIZE, 24),
        (SDL_GL_STENCIL_SIZE, 8),
        (SDL_GL_ACCELERATED_VISUAL, 1),
        (SDL_GL_MULTISAMPLEBUFFERS, 1),
        (SDL_GL_MULTISAMPLESAMPLES, 8),
        (SDL_GL_CONTEXT_FLAGS, SDL_GL_CONTEXT_FORWARD_COMPATIBLE_FLAG),
        (SDL_GL_CONTEXT_MAJOR_VERSION, 4),
        (SDL_GL_CONTEXT_MINOR_VERSION, 1),
        (SDL_GL_CONTEXT_PROFILE_MASK, SDL_GL_CONTEXT_PROFILE_CORE),
    )
    for attribute, value in gl_attributes:
        SDL_GL_SetAttribute(attribute, value)

    SDL_SetHint(SDL_HINT_MAC_CTRL_CLICK_EMULATE_RIGHT_CLICK, b"1")
    SDL_SetHint(SDL_HINT_VIDEO_HIGHDPI_DISABLED, b"1")

    window = SDL_CreateWindow(
        window_name.encode("utf-8"),
        SDL_WINDOWPOS_CENTERED,
        SDL_WINDOWPOS_CENTERED,
        width,
        height,
        SDL_WINDOW_OPENGL | SDL_WINDOW_RESIZABLE,
    )
    # A failed SDL_CreateWindow may come back as a NULL ctypes pointer (falsy)
    # rather than None, so test truthiness instead of `is None`.
    if not window:
        print(
            "Error: Window could not be created! SDL Error: " + sdl_error(),
            file=sys.stderr,
        )
        SDL_Quit()
        sys.exit(1)

    gl_context = SDL_GL_CreateContext(window)
    if not gl_context:
        print(
            "Error: Cannot create OpenGL Context! SDL Error: " + sdl_error(),
            file=sys.stderr,
        )
        SDL_DestroyWindow(window)
        SDL_Quit()
        sys.exit(1)

    SDL_GL_MakeCurrent(window, gl_context)

    # Prefer adaptive vsync (-1) and fall back to regular vsync (1) if the
    # driver rejects it. Running without vsync is not fatal, so only warn.
    if SDL_GL_SetSwapInterval(-1) < 0 and SDL_GL_SetSwapInterval(1) < 0:
        print(
            "Warning: Unable to set VSync! SDL Error: " + sdl_error(),
            file=sys.stderr,
        )

    return window, gl_context
if __name__ == "__main__":
    # Start the UI only when executed as a script, not when imported.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment