Parallel sorting of a distributed array using MPI
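
This gist implements a parallel sample sort: every rank sorts its local chunk, evenly spaced samples are gathered on rank 0 to choose global splitters, the data is repartitioned with an all-to-all exchange, and a final local sort on each rank finishes the job. To run it (assuming mpi4py and an MPI implementation such as Open MPI are installed): time mpiexec -n 4 python parallel-sort-arb.py
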
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Jan 29 13:34:48 2026
@author: alankar.
Usage: time mpiexec -n <nprocs> python parallel-sort-arb.py
"""
from mpi4py import MPI
import numpy as np


def sample_sort(local_data, comm):
    rank = comm.Get_rank()
    size = comm.Get_size()
    # 1. Local Sort
    local_data.sort()
    # 2. Select Samples (Regular sampling)
    # Pick 'size' samples from the local data
    indices = np.linspace(0, len(local_data) - 1, size, dtype=int)
    local_samples = local_data[indices] if len(local_data) > 0 else np.array([])
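    # With regular sampling, drawing 'size' evenly spaced samples from each
    # rank's sorted chunk keeps the final buckets roughly balanced even when
    # ranks hold different amounts of data.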
    # 3. Gather samples at Rank 0 to find global splitters
    all_samples = comm.gather(local_samples, root=0)
    splitters = None
    if rank == 0:
        # Flatten samples, sort them, and pick size-1 splitters
        flat_samples = np.concatenate(all_samples)
        flat_samples.sort()
        # Pick splitters that divide the samples into 'size' groups
        s_indices = np.linspace(0, len(flat_samples) - 1, size + 1, dtype=int)[1:-1]
        splitters = flat_samples[s_indices]
    # 4. Broadcast splitters to all processes
    splitters = comm.bcast(splitters, root=0)
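    # After the broadcast every rank holds the same size-1 splitters, so all
    # ranks partition their data against identical key-range boundaries.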
    # 5. Partition local data based on splitters
    buckets = []
    # Find where each element belongs using searchsorted
    indices = np.searchsorted(splitters, local_data)
    for i in range(size):
        buckets.append(local_data[indices == i])
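    # searchsorted gives each element its insertion index among the sorted
    # splitters, which is exactly its destination rank in 0..size-1.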
    # 6. Redistribute data (All-to-All)
    # Each rank sends buckets[j] to rank 'j' and receives one bucket from
    # every other rank
    received_buckets = comm.alltoall(buckets)
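    # The lowercase alltoall is the generic-object variant: each rank passes
    # a list of 'size' picklable objects and receives back a list of 'size'
    # objects, one contributed by every rank.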
    # 7. Final Local Sort of the received data
    final_local_data = np.concatenate(received_buckets)
    final_local_data.sort()
    return final_local_data
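
# Cost sketch: two local sorts of roughly O((n/p) log(n/p)) each (assuming
# reasonably balanced buckets), plus gathering O(p^2) samples (gather + bcast)
# and one all-to-all exchange of the data itself.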
# --- Execution ---
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
# Generate a different-sized random array on each core
local_data = np.random.randint(0, 5000, size=np.random.randint(10, 20))
print(rank, local_data, flush=True)
# Display the global array
all_data = comm.gather(local_data, root=0)
if rank == 0:
    final_input_data = np.concatenate(all_data)
    print(f"Total elements to be sorted: {len(final_input_data)}", flush=True)
    print(f"Global Input Array: {final_input_data}", flush=True)
sorted_part = sample_sort(local_data, comm)
comm.Barrier()
if rank == 0:
    print("Sort complete!", flush=True)
comm.Barrier()
print(rank, sorted_part, flush=True)
comm.Barrier()
# Final check: Gather all to Rank 0 to verify global order
all_results = comm.gather(sorted_part, root=0)
if rank == 0:
    final_array = np.concatenate(all_results)
    print(f"Total elements sorted: {len(final_array)}", flush=True)
    print(f"Is sorted? {np.all(np.diff(final_array) >= 0)}", flush=True)
    print(f"Global Sorted Array: {final_array}", flush=True)