Skip to content

Upgrading tfjs #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@ node_modules
dist
yarn-error.log
.rpt2_cache
.cache
example/.cache
20 changes: 10 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ await mlClassifier.train(imageData, {
onTrainBegin: () => {
console.log('training begins');
},
onBatchEnd: (batch: any,logs: any) => {
onBatchEnd: (batch, logs) => {
console.log('Loss is: ' + logs.loss.toFixed(5));
}
},
Expand Down Expand Up @@ -232,7 +232,7 @@ Nothing.
```
import MLClassifier from 'ml-classifier';
const mlClassifier = new MLClassifier();
mlClassifier.addData(images, labels, DataType.TRAIN);
mlClassifier.addData(images, labels, 'train');
mlClassifier.train({
callbacks: {
onTrainBegin: () => {
Expand All @@ -259,9 +259,9 @@ mlClassifier.train({
```
import MLClassifier from 'ml-classifier';
const mlClassifier = new MLClassifier();
mlClassifier.addData(images, labels, DataType.TRAIN);
mlClassifier.addData(images, labels, 'train');
mlClassifier.train();
mlClassifier.addData(evaluationImages, labels, DataType.EVALUATE);
mlClassifier.addData(evaluationImages, labels, 'evaluate');
mlClassifier.evaluate();
```

Expand All @@ -282,7 +282,7 @@ mlClassifier.evaluate();
```
import MLClassifier from 'ml-classifier';
const mlClassifier = new MLClassifier();
mlClassifier.addData(images, labels, DataType.TRAIN);
mlClassifier.addData(images, labels, 'train');
mlClassifier.train();
mlClassifier.predict(imageToPredict);
```
Expand All @@ -304,7 +304,7 @@ mlClassifier.predict(imageToPredict);
```
import MLClassifier from 'ml-classifier';
const mlClassifier = new MLClassifier();
mlClassifier.addData(images, labels, DataType.TRAIN);
mlClassifier.addData(images, labels, 'train');
mlClassifier.train();
mlClassifier.save(('path-to-save');
```
Expand All @@ -324,7 +324,7 @@ mlClassifier.save(('path-to-save');
```
import MLClassifier from 'ml-classifier';
const mlClassifier = new MLClassifier();
mlClassifier.addData(images, labels, DataType.TRAIN);
mlClassifier.addData(images, labels, 'train');
mlClassifier.train();
mlClassifier.getModel();
```
Expand All @@ -345,13 +345,13 @@ The saved Tensorflow.js model.
```
import MLClassifier from 'ml-classifier';
const mlClassifier = new MLClassifier();
mlClassifier.addData(images, labels, DataType.TRAIN);
mlClassifier.clearData(DataType.TRAIN);
mlClassifier.addData(images, labels, 'train');
mlClassifier.clearData('train');
```

#### Parameters

* **dataType** (`DataType`) *Optional* - specifies which data to clear. If no argument is provided, all data will be cleared.
* **dataType** (`string`) *Optional* - specifies which data to clear. If no argument is provided, all data will be cleared.

#### Returns

Expand Down
6 changes: 6 additions & 0 deletions example/.babelrc
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"plugins": [

"@babel/plugin-transform-runtime"
]
}
Binary file added example/images/cat.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added example/images/dog.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions example/index.html
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
<script src="./index.js"></script>
76 changes: 76 additions & 0 deletions example/index.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/* globals Promise */
import * as tf from '@tensorflow/tfjs';
import MLClassifier from 'ml-classifier';
import dog from './images/dog.jpg';
import cat from './images/cat.png';

const mlClassifier = new MLClassifier();

const loadImage = (src) => new Promise((resolve, reject) => {
const image = new Image();
image.src = src;
image.crossOrigin = '';
image.onload = () => resolve(image);
image.onerror = (err) => reject(err);
});

function cropImage(img) {
const height = img.shape[0];
const width = img.shape[1];

// use the shorter side as the size to which we will crop
const shorterSide = Math.min(img.shape[0], img.shape[1]);

// calculate beginning and ending crop points
const startingWidth = Math.floor((width - shorterSide) / 2);
const startingHeight = Math.floor((height - shorterSide) / 2);
const endingWidth = Math.floor(startingWidth + shorterSide);
const endingHeight = Math.floor(startingHeight + shorterSide);

// return image data cropped to those points
return img.slice([startingHeight, startingWidth, 0], [endingHeight, endingWidth, 3]);
}
function resizeImage(image) {
return tf.image.resizeBilinear(image, [224, 224]);
}
function batchImage(image) {
// Expand our tensor to have an additional dimension, whose size is 1
const batchedImage = image.expandDims(0);

// Turn pixel data into a float between -1 and 1.
return batchedImage.toFloat().div(tf.scalar(127)).sub(tf.scalar(1));
}
function loadAndProcessImage(image) {
const croppedImage = cropImage(image);
const resizedImage = resizeImage(croppedImage);
const batchedImage = batchImage(resizedImage);
return batchedImage;
}

const parseImage = async (src) => {
const img = await loadImage(src);
const pixels = tf.browser.fromPixels(img);
const imageData = loadAndProcessImage(pixels);
return imageData;
};

(async function() {
// const dogPixels = await parseImage(dog);
// const catPixels = await parseImage(cat);
// const images = dogPixels.concat(catPixels);

const images = [dog, cat];
const labels = ['dog', 'cat'];
await mlClassifier.addData(images, labels, 'train');
mlClassifier.train({
callbacks: {
onTrainBegin: () => {
console.log('training begins');
},
onTrainEnd: () => {
console.log('training ends');
},
},
});
})();

30 changes: 30 additions & 0 deletions example/package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"name": "ml-classifier-example",
"version": "0.1.0",
"description": "",
"main": "index.js",
"scripts": {
"develop": "parcel index.html"
},
"devDependencies": {
"@babel/core": "^7.0.0-beta.51",
"@babel/plugin-proposal-decorators": "^7.0.0-beta.51",
"@babel/plugin-transform-classes": "^7.0.0-beta.51",
"@babel/plugin-transform-regenerator": "^7.3.4",
"@babel/plugin-transform-runtime": "^7.3.4",
"@babel/preset-env": "^7.0.0-beta.51",
"@babel/preset-stage-0": "^7.0.0-beta.51",
"@babel/preset-typescript": "^7.0.0-beta.51",
"babel-core": "^7.0.0-0",
"babel-jest": "^23.2.0",
"babel-plugin-transform-class-properties": "^6.24.1",
"babel-plugin-transform-runtime": "^6.23.0",
"babel-polyfill": "^6.26.0",
"gh-pages": "^2.0.1",
"parcel": "^1.11.0"
},
"dependencies": {
"@babel/runtime-corejs2": "^7.3.4",
"@tensorflow/tfjs": "^1.0.0"
}
}
Loading