Skip to content

Commit abaa115

Browse files
Merge pull request #442 from JingyuanZhang/master
feat(core): support modelObj.params type ParamObject & publish @paddlejs/paddlejs-core@2.1.28
2 parents f68d889 + f2d3bc7 commit abaa115

File tree

3 files changed

+10
-5
lines changed

3 files changed

+10
-5
lines changed

packages/paddlejs-core/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@paddlejs/paddlejs-core",
3-
"version": "2.1.19",
3+
"version": "2.1.28",
44
"description": "",
55
"main": "lib/index",
66
"scripts": {

packages/paddlejs-core/src/commons/interface.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,13 @@ export interface FeedShape {
6363
fh: number;
6464
};
6565

66+
67+
export interface ParamObject {
68+
[key: string]: number;
69+
}
6670
interface ModelObj {
6771
model: Model;
68-
params: Float32Array
72+
params: Float32Array | ParamObject
6973
}
7074
export interface RunnerConfig {
7175
modelPath?: string;

packages/paddlejs-core/src/loader.ts

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
*/
44

55
import env from './env';
6-
import { Model } from './commons/interface';
6+
import { Model, ParamObject } from './commons/interface';
77
import { traverseVars } from './commons/utils';
88

99
interface UrlConf {
@@ -126,14 +126,15 @@ export default class ModelLoader {
126126
});
127127
}
128128

129-
static allocateParamsVar(vars, allChunksData: Float32Array) {
129+
static allocateParamsVar(vars, allChunksData: Float32Array | ParamObject) {
130130
let marker = 0; // 读到哪个位置了
131131
let len; // 当前op长度
132+
const chunkData: number[] = Array.isArray(allChunksData) ? allChunksData : Object.values(allChunksData);
132133
traverseVars(vars, item => {
133134
len = item.shape.reduce((a, b) => a * b); // 长度为shape的乘积
134135
// 为了减少模型体积,模型转换工具不会导出非persistable的数据,这里只需要读取persistable的数据
135136
if (item.persistable) {
136-
item.data = allChunksData.slice(marker, marker + len);
137+
item.data = chunkData.slice(marker, marker + len);
137138
marker += len;
138139
}
139140
});

0 commit comments

Comments
 (0)