⏩ 线性代数与空间解析几何 上
#include <cassert>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <string>
#include <vector>
using namespace std;
/// Runs `test` exactly once and reports how long it took.
///
/// @param test  Callable to measure.
/// @return Elapsed wall-clock time in milliseconds.
double timeit(std::function<void()> test) {
    // steady_clock is monotonic, so the measured interval cannot be skewed
    // by adjustments of the system clock.
    auto start = std::chrono::steady_clock::now();
    test();  // the workload has to run between the two timestamps
    auto stop = std::chrono::steady_clock::now();
    std::chrono::duration<double, std::milli> time = stop - start;
    return time.count();
}
/// Dense row-major matrix of 64-bit integers offering three multiplication
/// strategies (naive, loop-reordered and Strassen) for benchmarking.
///
/// NOTE(review): elements come from rand() and are multiplied and summed as
/// int64_t; with a large RAND_MAX the sums may overflow -- confirm whether
/// the value range should be restricted.
class matrix {
    size_t m, n;             // row count, column count
    std::vector<int64_t> e;  // m * n elements stored row by row

public:
    /// Creates an m x n matrix whose elements are all zero.
    explicit matrix(size_t m, size_t n): m(m), n(n), e(m * n) {}

    /// Overwrites every element with a value returned by rand().
    void random() {
        for (auto& x : e)
            x = rand();
    }

    matrix(matrix const& that) = default;
    matrix(matrix&&) = default;
    matrix& operator=(matrix const& that) = default;
    matrix& operator=(matrix&& that) = default;

    /// Element in row i, column j. The row stride is the column count n.
    int64_t& at(size_t i, size_t j) { return e[i * n + j]; }
    int64_t const& at(size_t i, size_t j) const { return e[i * n + j]; }

    /// Element-wise addition in place; both operands must have the same shape.
    matrix const& operator+=(matrix const& that) {
        assert(m == that.m && n == that.n);
        for (size_t i = 0; i < e.size(); ++i)
            e[i] += that.e[i];
        return *this;
    }

    /// Element-wise sum. `that` is taken by value so it can serve as the result.
    matrix operator+(matrix that) const {
        that += *this;
        return that;
    }

    /// Element-wise subtraction in place; both operands must have the same shape.
    matrix const& operator-=(matrix const& that) {
        assert(m == that.m && n == that.n);
        for (size_t i = 0; i < e.size(); ++i)
            e[i] -= that.e[i];
        return *this;
    }

    /// Element-wise difference.
    matrix operator-(matrix const& that) const {
        matrix copy(*this);
        copy -= that;
        return copy;
    }

    /// Copies the u x v block whose top-left corner is at row z, column w.
    matrix submat(size_t z, size_t w, size_t u, size_t v) const {
        assert(z + u <= m && w + v <= n);
        matrix sub(u, v);
        for (size_t i = 0; i < u; ++i)
            for (size_t j = 0; j < v; ++j)
                sub.at(i, j) = at(z + i, w + j);
        return sub;
    }

    /// Textbook triple loop in i-j-k order. Requires n == that.m.
    matrix mul_plain(matrix const& that) const {
        size_t p = that.m;
        assert(n == p);
        matrix product(m, that.n);
        for (size_t i = 0; i < m; ++i)
            for (size_t j = 0; j < that.n; ++j)
                for (size_t k = 0; k < p; ++k)
                    product.at(i, j) += at(i, k) * that.at(k, j);
        return product;
    }

    /// Same product as mul_plain, but in i-k-j order so the innermost loop
    /// walks consecutive elements of both `product` and `that`.
    /// Requires n == that.m.
    matrix mul_constant(matrix const& that) const {
        size_t p = that.m;
        assert(n == p);
        matrix product(m, that.n);
        for (size_t i = 0; i < m; ++i)
            for (size_t k = 0; k < p; ++k) {
                int64_t r = at(i, k);  // fixed for the whole inner loop
                for (size_t j = 0; j < that.n; ++j)
                    product.at(i, j) += r * that.at(k, j);
            }
        return product;
    }

