Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save ddhira123/b7d1aa07843823b0146b2459c74a278c to your computer and use it in GitHub Desktop.

Select an option

Save ddhira123/b7d1aa07843823b0146b2459c74a278c to your computer and use it in GitHub Desktop.
const tf = require('@tensorflow/tfjs-node');
const image = require('get-image-data');
const fs = require('fs');
var path = require('path');
// Label names for the three hand-gesture categories.
// NOTE(review): presumably ordered like the model's output scores — confirm against training.
const classes = [
  'rock',
  'paper',
  'scissors',
];
exports.makePredictions = async (req, res, next) => {
const imagePath = `./public/images/${req && req['filename']}`;
try {
const loadModel = async (img) => {
const output = {};
// laod model
console.log('Loading.......')
const model = await tf.node.loadSavedModel(path.join(__dirname,'..', 'SavedModel'));
// classify
// output.predictions = await model.predict(img).data();
let predictions = await model.predict(img).data();
predictions = Array.from(predictions);
output.success = true;
output.message = `Success.`;
output.predictions = predictions;
res.statusCode = 200;
res.json(output);
};
await image(imagePath, async (err, imageData) => {
try {
const image = fs.readFileSync(imagePath);
let tensor = tf.node.decodeImage(image);
const resizedImage = tensor.resizeNearestNeighbor([150, 150]);
const batchedImage = resizedImage.expandDims(0);
const input = batchedImage.toFloat().div(tf.scalar(255));
await loadModel(input);
// delete image file
fs.unlinkSync(imagePath, (error) => {
if (error) {
console.error(error);
}
});
} catch (error) {
res.status(500).json({message: "Internal Server Error!"});
}
});
} catch (error) {
console.log(error)
}
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment