Skip to main content

⏩ 线性代数与空间解析几何 上

题目

不使用并行计算的前提下,使矩阵乘法的运算速度变快。

题解

代码

#include <iostream>
#include <string>
#include <functional>
#include <chrono>
#include <vector>
using namespace std;


double timeit(std::function<void()> test) {
auto start = std::chrono::system_clock::now();
test();
auto stop = std::chrono::system_clock::now();
std::chrono::duration<double, std::milli> time = stop - start;
return time.count();
}

class matrix {
size_t m, n;
std::vector<int64_t> e;
public:
explicit matrix(size_t m, size_t n): m(m), n(n), e(m * n) {}
void random() {
for (size_t i = 0; i < m; ++i)
for (size_t j = 0; j < n; ++j)
at(i, j) = rand();
}
matrix(matrix const& that) = default;
matrix(matrix&&) = default;
matrix& operator=(matrix const& that) = default;
matrix& operator=(matrix&& that) = default;
int64_t& at(size_t i, size_t j) { return e[i * m + j]; }
int64_t const& at(size_t i, size_t j) const { return e[i * m + j]; }

matrix const& operator+=(matrix const& that) {
assert(m == that.m && n == that.n);
for (size_t i = 0; i < m; ++i) for (size_t j = 0; j < n; ++j) at(i, j) += that.at(i, j);
return *this;
}

matrix operator+(matrix that) const { return that += *this; }
matrix const& operator-=(matrix const& that) {
assert(m == that.m && n == that.n);
for (size_t i = 0; i < m; ++i) for (size_t j = 0; j < n; ++j) at(i, j) -= that.at(i, j);
return *this;
}
matrix operator-(matrix const& that) const { matrix copy(*this); copy -= that; return copy; }

matrix submat(size_t z, size_t w, size_t u, size_t v) const {
matrix sub(u, v);
for (size_t i = 0; i < u; ++i)
for (size_t j = 0; j < v; ++j)
sub.at(i, j) = at(z + i, w + j);
return sub;
}

matrix mul_plain(matrix const& that) const {
size_t p = that.m;
assert(n == p);
matrix product(m, that.n);
for (size_t i = 0; i < m; ++i)
for (size_t j = 0; j < n; ++j) {
for (size_t k = 0; k < p; ++k)
product.at(i, j) += at(i, k) * that.at(k, j);
}
return product;
}

matrix mul_constant(matrix const& that) const {
size_t p = that.m;
assert (n == p);
matrix product(m, that.n);
for (size_t i = 0; i < m; ++i)
for (size_t k = 0; k < p; ++k) {
int64_t r = at(i, k);
for (size_t j = 0; j < n; ++j)
product.at(i, j) += r * that.at(k, j);
}
return product;
}

matrix mul_Strassen(matrix const& that) const {
assert(m == that.m && n == that.n);
if (m == 1 && n == 1) {
matrix m(1, 1);
m.at(0, 0) = at(0, 0) * that.at(0, 0);
return m;
}
matrix A11 = submat(0, 0, m / 2, n / 2);
matrix A12 = submat(0, n / 2, m / 2, n - n / 2);
matrix A21 = submat(m / 2, 0, m - m / 2, n / 2);
matrix A22 = submat(m / 2, n / 2, m - m / 2, n - n / 2);
matrix B11 = that.submat(0, 0, m / 2, n / 2);
matrix B12 = that.submat(0, n / 2, m / 2, n - n / 2);
matrix B21 = that.submat(m / 2, 0, m - m / 2, n / 2);
matrix B22 = that.submat(m / 2, n / 2, m - m / 2, n - n / 2);
matrix S1 = B12 - B22;
matrix S2 = A11 + A12;
matrix S3 = A21 + A22;
matrix S4 = B21 - B11;
matrix S5 = A11 + A22;
matrix S6 = B11 + B22;
matrix S7 = A12 - A22;
matrix S8 = B21 + B22;
matrix S9 = A11 - A21;
matrix S10= B11 + B12;
matrix P1 = A11.mul_Strassen(S1);
matrix P2 = S2.mul_Strassen(B22);
matrix P3 = S3.mul_Strassen(B11);
matrix P4 = A22.mul_Strassen(S4);
matrix P5 = S5.mul_Strassen(S6);
matrix P6 = S7.mul_Strassen(S8);
matrix P7 = S9.mul_Strassen(S10);
matrix C11 = P5 + P4 - P2 + P6;
matrix C12 = P1 + P2;
matrix C21 = P3 + P4;
matrix C22 = P5 + P1 - P3 - P7;
matrix r(m, n);
for (size_t i = 0; i < C11.m; ++i)
for (size_t j = 0; j < C11.n; ++j)
r.at(i, j) = C11.at(i, j);
for (size_t i = 0; i < C12.m; ++i)
for (size_t j = 0; j < C12.n; ++j)
r.at(i, C11.n + j) = C12.at(i, j);
for (size_t i = 0; i < C21.m; ++i)
for (size_t j = 0; j < C21.n; ++j)
r.at(C11.m + i, j) = C21.at(i, j);
for (size_t i = 0; i < C22.m; ++i)
for (size_t j = 0; j < C22.n; ++j)
r.at(C11.m + i, C11.n + j) = C22.at(i, j);
return r;
}

bool operator==(matrix const& that) const {
return m == that.m && n == that.n && e == that.e;
}
};

