Skip to content

Instantly share code, notes, and snippets.

@Samu31Nd
Last active May 7, 2023 21:49
Show Gist options
  • Select an option

  • Save Samu31Nd/dd7bdbb11a38c2d84d6b3368e479df0f to your computer and use it in GitHub Desktop.

Select an option

Save Samu31Nd/dd7bdbb11a38c2d84d6b3368e479df0f to your computer and use it in GitHub Desktop.
Square Matrix Multiply with the recursive divide-and-conquer method (the basis for Strassen's method)
import java.util.Random;
import java.util.Scanner;
public class Main {
    /**
     * Demo entry point: reads a matrix size {@code n} from standard input, builds two random
     * {@code n x n} matrices, prints them side by side, and prints their product computed by
     * {@link StrassenUtils#squareMatrixMultiplyRecursive}.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        int n;
        // try-with-resources so the Scanner is always closed (the original leaked it).
        try (Scanner scanner = new Scanner(System.in)) {
            System.out.print("Insert the length of the matrix: ");
            n = scanner.nextInt();
        }
        // A non-positive size cannot form a matrix to multiply.
        if (n < 1) {
            System.out.println("The length must be a positive integer.");
            return;
        }
        // Use the size the user entered; the original read n and then ignored it.
        int[][] m1 = StrassenUtils.randomMatrix(n);
        int[][] m2 = StrassenUtils.randomMatrix(n);
        System.out.println("\nMatrix 1 and 2:");
        StrassenUtils.showMatrix(m1, m2);
        int[][] C = StrassenUtils.squareMatrixMultiplyRecursive(m1, m2);
        System.out.println("\nMatrix C: ");
        StrassenUtils.showMatrix(C);
    }
}
class StrassenUtils {
static public int[][] randomMatrix(int n){
int [][] x = new int[n][n];
for (int p = 0; p < n; p++)
for (int j = 0; j < n; j++)
x[p][j] = (int) (new Random().nextFloat(51)); //from 0 to (51 - 1)
return x;
}
static public void showMatrix(int[][]m){
for (int []a : m){
for (int b : a) System.out.print(b + "\t");
System.out.println(" ");
}
System.out.println(" ");
}
static public void showMatrix(int[][]m, int[][]n){
int length = m.length;
for (int i = 0; i < length; i++){
for (int a : m[i]) System.out.print(a + "\t");
System.out.print("\t\t");
for (int a : n[i]) System.out.print(a + "\t");
System.out.println(" ");
}
}
static public int[][] squareMatrixMultiply(int[][]A, int[][]B){
int n = A.length;
int[][] C = new int[n][n];
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++){
C[i][j] = 0;
for (int k = 0; k < n; k++) C[i][j] = C[i][j] + A[i][k] * B[k][j];
}
return C;
}
static public int[][] squareMatrixMultiplyRecursive(int[][]A, int[][]B){
int n = A.length;
int[][] C = new int[n][n];
if(n==1) C[0][0] = A[0][0] * B[0][0];
else{
int [][] A11 = partition(A,1);
int [][] A12 = partition(A,2);
int [][] A21 = partition(A,3);
int [][] A22 = partition(A,4);
int [][] B11 = partition(B,1);
int [][] B12 = partition(B,2);
int [][] B21 = partition(B,3);
int [][] B22 = partition(B,4);
int [][] C11 = addition(squareMatrixMultiplyRecursive(A11,B11), squareMatrixMultiplyRecursive(A12,B21));
int [][] C12 = addition(squareMatrixMultiplyRecursive(A11,B12), squareMatrixMultiplyRecursive(A12,B22));
int [][] C21 = addition(squareMatrixMultiplyRecursive(A21,B11), squareMatrixMultiplyRecursive(A22,B21));
int [][] C22 = addition(squareMatrixMultiplyRecursive(A21,B12), squareMatrixMultiplyRecursive(A22,B22));
C = merge(C11,C12,C21,C22);
}
return C;
}
static private int[][] merge(int[][]C1, int[][]C2, int[][]C3, int[][]C4){
int n = C1.length*2;
int [][] C = new int[n][n];
int n2 = n/2;
for(int i = 0; i < n2; i++)
for(int j = 0; j < n2; j++) C[i][j] = C1[i][j];
for(int i = 0; i < n2; i++)
for(int j = 0; j < n2; j++) C[i+n2][j] = C2[i][j];
for(int i = 0; i < n2; i++)
for(int j = 0; j < n2; j++) C[i][j+n2] = C3[i][j];
for(int i = 0; i < n2; i++)
for(int j = 0; j < n2; j++) C[i+n2][j+n2] = C4[i][j];
return C;
}
static private int[][] partition(int[][]M, int mode){
int n = M.length/2;
int[][] newM = new int[n][n];
switch (mode) {
case 1 -> {
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
newM[i][j] = M[i][j];
}
case 2 -> {
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
newM[i][j] = M[i + n][j];
}
case 3 -> {
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
newM[i][j] = M[i][j + n];
}
case 4 -> {
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
newM[i][j] = M[i + n][j + n];
}
}
return newM;
}
static private int[][] addition(int[][]A, int[][]B){
int n = A.length;
int[][] C = new int[n][n];
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
C[i][j] = A[i][j] + B [i][j];
return C;
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment