Skip to content

Commit 9988181

Browse files
Merge pull request #219 from JingyuanZhang/master
feat(core&webgl): support origin shape imgfeed op and support LINEAR …
2 parents e032be2 + 1228591 commit 9988181

File tree

12 files changed

+132
-60
lines changed

12 files changed

+132
-60
lines changed

packages/paddlejs-backend-webgl/src/ops/index.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,10 @@ import elementwise_pow from './shader/elementwise_pow';
4848
import elementwise_sub from './shader/elementwise_sub';
4949
import cast from './shader/cast';
5050

51-
import { pack_out, nhwc_2_nchw, unpacked_2_packed, packed_2_unpacked, feedPost } from './shader/custom';
51+
import {
52+
imgFeed, pack_out, nhwc_2_nchw, unpacked_2_packed,
53+
packed_2_unpacked, feedPost
54+
} from './shader/custom';
5255

5356

5457
const ops = {
@@ -109,7 +112,8 @@ const ops = {
109112
shuffle_channel,
110113
pack_out,
111114
nhwc_2_nchw,
112-
feedPost
115+
feedPost,
116+
imgFeed
113117
};
114118
export {
115119
ops

packages/paddlejs-backend-webgl/src/ops/shader/custom/feedPost.ts

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,33 @@ function mainFunc(
66
{},
77
{ mean = [0, 0, 0], std = [1, 1, 1] }
88
) {
9-
109
return `
1110
// start函数
1211
void main(void) {
1312
ivec4 oPos = getOutputTensorPos();
14-
float o = getValueFromTensorPos_origin(oPos.r, oPos.g, oPos.b, oPos.a);
15-
int c = oPos.g;
16-
o = float(o / 255.0);
13+
float res = 0.0;
14+
int c1 = int(mod(float(oPos[1]), 4.0));
15+
int c = oPos[1];
16+
vec4 o = getValueFromTensorPosPacking_origin(oPos[0], c / 4, oPos[2], oPos[3]) / 255.0;
17+
18+
if (c1 == 0) {
19+
res = o.r;
20+
} else if (c1 == 1) {
21+
res = o.g;
22+
} else if (c1 == 2) {
23+
res = o.b;
24+
} else if (c1 == 3) {
25+
res = o.a;
26+
}
27+
1728
if (c == 0) {
18-
o = (o - float(${mean[0]})) / float(${std[0]});
29+
res = (res - float(${mean[0]})) / float(${std[0]});
1930
} else if (c == 1) {
20-
o = (o - float(${mean[1]})) / float(${std[1]});
31+
res = (res - float(${mean[1]})) / float(${std[1]});
2132
} else if (c == 2) {
22-
o = (o - float(${mean[2]})) / float(${std[2]});
33+
res = (res - float(${mean[2]})) / float(${std[2]});
2334
}
24-
setOutput(float(o));
35+
setOutput(float(res));
2536
}
2637
`;
2738
}
@@ -32,6 +43,6 @@ export default {
3243
'std'
3344
],
3445
textureFuncConf: {
35-
origin: ['getValueFromTensorPos']
46+
origin: ['getValueFromTensorPosPacking']
3647
}
3748
};

packages/paddlejs-backend-webgl/src/ops/shader/custom/imgFeed.ts

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,21 @@
22
* @file feed post process
33
*/
44

5-
function mainFunc(
6-
{},
7-
{}
8-
) {
5+
function mainFunc() {
96

107
return `
11-
// start函数
8+
129
void main(void) {
13-
ivec4 oPos = getOutputTensorPos();
14-
vec4 o = getValueFromTensorPosPacking_origin(oPos.r, oPos.g, oPos.b, oPos.a);
15-
setPackedOutput(o / 255.0);
10+
vec2 outCoord = vCoord.xy;
11+
vec4 counter = TEXTURE2D(texture_origin, vCoord.xy);
12+
setPackedOutput(counter);
1613
}
1714
`;
1815
}
1916
export default {
2017
mainFunc,
2118
params: [],
2219
textureFuncConf: {
23-
origin: ['getValueFromTensorPosPacking']
20+
origin: []
2421
}
2522
};

packages/paddlejs-backend-webgl/src/ops/shader/custom/index.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ import pack_out from './pack_out';
33
import unpacked_2_packed from './unpacked_2_packed';
44
import packed_2_unpacked from './packed_2_unpacked';
55
import feedPost from './feedPost';
6+
import imgFeed from './imgFeed';
67

78
export {
9+
imgFeed,
810
feedPost,
911
nhwc_2_nchw,
1012
pack_out,

packages/paddlejs-backend-webgl/src/webgl/WebGLTexture.ts

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,11 +204,16 @@ export class GLTexture {
204204
}
205205

206206
public static genOutputTexture(gl, textureConf, outTensor, isFinalOp): WebGLTexture {
207+
const {
208+
interpType,
209+
width_texture,
210+
height_texture
211+
} = outTensor;
207212
// 生成output的texture缓存
208213
const texture = gl.createTexture();
209214
gl.bindTexture(gl.TEXTURE_2D, texture);
210-
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
211-
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
215+
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, interpType === 'LINEAR' ? gl.LINEAR : gl.NEAREST);
216+
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, interpType === 'LINEAR' ? gl.LINEAR : gl.NEAREST);
212217
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
213218
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
214219

@@ -231,12 +236,13 @@ export class GLTexture {
231236
: gl.UNSIGNED_BYTE
232237
: null;
233238

239+
234240
gl.texImage2D(
235241
gl.TEXTURE_2D, // Target, matches bind above.
236242
0, // Level of detail.
237243
internalFormat, // Internal format.
238-
outTensor.width_texture,
239-
outTensor.height_texture,
244+
width_texture,
245+
height_texture,
240246
0, // Always 0 in OpenGL ES.
241247
gl.RGBA, // Format for each pixel.
242248
isFinalOp ? textureTypeForReadPixel : textureType, // Data type for each chanel.

packages/paddlejs-core/src/commons/interface.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ export interface ModelOp {
66
'sub-attrs'?: OpAttrs[];
77
inputs: OpInputs;
88
outputs: OpOutputs;
9+
isPacked?: boolean;
910
}
1011

1112
export interface ModelVar {
@@ -14,6 +15,7 @@ export interface ModelVar {
1415
data?: number[] | Float32Array;
1516
persistable?: boolean;
1617
tensorName?: string;
18+
interpType?: string;
1719
total?: number;
1820
runtime?: number;
1921
}

packages/paddlejs-core/src/mediaProcessor.ts

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ export default class MediaProcessor {
1313
mean: number[] = [0, 0, 0];
1414
std: number[] = [1, 1, 1];
1515
bgr: boolean = false;
16-
result: Float32Array | number[] = [];
1716
pixelWidth: number = 1;
1817
pixelHeight: number = 1;
1918
inputFeed: InputFeed[] = [];
@@ -45,17 +44,13 @@ export default class MediaProcessor {
4544
targetShape: [1, fc, fh, fw]
4645
};
4746

48-
if (this.result.length === 0) {
49-
const [, c, h, w] = params.targetShape;
50-
// 计算确定targetShape所需Float32Array占用空间
51-
this.result = new Float32Array(h * w * c);
52-
}
5347
return this.fromPixels(input, params) || [];
5448
}
5549

5650
fromPixels(pixels, opt): InputFeed[] {
5751
let data: ImageData | number[] | Float32Array = [];
5852
const imageDataInfo = {
53+
gapFillWith: opt.gapFillWith,
5954
dx: 0,
6055
dy: 0,
6156
dWidth: opt.targetSize.width,
@@ -76,21 +71,22 @@ export default class MediaProcessor {
7671
this.pixelWidth = pixels.width;
7772
this.pixelHeight = pixels.height;
7873

79-
this.fitToTargetSize(pixels, opt);
80-
data = this.getImageData(imageDataInfo);
8174

75+
this.fitToTargetSize(pixels, imageDataInfo, env.get('webgl_feed_process'));
76+
data = this.getImageData(imageDataInfo);
8277
// process imageData in webgl
8378
if (env.get('webgl_feed_process')) {
8479
data = Float32Array.from((data as ImageData).data);
8580
return [{
8681
data,
87-
shape: [1, 4, opt.targetShape[2], opt.targetShape[3]],
82+
shape: [1, 1, imageDataInfo.dHeight, imageDataInfo.dWidth],
8883
name: 'image',
8984
persistable: true
9085
}] as InputFeed[];
9186
}
9287

93-
data = this.allReshapeToRGB(data, opt) as number[];
88+
89+
data = this.allReshapeToRGB(data, opt) as Float32Array;
9490
return [{
9591
data,
9692
shape: opt.targetShape || opt.shape,
@@ -113,7 +109,7 @@ export default class MediaProcessor {
113109
const { mean, std, targetShape, bgr } = opt;
114110
const [, c, h, w] = targetShape;
115111
const data = imageData.data || imageData;
116-
const result = this.result;
112+
const result = new Float32Array(h * w * c);
117113
let offset = 0;
118114
// 将数据映射为0~1, 1:映射为-1~1之间
119115
const normalizeType = 0;
@@ -138,30 +134,47 @@ export default class MediaProcessor {
138134
/**
139135
* 缩放成目标尺寸并居中
140136
*/
141-
fitToTargetSize(image, params) {
137+
fitToTargetSize(image, imageDataInfo, inGPU = false) {
142138
// 目标尺寸
143-
const targetWidth = params.targetSize.width;
144-
const targetHeight = params.targetSize.height;
145-
this.targetContext.canvas.width = targetWidth;
146-
this.targetContext.canvas.height = targetHeight;
147-
this.targetContext.fillStyle = params.gapFillWith;
148-
this.targetContext.fillRect(0, 0, targetHeight, targetWidth);
139+
const targetWidth = imageDataInfo.dWidth;
140+
const targetHeight = imageDataInfo.dHeight;
141+
142+
let canvasWidth = inGPU ? this.pixelWidth : targetWidth;
143+
let canvasHeight = inGPU ? this.pixelHeight : targetHeight;
149144
// 缩放后的宽高
150-
let sw = targetWidth;
151-
let sh = targetHeight;
145+
let sw = inGPU ? this.pixelWidth : targetWidth;
146+
let sh = inGPU ? this.pixelHeight : targetHeight;
152147
let x = 0;
153148
let y = 0;
154149
// target的长宽比大些 就把原图的高变成target那么高
155150
if (targetWidth / targetHeight * this.pixelHeight / this.pixelWidth >= 1) {
156-
sw = Math.round(sh * this.pixelWidth / this.pixelHeight);
157-
x = Math.floor((targetWidth - sw) / 2);
151+
if (inGPU) {
152+
canvasWidth = Math.round(sh * targetWidth / targetHeight);
153+
x = Math.floor((canvasWidth - sw) / 2);
154+
}
155+
else {
156+
sw = Math.round(sh * this.pixelWidth / this.pixelHeight);
157+
x = Math.floor((targetWidth - sw) / 2);
158+
}
158159
}
159160
// target的长宽比小些 就把原图的宽变成target那么宽
160161
else {
161-
sh = Math.round(sw * this.pixelHeight / this.pixelWidth);
162-
y = Math.floor((targetHeight - sh) / 2);
162+
if (inGPU) {
163+
canvasHeight = Math.round(sw * targetHeight / targetWidth);
164+
y = Math.floor((canvasHeight - sh) / 2);
165+
}
166+
else {
167+
sh = Math.round(sw * this.pixelHeight / this.pixelWidth);
168+
y = Math.floor((targetHeight - sh) / 2);
169+
}
163170
}
164171

172+
imageDataInfo.dWidth = canvasWidth;
173+
imageDataInfo.dHeight = canvasHeight;
174+
this.targetContext.canvas.width = canvasWidth;
175+
this.targetContext.canvas.height = canvasHeight;
176+
this.targetContext.fillStyle = imageDataInfo.gapFillWith;
177+
this.targetContext.fillRect(0, 0, canvasHeight, canvasWidth);
165178
this.targetContext.drawImage(image, x, y, sw, sh);
166179
}
167180

packages/paddlejs-core/src/opFactory/opDataBuilder.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,8 @@ export default class OpData {
152152
outTensor,
153153
inputTensors,
154154
shaderParams: this.fShaderParams[index],
155-
runtime: index
155+
runtime: index,
156+
isFinalOp: this.isFinalOp
156157
}));
157158
}
158159

@@ -200,6 +201,7 @@ export default class OpData {
200201
shape: data.shape,
201202
data: data.data || null,
202203
persistable: data.persistable || false,
204+
interpType: data.interpType || 'NEAREST',
203205
isPacked: this.isPackedOp || false,
204206
binding: index,
205207
noLayout: GLOBALS.backendInstance?.noLayout

packages/paddlejs-core/src/opFactory/opExecutor.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@ export default class OpExecutor {
1717
isPacked: boolean = false;
1818
finish: boolean = false;
1919

20-
constructor(op: ModelOp, idx: number, isPacked: boolean | undefined = false) {
20+
constructor(op: ModelOp, idx: number) {
2121
const {
2222
inputs,
2323
outputs,
2424
attrs = {},
25-
type
25+
type,
26+
isPacked = false
2627
} = op;
2728

2829
this.id = `${type}_${+new Date()}_${idx}`;
@@ -31,11 +32,10 @@ export default class OpExecutor {
3132
this.attrs = attrs;
3233
this.subAttrs = op['sub-attrs'] || [];
3334
this.type = type;
34-
this.isPacked = isPacked || false;
35+
this.isPacked = isPacked;
3536
this.finish = false;
3637
this.next = '';
3738
this.opData = null;
38-
this.isPacked = false;
3939
}
4040

4141
get inputsName() {

packages/paddlejs-core/src/opFactory/tensor.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ interface TensorParams {
1111
shape: number[];
1212
data: Float32Array | number[] | null;
1313
persistable: boolean;
14+
interpType?: string;
1415
isPacked?: boolean;
1516
binding?: number;
1617
noLayout?: boolean;
@@ -28,6 +29,7 @@ export default class Tensor {
2829
exceedMax: boolean = false;
2930
data: Float32Array | number[] | null = null;
3031
persistable: boolean = false;
32+
interpType: string = 'NEAREST';
3133

3234
constructor(opts: TensorParams) {
3335
this.opts = opts;
@@ -36,6 +38,7 @@ export default class Tensor {
3638
// 设置tensor名字
3739
this.name = opts.name;
3840
this.persistable = opts.persistable || false;
41+
this.interpType = opts.interpType || 'NEAREST';
3942
// 设置 tensorId
4043
this.tensorId = opts.type;
4144
// 保留 model 原生 shape 长度

0 commit comments

Comments
 (0)