What is Strassen's Multiplication?
The traditional matrix multiplication algorithm has a time complexity of $O(n^3)$, where $n$ is the dimension of the matrix. Strassen's algorithm reduces this time complexity to $O(n^{\log_2 7})$, which is approximately $O(n^{2.81})$. This makes Strassen's algorithm asymptotically faster than the traditional algorithm, so it wins for sufficiently large matrices.
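Strassen's algorithm works by recursively dividing each $n \times n$ matrix into four submatrices (quadrants) of size $n/2 \times n/2$. Multiplying the quadrants directly would take eight matrix products of size $n/2$; Strassen's algorithm needs only seven. Each of the seven products multiplies sums or differences of quadrants, and the four quadrants of the result are then recovered from the seven products using only additions and subtractions. Saving one multiplication at every level of the recursion gives the recurrence $T(n) = 7T(n/2) + O(n^2)$, which solves to $O(n^{\log_2 7})$.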
While Strassen's algorithm has a better theoretical time complexity than the traditional algorithm, it has higher constant factors and requires more memory. In practice, the traditional algorithm is often faster for small matrices, but Strassen's algorithm can be faster for very large matrices.
Who invented it?
Strassen's algorithm for matrix multiplication was invented by the German mathematician Volker Strassen and published in 1969. Strassen, who had completed his PhD at the University of Göttingen in 1966, was working at the University of Zurich at the time. The algorithm appeared in the paper "Gaussian Elimination is not Optimal," published in the journal Numerische Mathematik (volume 13, 1969). It was the first algorithm to multiply matrices asymptotically faster than $O(n^3)$. The paper is widely regarded as a breakthrough in the field of algorithms, started the ongoing search for faster matrix multiplication algorithms, and has had a significant impact on many areas of computer science.
Pseudocode
function strassen_multiply(A, B):
    # Multiply two n x n matrices A and B with Strassen's algorithm.
    # n is assumed to be a power of two, so every split yields four equal
    # (n/2) x (n/2) quadrants; other sizes can be zero-padded first.
    n = size(A)
    # Base case: return a 1 x 1 MATRIX (not a bare scalar) so the caller can
    # add, subtract and combine it exactly like any other quadrant product.
    if n == 1:
        return [[A[0][0] * B[0][0]]]
    # Divide A and B into quadrants
    A11, A12, A21, A22 = split_matrix(A)
    B11, B12, B21, B22 = split_matrix(B)
    # Compute 7 half-size products recursively (the naive method needs 8)
    P1 = strassen_multiply(A11 + A22, B11 + B22)
    P2 = strassen_multiply(A21 + A22, B11)
    P3 = strassen_multiply(A11, B12 - B22)
    P4 = strassen_multiply(A22, B21 - B11)
    P5 = strassen_multiply(A11 + A12, B22)
    P6 = strassen_multiply(A21 - A11, B11 + B12)
    P7 = strassen_multiply(A12 - A22, B21 + B22)
    # Recover the quadrants of C using only additions and subtractions
    C11 = P1 + P4 - P5 + P7
    C12 = P3 + P5
    C21 = P2 + P4
    C22 = P1 - P2 + P3 + P6
    # Combine the submatrices into a single matrix
    return combine_matrices(C11, C12, C21, C22)
function split_matrix(A):
    # Return the four equal quadrants of the square matrix A in the order
    # top-left, top-right, bottom-left, bottom-right.
    half = size(A) // 2
    top = A[:half, :]
    bottom = A[half:, :]
    return top[:, :half], top[:, half:], bottom[:, :half], bottom[:, half:]
function combine_matrices(C11, C12, C21, C22):
    # Assemble four equal m x m quadrants into one n x n matrix (n = 2m).
    # This is the inverse of split_matrix and uses the same naming:
    # n is the full size, m is the half size.
    m = size(C11)
    n = 2 * m
    C = zeros((n, n))
    C[:m, :m] = C11
    C[:m, m:] = C12
    C[m:, :m] = C21
    C[m:, m:] = C22
    return C
- Define a function strassen_multiply that takes two matrices A and B as input and returns their product C.
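- Get the size of the input matrices and check if they are 1x1 matrices. If so, return the 1x1 matrix containing the product of their single entries.
- Divide both matrices A and B into four equal quadrants: A11, A12, A21, A22 and B11, B12, B21, B22. (This requires n to be a power of two; other sizes can first be zero-padded up to the next power of two.)
- Compute the seven products P1–P7 recursively, each from sums or differences of these quadrants, as given in the pseudocode.
- Combine the seven products, using only additions and subtractions, to obtain the four quadrants C11, C12, C21 and C22 of the result matrix C.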
- Return the result matrix C.
Sample Code
// C++ code snippet
#include <algorithm>
#include <cstddef>
#include <vector>

using namespace std;

// Dense int matrix stored as a vector of rows; the functions below expect
// square matrices.
typedef vector<vector<int>> matrix;

// Returns the element-wise sum X + Y; X and Y must both be n x n.
matrix add(const matrix& X, const matrix& Y) {
    const size_t n = X.size();
    matrix R(n, vector<int>(n));
    for (size_t i = 0; i < n; i++) {
        for (size_t j = 0; j < n; j++) {
            R[i][j] = X[i][j] + Y[i][j];
        }
    }
    return R;
}

// Returns the element-wise difference X - Y; X and Y must both be n x n.
matrix subtract(const matrix& X, const matrix& Y) {
    const size_t n = X.size();
    matrix R(n, vector<int>(n));
    for (size_t i = 0; i < n; i++) {
        for (size_t j = 0; j < n; j++) {
            R[i][j] = X[i][j] - Y[i][j];
        }
    }
    return R;
}

// Returns a size x size copy of the top-left corner of the square matrix M;
// cells beyond M's bounds are zero. Used both to pad operands up to a power
// of two and to trim the padded product back down.
matrix resize_square(const matrix& M, size_t size) {
    matrix R(size, vector<int>(size, 0));
    const size_t limit = min(size, M.size());
    for (size_t i = 0; i < limit; i++) {
        for (size_t j = 0; j < limit; j++) {
            R[i][j] = M[i][j];
        }
    }
    return R;
}

// Multiplies two n x n matrices with Strassen's algorithm.
//
// Each level splits the operands into four (n/2) x (n/2) quadrants and forms
// seven recursive products P1..P7 instead of the naive eight. Quadrant
// splitting needs n to be a power of two, so any other size is zero-padded up
// to the next power of two and the product is trimmed back to n x n. An empty
// matrix yields an empty result. Entries are int, so very large values can
// overflow (undefined behavior for signed int).
matrix strassen_multiply(const matrix& A, const matrix& B) {
    const size_t n = A.size();
    if (n == 0) {
        // Without this guard m would be 0 and the recursion would never end.
        return matrix();
    }
    if (n == 1) {
        return matrix(1, vector<int>(1, A[0][0] * B[0][0]));
    }
    if ((n & (n - 1)) != 0) {
        // Odd splits give quadrants of different sizes that add()/subtract()
        // cannot combine, so pad to the next power of two and trim afterwards.
        size_t padded = 1;
        while (padded < n) {
            padded <<= 1;
        }
        return resize_square(
            strassen_multiply(resize_square(A, padded), resize_square(B, padded)),
            n);
    }

    const size_t m = n / 2;
    matrix A11(m, vector<int>(m)), A12(m, vector<int>(m)),
           A21(m, vector<int>(m)), A22(m, vector<int>(m));
    matrix B11(m, vector<int>(m)), B12(m, vector<int>(m)),
           B21(m, vector<int>(m)), B22(m, vector<int>(m));
    for (size_t i = 0; i < m; i++) {
        for (size_t j = 0; j < m; j++) {
            A11[i][j] = A[i][j];
            A12[i][j] = A[i][m + j];
            A21[i][j] = A[m + i][j];
            A22[i][j] = A[m + i][m + j];
            B11[i][j] = B[i][j];
            B12[i][j] = B[i][m + j];
            B21[i][j] = B[m + i][j];
            B22[i][j] = B[m + i][m + j];
        }
    }

    const matrix P1 = strassen_multiply(add(A11, A22), add(B11, B22));
    const matrix P2 = strassen_multiply(add(A21, A22), B11);
    const matrix P3 = strassen_multiply(A11, subtract(B12, B22));
    const matrix P4 = strassen_multiply(A22, subtract(B21, B11));
    const matrix P5 = strassen_multiply(add(A11, A12), B22);
    const matrix P6 = strassen_multiply(subtract(A21, A11), add(B11, B12));
    const matrix P7 = strassen_multiply(subtract(A12, A22), add(B21, B22));

    // Recover the result quadrants with additions and subtractions only.
    const matrix C11 = add(subtract(add(P1, P4), P5), P7);
    const matrix C12 = add(P3, P5);
    const matrix C21 = add(P2, P4);
    const matrix C22 = add(add(subtract(P1, P2), P3), P6);

    matrix C(n, vector<int>(n));
    for (size_t i = 0; i < m; i++) {
        for (size_t j = 0; j < m; j++) {
            C[i][j] = C11[i][j];
            C[i][m + j] = C12[i][j];
            C[m + i][j] = C21[i][j];
            C[m + i][m + j] = C22[i][j];
        }
    }
    return C;
}
#include <iostream>

// Demo: multiply two 2x2 matrices with strassen_multiply and print the
// product one row per line, each value followed by a space.
int main() {
    matrix A = {{1, 2}, {3, 4}};
    matrix B = {{5, 6}, {7, 8}};
    matrix C = strassen_multiply(A, B);
    // Range-for avoids the int vs. size_t comparison of index loops.
    for (const auto& row : C) {
        for (int value : row) {
            std::cout << value << " ";
        }
        std::cout << std::endl;
    }
    return 0;
}
# Python code snippet
def strassen_multiply(A, B):
    """Multiply two square matrices using Strassen's algorithm.

    Each level splits both operands into four (n/2) x (n/2) quadrants and
    forms seven recursive products P1..P7 instead of the naive eight; the
    result quadrants are then recovered with additions and subtractions only.

    Quadrant splitting needs n to be a power of two. Any other size is
    zero-padded up to the next power of two and the product is trimmed back
    to n x n. Without this, ``add``/``subtract`` would combine quadrants of
    different sizes, producing wrong results or an IndexError.

    Args:
        A: Left operand, an n x n matrix given as a list of row lists.
        B: Right operand, an n x n matrix given as a list of row lists.

    Returns:
        The n x n product of A and B as a new list of row lists; ``[]`` when
        A is empty.
    """
    n = len(A)
    if n == 0:
        # m would be 0 below and the recursion would never terminate.
        return []
    if n == 1:
        return [[A[0][0] * B[0][0]]]
    if n & (n - 1):
        # Not a power of two: pad both operands, multiply, then trim.
        size = 1 << (n - 1).bit_length()
        product = strassen_multiply(_pad_square(A, size), _pad_square(B, size))
        return [row[:n] for row in product[:n]]

    m = n // 2
    A11 = [row[:m] for row in A[:m]]
    A12 = [row[m:] for row in A[:m]]
    A21 = [row[:m] for row in A[m:]]
    A22 = [row[m:] for row in A[m:]]
    B11 = [row[:m] for row in B[:m]]
    B12 = [row[m:] for row in B[:m]]
    B21 = [row[:m] for row in B[m:]]
    B22 = [row[m:] for row in B[m:]]

    P1 = strassen_multiply(add(A11, A22), add(B11, B22))
    P2 = strassen_multiply(add(A21, A22), B11)
    P3 = strassen_multiply(A11, subtract(B12, B22))
    P4 = strassen_multiply(A22, subtract(B21, B11))
    P5 = strassen_multiply(add(A11, A12), B22)
    P6 = strassen_multiply(subtract(A21, A11), add(B11, B12))
    P7 = strassen_multiply(subtract(A12, A22), add(B21, B22))

    C11 = subtract(add(add(P1, P4), P7), P5)
    C12 = add(P3, P5)
    C21 = add(P2, P4)
    C22 = subtract(add(add(P1, P3), P6), P2)

    # Stitch the quadrants back together: each output row is a left-quadrant
    # row followed by the matching right-quadrant row.
    top = [left + right for left, right in zip(C11, C12)]
    bottom = [left + right for left, right in zip(C21, C22)]
    return top + bottom


def _pad_square(M, size):
    """Return a copy of matrix M zero-padded on the right and bottom to size x size."""
    padded = [list(row) + [0] * (size - len(row)) for row in M]
    padded.extend([0] * size for _ in range(size - len(M)))
    return padded
def add(A, B):
    """Return the element-wise sum of matrices A and B as a new list of rows.

    The column count comes from A's first row and B is read at the same
    positions; an empty A produces an empty result.
    """
    width = len(A[0]) if A else 0
    return [
        [row_a[c] + B[r][c] for c in range(width)]
        for r, row_a in enumerate(A)
    ]
def subtract(A, B):
    """Return the element-wise difference A - B as a new list of rows.

    The column count comes from A's first row and B is read at the same
    positions; an empty A produces an empty result.
    """
    width = len(A[0]) if A else 0
    return [
        [row_a[c] - B[r][c] for c in range(width)]
        for r, row_a in enumerate(A)
    ]
# Example usage (runs only when this file is executed directly, not on import):
if __name__ == "__main__":
    A = [[1, 2], [3, 4]]
    B = [[5, 6], [7, 8]]
    C = strassen_multiply(A, B)
    print(C)  # prints [[19, 22], [43, 50]]
public class StrassenMatrixMultiplication {
/**
 * Multiplies two n x n matrices using Strassen's algorithm.
 *
 * <p>Each level splits both operands into four (n/2) x (n/2) quadrants and forms seven
 * recursive products P1..P7 instead of the naive eight. Quadrant splitting needs n to be a
 * power of two, so any other size is zero-padded to the next power of two and the product is
 * trimmed back to n x n. An empty matrix yields an empty result. Arithmetic is done in
 * {@code int}, so very large entries overflow silently.
 *
 * @param A left operand, an n x n matrix
 * @param B right operand, an n x n matrix
 * @return a new n x n matrix holding the product of A and B
 */
public static int[][] strassenMultiply(int[][] A, int[][] B) {
    int n = A.length;
    if (n == 0) {
        // Without this guard m would be 0 and the recursion would never terminate.
        return new int[0][0];
    }
    if (n == 1) {
        return new int[][] {{A[0][0] * B[0][0]}};
    }
    if ((n & (n - 1)) != 0) {
        // Odd splits give quadrants of different sizes that add()/subtract() cannot combine
        // (wrong results or ArrayIndexOutOfBoundsException), so pad, multiply, then trim.
        int size = Integer.highestOneBit(n) << 1;
        return padOrTrim(strassenMultiply(padOrTrim(A, size), padOrTrim(B, size)), n);
    }
    int m = n / 2;
    int[][] A11 = submatrix(A, 0, m, 0, m);
    int[][] A12 = submatrix(A, 0, m, m, n);
    int[][] A21 = submatrix(A, m, n, 0, m);
    int[][] A22 = submatrix(A, m, n, m, n);
    int[][] B11 = submatrix(B, 0, m, 0, m);
    int[][] B12 = submatrix(B, 0, m, m, n);
    int[][] B21 = submatrix(B, m, n, 0, m);
    int[][] B22 = submatrix(B, m, n, m, n);
    int[][] P1 = strassenMultiply(add(A11, A22), add(B11, B22));
    int[][] P2 = strassenMultiply(add(A21, A22), B11);
    int[][] P3 = strassenMultiply(A11, subtract(B12, B22));
    int[][] P4 = strassenMultiply(A22, subtract(B21, B11));
    int[][] P5 = strassenMultiply(add(A11, A12), B22);
    int[][] P6 = strassenMultiply(subtract(A21, A11), add(B11, B12));
    int[][] P7 = strassenMultiply(subtract(A12, A22), add(B21, B22));
    // Recover the result quadrants with additions and subtractions only.
    int[][] C11 = subtract(add(add(P1, P4), P7), P5);
    int[][] C12 = add(P3, P5);
    int[][] C21 = add(P2, P4);
    int[][] C22 = subtract(add(add(P1, P3), P6), P2);
    int[][] C = new int[n][n];
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < m; j++) {
            C[i][j] = C11[i][j];
            C[i][m + j] = C12[i][j];
            C[m + i][j] = C21[i][j];
            C[m + i][m + j] = C22[i][j];
        }
    }
    return C;
}

/**
 * Returns a size x size copy of the top-left corner of the square matrix M; cells beyond M's
 * bounds are left as zero. Used to pad operands and to trim the padded product.
 */
private static int[][] padOrTrim(int[][] M, int size) {
    int[][] R = new int[size][size];
    int limit = Math.min(size, M.length);
    for (int i = 0; i < limit; i++) {
        System.arraycopy(M[i], 0, R[i], 0, limit);
    }
    return R;
}
/**
 * Returns the element-wise sum of two square matrices.
 *
 * @param A first addend; its length determines the size of the result
 * @param B second addend, read at the same indices as A
 * @return a new A.length x A.length matrix holding A[i][j] + B[i][j]
 */
private static int[][] add(int[][] A, int[][] B) {
    final int size = A.length;
    int[][] sum = new int[size][size];
    for (int row = 0; row < size; ++row) {
        int[] out = sum[row];
        for (int col = 0; col < size; ++col) {
            out[col] = A[row][col] + B[row][col];
        }
    }
    return sum;
}
private static int[][] subtract(int[][] A, int[][] B) {
int n = A.length;
int[][] C = new int[n][n];
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
C[i][j] = A[i][j] - B[i][j];
}
}
return C;