矩阵乘法(分治法)

题目描述

设 A 和 B 是两个 n * n 阶矩阵,求它们的乘积矩阵 C。要求使用分治法。

输入格式

输入为 1+2×n×n 个数字,第一个数字与第二个数字之间以空格和换行隔开,其余数字之间以空格隔开,

第 1 个表示矩阵阶层 n,

第 2 个至第 n+1 个表示矩阵 A,

第 n+2 个至第 2n+1 个表示矩阵 B。

输出格式

输出为 n×n 个数字,表示乘积矩阵 C

样例 #1

样例输入 #1

1
2
3
1 1 1 1 2 3 2 3 4 5 7 8 2 3 2 1 2 9

样例输出 #1

1
8 12 19 12 19 39 20 31 58

提示

0≤n≤100,A、B 均为整数矩阵

问题分析

解法一:超级大暴力

首先我们可以使用最简单的矩阵乘法,由定义

两个矩阵对指定行号和列号的元素依次相乘的和作为输出矩阵对应的行号列号的元素。用数学语言表示, ,那么有 , 且

利用这种最简单的想法,我们很容易写出一个基于定义式的 暴力解法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#include <algorithm>
#include <iostream>
#include <cstring>
#include <set>
#include <vector>
#include <map>
#include <cmath>
#include <queue>
#include <iomanip>
#include <stack>
#define debug(x) cerr << #x << " = " << x << endl
#define int long long
#define all(x) (x).begin(),(x).end()
#define endl ('\n')

using namespace std;

typedef pair<int, int> PII;
typedef long long LL;

const int N = 1e2 + 5, mod = 1e9 + 7;
const int INF = 0x3f3f3f3f;
//int dx[] = {1, 0, -1, 0}, dy[] = {0, -1, 0, 1};
//const double eps = 1e-6;

int n, m, k;
int a[N][N];
int b[N][N];
int c[N][N];
string s;

int get(int x, int y){
int ret = 0;
for (int i = 0; i < n; i ++)
ret += a[x][i] * b[i][y];
return ret;
}

void solve(){
cin >> n;
for (int i = 0; i < n * n; i ++ )
cin >> a[i / n][i % n];
for (int i = 0; i < n * n; i ++ )
cin >> b[i / n][i % n];
for (int i = 0; i < n; i ++){
for (int j = 0; j < n; j ++){
c[i][j] = get(i, j);
}
}
for (int i = 0; i < n * n; i ++ ){
cout << c[i / n][i % n] << " ";
}
}

signed main(){
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
solve();
return 0;
}

这样的超级大暴力代码能不能通过呢?

是可以的,因为时间放的很宽松,对于范围在 100 的数据竟然留了 1s,而一秒钟跑完 的数据只能说轻轻松松

解法二:普通分治乘法

参考 zqt 大佬的代码

对于矩阵 A,B 我们有

此时

但是如果采用这一种方式,那应当如何分割宽和高为奇数的矩阵的呢,如下所示

此时我们采取的解决方案为在外侧补上一圈的 0, 长度由原来的 length 变为 length + 1

1
2
3
4
5
6
7
8
9
10
11
12
for (int i = 0; i < length + 1; i++)
{
for (int j = 0; j < length + 1; j++)
{
// 补上外侧的 0
if (i == length || j == length)
{
at[i][j] = 0;
bt[i][j] = 0;
}
}
}

之后就变为长度为偶数的矩阵可以进行分块矩阵的分割,之后再递归的往下处理即可

分治的核心就在于此,将目标问题转换为更为简单的多个子问题,并拆分解决