void verify(matrix (matrix::* std)(matrix const&) const, matrix (matrix::* test)(matrix const&) const, const char* impl) {
srand(1);
for (int i = 0; i < 1000; ++i) {
matrix m1(16, 16), m2(16, 16);
m1.random(); m2.random();
if ((m1.*std)(m2) != (m1.*test)(m2)) {
fprintf(stderr, "Invalid implementation: %s\n", impl);
exit(-1);
}
}
fprintf(stdout, "Valid implementation: %s\n", impl);
}

double timeit(matrix (matrix::* test)(matrix const&) const, int size, int times, const char* impl, double base = 0) {
double time = timeit([=] {
srand(1);
for (int i = 0; i < times; ++i) {
matrix m1(size, size), m2(size, size);
m1.random(); m2.random();
(m1.*test)(m2);
}
});
printf("'%8s' algorithm took %.2f ms to complete %d times %d x %d matrices multiplication", impl, time, times, size, size);
if (base && base != time) {
printf(", %.2f%% ", abs(base - time) / base * 100);
if (base > time) printf("faster");
else printf("slower");
}
puts("");
return time;
}


int main() {

double base1 = timeit(&matrix::mul_plain, 16, 10000, "plain");
double base2 = timeit(&matrix::mul_plain, 1024, 1, "plain");

verify(&matrix::mul_plain, &matrix::mul_constant, "constant");
timeit(&matrix::mul_constant, 16, 10000, "constant", base1);
timeit(&matrix::mul_constant, 1024, 1, "constant", base2);

verify(&matrix::mul_plain, &matrix::mul_Strassen, "Strassen");
timeit(&matrix::mul_Strassen, 16, 10000, "Strassen", base1);
timeit(&matrix::mul_Strassen, 1024, 1, "Strassen", base2);
}

输出:

'   plain' algorithm took 107.22 ms to complete 10000 times 16 x 16 matrices multiplication
' plain' algorithm took 3825.46 ms to complete 1 times 1024 x 1024 matrices multiplication
Valid implementation: constant
'constant' algorithm took 116.40 ms to complete 10000 times 16 x 16 matrices multiplication, 8.55% slower
'constant' algorithm took 487.24 ms to complete 1 times 1024 x 1024 matrices multiplication, 87.26% faster
Valid implementation: Strassen
'Strassen' algorithm took 61000.17 ms to complete 10000 times 16 x 16 matrices multiplication, 56790.40% slower
'Strassen' algorithm took 576375.25 ms to complete 1 times 1024 x 1024 matrices multiplication, 14966.81% slower

分析

  • plain 算法就是基础的矩阵乘法。

  • constant 是常数优化之后的矩阵乘法。在矩阵较小的时候与前者速度差别不大,但是当矩阵较大时速度会有非常明显地提升。

其原理在于:重新排列了循环的顺序,最内层循环

for (size_t j = 0; j < n; ++j) 
product.at(i, j) += r * that.at(k, j);

变化的变量是 j ,(由于我的矩阵按行展开存储)使得每次寻址都与上次寻址的内存非常近,大大提高了空间局部性,从而实现常数的优化。

  • Strassen 是算法复杂度为 Θ(nlog7)\Theta(n^{\log7}) 的算法,它把一个大矩阵的乘法分解成了七个子矩阵线性组合的乘法。推导过程和复杂度分析在这里

数学家和算法工程师看了拍案叫绝,软件工程师看了直摇头——这个常数实在是太大了。实际上,比 Θ(n3)\Theta(n^3) 时间复杂度更优的算法,几乎都没有朴素的 Θ(n3)\Theta(n^3) 跑得快,因此它们的研究仅仅只能停留在数学和算法的研究中,不能用于实际的工程中。

