Strassen's algorithm for matrix multiplication

来源:互联网 发布:pc手机网站展示源码 编辑:程序博客网 时间:2024/06/07 13:55
//
//  main.cpp
//  Strassen
//
//  Created by Longxiang Lyu on 5/24/16.
//  Copyright (c) 2016 Longxiang Lyu. All rights reserved.
//

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>
#include <string>
#include <stdexcept>
#include <math.h>

using namespace std;

// Dense integer matrix stored as a vector of rows: matrix[row][col].
using Matrix = vector<vector<int>>;

// Prints `matrix` to stdout, one row per line; every element is followed by
// a single space.
void printMatrix(const Matrix &matrix)
{
    for (const auto &row : matrix)  // bind by reference: no per-row copy
    {
        for (int elem : row)
            cout << elem << " ";
        cout << '\n';
    }
}

// Returns the smallest power of two that is >= n (1 for n <= 1).
static size_t nextPowerOfTwo(size_t n)
{
    size_t p = 1;
    while (p < n)
        p <<= 1;
    return p;
}

// Pads `matrix` in place with zeros into a square matrix whose side is the
// smallest power of two >= max(number of rows, longest row, minSize).
// Existing entries keep their positions; shorter rows are extended with zeros.
// A matrix with no rows/columns and minSize == 0 is left unchanged.
// A power-of-two side is required because strassenHelper halves the side at
// every recursion level.
void zeroPadding(Matrix &matrix, size_t minSize = 0)
{
    size_t cols = 0;
    for (const auto &row : matrix)
        cols = max(cols, row.size());

    const size_t dim = max({matrix.size(), cols, minSize});
    if (dim == 0)
        return;

    const size_t target = nextPowerOfTwo(dim);
    matrix.resize(target);
    for (auto &row : matrix)
        row.resize(target, 0);
}

// Element-wise ret = A + B over the leading A.size() x A.size() entries.
// `ret` must already be at least that size; it is not resized here.
void sum(const Matrix &A, const Matrix &B, Matrix &ret)
{
    const size_t sz = A.size();
    for (size_t i = 0; i < sz; ++i)
        for (size_t j = 0; j < sz; ++j)
            ret[i][j] = A[i][j] + B[i][j];
}

// Element-wise ret = A - B over the leading A.size() x A.size() entries.
// `ret` must already be at least that size; it is not resized here.
void subtract(const Matrix &A, const Matrix &B, Matrix &ret)
{
    const size_t sz = A.size();
    for (size_t i = 0; i < sz; ++i)
        for (size_t j = 0; j < sz; ++j)
            ret[i][j] = A[i][j] - B[i][j];
}

// Multiplies two n x n matrices with Strassen's seven-product recursion,
// recursing all the way down to 1 x 1 blocks, and stores the n x n product in
// `ret` (any previous contents/size of `ret` are replaced).
// Precondition: A and B are both n x n with n a power of two (or n == 0).
void strassenHelper(const Matrix &A, const Matrix &B, Matrix &ret)
{
    const size_t n = A.size();
    if (n == 0)
    {
        ret.clear();
        return;
    }
    if (n == 1)
    {
        // Size `ret` here so the base case is safe even for an empty output.
        ret.assign(1, vector<int>(1, A[0][0] * B[0][0]));
        return;
    }

    const size_t h = n / 2;
    const Matrix zero(h, vector<int>(h, 0));

    // Split A and B into four h x h quadrants each.
    Matrix a11 = zero, a12 = zero, a21 = zero, a22 = zero;
    Matrix b11 = zero, b12 = zero, b21 = zero, b22 = zero;
    for (size_t i = 0; i < h; ++i)
    {
        for (size_t j = 0; j < h; ++j)
        {
            a11[i][j] = A[i][j];
            a12[i][j] = A[i][j + h];
            a21[i][j] = A[i + h][j];
            a22[i][j] = A[i + h][j + h];

            b11[i][j] = B[i][j];
            b12[i][j] = B[i][j + h];
            b21[i][j] = B[i + h][j];
            b22[i][j] = B[i + h][j + h];
        }
    }

    // Scratch operands reused across the seven products; the recursive calls
    // only read their inputs, so overwriting them between calls is safe.
    Matrix t1 = zero, t2 = zero;
    Matrix p1, p2, p3, p4, p5, p6, p7;  // sized by the recursive calls

    sum(a11, a22, t1);                                       // p1 = (a11 + a22)(b11 + b22)
    sum(b11, b22, t2);
    strassenHelper(t1, t2, p1);

    sum(a21, a22, t1);                                       // p2 = (a21 + a22) b11
    strassenHelper(t1, b11, p2);

    subtract(b12, b22, t2);                                  // p3 = a11 (b12 - b22)
    strassenHelper(a11, t2, p3);

    subtract(b21, b11, t2);                                  // p4 = a22 (b21 - b11)
    strassenHelper(a22, t2, p4);

    sum(a11, a12, t1);                                       // p5 = (a11 + a12) b22
    strassenHelper(t1, b22, p5);

    subtract(a21, a11, t1);                                  // p6 = (a21 - a11)(b11 + b12)
    sum(b11, b12, t2);
    strassenHelper(t1, t2, p6);

    subtract(a12, a22, t1);                                  // p7 = (a12 - a22)(b21 + b22)
    sum(b21, b22, t2);
    strassenHelper(t1, t2, p7);

    // Recombine the quadrants of the product.
    ret.assign(n, vector<int>(n, 0));
    for (size_t i = 0; i < h; ++i)
    {
        for (size_t j = 0; j < h; ++j)
        {
            ret[i][j]         = p1[i][j] + p4[i][j] + p7[i][j] - p5[i][j];  // c11
            ret[i][j + h]     = p3[i][j] + p5[i][j];                        // c12
            ret[i + h][j]     = p2[i][j] + p4[i][j];                        // c21
            ret[i + h][j + h] = p1[i][j] + p3[i][j] + p6[i][j] - p2[i][j];  // c22
        }
    }
}

// Computes ret = A * B for an m x k matrix A and a k x p matrix B, so `ret`
// ends up m x p. Both inputs are copied and zero-padded to one common
// power-of-two square size for the Strassen recursion; the padding is then
// stripped from the product. A and B themselves are not modified.
// Throws runtime_error if either matrix is empty, if A's column count differs
// from B's row count, or if either matrix has rows of unequal length.
void strassen(const Matrix &A, const Matrix &B, Matrix &ret)
{
    if (A.empty() || B.empty())
        throw runtime_error("empty matrices");
    if (A[0].size() != B.size())
        throw runtime_error("A's col not equal B's row");
    for (const auto &row : A)
        if (row.size() != A[0].size())
            throw runtime_error("A's rows differ in length");
    for (const auto &row : B)
        if (row.size() != B[0].size())
            throw runtime_error("B's rows differ in length");

    const size_t m = A.size();
    const size_t k = B.size();
    const size_t p = B[0].size();

    // Both operands must share one side length for strassenHelper, so pad
    // each to the power of two covering the largest of the three dimensions.
    const size_t n = nextPowerOfTwo(max({m, k, p}));
    Matrix paddedA = A;
    Matrix paddedB = B;
    zeroPadding(paddedA, n);
    zeroPadding(paddedB, n);

    Matrix product;
    strassenHelper(paddedA, paddedB, product);

    // Keep only the top-left m x p block; the rest is padding.
    ret.assign(m, vector<int>());
    for (size_t i = 0; i < m; ++i)
        ret[i].assign(product[i].begin(),
                      product[i].begin() + static_cast<ptrdiff_t>(p));
}

int main(int argc, const char * argv[]) {
    (void)argc;
    (void)argv;

    vector<vector<int>> A{{1, 2, 0}, {1, 2, 3}, {1, 2, 3}};
    vector<vector<int>> B{{1, 0, 1}, {1, 1, 1}, {2, 1, 1}};
    vector<vector<int>> ret;
    try
    {
        strassen(A, B, ret);
    }
    catch (const exception &e)
    {
        cerr << "error: " << e.what() << '\n';
        return 1;
    }
    printMatrix(ret);  // prints the 3 x 3 product: 3 2 3 / 9 5 6 / 9 5 6
    return 0;
}

Reference:

https://martin-thoma.com/strassen-algorithm-in-python-java-cpp/

0 0