1
2
3
4
5
6
7
8
9
10
11
12
int **t1 = MulM(a, b, l, ax11, ay11, bx11, by11, tx11, ty11);
int **t2 = MulM(a, b, l, ax12, ay12, bx21, by21, tx11, ty11);
int **t3 = MulM(a, b, l, ax11, ay11, bx12, by12, tx12, ty12);
int **t4 = MulM(a, b, l, ax12, ay12, bx22, by22, tx12, ty12);
int **t5 = MulM(a, b, l, ax21, ay21, bx11, by11, tx21, ty21);
int **t6 = MulM(a, b, l, ax22, ay22, bx21, by21, tx21, ty21);
int **t7 = MulM(a, b, l, ax21, ay21, bx12, by12, tx22, ty22);
int **t8 = MulM(a, b, l, ax22, ay22, bx22, by22, tx22, ty22);
AddM(t1, t2, temp, l, tx11, ty11);
AddM(t3, t4, temp, l, tx12, ty12);
AddM(t5, t6, temp, l, tx21, ty21);
AddM(t7, t8, temp, l, tx22, ty22);

按照这一思路实现,就能够正确的输出计算结果

但是提交到洛谷上之后发现,有很多测试点是 MLE (Memory limit exceeded)

造成这一情况的原因很可能是,由于开了太多的数组导致的,在本题中每一个子函数都创建了很多的数组(因此在使用完之后及时将内存进行释放就能够解决这一问题)

1
2
3
4
5
6
7
8
9
10
11
12
AddM(t1, t2, temp, l, tx11, ty11);
del(l, t1);
del(l, t2);
AddM(t3, t4, temp, l, tx12, ty12);
del(l, t3);
del(l, t4);
AddM(t5, t6, temp, l, tx21, ty21);
del(l, t5);
del(l, t6);
AddM(t7, t8, temp, l, tx22, ty22);
del(l, t7);
del(l, t8);

加上这一修改后即可通过该问题

复杂度分析,由于在该问题中,采用这样的分治方式,只是将代码转为递归形式的延续,并不能降低代码的复杂度

时间复杂度:

示例代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
#include <stdio.h>
#include <stdlib.h>

int n;
int **C;
void AddM(int **a, int **b, int **temp, int length, int ax, int ay, int bx, int by, int tx, int ty)
{
for (int i = 0; i < length; i++)
{
for (int j = 0; j < length; j++)
{
temp[tx + i][ty + j] = a[ax+i][ay+j] + b[bx+i][by+j];
}
}
}
void SubM(int **a, int **b, int **temp, int length, int ax, int ay, int bx, int by, int tx, int ty)
{
for (int i = 0; i < length; i++)
{
for (int j = 0; j < length; j++)
{
temp[tx + i][ty + j] = a[ax+i][ay+j] - b[bx+i][by+j];
}
}
}

// 用del将所有的函数都释放
void del(int n, int **a)
{ // 将矩阵空间释放
for (int i = 0; i < n; ++i)
{
free(a[i]);
}
free(a);
}

