Skip to content

Commit a9c5098

Browse files
authored
Merge pull request #1361 from lindapaiste/fix/xhr-to-axios
Rewrite XHR code using Axios
2 parents 3c1f132 + 4e6a427 commit a9c5098

File tree

3 files changed

+152
-136
lines changed

3 files changed

+152
-136
lines changed

src/CVAE/index.js

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
*/
1111

1212
import * as tf from '@tensorflow/tfjs';
13+
import axios from "axios";
1314
import callCallback from '../utils/callcallback';
1415
import p5Utils from '../utils/p5Utils';
1516

@@ -28,13 +29,11 @@ class Cvae {
2829
this.ready = false;
2930
this.model = {};
3031
this.latentDim = tf.randomUniform([1, 16]);
31-
this.modelPath = modelPath;
32-
this.modelPathPrefix = '';
3332

34-
this.jsonLoader().then(val => {
35-
this.modelPathPrefix = this.modelPath.split('manifest.json')[0]
36-
this.ready = callCallback(this.loadCVAEModel(this.modelPathPrefix+val.model), callback);
37-
this.labels = val.labels;
33+
const [modelPathPrefix] = modelPath.split('manifest.json');
34+
axios.get(modelPath).then(({ data }) => {
35+
this.ready = callCallback(this.loadCVAEModel(modelPathPrefix + data.model), callback);
36+
this.labels = data.labels;
3837
// get an array full of zero with the length of labels [0, 0, 0 ...]
3938
this.labelVector = Array(this.labels.length+1).fill(0);
4039
});
@@ -114,21 +113,6 @@ class Cvae {
114113
return { src, raws, image };
115114
}
116115

117-
async jsonLoader() {
118-
return new Promise((resolve, reject) => {
119-
const xhr = new XMLHttpRequest();
120-
xhr.open('GET', this.modelPath);
121-
122-
xhr.onload = () => {
123-
const json = JSON.parse(xhr.responseText);
124-
resolve(json);
125-
};
126-
xhr.onerror = (error) => {
127-
reject(error);
128-
};
129-
xhr.send();
130-
});
131-
}
132116
}
133117

134118
const CVAE = (model, callback) => new Cvae(model, callback);

src/utils/checkpointLoader.js

Lines changed: 94 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -4,86 +4,119 @@
44
// https://opensource.org/licenses/MIT
55

66
import * as tf from '@tensorflow/tfjs';
7+
import axios from 'axios';
78

89
const MANIFEST_FILE = 'manifest.json';
910

11+
/**
12+
* @typedef {Record<string, { filename: string, shape: Array<number> }>} Manifest
13+
*/
14+
/**
15+
* Loads all of the variables of a model from a directory
16+
* which contains a `manifest.json` file and individual variable data files.
17+
* The `manifest.json` contains the `filename` and `shape` for each data file.
18+
*
19+
* @class
20+
* @property {string} urlPath
21+
* @property {Manifest} [checkpointManifest]
22+
* @property {Record<string, tf.Tensor>} variables
23+
*/
1024
export default class CheckpointLoader {
25+
/**
26+
* @param {string} urlPath - the directory URL
27+
*/
1128
constructor(urlPath) {
12-
this.urlPath = urlPath;
13-
if (this.urlPath.charAt(this.urlPath.length - 1) !== '/') {
14-
this.urlPath += '/';
15-
}
29+
this.urlPath = urlPath.endsWith('/') ? urlPath : `${urlPath}/`;
30+
this.variables = {};
1631
}
1732

33+
/**
34+
* @private
35+
* Executes the request to load the manifest.json file.
36+
*
37+
* @return {Promise<Manifest>}
38+
*/
1839
async loadManifest() {
19-
return new Promise((resolve, reject) => {
20-
const xhr = new XMLHttpRequest();
21-
xhr.open('GET', this.urlPath + MANIFEST_FILE);
22-
23-
xhr.onload = () => {
24-
this.checkpointManifest = JSON.parse(xhr.responseText);
25-
resolve();
26-
};
27-
xhr.onerror = (error) => {
28-
reject();
29-
throw new Error(`${MANIFEST_FILE} not found at ${this.urlPath}. ${error}`);
30-
};
31-
xhr.send();
32-
});
40+
try {
41+
const response = await axios.get(this.urlPath + MANIFEST_FILE);
42+
return response.data;
43+
} catch (error) {
44+
throw new Error(`${MANIFEST_FILE} not found at ${this.urlPath}. ${error}`);
45+
}
3346
}
3447

48+
/**
49+
* @private
50+
* Executes the request to load the file for a variable.
51+
*
52+
* @param {string} varName
53+
* @return {Promise<tf.Tensor>}
54+
*/
55+
async loadVariable(varName) {
56+
const manifest = await this.getCheckpointManifest();
57+
if (!(varName in manifest)) {
58+
throw new Error(`Cannot load non-existent variable ${varName}`);
59+
}
60+
const { filename, shape } = manifest[varName];
61+
const url = this.urlPath + filename;
62+
try {
63+
const response = await axios.get(url, { responseType: 'arraybuffer' });
64+
const values = new Float32Array(response.data);
65+
return tf.tensor(values, shape);
66+
} catch (error) {
67+
throw new Error(`Error loading variable ${varName} from URL ${url}: ${error}`);
68+
}
69+
}
3570

71+
/**
72+
* @public
73+
* Lazy-load the contents of the manifest.json file.
74+
*
75+
* @return {Promise<Manifest>}
76+
*/
3677
async getCheckpointManifest() {
37-
if (this.checkpointManifest == null) {
38-
await this.loadManifest();
78+
if (!this.checkpointManifest) {
79+
this.checkpointManifest = await this.loadManifest();
3980
}
4081
return this.checkpointManifest;
4182
}
4283

84+
/**
85+
* @public
86+
* Get the property names for each variable in the manifest.
87+
*
88+
* @return {Promise<string[]>}
89+
*/
90+
async getKeys() {
91+
const manifest = await this.getCheckpointManifest();
92+
return Object.keys(manifest);
93+
}
94+
95+
/**
96+
* @public
97+
* Get a dictionary with the tensors for all variables in the manifest.
98+
*
99+
* @return {Promise<Record<string, tf.Tensor>>}
100+
*/
43101
async getAllVariables() {
44-
if (this.variables != null) {
45-
return Promise.resolve(this.variables);
46-
}
47-
await this.getCheckpointManifest();
48-
const variableNames = Object.keys(this.checkpointManifest);
102+
// Ensure that all keys are loaded and then return the dictionary.
103+
const variableNames = await this.getKeys();
49104
const variablePromises = variableNames.map(v => this.getVariable(v));
50-
return Promise.all(variablePromises).then((variables) => {
51-
this.variables = {};
52-
for (let i = 0; i < variables.length; i += 1) {
53-
this.variables[variableNames[i]] = variables[i];
54-
}
55-
return this.variables;
56-
});
105+
await Promise.all(variablePromises);
106+
return this.variables;
57107
}
58-
getVariable(varName) {
59-
if (!(varName in this.checkpointManifest)) {
60-
throw new Error(`Cannot load non-existent variable ${varName}`);
61-
}
62-
const variableRequestPromiseMethod = (resolve) => {
63-
const xhr = new XMLHttpRequest();
64-
xhr.responseType = 'arraybuffer';
65-
const fname = this.checkpointManifest[varName].filename;
66-
xhr.open('GET', this.urlPath + fname);
67-
xhr.onload = () => {
68-
if (xhr.status === 404) {
69-
throw new Error(`Not found variable ${varName}`);
70-
}
71-
const values = new Float32Array(xhr.response);
72-
const tensor = tf.tensor(values, this.checkpointManifest[varName].shape);
73-
resolve(tensor);
74-
};
75-
xhr.onerror = (error) => {
76-
throw new Error(`Could not fetch variable ${varName}: ${error}`);
77-
};
78-
xhr.send();
79-
};
80-
if (this.checkpointManifest == null) {
81-
return new Promise((resolve) => {
82-
this.loadManifest().then(() => {
83-
new Promise(variableRequestPromiseMethod).then(resolve);
84-
});
85-
});
108+
109+
/**
110+
* @public
111+
* Access a single variable from its key. Will load only if not previously loaded.
112+
*
113+
* @param {string} varName
114+
* @return {Promise<tf.Tensor>}
115+
*/
116+
async getVariable(varName) {
117+
if (!this.variables[varName]) {
118+
this.variables[varName] = await this.loadVariable(varName);
86119
}
87-
return new Promise(variableRequestPromiseMethod);
120+
return this.variables[varName];
88121
}
89122
}
Lines changed: 53 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,67 @@
1-
/* eslint max-len: "off" */
2-
31
import * as tf from '@tensorflow/tfjs';
2+
import axios from 'axios';
43

4+
/**
5+
* Pix2Pix loads data from a '.pict' file.
6+
* File contains the properties (name and tensor shape) for each variable
7+
* and a huge array of numbers for all of the variables.
8+
* Numbers must be assigned to the correct variable.
9+
*/
510
export default class CheckpointLoaderPix2pix {
11+
/**
12+
* @param {string} urlPath
13+
*/
614
constructor(urlPath) {
15+
/**
16+
* @type {string}
17+
*/
718
this.urlPath = urlPath;
819
}
920

10-
getAllVariables() {
11-
return new Promise((resolve, reject) => {
12-
const weightsCache = {};
13-
if (this.urlPath in weightsCache) {
14-
resolve(weightsCache[this.urlPath]);
15-
return;
16-
}
17-
18-
const xhr = new XMLHttpRequest();
19-
xhr.open('GET', this.urlPath, true);
20-
xhr.responseType = 'arraybuffer';
21-
xhr.onload = () => {
22-
if (xhr.status !== 200) {
23-
reject(new Error('missing model'));
24-
return;
25-
}
26-
const buf = xhr.response;
27-
if (!buf) {
28-
reject(new Error('invalid arraybuffer'));
29-
return;
30-
}
21+
async getAllVariables() {
22+
// Load the file as an ArrayBuffer.
23+
const response = await axios.get(this.urlPath, { responseType: 'arraybuffer' })
24+
.catch(error => {
25+
throw new Error(`No model found. Failed with error ${error}`);
26+
});
27+
/** @type {ArrayBuffer} */
28+
const buf = response.data;
3129

32-
const parts = [];
33-
let offset = 0;
34-
while (offset < buf.byteLength) {
35-
const b = new Uint8Array(buf.slice(offset, offset + 4));
36-
offset += 4;
37-
const len = (b[0] << 24) + (b[1] << 16) + (b[2] << 8) + b[3]; // eslint-disable-line no-bitwise
38-
parts.push(buf.slice(offset, offset + len));
39-
offset += len;
40-
}
30+
// Break data into three parts: shapes, index, and encoded.
31+
/** @type {ArrayBuffer[]} */
32+
const parts = [];
33+
let offset = 0;
34+
while (offset < buf.byteLength) {
35+
const b = new Uint8Array(buf.slice(offset, offset + 4));
36+
offset += 4;
37+
const len = (b[0] << 24) + (b[1] << 16) + (b[2] << 8) + b[3]; // eslint-disable-line no-bitwise
38+
parts.push(buf.slice(offset, offset + len));
39+
offset += len;
40+
}
4141

42-
const shapes = JSON.parse((new TextDecoder('utf8')).decode(parts[0]));
43-
const index = new Float32Array(parts[1]);
44-
const encoded = new Uint8Array(parts[2]);
42+
/** @type {Array<{ name: string, shape: number[] }>} */
43+
const shapes = JSON.parse((new TextDecoder('utf8')).decode(parts[0]));
44+
const index = new Float32Array(parts[1]);
45+
const encoded = new Uint8Array(parts[2]);
4546

46-
// decode using index
47-
const arr = new Float32Array(encoded.length);
48-
for (let i = 0; i < arr.length; i += 1) {
49-
arr[i] = index[encoded[i]];
50-
}
47+
// Dictionary of variables by name.
48+
/** @type {Record<string, tf.Tensor>} */
49+
const weights = {};
5150

52-
const weights = {};
53-
offset = 0;
54-
for (let i = 0; i < shapes.length; i += 1) {
55-
const { shape } = shapes[i];
56-
const size = shape.reduce((total, num) => total * num);
57-
const values = arr.slice(offset, offset + size);
58-
const tfarr = tf.tensor1d(values, 'float32');
59-
weights[shapes[i].name] = tfarr.reshape(shape);
60-
offset += size;
61-
}
62-
weightsCache[this.urlPath] = weights;
63-
resolve(weights);
64-
};
65-
xhr.send(null);
51+
// Create a tensor for each shape.
52+
offset = 0;
53+
shapes.forEach(({ shape, name }) => {
54+
const size = shape.reduce((total, num) => total * num);
55+
// Get the raw data.
56+
const raw = encoded.slice(offset, offset + size);
57+
// Decode using index.
58+
const values = new Float32Array(raw.length);
59+
raw.forEach((value, i) => {
60+
values[i] = index[value];
61+
});
62+
weights[name] = tf.tensor(values, shape, 'float32');
63+
offset += size;
6664
});
65+
return weights;
6766
}
6867
}

0 commit comments

Comments
 (0)