Skip to content

Commit 34508ed

Browse files
authored
Merge pull request #460 from JingyuanZhang/master
feat(webgl): add new op pool2d_avg_adaptive
2 parents 7d613cb + e22f789 commit 34508ed

File tree

4 files changed

+59
-0
lines changed

4 files changed

+59
-0
lines changed

packages/paddlejs-backend-webgl/src/ops/index.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import conv2d_depthwise from './shader/conv2d_depthwise';
1010
import depthwise_conv2d from './shader/depthwise_conv2d';
1111
import conv2d_elementwise_add from './shader/conv2d_elementwise_add';
1212
import pool2d from './shader/pool2d';
13+
import pool2d_avg_adaptive from './shader/pool2d_avg_adaptive';
1314
import pool2d_max from './shader/pool2d_max';
1415
import pool2d_winograd from './shader/pool2d_winograd';
1516
import elementwise_add from './shader/elementwise_add';
@@ -117,6 +118,7 @@ const ops = {
117118
rnn_cell,
118119
rnn_origin,
119120
pool2d_avg,
121+
pool2d_avg_adaptive,
120122
prelu: dynamic('prelu'),
121123
relu6: dynamic('relu6'),
122124
leakyRelu: dynamic('leakyRelu'),

packages/paddlejs-backend-webgl/src/ops/shader/pool2d.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ export default {
6464
origin: ['getValueFromTensorPos']
6565
},
6666
behaviors: [
67+
'isAdaptiveAvg',
6768
'isMax',
6869
'setPacked',
6970
'setAdaptive',
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/**
2+
* @file pool2d_avg_adaptive
3+
* @desc https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/AdaptiveAvgPool2D_cn.html#adaptiveavgpool2d
4+
*/
5+
6+
function mainFunc(
7+
{ origin },
8+
{ strides = [], paddings = [], ksize }
9+
) {
10+
const [stride_v = 1, stride_h = 1] = strides;
11+
const [padTop = 0, padLeft = 0] = paddings;
12+
const [ksize_y, ksize_x] = ksize;
13+
const H = origin.height_shape;
14+
const W = origin.width_shape;
15+
return `
16+
// start函数
17+
void main(void) {
18+
float res = 0.0;
19+
// 获取output的坐标
20+
ivec4 out_pos = getOutputTensorPos();
21+
int i = out_pos[2] * ${stride_v} - ${padTop};
22+
int j = out_pos[3] * ${stride_h} - ${padLeft};
23+
int hstart = int(floor(float(i) * float(${H}) / float(${ksize_y})));
24+
int hend = int(ceil(float(i + 1) * float(${H}) / float(${ksize_y})));
25+
int wstart = int(floor(float(j) * float(${W}) / float(${ksize_x})));
26+
int wend = int(ceil(float(j + 1) * float(${W}) / float(${ksize_x})));
27+
for (int fy = hstart; fy < hend; fy++) {
28+
for (int fx = wstart; fx < wend; fx++) {
29+
float curr = getValueFromTensorPos_origin(out_pos[0], out_pos[1], fy, fx);
30+
res += curr;
31+
}
32+
}
33+
int count_pool = (hend - hstart) * (wend - wstart);
34+
res = res / float(count_pool);
35+
setOutput(res);
36+
}
37+
`;
38+
}
39+
export default {
40+
mainFunc,
41+
textureFuncConf: {
42+
origin: ['getValueFromTensorPos']
43+
},
44+
behaviors: []
45+
};

packages/paddlejs-core/src/opFactory/opBehaviors.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,17 @@ const behaviors : Behaviors = {
9292
}
9393
},
9494

95+
isAdaptiveAvg() {
96+
const {
97+
adaptive,
98+
pooling_type
99+
} = this.processedAttrs;
100+
101+
if (adaptive && pooling_type === 'avg') {
102+
this.name += '_avg_adaptive';
103+
}
104+
},
105+
95106
isMax() {
96107
const type = this.processedAttrs['pooling_type'] === 'max' ? 1 : 0;
97108
this.processedAttrs['pooling_type'] = type;

0 commit comments

Comments
 (0)