// 返回了当前矩阵的运算结果
int **MulM(int **a, int **b, int length, int ax, int ay, int bx, int by)
{
// 申请一个临时的矩阵,保存这一次a*b的结果
int **temp;
// 分治的最底层
if (length == 1)
{
// 申请一个1*1的
temp = (int **)malloc(sizeof(int *) * 1);
// 用temp[0]
temp[0] = (int *)malloc(sizeof(int) * 1);
temp[0][0] = a[ax][ay] * b[bx][by];
return temp;
}
// 否则的话,把矩阵分成四块
// 分A
int l;
if (length % 2 == 0)
{
l = length / 2;
temp = (int **)malloc(sizeof(int *) * length);
for (int i = 0; i < length; i++)
{
temp[i] = (int *)malloc(sizeof(int) * length);
}
int ax11 = ax;
int ay11 = ay;
int ax12 = ax;
int ay12 = ay + l;
int ax21 = ax + l;
int ay21 = ay;
int ax22 = ax + l;
int ay22 = ay + l;
// 分B
int bx11 = bx;
int by11 = by;
int bx12 = bx;
int by12 = by + l;
int bx21 = bx + l;
int by21 = by;
int bx22 = bx + l;
int by22 = by + l;
// 分temp,他就是从0开始分的
int tx11 = 0;
int ty11 = 0;
int tx12 = 0;
int ty12 = l;
int tx21 = l;
int ty21 = 0;
int tx22 = l;
int ty22 = l;
// 7次
//开一个t,来存每次相加减的结果
int **t;
int **t2;
t = (int **)malloc(sizeof(int *) * l);
t2 = (int **)malloc(sizeof(int *) * l);
for (int i = 0; i < l; i++)
{
t[i] = (int *)malloc(sizeof(int) * l);
t2[i] = (int *)malloc(sizeof(int) * l);
}
//A11*(B12-B22)
SubM( b , b , t , l, bx12 , by12 , bx22 , by22 , 0, 0);
int **m1 = MulM(a, t, l, ax11, ay11, 0, 0);
//(A11+A12)*B22
AddM( a , a , t , l, ax11 , ay11 , ax12 , ay12 ,0, 0);
int **m2 = MulM(t, b, l, 0, 0, bx22,by22);
//(A21+A22)*B11
AddM( a , a , t , l, ax21 , ay21 , ax22 , ay22 ,0, 0);
int **m3 = MulM(t, b, l, 0, 0 , bx11, by11);
//A22(B21-B11)
SubM( b , b , t , l, bx21 , by21 , bx11 , by11 , 0, 0);
int **m4 = MulM(a, t, l, ax22, ay22 , 0, 0);
//(A11+A22)(B11+B22)
AddM( a , a , t , l, ax11 , ay11 , ax22 , ay22 , 0, 0);
AddM( b , b , t2 , l, bx11 , by11 , bx22 , by22 , 0, 0);
int **m5 = MulM(t, t2, l, 0, 0 , 0, 0);
//(A12-A22)(B21+B22)
SubM( a , a , t , l, ax12 , ay12 , ax22 , ay22 , 0, 0);
AddM( b , b , t2 , l, bx21 , by21 , bx22 , by22 , 0, 0);
int **m6 = MulM(t, t2, l, 0, 0 , 0, 0);
//(A11-A21)(B11+B12)
SubM( a , a , t , l, ax11 , ay11 , ax21 , ay21 , 0, 0);
AddM( b , b , t2 , l, bx11 , by11 , bx12 , by12 , 0, 0);
int **m7 = MulM(t, t2, l, 0, 0 , 0, 0);
//C11=M5+M4-M2+M6
AddM( m5 , m4 , temp , l, 0 , 0 , 0 , 0 , tx11,ty11);
SubM( temp , m2 , temp , l, tx11,ty11 , 0 , 0 , tx11,ty11);
AddM( temp , m6 , temp , l, tx11,ty11 , 0 , 0 , tx11,ty11);
del(l,m6);
//C12=M1+M2
AddM( m1 , m2 , temp , l, 0 , 0 , 0 , 0 , tx12,ty12);
del(l,m2);
//C21=M3+M4
AddM( m3 , m4 , temp , l, 0 , 0 , 0 , 0 , tx21,ty21);
del(l,m4);
//C22=M5+M1-M3-M7
AddM( m5 , m1 , temp , l, 0 , 0 , 0 , 0 , tx22,ty22);
SubM( temp , m3 , temp , l, tx22,ty22, 0 , 0 , tx22,ty22);
SubM( temp , m7 , temp , l, tx22,ty22, 0 , 0 , tx22,ty22);
del(l,m1);del(l,m3);del(l,m5);del(l,m7);
return temp;
}
else
{
l = length / 2 + 1;
// 如果是奇数,则分开的时候
// 要把最后一行、最后一列补0,合起来的时候,再把0去掉
// 那么这里最简单的方法就是开新的a,b
int **at;
int **bt;
at = (int **)malloc(sizeof(int *) * (length + 1));
bt = (int **)malloc(sizeof(int *) * (length + 1));
temp = (int **)malloc(sizeof(int *) * (length + 1));
for (int i = 0; i < length + 1; i++)
{
temp[i] = (int *)malloc(sizeof(int) * (length + 1));
at[i] = (int *)malloc(sizeof(int) * (length + 1));
bt[i] = (int *)malloc(sizeof(int) * (length + 1));
}
// 赋值
for (int i = 0; i < length + 1; i++)
{
for (int j = 0; j < length + 1; j++)
{
if (i == length || j == length)
{
at[i][j] = 0;
bt[i][j] = 0;
}
else
{
at[i][j] = a[ax + i][ay + j];
bt[i][j] = b[bx + i][by + j];
}
}
}
// ab都从00开始,进行分块,也存在temp里面,不影响
int ax11 = 0;
int ay11 = 0;
int ax12 = 0;
int ay12 = l;
int ax21 = l;
int ay21 = 0;
int ax22 = l;
int ay22 = l;
// 分B
int bx11 = 0;
int by11 = 0;
int bx12 = 0;
int by12 = l;
int bx21 = l;
int by21 = 0;
int bx22 = l;
int by22 = l;
// 分temp,他就是从0开始分的
int tx11 = 0;
int ty11 = 0;
int tx12 = 0;
int ty12 = l;
int tx21 = l;
int ty21 = 0;
int tx22 = l;
int ty22 = l;
// 分别计算
int **t;
int **t2;
t = (int **)malloc(sizeof(int *) * l);
t2 = (int **)malloc(sizeof(int *) * l);
for (int i = 0; i < l; i++)
{
t[i] = (int *)malloc(sizeof(int) * l);
t2[i] = (int *)malloc(sizeof(int) * l);
}
//A11*(B12-B22)
SubM( bt , bt , t , l, bx12 , by12 , bx22 , by22 , 0, 0);
int **m1 = MulM(at, t, l, ax11, ay11, 0, 0);
//(A11+A12)*B22
AddM( at , at , t , l, ax11 , ay11 , ax12 , ay12 ,0, 0);
int **m2 = MulM(t, bt, l, 0, 0, bx22,by22);
//(A21+A22)*B11
AddM( at , at , t , l, ax21 , ay21 , ax22 , ay22 ,0, 0);
int **m3 = MulM(t, bt, l, 0, 0 , bx11, by11);
//A22(B21-B11)
SubM( bt , bt , t , l, bx21 , by21 , bx11 , by11 , 0, 0);
int **m4 = MulM(at, t, l, ax22, ay22 , 0, 0);
//(A11+A22)(B11+B22)
AddM( at , at , t , l, ax11 , ay11 , ax22 , ay22 , 0, 0);
AddM( bt , bt , t2 , l, bx11 , by11 , bx22 , by22 , 0, 0);
int **m5 = MulM(t, t2, l, 0, 0 , 0, 0);
//(A12-A22)(B21+B22)
SubM( at , at , t , l, ax12 , ay12 , ax22 , ay22 , 0, 0);
AddM( bt , bt , t2 , l, bx21 , by21 , bx22 , by22 , 0, 0);
int **m6 = MulM(t, t2, l, 0, 0 , 0, 0);
//(A11-A21)(B11+B12)
SubM( at , at , t , l, ax11 , ay11 , ax21 , ay21 , 0, 0);
AddM( bt , bt , t2 , l, bx11 , by11 , bx12 , by12 , 0, 0);
int **m7 = MulM(t, t2, l, 0, 0 , 0, 0);
//C11=M5+M4-M2+M6
AddM( m5 , m4 , temp , l, 0 , 0 , 0 , 0 , tx11,ty11);
SubM( temp , m2 , temp , l, tx11,ty11 , 0 , 0 , tx11,ty11);
AddM( temp , m6 , temp , l, tx11,ty11 , 0 , 0 , tx11,ty11);
del(l,m6);
//C12=M1+M2
AddM( m1 , m2 , temp , l, 0 , 0 , 0 , 0 , tx12,ty12);
del(l,m2);
//C21=M3+M4
AddM( m3 , m4 , temp , l, 0 , 0 , 0 , 0 , tx21,ty21);
del(l,m4);
//C22=M5+M1-M3-M7
AddM( m5 , m1 , temp , l, 0 , 0 , 0 , 0 , tx22,ty22);
SubM( temp , m3 , temp , l, tx22,ty22, 0 , 0 , tx22,ty22);
SubM( temp , m7 , temp , l, tx22,ty22, 0 , 0 , tx22,ty22);
del(l,m1);del(l,m3);del(l,m5);del(l,m7);
// 返回的话,都要把最后一行最后一列应该去掉
// 但是length的话,本来就是不包括的
del(length,at);del(length,bt);
return temp;
}
}