    /// Strassen multiplication: seven recursive half-size products instead
    /// of eight. Requires n == that.m.
    ///
    /// The recursion only applies to square operands of equal, even size;
    /// any other shape is delegated to mul_constant, so the result is
    /// defined for every valid pair of operands.
    matrix mul_Strassen(matrix const& that) const {
        assert(n == that.m);
        if (m == 1 && n == 1 && that.n == 1) {
            matrix product(1, 1);
            product.at(0, 0) = at(0, 0) * that.at(0, 0);
            return product;
        }
        // Quadrants of unequal size cannot be added or subtracted, and an
        // empty matrix would recurse forever.
        if (m == 0 || m != n || that.n != n || m % 2 != 0)
            return mul_constant(that);

        size_t const h = m / 2;
        matrix A11 = submat(0, 0, h, h);
        matrix A12 = submat(0, h, h, h);
        matrix A21 = submat(h, 0, h, h);
        matrix A22 = submat(h, h, h, h);
        matrix B11 = that.submat(0, 0, h, h);
        matrix B12 = that.submat(0, h, h, h);
        matrix B21 = that.submat(h, 0, h, h);
        matrix B22 = that.submat(h, h, h, h);

        matrix S1 = B12 - B22;
        matrix S2 = A11 + A12;
        matrix S3 = A21 + A22;
        matrix S4 = B21 - B11;
        matrix S5 = A11 + A22;
        matrix S6 = B11 + B22;
        matrix S7 = A12 - A22;
        matrix S8 = B21 + B22;
        matrix S9 = A11 - A21;
        matrix S10 = B11 + B12;

        matrix P1 = A11.mul_Strassen(S1);
        matrix P2 = S2.mul_Strassen(B22);
        matrix P3 = S3.mul_Strassen(B11);
        matrix P4 = A22.mul_Strassen(S4);
        matrix P5 = S5.mul_Strassen(S6);
        matrix P6 = S7.mul_Strassen(S8);
        matrix P7 = S9.mul_Strassen(S10);

        matrix C11 = P5 + P4 - P2 + P6;
        matrix C12 = P1 + P2;
        matrix C21 = P3 + P4;
        matrix C22 = P5 + P1 - P3 - P7;

        // Stitch the four result quadrants back together.
        matrix r(m, n);
        for (size_t i = 0; i < h; ++i)
            for (size_t j = 0; j < h; ++j) {
                r.at(i, j) = C11.at(i, j);
                r.at(i, h + j) = C12.at(i, j);
                r.at(h + i, j) = C21.at(i, j);
                r.at(h + i, h + j) = C22.at(i, j);
            }
        return r;
    }

    /// True when both matrices have the same shape and the same elements.
    bool operator==(matrix const& that) const {
        return m == that.m && n == that.n && e == that.e;
    }

