# What is Strassen's Multiplication

The traditional algorithm for multiplying two n × n matrices has a time complexity of O(n^3). Strassen's algorithm reduces this to O(n^log2(7)), which is approximately O(n^2.81). This makes Strassen's algorithm asymptotically faster than the traditional algorithm for sufficiently large matrices.

**Strassen's algorithm** works by recursively dividing each matrix into four submatrices of size n/2 × n/2 and computing the product from seven multiplications of these submatrices, instead of the eight that straightforward block multiplication needs. The operands of the seven multiplications, and the four quadrants of the result, are formed using only matrix additions and subtractions.
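To make the seven-multiplication idea concrete, here is a minimal sketch for the 2 × 2 case, where each "submatrix" is a single number. The function names `strassen_2x2` and `naive_2x2` are illustrative, not taken from any library; the second function is only there so the two results can be compared.

```python
def strassen_2x2(A, B):
    """Multiply two 2x2 matrices using seven scalar multiplications."""
    (a11, a12), (a21, a22) = A
    (b11, b12), (b21, b22) = B

    # Seven products; every operand is built with additions/subtractions only.
    p1 = (a11 + a22) * (b11 + b22)
    p2 = (a21 + a22) * b11
    p3 = a11 * (b12 - b22)
    p4 = a22 * (b21 - b11)
    p5 = (a11 + a12) * b22
    p6 = (a21 - a11) * (b11 + b12)
    p7 = (a12 - a22) * (b21 + b22)

    # Recombine the products into the four entries of the result.
    return [[p1 + p4 - p5 + p7, p3 + p5],
            [p2 + p4, p1 - p2 + p3 + p6]]


def naive_2x2(A, B):
    """Row-by-column product using eight scalar multiplications."""
    return [[sum(A[i][k] * B[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]


A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
print(strassen_2x2(A, B))  # [[19, 22], [43, 50]]
print(strassen_2x2(A, B) == naive_2x2(A, B))  # True
```

In the full algorithm the same formulas are applied to n/2 × n/2 blocks, and each `*` becomes a recursive call.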

While Strassen's algorithm has a better theoretical time complexity than the traditional algorithm, it has higher constant factors and requires more memory. In practice, the traditional algorithm is often faster for small matrices, but Strassen's algorithm can be faster for very large matrices.

## Who invented it?

**Strassen's algorithm** for matrix multiplication was invented by the German mathematician Volker Strassen. He published it in 1969 in a paper titled "Gaussian Elimination is not Optimal," which appeared in the journal Numerische Mathematik. The paper is widely regarded as a breakthrough in the field of algorithms and has had a significant impact on many areas of computer science.

## Pseudocode

```
# Assumes A and B are square n x n matrices and n is a power of two
function strassen_multiply(A, B):
    n = size(A)
    # Base case: return a 1x1 matrix so the result can be combined like any other block
    if n == 1:
        return [[A[0][0] * B[0][0]]]
    # Divide A and B into submatrices
    A11, A12, A21, A22 = split_matrix(A)
    B11, B12, B21, B22 = split_matrix(B)
    # Compute 7 matrix products recursively
    P1 = strassen_multiply(A11 + A22, B11 + B22)
    P2 = strassen_multiply(A21 + A22, B11)
    P3 = strassen_multiply(A11, B12 - B22)
    P4 = strassen_multiply(A22, B21 - B11)
    P5 = strassen_multiply(A11 + A12, B22)
    P6 = strassen_multiply(A21 - A11, B11 + B12)
    P7 = strassen_multiply(A12 - A22, B21 + B22)
    # Compute the resulting submatrices
    C11 = P1 + P4 - P5 + P7
    C12 = P3 + P5
    C21 = P2 + P4
    C22 = P1 - P2 + P3 + P6
    # Combine the submatrices into a single matrix
    return combine_matrices(C11, C12, C21, C22)

function split_matrix(A):
    # Split matrix A into four equally sized quadrants
    n = size(A)
    m = n // 2
    A11 = A[:m, :m]
    A12 = A[:m, m:]
    A21 = A[m:, :m]
    A22 = A[m:, m:]
    return A11, A12, A21, A22

function combine_matrices(C11, C12, C21, C22):
    # Combine the four quadrants into a single matrix
    n = size(C11)
    m = 2 * n
    C = zeros((m, m))
    C[:n, :n] = C11
    C[:n, n:] = C12
    C[n:, :n] = C21
    C[n:, n:] = C22
    return C
```

- Define a function strassen_multiply that takes two matrices A and B as input and returns their product C.
- Get the size of the input matrices and check if they are 1x1 matrices. If so, compute their product and return it.
- Divide both matrices A and B into four submatrices A11, A12, A21, and A22 and B11, B12, B21, and B22.
- Compute seven matrix products recursively using these submatrices and the formulae specified in the algorithm.
- Combine the seven products to obtain the four quadrants of the result matrix C.
- Return the result matrix C.
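The steps above split every matrix exactly in half, so they assume that n is a power of two. One common workaround, sketched here as an assumption rather than as part of the pseudocode, is to pad the inputs with zeros up to the next power of two and crop the result afterwards. The helper names `next_power_of_two`, `pad_matrix`, and `crop_matrix` are hypothetical.

```python
def next_power_of_two(n):
    """Smallest power of two that is >= n (for n >= 1)."""
    p = 1
    while p < n:
        p *= 2
    return p


def pad_matrix(A, size):
    """Return a size x size copy of A with zeros added on the right and bottom."""
    rows, cols = len(A), len(A[0])
    return [[A[i][j] if i < rows and j < cols else 0 for j in range(size)]
            for i in range(size)]


def crop_matrix(C, rows, cols):
    """Keep only the top-left rows x cols corner of C."""
    return [row[:cols] for row in C[:rows]]


A = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
size = next_power_of_two(len(A))
padded = pad_matrix(A, size)
print(size)                            # 4
print(padded[3])                       # [0, 0, 0, 0]
print(crop_matrix(padded, 3, 3) == A)  # True
```

Because the added rows and columns are all zeros, the top-left corner of the padded product equals the product of the original matrices.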

## Sample Code

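```
// C++ code snippet
// Assumes A and B are square n x n matrices with n a power of two.
#include <iostream>
#include <vector>
using namespace std;

typedef vector<vector<int>> matrix;

// Element-wise sum of two equally sized matrices.
matrix operator+(const matrix& A, const matrix& B) {
    matrix C = A;
    for (size_t i = 0; i < A.size(); i++) {
        for (size_t j = 0; j < A[i].size(); j++) {
            C[i][j] += B[i][j];
        }
    }
    return C;
}

// Element-wise difference of two equally sized matrices.
matrix operator-(const matrix& A, const matrix& B) {
    matrix C = A;
    for (size_t i = 0; i < A.size(); i++) {
        for (size_t j = 0; j < A[i].size(); j++) {
            C[i][j] -= B[i][j];
        }
    }
    return C;
}

matrix strassen_multiply(const matrix& A, const matrix& B) {
    int n = A.size();
    // Base case: 1x1 matrices are multiplied directly.
    if (n == 1) {
        matrix C(1, vector<int>(1));
        C[0][0] = A[0][0] * B[0][0];
        return C;
    }
    // Split both inputs into four m x m quadrants.
    int m = n / 2;
    matrix A11(m, vector<int>(m));
    matrix A12(m, vector<int>(m));
    matrix A21(m, vector<int>(m));
    matrix A22(m, vector<int>(m));
    matrix B11(m, vector<int>(m));
    matrix B12(m, vector<int>(m));
    matrix B21(m, vector<int>(m));
    matrix B22(m, vector<int>(m));
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < m; j++) {
            A11[i][j] = A[i][j];
            A12[i][j] = A[i][m+j];
            A21[i][j] = A[m+i][j];
            A22[i][j] = A[m+i][m+j];
            B11[i][j] = B[i][j];
            B12[i][j] = B[i][m+j];
            B21[i][j] = B[m+i][j];
            B22[i][j] = B[m+i][m+j];
        }
    }
    // Seven recursive products.
    matrix P1 = strassen_multiply(A11 + A22, B11 + B22);
    matrix P2 = strassen_multiply(A21 + A22, B11);
    matrix P3 = strassen_multiply(A11, B12 - B22);
    matrix P4 = strassen_multiply(A22, B21 - B11);
    matrix P5 = strassen_multiply(A11 + A12, B22);
    matrix P6 = strassen_multiply(A21 - A11, B11 + B12);
    matrix P7 = strassen_multiply(A12 - A22, B21 + B22);
    // Quadrants of the result.
    matrix C11 = P1 + P4 - P5 + P7;
    matrix C12 = P3 + P5;
    matrix C21 = P2 + P4;
    matrix C22 = P1 - P2 + P3 + P6;
    // Copy the quadrants into the n x n result.
    matrix C(n, vector<int>(n));
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < m; j++) {
            C[i][j] = C11[i][j];
            C[i][m+j] = C12[i][j];
            C[m+i][j] = C21[i][j];
            C[m+i][m+j] = C22[i][j];
        }
    }
    return C;
}

int main() {
    matrix A = {{1, 2}, {3, 4}};
    matrix B = {{5, 6}, {7, 8}};
    matrix C = strassen_multiply(A, B);
    // Prints "19 22" and "43 50" on separate lines.
    for (size_t i = 0; i < C.size(); i++) {
        for (size_t j = 0; j < C[i].size(); j++) {
            cout << C[i][j] << " ";
        }
        cout << endl;
    }
    return 0;
}
```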

```
# Python code snippet
# Assumes A and B are square n x n lists of lists with n a power of two.
def strassen_multiply(A, B):
    n = len(A)
    # Base case: 1x1 matrices are multiplied directly.
    if n == 1:
        return [[A[0][0] * B[0][0]]]
    # Split both inputs into four m x m quadrants.
    m = n // 2
    A11 = [row[:m] for row in A[:m]]
    A12 = [row[m:] for row in A[:m]]
    A21 = [row[:m] for row in A[m:]]
    A22 = [row[m:] for row in A[m:]]
    B11 = [row[:m] for row in B[:m]]
    B12 = [row[m:] for row in B[:m]]
    B21 = [row[:m] for row in B[m:]]
    B22 = [row[m:] for row in B[m:]]
    # Seven recursive products.
    P1 = strassen_multiply(add(A11, A22), add(B11, B22))
    P2 = strassen_multiply(add(A21, A22), B11)
    P3 = strassen_multiply(A11, subtract(B12, B22))
    P4 = strassen_multiply(A22, subtract(B21, B11))
    P5 = strassen_multiply(add(A11, A12), B22)
    P6 = strassen_multiply(subtract(A21, A11), add(B11, B12))
    P7 = strassen_multiply(subtract(A12, A22), add(B21, B22))
    # Quadrants of the result.
    C11 = subtract(add(add(P1, P4), P7), P5)
    C12 = add(P3, P5)
    C21 = add(P2, P4)
    C22 = subtract(add(add(P1, P3), P6), P2)
    # Copy the quadrants into the n x n result.
    C = [[0] * n for _ in range(n)]
    for i in range(m):
        for j in range(m):
            C[i][j] = C11[i][j]
            C[i][m+j] = C12[i][j]
            C[m+i][j] = C21[i][j]
            C[m+i][m+j] = C22[i][j]
    return C


def add(A, B):
    # Element-wise sum of two equally sized matrices.
    return [[A[i][j] + B[i][j] for j in range(len(A[0]))] for i in range(len(A))]


def subtract(A, B):
    # Element-wise difference of two equally sized matrices.
    return [[A[i][j] - B[i][j] for j in range(len(A[0]))] for i in range(len(A))]


# Example usage:
A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
C = strassen_multiply(A, B)
print(C)  # prints [[19, 22], [43, 50]]
```

```
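// Java code snippet
// Assumes A and B are square n x n matrices with n a power of two.
public class StrassenMatrixMultiplication {
    public static int[][] strassenMultiply(int[][] A, int[][] B) {
        int n = A.length;
        // Base case: 1x1 matrices are multiplied directly.
        if (n == 1) {
            return new int[][] {{A[0][0] * B[0][0]}};
        }
        // Split both inputs into four m x m quadrants.
        int m = n / 2;
        int[][] A11 = submatrix(A, 0, m, 0, m);
        int[][] A12 = submatrix(A, 0, m, m, n);
        int[][] A21 = submatrix(A, m, n, 0, m);
        int[][] A22 = submatrix(A, m, n, m, n);
        int[][] B11 = submatrix(B, 0, m, 0, m);
        int[][] B12 = submatrix(B, 0, m, m, n);
        int[][] B21 = submatrix(B, m, n, 0, m);
        int[][] B22 = submatrix(B, m, n, m, n);
        // Seven recursive products.
        int[][] P1 = strassenMultiply(add(A11, A22), add(B11, B22));
        int[][] P2 = strassenMultiply(add(A21, A22), B11);
        int[][] P3 = strassenMultiply(A11, subtract(B12, B22));
        int[][] P4 = strassenMultiply(A22, subtract(B21, B11));
        int[][] P5 = strassenMultiply(add(A11, A12), B22);
        int[][] P6 = strassenMultiply(subtract(A21, A11), add(B11, B12));
        int[][] P7 = strassenMultiply(subtract(A12, A22), add(B21, B22));
        // Quadrants of the result.
        int[][] C11 = subtract(add(add(P1, P4), P7), P5);
        int[][] C12 = add(P3, P5);
        int[][] C21 = add(P2, P4);
        int[][] C22 = subtract(add(add(P1, P3), P6), P2);
        // Copy the quadrants into the n x n result.
        int[][] C = new int[n][n];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < m; j++) {
                C[i][j] = C11[i][j];
                C[i][m+j] = C12[i][j];
                C[m+i][j] = C21[i][j];
                C[m+i][m+j] = C22[i][j];
            }
        }
        return C;
    }

    // This helper is not shown in the original listing. It is assumed to copy
    // rows [rowStart, rowEnd) and columns [colStart, colEnd) of A.
    private static int[][] submatrix(int[][] A, int rowStart, int rowEnd, int colStart, int colEnd) {
        int[][] S = new int[rowEnd - rowStart][colEnd - colStart];
        for (int i = rowStart; i < rowEnd; i++) {
            for (int j = colStart; j < colEnd; j++) {
                S[i - rowStart][j - colStart] = A[i][j];
            }
        }
        return S;
    }

    // Element-wise sum of two equally sized square matrices.
    private static int[][] add(int[][] A, int[][] B) {
        int n = A.length;
        int[][] C = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                C[i][j] = A[i][j] + B[i][j];
            }
        }
        return C;
    }

    // Element-wise difference of two equally sized square matrices.
    private static int[][] subtract(int[][] A, int[][] B) {
        int n = A.length;
        int[][] C = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                C[i][j] = A[i][j] - B[i][j];
            }
        }
        return C;
    }
}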
```

## Time and Space Complexity

- The time complexity of Strassen's algorithm is O(n^log2(7)), which is approximately O(n^2.81). This is better than the traditional matrix multiplication algorithm's time complexity of O(n^3) for large matrices. However, the algorithm is slower in practice for small matrices because of the overhead of the recursive calls and the extra matrix additions and subtractions at each level.
- The space complexity of Strassen's algorithm is O(n^2), the same order as the traditional matrix multiplication algorithm. In practice it needs more memory, because each level of recursion stores temporary sums, differences, and the seven intermediate products in addition to the inputs and the result.
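The exponent follows from the recursion: each level replaces one multiplication of size n with seven of size n/2, so recursing down to 1 × 1 blocks uses 7^(log2 n) = n^(log2 7) scalar multiplications. The sketch below counts scalar multiplications only, ignores the additions, and assumes n is a power of two; the function names are illustrative.

```python
import math


def strassen_mults(n):
    """Scalar multiplications used by Strassen's recursion down to 1x1 blocks."""
    if n == 1:
        return 1
    return 7 * strassen_mults(n // 2)


def classical_mults(n):
    """Scalar multiplications used by the traditional triple loop."""
    return n ** 3


for n in (2, 8, 64, 1024):
    s, c = strassen_mults(n), classical_mults(n)
    print(f"n={n:5d}  strassen={s:>12d}  classical={c:>12d}  ratio={s / c:.3f}")

# The exponent log2(7), roughly 2.807
print(math.log2(7))
```

The ratio shrinks as n grows, which is the asymptotic advantage. The count deliberately leaves out the additional additions and memory traffic that dominate for small n.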

## Advantages

- Time complexity: Strassen's algorithm has a lower theoretical time complexity than traditional matrix multiplication for large matrices, making it more efficient in those cases.
- Asymptotic space complexity: Strassen's algorithm uses O(n^2) space, the same order as traditional matrix multiplication, so the faster running time does not raise the asymptotic memory requirement, although the temporary matrices increase the constant factor.

## Disadvantages

- Overhead: Strassen's algorithm involves more overhead than traditional matrix multiplication due to the need to divide matrices into smaller sub-matrices and the recursive calls to itself. This can make it slower for smaller matrices, where the overhead can outweigh the benefits of the lower time complexity.
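- Numerical stability: With floating-point entries, the extra additions and subtractions make Strassen's algorithm less numerically stable than traditional matrix multiplication, so rounding errors can be larger, especially when entries differ widely in magnitude. This can be a disadvantage for applications that need tight error bounds. With exact integer arithmetic the result is exact, provided no overflow occurs.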
- Constant factors: The constant factors involved in Strassen's algorithm can make it slower in practice than the traditional matrix multiplication algorithm for matrices of moderate size, despite its lower theoretical time complexity.
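A common way to limit the overhead described above is to stop recursing below a cutoff size and use the traditional algorithm there. The following is a minimal sketch, assuming square inputs stored as nested lists with a power-of-two size. The name `hybrid_multiply` and the value `CUTOFF = 2` are illustrative only; a practical threshold has to be found by benchmarking on the target machine.

```python
CUTOFF = 2  # illustrative value only; tune by benchmarking


def add(A, B):
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]


def sub(A, B):
    return [[a - b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]


def naive_multiply(A, B):
    """Traditional O(n^3) row-by-column product."""
    n = len(A)
    return [[sum(A[i][k] * B[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]


def split(M):
    """Return the four quadrants of a square matrix with even size."""
    m = len(M) // 2
    return ([r[:m] for r in M[:m]], [r[m:] for r in M[:m]],
            [r[:m] for r in M[m:]], [r[m:] for r in M[m:]])


def hybrid_multiply(A, B):
    """Strassen recursion that falls back to naive_multiply for small blocks."""
    if len(A) <= CUTOFF:
        return naive_multiply(A, B)
    A11, A12, A21, A22 = split(A)
    B11, B12, B21, B22 = split(B)
    P1 = hybrid_multiply(add(A11, A22), add(B11, B22))
    P2 = hybrid_multiply(add(A21, A22), B11)
    P3 = hybrid_multiply(A11, sub(B12, B22))
    P4 = hybrid_multiply(A22, sub(B21, B11))
    P5 = hybrid_multiply(add(A11, A12), B22)
    P6 = hybrid_multiply(sub(A21, A11), add(B11, B12))
    P7 = hybrid_multiply(sub(A12, A22), add(B21, B22))
    C11 = add(sub(add(P1, P4), P5), P7)
    C12 = add(P3, P5)
    C21 = add(P2, P4)
    C22 = add(add(sub(P1, P2), P3), P6)
    # Stitch the quadrants back together row by row.
    top = [r1 + r2 for r1, r2 in zip(C11, C12)]
    bottom = [r1 + r2 for r1, r2 in zip(C21, C22)]
    return top + bottom


A = [[i + j for j in range(4)] for i in range(4)]
B = [[i * j + 1 for j in range(4)] for i in range(4)]
print(hybrid_multiply(A, B) == naive_multiply(A, B))  # True
```

With integer entries the two results match exactly; with floating-point entries they may differ by rounding error, so a tolerance-based comparison would be more appropriate.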