int main()
{
// 读入
scanf("%d", &n);
int **A;
int **B;
// 多开一个
A = (int **)malloc(sizeof(int *) * (n + 1));
B = (int **)malloc(sizeof(int *) * (n + 1));
for (int i = 0; i < n + 1; i++)
{
A[i] = (int *)malloc(sizeof(int) * (n + 1));
B[i] = (int *)malloc(sizeof(int) * (n + 1));
}
for (int i = 0; i < n; i++)
{
for (int j = 0; j < n; j++)
scanf("%d", &A[i][j]);
}
for (int i = 0; i < n; i++)
{
for (int j = 0; j < n; j++)
scanf("%d", &B[i][j]);
}
C = MulM(A, B, n, 0, 0, 0, 0);
for (int i = 0; i < n; i++)
{
for (int j = 0; j < n; j++)
printf("%d ", C[i][j]);
}
}

解法三:斯特拉森 (Strassen) 矩阵乘法

1969 年,Volker Strassen 提出了第一个算法时间复杂度低于 矩阵乘法算法,算法复杂度为 。性能上才有很大的优势,可以减少很多乘法计算。

设 A 和 B 是两个 n 的矩阵,其中 n 可以写成 。将 A 和 B 分别等分成 4 个小矩阵,此时如果把 A 和 B 都当成 矩阵来看,每个元素就是一个 矩阵,而矩阵 A 和 B 的乘积就可以写成