    /// Negation of operator==, spelled out for pre-C++20 compilers.
    bool operator!=(matrix const& that) const {
        return !(*this == that);
    }
};
void verify(matrix (matrix::* std)(matrix const&) const, matrix (matrix::* test)(matrix const&) const, const char* impl) {
for (int i = 0; i < 1000; ++i) {
matrix m1(16, 16), m2(16, 16);
m1.random(); m2.random();
if ((m1.*std)(m2) != (m1.*test)(m2)) {
fprintf(stderr, "Invalid implementation: %s\n", impl);
fprintf(stdout, "Valid implementation: %s\n", impl);
/// Benchmarks one multiplication routine and prints a one-line report.
///
/// Each of the `times` iterations builds two random size x size matrices and
/// multiplies them; the construction and random fill run inside the timed
/// region together with the multiplication.
///
/// @param test   Multiplication member function to measure.
/// @param size   Row and column count of the square operands.
/// @param times  Number of multiplications to perform.
/// @param impl   Implementation name shown in the report.
/// @param base   Baseline time in ms; when non-zero and different from the
///               measured time, the relative difference is printed as well.
/// @return Measured time in milliseconds.
double timeit(matrix (matrix::* test)(matrix const&) const, int size, int times, const char* impl, double base = 0) {
    double time = timeit([=] {
        for (int i = 0; i < times; ++i) {
            matrix m1(size, size), m2(size, size);
            m1.random(); m2.random();
            (m1.*test)(m2);  // the operation being measured
        }
    });
    printf("'%8s' algorithm took %.2f ms to complete %d times %d x %d matrices multiplication", impl, time, times, size, size);
    if (base && base != time) {
        printf(", %.2f%% ", std::abs(base - time) / base * 100);
        if (base > time) printf("faster");
        else printf("slower");
    }
    printf("\n");  // end the report line
    return time;
}
/// Times the plain algorithm as a baseline, then verifies and times the
/// loop-reordered and Strassen variants against it.
int main() {
    // Baselines: many small products and one large product.
    double base1 = timeit(&matrix::mul_plain, 16, 10000, "plain");
    double base2 = timeit(&matrix::mul_plain, 1024, 1, "plain");

    verify(&matrix::mul_plain, &matrix::mul_constant, "constant");
    timeit(&matrix::mul_constant, 16, 10000, "constant", base1);
    timeit(&matrix::mul_constant, 1024, 1, "constant", base2);

    verify(&matrix::mul_plain, &matrix::mul_Strassen, "Strassen");
    timeit(&matrix::mul_Strassen, 16, 10000, "Strassen", base1);
    timeit(&matrix::mul_Strassen, 1024, 1, "Strassen", base2);
    return 0;
}
' plain' algorithm took 107.22 ms to complete 10000 times 16 x 16 matrices multiplication
' plain' algorithm took 3825.46 ms to complete 1 times 1024 x 1024 matrices multiplication
Valid implementation: constant
'constant' algorithm took 116.40 ms to complete 10000 times 16 x 16 matrices multiplication, 8.55% slower
'constant' algorithm took 487.24 ms to complete 1 times 1024 x 1024 matrices multiplication, 87.26% faster
Valid implementation: Strassen
'Strassen' algorithm took 61000.17 ms to complete 10000 times 16 x 16 matrices multiplication, 56790.40% slower
'Strassen' algorithm took 576375.25 ms to complete 1 times 1024 x 1024 matrices multiplication, 14966.81% slower
`plain` 算法就是基础的矩阵乘法。`constant` 算法调整了循环的嵌套顺序,使最内层循环变为:
for (size_t j = 0; j < n; ++j)
product.at(i, j) += r * that.at(k, j);
变化的变量是 j
`Strassen` 是时间复杂度为 $O(n^{\log_2 7}) \approx O(n^{2.81})$ 的算法,它把一个大矩阵的乘法分解成了七个子矩阵线性组合的乘法。推导过程和复杂度分析在这里。
数学家和算法工程师看了拍案叫绝,软件工程师看了直摇头——这个常数实在是太大了。实际上,比 $O(n^3)$ 时间复杂度更优的算法,几乎都没有朴素的 $O(n^3)$ 算法跑得快,因此它们的研究仅仅只能停留在数学和算法的研究中,不能用于实际的工程中。
由于年少无知,错把 SIMD 的优化也放到了不允许并行的题解里,出题人也没分清,特此说明。
既然转换为了向量的问题,就可以使用向量化的方法大大提高计算的速度。下面以 1024 x 1024 的 double 矩阵为例。
#include <iostream>
#include <string>
#include <functional>
#include <chrono>
#include <vector>
#include <cstring>
using namespace std;
/// Runs `test` exactly once and reports how long it took.
///
/// @param test  Callable to measure.
/// @return Elapsed wall-clock time in milliseconds.
double timeit(std::function<void()> test) {
    // steady_clock is monotonic, so the measured interval cannot be skewed
    // by adjustments of the system clock.
    auto start = std::chrono::steady_clock::now();
    test();  // the workload has to run between the two timestamps
    auto stop = std::chrono::steady_clock::now();
    std::chrono::duration<double, std::milli> time = stop - start;
    return time.count();
}
#pragma GCC target("sse,sse2,sse3,ssse3,sse4.1,sse4.2,avx,avx2,tune=native")
#include <immintrin.h>
#include <emmintrin.h>
/// Fixed-length vector of 4 * N doubles kept in N 256-bit AVX values.
struct vec {
    constexpr static int N = 256;  // number of __m256d values, 4 doubles each
    __m256d a[N];

    /// Access to the x-th double, 0 <= x < 4 * N (not range-checked).
    // NOTE(review): viewing __m256d storage as double relies on the
    // compiler's vector types being allowed to alias -- confirm for
    // toolchains other than GCC/Clang.
    double& operator[](int x) {
        return reinterpret_cast<double*>(a)[x];
    }
    const double& operator[](int x) const {
        return reinterpret_cast<const double*>(a)[x];
    }

    /// Dot product with `x`, accumulating four lanes in parallel.
    double dot(const vec& x) const {
        __m256d sum = _mm256_setzero_pd();
        for (int i = 0; i < N; i++)
            sum = _mm256_add_pd(sum, _mm256_mul_pd(a[i], x.a[i]));
        // Horizontal reduction: first add neighbouring lanes inside each
        // 128-bit half, then add the upper half to the lower half.
        sum = _mm256_hadd_pd(sum, sum);
        __m128d sum_high = _mm256_extractf128_pd(sum, 1);
        __m128d result = _mm_add_pd(sum_high, _mm256_castpd256_pd128(sum));
        return _mm_cvtsd_f64(result);  // lowest lane holds the total
    }
};
class matrix {
constexpr static size_t m = 1024, n = 1024;
vector<double> e;
explicit matrix(): e(m * n) {}
void random() {
for (size_t i = 0; i < m; ++i)
for (size_t j = 0; j < n; ++j)
at(i, j) = rand();
matrix(matrix const& that) = default;
matrix(matrix&&) = default;
matrix& operator=(matrix const& that) = default;
matrix& operator=(matrix&& that) = default;
double& at(size_t i, size_t j) { return e[i * m + j]; }
double const& at(size_t i, size_t j) const { return e[i * m + j]; }
matrix mul_plain(matrix const& that) const {
size_t p = that.m;
matrix product;
for (size_t i = 0; i < m; ++i)
for (size_t j = 0; j < n; ++j) {
for (size_t k = 0; k < p; ++k)
product.at(i, j) += at(i, k) * that.at(k, j);
return product;
matrix mul_vectorized(matrix const& that) const {
vector<vec> lines(m), columns(n);
memcpy(lines.data(), e.data(), sizeof(double) * m * n);
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
columns[j][i] = that.at(i, j);
matrix r;
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
r.at(i, j) = lines[i].dot(columns[j]);
return r;
/// Benchmarks one multiplication routine and prints a one-line report.
///
/// Each of the `times` iterations builds two random matrices and multiplies
/// them; the construction and random fill run inside the timed region
/// together with the multiplication.
///
/// @param test   Multiplication member function to measure.
/// @param times  Number of multiplications to perform.
/// @param impl   Implementation name shown in the report.
/// @param base   Baseline time in ms; when non-zero and different from the
///               measured time, the relative difference is printed as well.
/// @return Measured time in milliseconds.
double timeit(matrix (matrix::* test)(matrix const&) const, int times, const char* impl, double base = 0) {
    double time = timeit([=] {
        for (int i = 0; i < times; ++i) {
            matrix m1, m2;
            m1.random(); m2.random();
            (m1.*test)(m2);  // the operation being measured
        }
    });
    printf("'%10s' algorithm took %.2f ms to complete %d times 1024 x 1024 matrices multiplication", impl, time, times);
    if (base && base != time) {
        printf(", %.2f%% ", std::abs(base - time) / base * 100);
        if (base > time) printf("faster");
        else printf("slower");
    }
    printf("\n");  // end the report line
    return time;
}
/// Times the plain multiplication as a baseline, then the vectorized one
/// relative to it.
int main() {
    double base = timeit(&matrix::mul_plain, 1, "plain");
    timeit(&matrix::mul_vectorized, 1, "vectorized", base);
    return 0;
}
' plain' algorithm took 3405.35 ms to complete 1 times 1024 x 1024 matrices multiplication
'vectorized' algorithm took 223.44 ms to complete 1 times 1024 x 1024 matrices multiplication, 93.44% faster
向量化让 CPU 可以在一条指令里处理多个数据,从而大大提高计算效率。