Transfer learning with MobileNet and KNN
Transfer learning is a popular approach in machine learning where a model trained on one task is re-purposed for a different but related task.
More formally, transfer learning can be defined as follows: given a source domain D(s) and learning task T(s), and a target domain D(t) and learning task T(t), transfer learning aims to improve the learning of the target predictive function f(t) in D(t) using the knowledge in D(s) and T(s), where D(s) ≠ D(t) or T(s) ≠ T(t).
In this tutorial the source is MobileNet, trained to classify ImageNet images, and the target is recognizing three custom classes (A, B, C) captured from the webcam.
Requirements
The requirements for this tutorial are:
- Node.js >= 10
- Parcel bundler (npm i -g parcel-bundler)
Step 1
Create a project folder, initialize an npm project, and install the required libraries:
mkdir mobilenet-knn &&
cd mobilenet-knn &&
npm init -y &&
npm i @tensorflow/tfjs @tensorflow-models/mobilenet @tensorflow-models/knn-classifier
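After the install finishes, the three libraries should appear under dependencies in package.json, roughly like the sketch below (the version placeholders are illustrative; the actual versions depend on when you run the install):

{
  "dependencies": {
    "@tensorflow-models/knn-classifier": "<installed version>",
    "@tensorflow-models/mobilenet": "<installed version>",
    "@tensorflow/tfjs": "<installed version>"
  }
}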
Step 2
Create an index.html file with the following content:
<html>
<head>
<title>transfer learning with tfjs mobilenet and knn-classifier</title>
</head>
<body>
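<!-- Prediction results (class and probability) are written into this div -->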
<div id="console"></div>
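<!-- Webcam stream; 224x224 matches MobileNet's expected input size -->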
<video autoplay playsinline muted id="webcam" width="224" height="224"></video>
<div>
<button id="class-a">Add A</button>
<button id="class-b">Add B</button>
<button id="class-c">Add C</button>
</div>
<script src="index.js"></script>
</body>
</html>
Step 3
Create an index.js file with the following content:
import * as tf from "@tensorflow/tfjs";
import * as mobilenet from "@tensorflow-models/mobilenet";
import * as knnClassifier from "@tensorflow-models/knn-classifier";
let net;
const webcamElement = document.getElementById("webcam");
const classifier = knnClassifier.create();
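// Ask for permission to use the webcam and stream it into the video element.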
async function setupWebcam() {
return new Promise((resolve, reject) => {
const navigatorAny = navigator;
navigator.getUserMedia =
navigator.getUserMedia ||
navigatorAny.webkitGetUserMedia ||
navigatorAny.mozGetUserMedia ||
navigatorAny.msGetUserMedia;
if (navigator.getUserMedia) {
navigator.getUserMedia(
{ video: true },
stream => {
webcamElement.srcObject = stream;
webcamElement.addEventListener("loadeddata", () => resolve(), false);
},
error => reject()
);
} else {
reject();
}
});
}
async function app() {
console.log("app");
console.log("Loading mobilenet..");
// Load the model.
net = await mobilenet.load();
console.log("Sucessfully loaded model");
await setupWebcam();
// Reads an image from the webcam and associates it with a specific class
// index.
const addExample = classId => {
// Get the intermediate activation of MobileNet 'conv_preds' and pass that
// to the KNN classifier.
const activation = net.infer(webcamElement, "conv_preds");
// Pass the intermediate activation to the classifier.
classifier.addExample(activation, classId);
};
// When clicking a button, add an example for that class.
document
.getElementById("class-a")
.addEventListener("click", () => addExample(0));
document
.getElementById("class-b")
.addEventListener("click", () => addExample(1));
document
.getElementById("class-c")
.addEventListener("click", () => addExample(2));
setInterval(async () => {
if (classifier.getNumClasses() > 0) {
// Get the activation from MobileNet for the current webcam frame.
const activation = net.infer(webcamElement, "conv_preds");
// Get the most likely class and confidences from the classifier module.
const result = await classifier.predictClass(activation);
const classes = ["A", "B", "C"];
document.getElementById("console").innerText = `
prediction: ${classes[result.classIndex]}\n
probability: ${result.confidences[result.classIndex]}
`;
// Dispose the activation tensor so memory does not grow with every prediction.
activation.dispose();
}
// Wait for the next animation frame to keep the UI responsive.
await tf.nextFrame();
}, 1000);
}
app();
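The KNN classifier keeps all added examples in memory, so they are lost when the page reloads. As a rough sketch (not part of the tutorial code above, assuming the classifier and tf variables from index.js and an arbitrary localStorage key), the examples can be exported with getClassifierDataset and restored with setClassifierDataset:

// Sketch: persist the learned examples so a page reload does not wipe them.
// Assumes `classifier` and `tf` from index.js; the storage key is arbitrary.
function saveClassifier() {
  const dataset = classifier.getClassifierDataset();
  const serialized = Object.entries(dataset).map(([classId, tensor]) => [
    classId,
    Array.from(tensor.dataSync()), // flatten the Tensor2D into a plain array
    tensor.shape
  ]);
  localStorage.setItem("knn-dataset", JSON.stringify(serialized));
}

function loadClassifier() {
  const stored = localStorage.getItem("knn-dataset");
  if (!stored) return;
  const dataset = {};
  for (const [classId, data, shape] of JSON.parse(stored)) {
    dataset[classId] = tf.tensor2d(data, shape);
  }
  classifier.setClassifierDataset(dataset);
}

You could, for example, call saveClassifier() after each addExample call and loadClassifier() once the webcam has been set up.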
Step 4
Run parcel index.html to serve the application and open the URL that Parcel prints in the terminal.
The app will ask you for permission to use the webcam. Once you allow access, you can add training examples for the three classes with the provided buttons (A, B, C); as soon as at least one example has been added, the predicted class and its probability are shown above the video.
<div id="console"></div>
<video autoplay playsinline muted id="webcam" width="400" height="400"></video>
<div>
<button id="class-a">Add A</button>
<button id="class-b">Add B</button>
<button id="class-c">Add C</button>
</div>