其中利用斯特拉森方法得到 7 个小矩阵,分别定义为:

矩阵 可以通过 7 次矩阵乘法,6 次矩阵加法和 4 次矩阵减法计算得出,前述 4 个小矩阵 可以由矩阵 通过 6 次矩阵加法和 2 次矩阵减法得出,方法如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
int **t1 = init(n / 2);
sub(n / 2, B12, B22, t1);
int **M1 = init(n / 2);
strassen(n / 2, A11, t1, M1);
del(n / 2, t1);

int **t2 = init(n / 2);
add(n / 2, A11, A12, t2);
int **M2 = init(n / 2);
strassen(n / 2, t2, B22, M2);
del(n / 2, t2);

int **t3 = init(n / 2);
add(n / 2, A21, A22, t3);
int **M3 = init(n / 2);
strassen(n / 2, t3, B11, M3);
del(n / 2, t3);

int **t4 = init(n / 2);
sub(n / 2, B21, B11, t4);
int **M4 = init(n / 2);
strassen(n / 2, A22, t4, M4);
del(n / 2, t4);

int **t5 = init(n / 2);
int **t6 = init(n / 2);
add(n / 2, A11, A22, t5);
add(n / 2, B11, B22, t6);
int **M5 = init(n / 2);
strassen(n / 2, t5, t6, M5);
del(n / 2, t5);
del(n / 2, t6);

int **t7 = init(n / 2);
int **t8 = init(n / 2);
sub(n / 2, A12, A22, t7);
add(n / 2, B21, B22, t8);
int **M6 = init(n / 2);
strassen(n / 2, t7, t8, M6);
del(n / 2, t7);
del(n / 2, t8);

int **t9 = init(n / 2);
int **t10 = init(n / 2);
sub(n / 2, A11, A21, t9);
add(n / 2, B11, B12, t10);
int **M7 = init(n / 2);
strassen(n / 2, t9, t10, M7);

那么我们再来考虑奇数的处理方式 实际上,还有第二种思路进行队奇数进行补全

因为问题中每一次处理都是将原本的问题规模缩减一半,即变为原来的

所以如果只想进行一次补全,之后解决的子问题长度全是偶数,我们就要保证最开始的起始问题规模为

如 {2, 4, 8, 16, 32, 64, 128}

由于该问题中数值最大为 100,所以问题研究的大小最大到 128 即可

算数据范围:

1
2
3
4
5
6
7
8
9
10
11
12
int next_power_of_two(unsigned int n) {
// 如果 n 已经是 2 的幂次方,则直接返回 n
if ((n & (n - 1)) == 0) {
return n;
}

// 将 n 转换为二进制后,将最高位后面的所有位都设为 1
while ((n & (n - 1)) != 0) {
n = n & (n - 1);
}
return n << 1;
}

补 0:

1
2
3
4
5
6
7
8
9
10
11
for (int i = 1; i < N; i++)
{
for (int j = 1; j < N; j++)
{
if(i >= n || j >= n)
{
A[i][j] = 0;
B[i][j] = 0;
}
}
}

最终完整的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
#include <iostream>
#include <cstring>
#include <cstdlib>

int **init(int n)
{
int **a = (int **)malloc(sizeof(int *) * n);
for (int i = 0; i < n; ++i)
{
a[i] = (int *)malloc(sizeof(int) * n);
}
return a;
}

void del(int n, int **a)
{
for (int i = 0; i < n; ++i)
{
free(a[i]);
}
free(a);
}

void add(int n, int **A, int **B, int **C)
{
for (int i = 0; i < n; i++)
{
for (int j = 0; j < n; j++)
{
C[i][j] = A[i][j] + B[i][j];
}
}
}

void sub(int n, int **A, int **B, int **C)
{
for (int i = 0; i < n; i++)
{
for (int j = 0; j < n; j++)
{
C[i][j] = A[i][j] - B[i][j];
}
}
}

void strassen(int n, int **A, int **B, int **C)
{

int **A11 = init(n / 2);
int **A12 = init(n / 2);
int **A21 = init(n / 2);
int **A22 = init(n / 2);
int **B11 = init(n / 2);
int **B12 = init(n / 2);
int **B21 = init(n / 2);
int **B22 = init(n / 2);

if (n == 2)
{
for (int i = 0; i < n; i++)
{
for (int j = 0; j < n; j++)
{
C[i][j] = 0;
for (int k = 0; k < n; k++)
{
C[i][j] += A[i][k] * B[k][j];
}
}
}
return;
}
for (int i = 0; i < n / 2; i++)
{
for (int j = 0; j < n / 2; j++)
{
A11[i][j] = A[i][j];
A12[i][j] = A[i][j + n / 2];
A21[i][j] = A[i + n / 2][j];
A22[i][j] = A[i + n / 2][j + n / 2];

B11[i][j] = B[i][j];
B12[i][j] = B[i][j + n / 2];
B21[i][j] = B[i + n / 2][j];
B22[i][j] = B[i + n / 2][j + n / 2];
}
}
int **t1 = init(n / 2);
sub(n / 2, B12, B22, t1);
int **M1 = init(n / 2);
strassen(n / 2, A11, t1, M1);
del(n / 2, t1);

int **t2 = init(n / 2);
add(n / 2, A11, A12, t2);
int **M2 = init(n / 2);
strassen(n / 2, t2, B22, M2);
del(n / 2, t2);

int **t3 = init(n / 2);
add(n / 2, A21, A22, t3);
int **M3 = init(n / 2);
strassen(n / 2, t3, B11, M3);
del(n / 2, t3);

int **t4 = init(n / 2);
sub(n / 2, B21, B11, t4);
int **M4 = init(n / 2);
strassen(n / 2, A22, t4, M4);
del(n / 2, t4);

int **t5 = init(n / 2);
int **t6 = init(n / 2);
add(n / 2, A11, A22, t5);
add(n / 2, B11, B22, t6);
int **M5 = init(n / 2);
strassen(n / 2, t5, t6, M5);
del(n / 2, t5);
del(n / 2, t6);

int **t7 = init(n / 2);
int **t8 = init(n / 2);
sub(n / 2, A12, A22, t7);
add(n / 2, B21, B22, t8);
int **M6 = init(n / 2);
strassen(n / 2, t7, t8, M6);
del(n / 2, t7);
del(n / 2, t8);

int **t9 = init(n / 2);
int **t10 = init(n / 2);
sub(n / 2, A11, A21, t9);
add(n / 2, B11, B12, t10);
int **M7 = init(n / 2);
strassen(n / 2, t9, t10, M7);
del(n / 2, t9);
del(n / 2, t10);

del(n / 2, A11);
del(n / 2, A12);
del(n / 2, A21);
del(n / 2, A22);
del(n / 2, B11);
del(n / 2, B12);
del(n / 2, B21);
del(n / 2, B22);

int **t11 = init(n / 2);
int **t12 = init(n / 2);
add(n / 2, M5, M4, t11);
sub(n / 2, t11, M2, t12);

int **C11 = init(n / 2);
int **C12 = init(n / 2);
int **C21 = init(n / 2);
int **C22 = init(n / 2);
add(n / 2, t12, M6, C11);
del(n / 2, t11);
del(n / 2, t12);

add(n / 2, M1, M2, C12);

add(n / 2, M3, M4, C21);

int **t13 = init(n / 2);
int **t14 = init(n / 2);
add(n / 2, M5, M1, t13);
sub(n / 2, t13, M3, t14);
sub(n / 2, t14, M7, C22);
del(n / 2, t13);
del(n / 2, t14);

del(n / 2, M1);
del(n / 2, M2);
del(n / 2, M3);
del(n / 2, M4);
del(n / 2, M5);
del(n / 2, M6);
del(n / 2, M7);

for (int i = 0; i < n / 2; i++)
{
for (int j = 0; j < n / 2; j++)
{
C[i][j] = C11[i][j];
C[i + n / 2][j] = C21[i][j];
C[i][j + n / 2] = C12[i][j];
C[i + n / 2][j + n / 2] = C22[i][j];
}
}
del(n / 2, C11);
del(n / 2, C12);
del(n / 2, C21);
del(n / 2, C22);
}

int next_power_of_two(int n) {
if ((n & (n - 1)) == 0) {
return n;
}

while ((n & (n - 1)) != 0) {
n = n & (n - 1);
}
return n << 1;
}

signed main()
{
int n;
scanf("%d", &n);
int N;
if (n >= 0 && n <= 2)
{
N = 2;
}
else
{
N = next_power_of_two(n);
}
int **A = init(N);
int **B = init(N);
int **C = init(N);

for (int i = 0; i < n; i++)
{
for (int j = 0; j < n; j++)
{
scanf("%d", &A[i][j]);
}
}
for (int i = 0; i < n; i++)
{
for (int j = 0; j < n; j++)
{
scanf("%d", &B[i][j]);
}
}
for (int i = 1; i < N; i++)
{
for (int j = 1; j < N; j++)
{
if(i >= n || j >= n)
{
A[i][j] = 0;
B[i][j] = 0;
}
}
}
strassen(N, A, B, C);
for (int i = 0; i < n; i++)
{
for (int j = 0; j < n; j++)
{
printf("%d ", C[i][j]);
}
}
}

可惜的是,虽然在算法的时间复杂度上,strassen 是有优势的,但是考虑到数据范围较小仅为 100,时间复杂度的优势体现不出来,反而是 malloc 函数的常数过大,导致总体的时间更高。

最为讽刺的是,速度最快的竟然是直接暴力相乘。

如果在实际的生产中 可以设置一个数据范围,比如 300 当大于要相乘的矩阵数据范围大于 300 时,再采用 strassen 算法进行计算

具体参考 alibaba 的开源算法库中 MNN 中关于 Strassen 算法实现https://github.com/alibaba/MNN

Reference:

[1] https://zhuanlan.zhihu.com/p/268392799

[2] https://zhuanlan.zhihu.com/p/78657463