Last active
January 29, 2026 12:53
-
-
Save dutta-alankar/39ec994926f378c8ac4d4fe99feff2b7 to your computer and use it in GitHub Desktop.
Parallel sorting of a distributed array using MPI
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| Created on Sat Jan 29 13:34:48 2026 | |
| @author: alankar. | |
| Usage: time python parallel-sort-arb.py | |
| """ | |
| from mpi4py import MPI | |
| import numpy as np | |
def sample_sort(local_data, comm):
    """Sort a globally distributed 1-D array with parallel sample sort.

    Each rank holds an arbitrary-length chunk of the global array. After the
    call, rank i holds the i-th contiguous slice of the globally sorted array,
    so concatenating the return values over ranks 0..size-1 in rank order
    yields the fully sorted global array.

    Parameters
    ----------
    local_data : np.ndarray
        This rank's 1-D chunk (may be empty). Sorted in place as a side effect.
    comm : mpi4py communicator
        Communicator over which the array is distributed.

    Returns
    -------
    np.ndarray
        This rank's chunk of the globally sorted array.
    """
    rank = comm.Get_rank()
    size = comm.Get_size()

    # 1. Local sort (in place).
    local_data.sort()

    # 2. Regular sampling: pick 'size' evenly spaced samples from local data.
    if len(local_data) > 0:
        sample_idx = np.linspace(0, len(local_data) - 1, size, dtype=int)
        local_samples = local_data[sample_idx]
    else:
        # Keep the dtype consistent with local_data so concatenation at the
        # root does not silently upcast the splitters to float64.
        local_samples = np.array([], dtype=local_data.dtype)

    # 3. Gather samples at rank 0 and derive the size-1 global splitters.
    all_samples = comm.gather(local_samples, root=0)
    splitters = None
    if rank == 0:
        flat_samples = np.concatenate(all_samples)
        flat_samples.sort()
        # Indices that divide the sorted samples into 'size' equal groups;
        # the interior boundaries become the splitters.
        s_indices = np.linspace(0, len(flat_samples) - 1, size + 1, dtype=int)[1:-1]
        splitters = flat_samples[s_indices]

    # 4. Broadcast splitters to all processes.
    splitters = comm.bcast(splitters, root=0)

    # 5. Partition local data into one bucket per destination rank.
    #    searchsorted maps each element to the index of its splitter interval.
    bucket_of = np.searchsorted(splitters, local_data)
    buckets = [local_data[bucket_of == i] for i in range(size)]

    # 6. All-to-all exchange: rank i sends buckets[j] to rank j.
    received_buckets = comm.alltoall(buckets)

    # 7. Final local sort of everything received.
    final_local_data = np.concatenate(received_buckets)
    final_local_data.sort()
    return final_local_data
# --- Execution ---
comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# Every rank draws a random chunk whose length is itself random (10..19),
# so the distribution across ranks is deliberately uneven.
local_data = np.random.randint(0, 5000, size=np.random.randint(10, 20))
print(rank, local_data, flush=True)

# Show the unsorted global array on rank 0.
all_data = comm.gather(local_data, root=0)
if rank == 0:
    final_input_data = np.concatenate(all_data)
    print(f"Total elements to be sorted: {len(final_input_data)}", flush=True)
    print(f"Global Input Array: {final_input_data}", flush=True)

sorted_part = sample_sort(local_data, comm)

comm.Barrier()
if rank == 0:
    print("Sort complete!", flush=True)
comm.Barrier()
print(rank, sorted_part, flush=True)
comm.Barrier()

# Verification: gather every rank's slice on rank 0 and confirm that the
# concatenation (in rank order) is globally non-decreasing.
all_results = comm.gather(sorted_part, root=0)
if rank == 0:
    final_array = np.concatenate(all_results)
    print(f"Total elements sorted: {len(final_array)}", flush=True)
    print(f"Is sorted? {np.all(np.diff(final_array) >= 0)}", flush=True)
    print(f"Global Sorted Array: {final_array}", flush=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment