|
4 | 4 | // https://opensource.org/licenses/MIT |
5 | 5 |
|
6 | 6 | import * as tf from '@tensorflow/tfjs'; |
| 7 | +import axios from 'axios'; |
7 | 8 |
|
8 | 9 | const MANIFEST_FILE = 'manifest.json'; |
9 | 10 |
|
| 11 | +/** |
| 12 | + * @typedef {Record<string, { filename: string, shape: Array<number> }>} Manifest |
| 13 | + */ |
| 14 | +/** |
| 15 | + * Loads all of the variables of a model from a directory |
| 16 | + * which contains a `manifest.json` file and individual variable data files. |
| 17 | + * The `manifest.json` contains the `filename` and `shape` for each data file. |
| 18 | + * |
| 19 | + * @class |
| 20 | + * @property {string} urlPath |
| 21 | + * @property {Manifest} [checkpointManifest] |
| 22 | + * @property {Record<string, tf.Tensor>} variables |
| 23 | + */ |
10 | 24 | export default class CheckpointLoader { |
| 25 | + /** |
| 26 | + * @param {string} urlPath - the directory URL |
| 27 | + */ |
11 | 28 | constructor(urlPath) { |
12 | | - this.urlPath = urlPath; |
13 | | - if (this.urlPath.charAt(this.urlPath.length - 1) !== '/') { |
14 | | - this.urlPath += '/'; |
15 | | - } |
| 29 | + this.urlPath = urlPath.endsWith('/') ? urlPath : `${urlPath}/`; |
| 30 | + this.variables = {}; |
16 | 31 | } |
17 | 32 |
|
| 33 | + /** |
| 34 | + * @private |
| 35 | + * Executes the request to load the manifest.json file. |
| 36 | + * |
| 37 | + * @return {Promise<Manifest>} |
| 38 | + */ |
18 | 39 | async loadManifest() { |
19 | | - return new Promise((resolve, reject) => { |
20 | | - const xhr = new XMLHttpRequest(); |
21 | | - xhr.open('GET', this.urlPath + MANIFEST_FILE); |
22 | | - |
23 | | - xhr.onload = () => { |
24 | | - this.checkpointManifest = JSON.parse(xhr.responseText); |
25 | | - resolve(); |
26 | | - }; |
27 | | - xhr.onerror = (error) => { |
28 | | - reject(); |
29 | | - throw new Error(`${MANIFEST_FILE} not found at ${this.urlPath}. ${error}`); |
30 | | - }; |
31 | | - xhr.send(); |
32 | | - }); |
| 40 | + try { |
| 41 | + const response = await axios.get(this.urlPath + MANIFEST_FILE); |
| 42 | + return response.data; |
| 43 | + } catch (error) { |
| 44 | + throw new Error(`${MANIFEST_FILE} not found at ${this.urlPath}. ${error}`); |
| 45 | + } |
33 | 46 | } |
34 | 47 |
|
| 48 | + /** |
| 49 | + * @private |
| 50 | + * Executes the request to load the file for a variable. |
| 51 | + * |
| 52 | + * @param {string} varName |
| 53 | + * @return {Promise<tf.Tensor>} |
| 54 | + */ |
| 55 | + async loadVariable(varName) { |
| 56 | + const manifest = await this.getCheckpointManifest(); |
| 57 | + if (!(varName in manifest)) { |
| 58 | + throw new Error(`Cannot load non-existent variable ${varName}`); |
| 59 | + } |
| 60 | + const { filename, shape } = manifest[varName]; |
| 61 | + const url = this.urlPath + filename; |
| 62 | + try { |
| 63 | + const response = await axios.get(url, { responseType: 'arraybuffer' }); |
| 64 | + const values = new Float32Array(response.data); |
| 65 | + return tf.tensor(values, shape); |
| 66 | + } catch (error) { |
| 67 | + throw new Error(`Error loading variable ${varName} from URL ${url}: ${error}`); |
| 68 | + } |
| 69 | + } |
35 | 70 |
|
| 71 | + /** |
| 72 | + * @public |
| 73 | + * Lazy-load the contents of the manifest.json file. |
| 74 | + * |
| 75 | + * @return {Promise<Manifest>} |
| 76 | + */ |
36 | 77 | async getCheckpointManifest() { |
37 | | - if (this.checkpointManifest == null) { |
38 | | - await this.loadManifest(); |
| 78 | + if (!this.checkpointManifest) { |
| 79 | + this.checkpointManifest = await this.loadManifest(); |
39 | 80 | } |
40 | 81 | return this.checkpointManifest; |
41 | 82 | } |
42 | 83 |
|
| 84 | + /** |
| 85 | + * @public |
| 86 | + * Get the property names for each variable in the manifest. |
| 87 | + * |
| 88 | + * @return {Promise<string[]>} |
| 89 | + */ |
| 90 | + async getKeys() { |
| 91 | + const manifest = await this.getCheckpointManifest(); |
| 92 | + return Object.keys(manifest); |
| 93 | + } |
| 94 | + |
| 95 | + /** |
| 96 | + * @public |
| 97 | + * Get a dictionary with the tensors for all variables in the manifest. |
| 98 | + * |
| 99 | + * @return {Promise<Record<string, tf.Tensor>>} |
| 100 | + */ |
43 | 101 | async getAllVariables() { |
44 | | - if (this.variables != null) { |
45 | | - return Promise.resolve(this.variables); |
46 | | - } |
47 | | - await this.getCheckpointManifest(); |
48 | | - const variableNames = Object.keys(this.checkpointManifest); |
| 102 | + // Ensure that all keys are loaded and then return the dictionary. |
| 103 | + const variableNames = await this.getKeys(); |
49 | 104 | const variablePromises = variableNames.map(v => this.getVariable(v)); |
50 | | - return Promise.all(variablePromises).then((variables) => { |
51 | | - this.variables = {}; |
52 | | - for (let i = 0; i < variables.length; i += 1) { |
53 | | - this.variables[variableNames[i]] = variables[i]; |
54 | | - } |
55 | | - return this.variables; |
56 | | - }); |
| 105 | + await Promise.all(variablePromises); |
| 106 | + return this.variables; |
57 | 107 | } |
58 | | - getVariable(varName) { |
59 | | - if (!(varName in this.checkpointManifest)) { |
60 | | - throw new Error(`Cannot load non-existent variable ${varName}`); |
61 | | - } |
62 | | - const variableRequestPromiseMethod = (resolve) => { |
63 | | - const xhr = new XMLHttpRequest(); |
64 | | - xhr.responseType = 'arraybuffer'; |
65 | | - const fname = this.checkpointManifest[varName].filename; |
66 | | - xhr.open('GET', this.urlPath + fname); |
67 | | - xhr.onload = () => { |
68 | | - if (xhr.status === 404) { |
69 | | - throw new Error(`Not found variable ${varName}`); |
70 | | - } |
71 | | - const values = new Float32Array(xhr.response); |
72 | | - const tensor = tf.tensor(values, this.checkpointManifest[varName].shape); |
73 | | - resolve(tensor); |
74 | | - }; |
75 | | - xhr.onerror = (error) => { |
76 | | - throw new Error(`Could not fetch variable ${varName}: ${error}`); |
77 | | - }; |
78 | | - xhr.send(); |
79 | | - }; |
80 | | - if (this.checkpointManifest == null) { |
81 | | - return new Promise((resolve) => { |
82 | | - this.loadManifest().then(() => { |
83 | | - new Promise(variableRequestPromiseMethod).then(resolve); |
84 | | - }); |
85 | | - }); |
| 108 | + |
| 109 | + /** |
| 110 | + * @public |
| 111 | + * Access a single variable from its key. Will load only if not previously loaded. |
| 112 | + * |
| 113 | + * @param {string} varName |
| 114 | + * @return {Promise<tf.Tensor>} |
| 115 | + */ |
| 116 | + async getVariable(varName) { |
| 117 | + if (!this.variables[varName]) { |
| 118 | + this.variables[varName] = await this.loadVariable(varName); |
86 | 119 | } |
87 | | - return new Promise(variableRequestPromiseMethod); |
| 120 | + return this.variables[varName]; |
88 | 121 | } |
89 | 122 | } |
0 commit comments