不过这个算法的优势其实也有所体现——擅长大矩阵。它的核心就是分治的思想:大矩阵化为小矩阵。矩阵小时,相较于朴素算法慢了56790.40%,但是当矩阵较大时,就只慢了14966.81%,可见矩阵足够大的时候,后者也许能反超前者。(不过那可能太大了,至少在个人电脑上可以实验的数据范围里是看不到这一刻了)

由于年少无知,错把 SIMD 的优化也放到了不允许并行的题解里,出题人也没分清,特此说明。

向量化

先复习一下线性代数与空间解析几何:

A=(α1Tα2TαnT)B=(β1,β2,,βn)C=(cij)=AB=(αiTβj)A=\left(\begin{array}{}\alpha_1^T\\\alpha_2^T\\\vdots\\\alpha_n^T\end{array}\right)\\ B=(\beta_1,\beta_2,\cdots,\beta_n)\\ C=(c_{ij})=AB=(\alpha_i^T\beta_j)

也就是说,矩阵的乘法可以看成左矩阵的行向量和右矩阵的列向量的内积!

既然转换为了向量的问题,就可以使用向量化的方法可以大大提高计算的速度。下面用 1024 x 1024 的 double 矩阵为例。

#include <iostream>
#include <string>
#include <functional>
#include <chrono>
#include <vector>
#include <cstring>
using namespace std;


double timeit(std::function<void()> test) {
auto start = std::chrono::system_clock::now();
test();
auto stop = std::chrono::system_clock::now();
std::chrono::duration<double, std::milli> time = stop - start;
return time.count();
}

#pragma GCC target("sse,sse2,sse3,ssse3,sse4.1,sse4.2,avx,avx2,tune=native")
#include <immintrin.h>
#include <emmintrin.h>

struct vec {
constexpr static int N = 256;
__m256d a[N];
double& operator[](int x) {
return ((double*) a)[x];
}
const double& operator[](int x) const {
return ((double*) a)[x];
}
double dot(const vec& x) const {
__m256d sum = _mm256_set1_pd(0.0);
for (int i = 0; i < N; i++)
sum = _mm256_add_pd(sum, _mm256_mul_pd(a[i], x.a[i]));
sum = _mm256_hadd_pd(sum, sum);
__m128d sum_high = _mm256_extractf128_pd(sum, 1);
__m128d result = _mm_add_pd(sum_high, _mm256_castpd256_pd128(sum));
return ((double*) &result)[0];
}
};

class matrix {
constexpr static size_t m = 1024, n = 1024;
vector<double> e;
public:
explicit matrix(): e(m * n) {}
void random() {
for (size_t i = 0; i < m; ++i)
for (size_t j = 0; j < n; ++j)
at(i, j) = rand();
}
matrix(matrix const& that) = default;
matrix(matrix&&) = default;
matrix& operator=(matrix const& that) = default;
matrix& operator=(matrix&& that) = default;
double& at(size_t i, size_t j) { return e[i * m + j]; }
double const& at(size_t i, size_t j) const { return e[i * m + j]; }

matrix mul_plain(matrix const& that) const {
size_t p = that.m;
matrix product;
for (size_t i = 0; i < m; ++i)
for (size_t j = 0; j < n; ++j) {
for (size_t k = 0; k < p; ++k)
product.at(i, j) += at(i, k) * that.at(k, j);
}
return product;
}

matrix mul_vectorized(matrix const& that) const {
vector<vec> lines(m), columns(n);
memcpy(lines.data(), e.data(), sizeof(double) * m * n);
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
columns[j][i] = that.at(i, j);
}
}
matrix r;
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
r.at(i, j) = lines[i].dot(columns[j]);
}
}
return r;
}
};

double timeit(matrix (matrix::* test)(matrix const&) const, int times, const char* impl, double base = 0) {
double time = timeit([=] {
srand(1);
for (int i = 0; i < times; ++i) {
matrix m1, m2;
m1.random(); m2.random();
(m1.*test)(m2);
}
});
printf("'%10s' algorithm took %.2f ms to complete %d times 1024 x 1024 matrices multiplication", impl, time, times);
if (base && base != time) {
printf(", %.2f%% ", abs(base - time) / base * 100);
if (base > time) printf("faster");
else printf("slower");
}
puts("");
return time;
}


int main() {
double base = timeit(&matrix::mul_plain, 1, "plain");
timeit(&matrix::mul_vectorized, 1, "vectorized", base);
}

输出:

'     plain' algorithm took 3405.35 ms to complete 1 times 1024 x 1024 matrices multiplication
'vectorized' algorithm took 223.44 ms to complete 1 times 1024 x 1024 matrices multiplication, 93.44% faster

向量化让 CPU 可以在一条指令里处理多个数据,从而大大提高计算效率。