-
-
Save Magnus167/8dec0cef789481b8e53d934850b4d609 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _build_sankey_graph(df, show_reached):
    """Build node labels and links for the pipeline Sankey.

    Returns ``(labels, sources, targets, values, n_ranks)``. The first
    ``n_ranks`` labels are spine nodes (one per distinct ``stage_rank``, in
    ascending order); every later label is a terminal-outcome node.

    Raises:
        ValueError: if ``df`` has no rows with a usable ``stage_rank``.
    """
    from itertools import accumulate

    d = df.sort_values(["stage_rank", "stage"]).reset_index(drop=True)
    # One (rank, rows) pair per distinct rank, ascending (groupby sort=True).
    rank_groups = list(d.groupby("stage_rank", sort=True))
    if not rank_groups:
        raise ValueError("df must contain at least one row with a stage_rank")
    n_ranks = len(rank_groups)

    # Users currently sitting at each rank (sum over same-rank stage variants).
    current_total = [grp["count"].sum() for _, grp in rank_groups]
    # reached[i] = users at rank i or any later rank, i.e. everyone who got at
    # least this far (relies on the monotonic-pipeline assumption).
    reached = list(accumulate(reversed(current_total)))[::-1]

    # Spine labels: a single-stage rank is named after its stage, a
    # multi-stage rank is named generically.
    labels = []
    for (rank, grp), cur, rec in zip(rank_groups, current_total, reached):
        stage_names = grp["stage"].tolist()
        rank_label = stage_names[0] if len(stage_names) == 1 else f"Rank {rank}"
        if show_reached:
            labels.append(f"{rank_label} (current_total={cur}, reached={rec})")
        else:
            labels.append(f"{rank_label} (current_total={cur})")

    sources, targets, values = [], [], []

    def add_link(src, tgt, value):
        sources.append(src)
        targets.append(tgt)
        values.append(value)

    def add_terminal(src, label, value):
        # Always a fresh node: distinct outcomes that happen to share a label
        # (same stage name and count at different ranks) must not merge.
        labels.append(label)
        add_link(src, len(labels) - 1, value)

    for i, (_, grp) in enumerate(rank_groups):
        is_last = i == n_ranks - 1
        if not is_last:
            # Everyone who reached the next rank continues along the spine.
            add_link(i, i + 1, reached[i + 1])
        outcomes = grp[["stage", "count"]].values.tolist()
        if len(outcomes) > 1:
            # Several stages share this rank: one parallel terminal per stage.
            for stage_name, c in outcomes:
                add_terminal(i, f"{stage_name} ({int(c)})", int(c))
        else:
            # Exactly one terminal per single-stage rank, so the spine node's
            # outflow (forward + terminal) equals its inflow.
            stage_name, c = outcomes[0]
            prefix = "Ended at" if is_last else "Dropped at"
            add_terminal(i, f"{prefix} {stage_name} ({int(c)})", int(c))

    return labels, sources, targets, values, n_ranks


def _evenly_spaced(k, span=1.0, single=0.5):
    """Return k positions evenly spread over [0, span]; a lone one sits at `single`."""
    if k == 1:
        return [single]
    return [span * i / (k - 1) for i in range(k)]


def pipeline_sankey_fig_same_rank_terminals(df, show_reached=True):
    """
    Plotly Sankey for a monotonic pipeline where multiple stages at the same
    rank "end together" as parallel terminal outcomes.

    Input df columns:
      - stage_rank (int): 0..N in order (can repeat)
      - stage (str): stage name
      - count (int): users CURRENTLY at that stage

    Assumption:
      Users at rank N have passed all ranks < N.

    Node structure:
      - One spine node per distinct rank. Each non-last spine node forwards
        everyone who reached the next rank.
      - A rank with several stages gets one terminal "<stage> (<count>)"
        per stage.
      - A single-stage rank gets exactly one terminal: "Dropped at <stage>
        (<count>)", or "Ended at <stage> (<count>)" for the last rank.

    Params:
      show_reached (bool):
        If True, spine labels show:
          "RankLabel (current_total=..., reached=...)"
        else:
          "RankLabel (current_total=...)"

    Returns:
      plotly.graph_objects.Figure

    Raises:
      ValueError: if df has no rows with a usable stage_rank.
    """
    import plotly.graph_objects as go

    labels, sources, targets, values, n_ranks = _build_sankey_graph(df, show_reached)
    term_count = len(labels) - n_ranks

    # Layout: spine runs diagonally across the left 70% of the width (top to
    # bottom); all terminals are stacked in one column at the right edge.
    x = _evenly_spaced(n_ranks, span=0.7, single=0.0) + [1.0] * term_count
    y = _evenly_spaced(n_ranks) + _evenly_spaced(term_count)

    fig = go.Figure(go.Sankey(
        arrangement="fixed",
        node=dict(
            label=labels,
            x=x, y=y,
            pad=18,
            thickness=18,
            line=dict(width=0.5),
        ),
        link=dict(source=sources, target=targets, value=values),
    ))
    fig.update_layout(
        title="Pipeline Sankey (same-rank terminals)",
        height=650,
        font_size=12,
    )
    return fig
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment