Strassen's algorithm to compute matrix multiplication
来源:互联网 发布:pc手机网站展示源码 编辑:程序博客网 时间:2024/06/07 13:55
//
//  main.cpp
//  Strassen
//
//  Created by Longxiang Lyu on 5/24/16.
//  Copyright (c) 2016 Longxiang Lyu. All rights reserved.
//
//  Integer matrix multiplication with Strassen's algorithm. Both operands are
//  zero-padded to one common power-of-two square size, multiplied recursively
//  with seven half-size products per level (instead of the naive eight), and
//  the padded product is cropped back to rows(A) x cols(B).
//
//  Note: entries are plain `int`; intermediate sums and products can overflow
//  (undefined behavior for signed integers) when the inputs are large.
//

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <math.h>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

/// Dense row-major integer matrix, indexed as m[row][col].
using Matrix = std::vector<std::vector<int>>;

/// Prints `matrix` to stdout, one row per line, each element followed by a space.
void printMatrix(const Matrix &matrix)
{
    for (const auto &row : matrix)
    {
        for (const int elem : row)
            std::cout << elem << ' ';
        std::cout << '\n';
    }
}

/// Returns the smallest power of two that is >= n (1 when n <= 1).
static std::size_t nextPowerOfTwo(std::size_t n)
{
    std::size_t p = 1;
    while (p < n)
        p <<= 1;
    return p;
}

/// Zero-pads `matrix` in place into an S x S square, where S is the smallest
/// power of two >= max(row count, longest row, minSize). Existing entries keep
/// their positions; every added entry is 0.
///
/// @param matrix  matrix to pad (may be empty or ragged)
/// @param minSize lower bound on S, so two operands can be padded to the same size
void zeroPadding(Matrix &matrix, std::size_t minSize = 0)
{
    std::size_t needed = std::max(matrix.size(), minSize);
    for (const auto &row : matrix)
        needed = std::max(needed, row.size());

    const std::size_t sz = nextPowerOfTwo(needed);
    matrix.resize(sz);
    for (auto &row : matrix)
        row.resize(sz, 0);
}

/// ret = A + B element-wise over the sz x sz square, where sz = A.size().
/// A and B must be at least sz x sz. `ret` is resized to sz x sz (it need not
/// be pre-sized) and may alias A or B, since each entry is read before written.
void sum(const Matrix &A, const Matrix &B, Matrix &ret)
{
    const std::size_t sz = A.size();
    ret.resize(sz);
    for (std::size_t i = 0; i < sz; ++i)
    {
        ret[i].resize(sz);
        for (std::size_t j = 0; j < sz; ++j)
            ret[i][j] = A[i][j] + B[i][j];
    }
}

/// ret = A - B element-wise over the sz x sz square, where sz = A.size().
/// Same sizing and aliasing rules as sum().
void subtract(const Matrix &A, const Matrix &B, Matrix &ret)
{
    const std::size_t sz = A.size();
    ret.resize(sz);
    for (std::size_t i = 0; i < sz; ++i)
    {
        ret[i].resize(sz);
        for (std::size_t j = 0; j < sz; ++j)
            ret[i][j] = A[i][j] - B[i][j];
    }
}

/// Returns a copy of the n x n sub-block of `src` whose top-left corner is (row0, col0).
static Matrix subBlock(const Matrix &src, std::size_t row0, std::size_t col0, std::size_t n)
{
    Matrix block(n, std::vector<int>(n));
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = 0; j < n; ++j)
            block[i][j] = src[row0 + i][col0 + j];
    return block;
}

/// Copies `block` into `dst` with its top-left corner at (row0, col0); `dst` must be large enough.
static void putBlock(Matrix &dst, const Matrix &block, std::size_t row0, std::size_t col0)
{
    for (std::size_t i = 0; i < block.size(); ++i)
        for (std::size_t j = 0; j < block[i].size(); ++j)
            dst[row0 + i][col0 + j] = block[i][j];
}

/// Recursive Strassen step: ret = A * B for two n x n matrices, where n is a
/// power of two (strassen() establishes this by zero-padding).
/// `ret` is overwritten and resized to n x n; it need not be pre-sized.
void strassenHelper(const Matrix &A, const Matrix &B, Matrix &ret)
{
    const std::size_t sz = A.size();
    if (sz == 0)
    {
        ret.clear();
        return;
    }
    if (sz == 1)
    {
        // The product is computed before assign() runs, so aliasing is safe.
        ret.assign(1, std::vector<int>(1, A[0][0] * B[0][0]));
        return;
    }

    const std::size_t half = sz / 2;
    const Matrix a11 = subBlock(A, 0, 0, half);
    const Matrix a12 = subBlock(A, 0, half, half);
    const Matrix a21 = subBlock(A, half, 0, half);
    const Matrix a22 = subBlock(A, half, half, half);
    const Matrix b11 = subBlock(B, 0, 0, half);
    const Matrix b12 = subBlock(B, 0, half, half);
    const Matrix b21 = subBlock(B, half, 0, half);
    const Matrix b22 = subBlock(B, half, half, half);

    // Scratch operands reused for every product; sum()/subtract() size them.
    Matrix lhs;
    Matrix rhs;

    Matrix p1; // (A11 + A22)(B11 + B22)
    sum(a11, a22, lhs);
    sum(b11, b22, rhs);
    strassenHelper(lhs, rhs, p1);

    Matrix p2; // (A21 + A22) B11
    sum(a21, a22, lhs);
    strassenHelper(lhs, b11, p2);

    Matrix p3; // A11 (B12 - B22)
    subtract(b12, b22, rhs);
    strassenHelper(a11, rhs, p3);

    Matrix p4; // A22 (B21 - B11)
    subtract(b21, b11, rhs);
    strassenHelper(a22, rhs, p4);

    Matrix p5; // (A11 + A12) B22
    sum(a11, a12, lhs);
    strassenHelper(lhs, b22, p5);

    Matrix p6; // (A21 - A11)(B11 + B12)
    subtract(a21, a11, lhs);
    sum(b11, b12, rhs);
    strassenHelper(lhs, rhs, p6);

    Matrix p7; // (A12 - A22)(B21 + B22)
    subtract(a12, a22, lhs);
    sum(b21, b22, rhs);
    strassenHelper(lhs, rhs, p7);

    // C11 = P1 + P4 - P5 + P7
    Matrix c11;
    sum(p1, p4, lhs);
    sum(lhs, p7, rhs);
    subtract(rhs, p5, c11);

    // C12 = P3 + P5
    Matrix c12;
    sum(p3, p5, c12);

    // C21 = P2 + P4
    Matrix c21;
    sum(p2, p4, c21);

    // C22 = P1 - P2 + P3 + P6
    Matrix c22;
    sum(p1, p3, lhs);
    sum(lhs, p6, rhs);
    subtract(rhs, p2, c22);

    // Assemble only after every read of A and B, so `ret` may alias an operand.
    ret.assign(sz, std::vector<int>(sz, 0));
    putBlock(ret, c11, 0, 0);
    putBlock(ret, c12, 0, half);
    putBlock(ret, c21, half, 0);
    putBlock(ret, c22, half, half);
}

/// Computes ret = A * B with Strassen's algorithm.
///
/// @param A   m x k matrix (all rows the same length); not modified
/// @param B   k x p matrix (all rows the same length); not modified
/// @param ret receives the m x p product; previous contents are discarded
/// @throws std::runtime_error if either matrix is empty, if A's column count
///         differs from B's row count, or if either matrix is ragged
void strassen(const Matrix &A, const Matrix &B, Matrix &ret)
{
    if (A.empty() || B.empty())
        throw std::runtime_error("empty matrices");
    if (A[0].size() != B.size())
        throw std::runtime_error("A's col not equal B's row");

    const auto isRectangular = [](const Matrix &m) -> bool {
        const std::size_t width = m.front().size();
        return std::all_of(m.begin(), m.end(),
                           [width](const std::vector<int> &row) { return row.size() == width; });
    };
    if (!isRectangular(A) || !isRectangular(B))
        throw std::runtime_error("ragged matrix");

    const std::size_t rows = A.size();
    const std::size_t inner = B.size();
    const std::size_t cols = B[0].size();

    // Pad copies of both operands to ONE common power-of-two size; padding each
    // to its own size could leave them mismatched for non-square inputs.
    const std::size_t n = nextPowerOfTwo(std::max({rows, inner, cols}));
    Matrix paddedA = A;
    Matrix paddedB = B;
    zeroPadding(paddedA, n);
    zeroPadding(paddedB, n);

    Matrix product;
    strassenHelper(paddedA, paddedB, product);

    // Drop the padding: the real product occupies the top-left rows x cols.
    product.resize(rows);
    for (auto &row : product)
        row.resize(cols);
    ret = std::move(product);
}

int main()
{
    const Matrix A{{1, 2, 0}, {1, 2, 3}, {1, 2, 3}};
    const Matrix B{{1, 0, 1}, {1, 1, 1}, {2, 1, 1}};
    Matrix ret;
    strassen(A, B, ret);
    printMatrix(ret); // expected rows: "3 2 3", "9 5 6", "9 5 6"
    return 0;
}
Reference:
https://martin-thoma.com/strassen-algorithm-in-python-java-cpp/
0 0
- Strassen's algorithm to compute matrix multiplication
- Strassen’s algorithm for matrix multiplication
- Strassen's Subcubic Matrix Multiplication Algorithm
- Implementation of Strassen’s Algorithm for Matrix Multiplication
- Strassen algorithm
- Strassen Algorithm
- Strassen Algorithm解析
- Matrix Multiplication
- Matrix Multiplication
- Matrix Multiplication
- Matrix Multiplication
- Matrix multiplication
- Matrix multiplication
- Strassen‘s 矩阵乘法
- Use python to implement Dijkstra's algorithm
- Hirschberg's algorithm to find string alignment
- Matrix Chain Multiplication
- zoj1094 Matrix Chain Multiplication
- MPI compile 设置 centos 7
- eclipse 安卓开发环境搭建
- Android应用开发Scroller详解及源码浅析
- 旋转木马3D环形特效
- LeetCode: Edit Distance
- Strassen's algorithm to compute matrix multiplication
- #define
- gnu autotools 相关技术资料
- MATLAB FDATool IIR数字滤波器设计
- 关于分布式事务、两阶段提交、一阶段提交、Best Efforts 1PC模式和事务补偿机制的研究
- 关于一场“信任危机”
- 分布式事务介绍
- leetcode #33 in cpp
- 勾股定理一日一证连载3