@dutta-alankar
Created January 29, 2026 12:55
Sorts a distributed array in parallel using MPI (works only when the total processor count is 2^N)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Jan 29 13:15:48 2026
@author: alankar.
Usage: time mpiexec -n <2^N> python parallel-sort.py
"""
from mpi4py import MPI
import numpy as np
def parallel_quicksort(data, comm):
    rank = comm.Get_rank()
    size = comm.Get_size()
    # Base case: if the communicator has only one process, sort locally
    if size <= 1:
        return np.sort(data)
    # 1. Pivot selection (median of the data on the communicator's first process)
    pivot = None
    if rank == 0:
        pivot = np.median(data) if len(data) > 0 else 0
    pivot = comm.bcast(pivot, root=0)
    # 2. Local partitioning around the pivot
    less_than_pivot = data[data <= pivot]
    greater_than_pivot = data[data > pivot]
    # 3. Split the communicator in half
    mid = size // 2
    color = 0 if rank < mid else 1
    new_comm = comm.Split(color, rank)
    # 4. Exchange data between paired processes
    if color == 0:
        # Lower half sends "greater" data to the upper half and receives "lesser";
        # the partner is rank + mid (specifying source pins the exchange to that partner)
        send_data = greater_than_pivot
        recv_data = comm.sendrecv(send_data, dest=rank + mid, source=rank + mid)
        new_data = np.concatenate((less_than_pivot, recv_data))
    else:
        # Upper half sends "lesser" data to the lower half and receives "greater";
        # the partner is rank - mid
        send_data = less_than_pivot
        recv_data = comm.sendrecv(send_data, dest=rank - mid, source=rank - mid)
        new_data = np.concatenate((greater_than_pivot, recv_data))
    # 5. Recurse on the new, smaller communicator
    return parallel_quicksort(new_data, new_comm)
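# Illustration (4 ranks): mid = 2, so ranks 0<->2 and 1<->3 exchange partitions.
# After the exchange, ranks 0-1 hold only values <= pivot and ranks 2-3 hold only
# values > pivot; each half then recurses on its own size-2 communicator, and the
# size-1 base case finishes with a local np.sort. Gathering the pieces in rank
# order therefore yields a globally sorted array.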
# --- Execution ---
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
# Initial data distribution (random data on each core)
local_data = np.random.randint(0, 1000, size=10)
print(rank, local_data, flush=True)
sorted_local_data = parallel_quicksort(local_data, comm)
# Gather results to Rank 0 to see the final sorted array
all_sorted_parts = comm.gather(sorted_local_data, root=0)
if rank == 0:
    final_array = np.concatenate(all_sorted_parts)
    print(f"Global Sorted Array: {final_array}", flush=True)