Skip to content

Commit 4e6a427

Browse files
committed
Rewrite CheckpointLoaderPix2pix using axios instead of xhr.
1 parent 35a003e commit 4e6a427

File tree

1 file changed

+53
-54
lines changed

1 file changed

+53
-54
lines changed
Lines changed: 53 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,67 @@
1-
/* eslint max-len: "off" */
2-
31
import * as tf from '@tensorflow/tfjs';
2+
import axios from 'axios';
43

4+
/**
5+
* Pix2Pix loads data from a '.pict' file.
6+
* File contains the properties (name and tensor shape) for each variable
7+
* and a huge array of numbers for all of the variables.
8+
* Numbers must be assigned to the correct variable.
9+
*/
510
export default class CheckpointLoaderPix2pix {
11+
/**
12+
* @param {string} urlPath
13+
*/
614
constructor(urlPath) {
15+
/**
16+
* @type {string}
17+
*/
718
this.urlPath = urlPath;
819
}
920

10-
getAllVariables() {
11-
return new Promise((resolve, reject) => {
12-
const weightsCache = {};
13-
if (this.urlPath in weightsCache) {
14-
resolve(weightsCache[this.urlPath]);
15-
return;
16-
}
17-
18-
const xhr = new XMLHttpRequest();
19-
xhr.open('GET', this.urlPath, true);
20-
xhr.responseType = 'arraybuffer';
21-
xhr.onload = () => {
22-
if (xhr.status !== 200) {
23-
reject(new Error('missing model'));
24-
return;
25-
}
26-
const buf = xhr.response;
27-
if (!buf) {
28-
reject(new Error('invalid arraybuffer'));
29-
return;
30-
}
21+
async getAllVariables() {
22+
// Load the file as an ArrayBuffer.
23+
const response = await axios.get(this.urlPath, { responseType: 'arraybuffer' })
24+
.catch(error => {
25+
throw new Error(`No model found. Failed with error ${error}`);
26+
});
27+
/** @type {ArrayBuffer} */
28+
const buf = response.data;
3129

32-
const parts = [];
33-
let offset = 0;
34-
while (offset < buf.byteLength) {
35-
const b = new Uint8Array(buf.slice(offset, offset + 4));
36-
offset += 4;
37-
const len = (b[0] << 24) + (b[1] << 16) + (b[2] << 8) + b[3]; // eslint-disable-line no-bitwise
38-
parts.push(buf.slice(offset, offset + len));
39-
offset += len;
40-
}
30+
// Break data into three parts: shapes, index, and encoded.
31+
/** @type {ArrayBuffer[]} */
32+
const parts = [];
33+
let offset = 0;
34+
while (offset < buf.byteLength) {
35+
const b = new Uint8Array(buf.slice(offset, offset + 4));
36+
offset += 4;
37+
const len = (b[0] << 24) + (b[1] << 16) + (b[2] << 8) + b[3]; // eslint-disable-line no-bitwise
38+
parts.push(buf.slice(offset, offset + len));
39+
offset += len;
40+
}
4141

42-
const shapes = JSON.parse((new TextDecoder('utf8')).decode(parts[0]));
43-
const index = new Float32Array(parts[1]);
44-
const encoded = new Uint8Array(parts[2]);
42+
/** @type {Array<{ name: string, shape: number[] }>} */
43+
const shapes = JSON.parse((new TextDecoder('utf8')).decode(parts[0]));
44+
const index = new Float32Array(parts[1]);
45+
const encoded = new Uint8Array(parts[2]);
4546

46-
// decode using index
47-
const arr = new Float32Array(encoded.length);
48-
for (let i = 0; i < arr.length; i += 1) {
49-
arr[i] = index[encoded[i]];
50-
}
47+
// Dictionary of variables by name.
48+
/** @type {Record<string, tf.Tensor>} */
49+
const weights = {};
5150

52-
const weights = {};
53-
offset = 0;
54-
for (let i = 0; i < shapes.length; i += 1) {
55-
const { shape } = shapes[i];
56-
const size = shape.reduce((total, num) => total * num);
57-
const values = arr.slice(offset, offset + size);
58-
const tfarr = tf.tensor1d(values, 'float32');
59-
weights[shapes[i].name] = tfarr.reshape(shape);
60-
offset += size;
61-
}
62-
weightsCache[this.urlPath] = weights;
63-
resolve(weights);
64-
};
65-
xhr.send(null);
51+
// Create a tensor for each shape.
52+
offset = 0;
53+
shapes.forEach(({ shape, name }) => {
54+
const size = shape.reduce((total, num) => total * num);
55+
// Get the raw data.
56+
const raw = encoded.slice(offset, offset + size);
57+
// Decode using index.
58+
const values = new Float32Array(raw.length);
59+
raw.forEach((value, i) => {
60+
values[i] = index[value];
61+
});
62+
weights[name] = tf.tensor(values, shape, 'float32');
63+
offset += size;
6664
});
65+
return weights;
6766
}
6867
}

0 commit comments

Comments
 (0)