Skip to content

Commit e33999d

Browse files
Merge pull request #449 from JingyuanZhang/master
feat(webgl): support conv2d fused hard_swish
2 parents abaa115 + ee6d9e6 commit e33999d

File tree

3 files changed

+13
-4
lines changed

3 files changed

+13
-4
lines changed

packages/paddlejs-backend-webgl/package-lock.json

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

packages/paddlejs-backend-webgl/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@paddlejs/paddlejs-backend-webgl",
3-
"version": "1.2.7",
3+
"version": "1.2.8",
44
"description": "",
55
"main": "lib/index",
66
"scripts": {

packages/paddlejs-backend-webgl/src/ops/shader/conv2d.ts

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@ function mainFunc(
1313
fuse_relu,
1414
filter_nearest_vec4,
1515
filter_remainder_vec4,
16-
act_type,
17-
padding_algorithm = ''
16+
act_type = '',
17+
padding_algorithm = '',
18+
hard_swish_offset = 3.0,
19+
hard_swish_scale = 6.0,
20+
hard_swish_threshold = 6.0
1821
}
1922
) {
2023
const [stride_v = 1, stride_h = 1] = strides;
@@ -119,6 +122,12 @@ function mainFunc(
119122
else if (${act_type === 'relu6'}) {
120123
res = min(max(0.0, res), 6.0);
121124
}
125+
else if (${act_type === 'hard_swish'}) {
126+
res = res * min(
127+
max(0.0, res + float(${hard_swish_offset})),
128+
float(${hard_swish_threshold})
129+
) / float(${hard_swish_scale});
130+
}
122131
123132
setOutput(res);
124133
}

0 commit comments

Comments
 (0)