JavaScript机器学习之KNN算法
Leo09L
7年前
<p style="text-align:center"><img src="https://simg.open-open.com/show/d575a0346f2034f12b20c9daab561852.png"></p> <p style="text-align:center">上图使用 <a href="/misc/goto?guid=4958839605157594927" rel="nofollow,noindex">plot.ly</a> 所画。</p> <p>上次我们用JavaScript实现了 <a href="/misc/goto?guid=4959750688648678659" rel="nofollow,noindex">线性规划</a> ,这次我们来聊聊KNN算法。</p> <p>KNN是 <strong>k-Nearest-Neighbours</strong> 的缩写,它是一种监督学习算法。KNN算法可以用来做分类,也可以用来解决回归问题。</p> <p>GitHub仓库: <a href="/misc/goto?guid=4959750221901345822" rel="nofollow,noindex">machine-learning-with-js</a></p> <h2>KNN算法简介</h2> <p>简单地说, KNN算法由那离自己最近的K个点来投票决定待分类数据归为哪一类 。</p> <p>如果待分类的数据有这些邻近数据, <em>NY</em> : <strong>7</strong> , <em>NJ</em> : <strong>0</strong> , <em>IN</em> : <strong>4</strong> ,即它有7个 <strong>NY</strong> 邻居,0个 <strong>NJ</strong> 邻居,4个 <strong>IN</strong> 邻居,则这个数据应该归类为 <strong>NY</strong> 。</p> <p>假设你在邮局工作,你的任务是为邮递员分配信件,目标是最小化到各个社区的投递旅程。不妨假设一共有7个街区。这就是一个实际的分类问题。你需要将这些信件分类,决定它属于哪个社区,比如 <strong>上东城</strong> 、 <strong>曼哈顿下城</strong> 等。</p> <p>最坏的方案是随意分配信件分配给邮递员,这样每个邮递员会拿到各个社区的信件。</p> <p>最佳的方案是根据信件地址进行分类,这样每个邮递员只需要负责邻近社区的信件。</p> <p>也许你是这样想的:”将邻近3个街区的信件分配给同一个邮递员”。这时,邻近街区的个数就是 <strong>k</strong> 。你可以不断增加 <strong>k</strong> ,直到获得最佳的分配方案。这个 <strong>k</strong> 就是分类问题的最佳值。</p> <h2>KNN代码实现</h2> <p>像 <a href="/misc/goto?guid=4959750688648678659" rel="nofollow,noindex">上次</a> 一样,我们将使用 <a href="/misc/goto?guid=4959750221572178382" rel="nofollow,noindex">mljs</a> 的 <strong>KNN</strong> 模块 <a href="/misc/goto?guid=4959750688812642893" rel="nofollow,noindex">ml-knn</a> 来实现。</p> <p>每一个机器学习算法都需要数据,这次我将使用 <strong>IRIS数据集</strong> 。其数据集包含了150个样本,都属于 <a href="/misc/goto?guid=4959750688896168448" rel="nofollow,noindex">鸢尾属</a> 下的三个亚属,分别是 <a href="/misc/goto?guid=4959750688984671333" rel="nofollow,noindex">山鸢尾</a> 、 <a href="/misc/goto?guid=4959750689066410907" rel="nofollow,noindex">变色鸢尾</a> 和 <a href="https://zh.wikipedia.org/w/index.php?title=%E7%BB%B4%E5%90%89%E5%B0%BC%E4%BA%9A%E9%B8%A2%E5%B0%BE&action=edit&redlink=1" rel="nofollow,noindex">维吉尼亚鸢尾</a> 。四个特征被用作样本的定量分析,它们分别是 <a href="/misc/goto?guid=4959750689233533920" rel="nofollow,noindex">花萼</a> 和 <a href="/misc/goto?guid=4959750689335486003" rel="nofollow,noindex">花瓣</a> 的长度和宽度。</p> <h3>1. 安装模块</h3> <pre> <code class="language-javascript">$npm install ml-knn@2.0.0 csvtojson prompt </code></pre> <p><a href="/misc/goto?guid=4959750688812642893" rel="nofollow,noindex">ml-knn</a> : <strong>k-Nearest-Neighbours</strong> 模块,不同版本的接口可能不同,这篇博客使用了2.0.0</p> <p><a href="/misc/goto?guid=4959750222072386551" rel="nofollow,noindex">csvtojson</a> : 用于将CSV数据转换为JSON</p> <p><a href="/misc/goto?guid=4959750689455181683" rel="nofollow,noindex">prompt</a> : 在控制台输入输出数据</p> <h3>2. 初始化并导入数据</h3> <p><a href="/misc/goto?guid=4959750689535289418" rel="nofollow,noindex">IRIS数据集</a> 由加州大学欧文分校提供。</p> <pre> <code class="language-javascript">curl https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data > iris.csv </code></pre> <p>假设你已经初始化了一个NPM项目,请在 <strong>index.js</strong> 中输入以下内容:</p> <pre> <code class="language-javascript">const KNN = require('ml-knn'); const csv = require('csvtojson'); const prompt = require('prompt'); var knn; const csvFilePath = 'iris.csv'; // 数据集 const names = ['sepalLength', 'sepalWidth', 'petalLength', 'petalWidth', 'type']; let seperationSize; // 分割训练和测试数据 let data = [], X = [], y = []; let trainingSetX = [], trainingSetY = [], testSetX = [], testSetY = []; </code></pre> <ul> <li><strong>seperationSize</strong> 用于分割数据和测试数据</li> </ul> <p>使用csvtojson模块的fromFile方法加载数据:</p> <pre> <code class="language-javascript">csv( { noheader: true, headers: names }) .fromFile(csvFilePath) .on('json', (jsonObj) => { data.push(jsonObj); // 将数据集转换为JS对象数组 }) .on('done', (error) => { seperationSize = 0.7 * data.length; data = shuffleArray(data); dressData(); }); </code></pre> <p>我们将 <strong>seperationSize</strong> 设为样本数目的0.7倍。注意,如果训练数据集太小的话,分类效果将变差。</p> <p>由于数据集是根据种类排序的,所以需要使用 <strong>shuffleArray</strong> 函数对数据进行混淆,这样才能方便分割出训练数据。这个函数的定义请参考StackOverflow的提问 <a href="/misc/goto?guid=4959750689616574243" rel="nofollow,noindex">How to randomize (shuffle) a JavaScript array?</a> :</p> <pre> <code class="language-javascript">function shuffleArray(array) { for (var i = array.length - 1; i > 0; i--) { var j = Math.floor(Math.random() * (i + 1)); var temp = array[i]; array[i] = array[j]; array[j] = temp; } return array; } </code></pre> <h3>3. 转换数据</h3> <p>数据集中每一条数据可以转换为一个JS对象:</p> <pre> <code class="language-javascript">{ sepalLength: ‘5.1’, sepalWidth: ‘3.5’, petalLength: ‘1.4’, petalWidth: ‘0.2’, type: ‘Iris-setosa’ } </code></pre> <p>在使用 <strong>KNN</strong> 算法训练数据之前,需要对数据进行这些处理:</p> <ol> <li>将属性(sepalLength, sepalWidth,petalLength,petalWidth)由字符串转换为浮点数. ( <strong>parseFloat</strong> )</li> <li>将分类 (type)用数字表示</li> </ol> <pre> <code class="language-javascript">function dressData() { let types = new Set(); data.forEach((row) => { types.add(row.type); }); let typesArray = [...types]; data.forEach((row) => { let rowArray, typeNumber; rowArray = Object.keys(row).map(key => parseFloat(row[key])).slice(0, 4); typeNumber = typesArray.indexOf(row.type); // Convert type(String) to type(Number) X.push(rowArray); y.push(typeNumber); }); trainingSetX = X.slice(0, seperationSize); trainingSetY = y.slice(0, seperationSize); testSetX = X.slice(seperationSize); testSetY = y.slice(seperationSize); train(); } </code></pre> <h3>4. 训练数据并测试</h3> <pre> <code class="language-javascript">function train() { knn = new KNN(trainingSetX, trainingSetY, { k: 7 }); test(); } </code></pre> <p>train方法需要2个必须的参数: 输入数据,即 <a href="/misc/goto?guid=4959750689233533920" rel="nofollow,noindex">花萼</a> 和 <a href="/misc/goto?guid=4959750689335486003" rel="nofollow,noindex">花瓣</a> 的长度和宽度;实际分类,即 <a href="/misc/goto?guid=4959750688984671333" rel="nofollow,noindex">山鸢尾</a> 、 <a href="/misc/goto?guid=4959750689066410907" rel="nofollow,noindex">变色鸢尾</a> 和 <a href="https://zh.wikipedia.org/w/index.php?title=%E7%BB%B4%E5%90%89%E5%B0%BC%E4%BA%9A%E9%B8%A2%E5%B0%BE&action=edit&redlink=1" rel="nofollow,noindex">维吉尼亚鸢尾</a> 。另外,第三个参数是可选的,用于提供调整 <strong>KNN</strong> 算法的内部参数。我将 <strong>k</strong> 参数设为7,其默认值为5。</p> <p>训练好模型之后,就可以使用测试数据来检查准确性了。我们主要对预测出错的个数比较感兴趣。</p> <pre> <code class="language-javascript">function test() { const result = knn.predict(testSetX); const testSetLength = testSetX.length; const predictionError = error(result, testSetY); console.log(`Test Set Size = ${testSetLength} and number of Misclassifications = ${predictionError}`); predict(); } </code></pre> <p>比较预测值与真实值,就可以得到出错个数:</p> <pre> <code class="language-javascript">function error(predicted, expected) { let misclassifications = 0; for (var index = 0; index < predicted.length; index++) { if (predicted[index] !== expected[index]) { misclassifications++; } } return misclassifications; } </code></pre> <h3>5. 进行预测(可选)</h3> <p>任意输入属性值,就可以得到预测值</p> <pre> <code class="language-javascript">function predict() { let temp = []; prompt.start(); prompt.get(['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width'], function(err, result) { if (!err) { for (var key in result) { temp.push(parseFloat(result[key])); } console.log(`With ${temp} -- type = ${knn.predict(temp)}`); } }); } </code></pre> <h3>6. 完整程序</h3> <p>完整的程序 <strong>index.js</strong> 是这样的:</p> <pre> <code class="language-javascript">const KNN = require('ml-knn'); const csv = require('csvtojson'); const prompt = require('prompt'); var knn; const csvFilePath = 'iris.csv'; // 数据集 const names = ['sepalLength', 'sepalWidth', 'petalLength', 'petalWidth', 'type']; let seperationSize; // 分割训练和测试数据 let data = [], X = [], y = []; let trainingSetX = [], trainingSetY = [], testSetX = [], testSetY = []; csv( { noheader: true, headers: names }) .fromFile(csvFilePath) .on('json', (jsonObj) => { data.push(jsonObj); // 将数据集转换为JS对象数组 }) .on('done', (error) => { seperationSize = 0.7 * data.length; data = shuffleArray(data); dressData(); }); function dressData() { let types = new Set(); data.forEach((row) => { types.add(row.type); }); let typesArray = [...types]; data.forEach((row) => { let rowArray, typeNumber; rowArray = Object.keys(row).map(key => parseFloat(row[key])).slice(0, 4); typeNumber = typesArray.indexOf(row.type); // Convert type(String) to type(Number) X.push(rowArray); y.push(typeNumber); }); trainingSetX = X.slice(0, seperationSize); trainingSetY = y.slice(0, seperationSize); testSetX = X.slice(seperationSize); testSetY = y.slice(seperationSize); train(); } // 使用KNN算法训练数据 function train() { knn = new KNN(trainingSetX, trainingSetY, { k: 7 }); test(); } // 测试训练的模型 function test() { const result = knn.predict(testSetX); const testSetLength = testSetX.length; const predictionError = error(result, testSetY); console.log(`Test Set Size = ${testSetLength} and number of Misclassifications = ${predictionError}`); predict(); } // 计算出错个数 function error(predicted, expected) { let misclassifications = 0; for (var index = 0; index < predicted.length; index++) { if (predicted[index] !== expected[index]) { misclassifications++; } } return misclassifications; } // 根据输入预测结果 function predict() { let temp = []; prompt.start(); prompt.get(['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width'], function(err, result) { if (!err) { for (var key in result) { temp.push(parseFloat(result[key])); } console.log(`With ${temp} -- type = ${knn.predict(temp)}`); } }); } // 混淆数据集的顺序 function shuffleArray(array) { for (var i = array.length - 1; i > 0; i--) { var j = Math.floor(Math.random() * (i + 1)); var temp = array[i]; array[i] = array[j]; array[j] = temp; } return array; } </code></pre> <p>在控制台执行 <strong>node index.js</strong></p> <pre> <code class="language-javascript">$ node index.js </code></pre> <p>输出如下:</p> <pre> <code class="language-javascript">Test Set Size = 45 and number of Misclassifications = 2 prompt: Sepal Length: 1.7 prompt: Sepal Width: 2.5 prompt: Petal Length: 0.5 prompt: Petal Width: 3.4 With 1.7,2.5,0.5,3.4 -- type = 2 </code></pre> <h3>参考链接</h3> <ul> <li><a href="/misc/goto?guid=4958522355535028428" rel="nofollow,noindex">K NEAREST NEIGHBOR 算法</a></li> <li><a href="/misc/goto?guid=4959750689780782702" rel="nofollow,noindex">安德森鸢尾花卉数据集</a></li> </ul> <p>欢迎加入 <a href="/misc/goto?guid=4959746568903379993" rel="nofollow,noindex">我们Fundebug</a> 的 <strong>全栈BUG监控交流群: 622902485</strong> 。</p> <p><img src="https://simg.open-open.com/show/37aab51a0b9583cf0cb9426da1e81cdc.jpg"></p> <p>来自:https://kiwenlau.com/2017/07/10/javascript-machine-learning-knn/</p> <p> </p>