JavaScript机器学习之KNN算法

Leo09L 7年前
   <p style="text-align:center"><img src="https://simg.open-open.com/show/d575a0346f2034f12b20c9daab561852.png"></p>    <p style="text-align:center">上图使用 <a href="/misc/goto?guid=4958839605157594927" rel="nofollow,noindex">plot.ly</a> 所画。</p>    <p>上次我们用JavaScript实现了 <a href="/misc/goto?guid=4959750688648678659" rel="nofollow,noindex">线性规划</a> ,这次我们来聊聊KNN算法。</p>    <p>KNN是 <strong>k-Nearest-Neighbours</strong> 的缩写,它是一种监督学习算法。KNN算法可以用来做分类,也可以用来解决回归问题。</p>    <p>GitHub仓库: <a href="/misc/goto?guid=4959750221901345822" rel="nofollow,noindex">machine-learning-with-js</a></p>    <h2>KNN算法简介</h2>    <p>简单地说, KNN算法由那离自己最近的K个点来投票决定待分类数据归为哪一类 。</p>    <p>如果待分类的数据有这些邻近数据, <em>NY</em> : <strong>7</strong> , <em>NJ</em> : <strong>0</strong> , <em>IN</em> : <strong>4</strong> ,即它有7个 <strong>NY</strong> 邻居,0个 <strong>NJ</strong> 邻居,4个 <strong>IN</strong> 邻居,则这个数据应该归类为 <strong>NY</strong> 。</p>    <p>假设你在邮局工作,你的任务是为邮递员分配信件,目标是最小化到各个社区的投递旅程。不妨假设一共有7个街区。这就是一个实际的分类问题。你需要将这些信件分类,决定它属于哪个社区,比如 <strong>上东城</strong> 、 <strong>曼哈顿下城</strong> 等。</p>    <p>最坏的方案是随意分配信件分配给邮递员,这样每个邮递员会拿到各个社区的信件。</p>    <p>最佳的方案是根据信件地址进行分类,这样每个邮递员只需要负责邻近社区的信件。</p>    <p>也许你是这样想的:”将邻近3个街区的信件分配给同一个邮递员”。这时,邻近街区的个数就是 <strong>k</strong> 。你可以不断增加 <strong>k</strong> ,直到获得最佳的分配方案。这个 <strong>k</strong> 就是分类问题的最佳值。</p>    <h2>KNN代码实现</h2>    <p>像 <a href="/misc/goto?guid=4959750688648678659" rel="nofollow,noindex">上次</a> 一样,我们将使用 <a href="/misc/goto?guid=4959750221572178382" rel="nofollow,noindex">mljs</a> 的 <strong>KNN</strong> 模块 <a href="/misc/goto?guid=4959750688812642893" rel="nofollow,noindex">ml-knn</a> 来实现。</p>    <p>每一个机器学习算法都需要数据,这次我将使用 <strong>IRIS数据集</strong> 。其数据集包含了150个样本,都属于 <a href="/misc/goto?guid=4959750688896168448" rel="nofollow,noindex">鸢尾属</a> 下的三个亚属,分别是 <a href="/misc/goto?guid=4959750688984671333" rel="nofollow,noindex">山鸢尾</a> 、 <a href="/misc/goto?guid=4959750689066410907" rel="nofollow,noindex">变色鸢尾</a> 和 <a href="https://zh.wikipedia.org/w/index.php?title=%E7%BB%B4%E5%90%89%E5%B0%BC%E4%BA%9A%E9%B8%A2%E5%B0%BE&action=edit&redlink=1" rel="nofollow,noindex">维吉尼亚鸢尾</a> 。四个特征被用作样本的定量分析,它们分别是 <a href="/misc/goto?guid=4959750689233533920" rel="nofollow,noindex">花萼</a> 和 <a href="/misc/goto?guid=4959750689335486003" rel="nofollow,noindex">花瓣</a> 的长度和宽度。</p>    <h3>1. 安装模块</h3>    <pre>  <code class="language-javascript">$npm install ml-knn@2.0.0 csvtojson prompt  </code></pre>    <p><a href="/misc/goto?guid=4959750688812642893" rel="nofollow,noindex">ml-knn</a> : <strong>k-Nearest-Neighbours</strong> 模块,不同版本的接口可能不同,这篇博客使用了2.0.0</p>    <p><a href="/misc/goto?guid=4959750222072386551" rel="nofollow,noindex">csvtojson</a> : 用于将CSV数据转换为JSON</p>    <p><a href="/misc/goto?guid=4959750689455181683" rel="nofollow,noindex">prompt</a> : 在控制台输入输出数据</p>    <h3>2. 初始化并导入数据</h3>    <p><a href="/misc/goto?guid=4959750689535289418" rel="nofollow,noindex">IRIS数据集</a> 由加州大学欧文分校提供。</p>    <pre>  <code class="language-javascript">curl https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data > iris.csv  </code></pre>    <p>假设你已经初始化了一个NPM项目,请在 <strong>index.js</strong> 中输入以下内容:</p>    <pre>  <code class="language-javascript">const KNN = require('ml-knn');  const csv = require('csvtojson');  const prompt = require('prompt');    var knn;    const csvFilePath = 'iris.csv'; // 数据集  const names = ['sepalLength', 'sepalWidth', 'petalLength', 'petalWidth', 'type'];    let seperationSize; // 分割训练和测试数据    let data = [],   X = [],   y = [];    let trainingSetX = [],   trainingSetY = [],   testSetX = [],   testSetY = [];  </code></pre>    <ul>     <li><strong>seperationSize</strong> 用于分割数据和测试数据</li>    </ul>    <p>使用csvtojson模块的fromFile方法加载数据:</p>    <pre>  <code class="language-javascript">csv(   {   noheader: true,   headers: names   })   .fromFile(csvFilePath)   .on('json', (jsonObj) =>   {   data.push(jsonObj); // 将数据集转换为JS对象数组   })   .on('done', (error) =>   {   seperationSize = 0.7 * data.length;   data = shuffleArray(data);   dressData();   });  </code></pre>    <p>我们将 <strong>seperationSize</strong> 设为样本数目的0.7倍。注意,如果训练数据集太小的话,分类效果将变差。</p>    <p>由于数据集是根据种类排序的,所以需要使用 <strong>shuffleArray</strong> 函数对数据进行混淆,这样才能方便分割出训练数据。这个函数的定义请参考StackOverflow的提问 <a href="/misc/goto?guid=4959750689616574243" rel="nofollow,noindex">How to randomize (shuffle) a JavaScript array?</a> :</p>    <pre>  <code class="language-javascript">function shuffleArray(array)  {   for (var i = array.length - 1; i > 0; i--)   {   var j = Math.floor(Math.random() * (i + 1));   var temp = array[i];   array[i] = array[j];   array[j] = temp;   }   return array;  }  </code></pre>    <h3>3. 转换数据</h3>    <p>数据集中每一条数据可以转换为一个JS对象:</p>    <pre>  <code class="language-javascript">{  sepalLength: ‘5.1’,  sepalWidth: ‘3.5’,  petalLength: ‘1.4’,  petalWidth: ‘0.2’,  type: ‘Iris-setosa’   }  </code></pre>    <p>在使用 <strong>KNN</strong> 算法训练数据之前,需要对数据进行这些处理:</p>    <ol>     <li>将属性(sepalLength, sepalWidth,petalLength,petalWidth)由字符串转换为浮点数. ( <strong>parseFloat</strong> )</li>     <li>将分类 (type)用数字表示</li>    </ol>    <pre>  <code class="language-javascript">function dressData()  {   let types = new Set();    data.forEach((row) =>   {   types.add(row.type);   });   let typesArray = [...types];      data.forEach((row) =>   {   let rowArray, typeNumber;   rowArray = Object.keys(row).map(key => parseFloat(row[key])).slice(0, 4);   typeNumber = typesArray.indexOf(row.type); // Convert type(String) to type(Number)     X.push(rowArray);   y.push(typeNumber);   });     trainingSetX = X.slice(0, seperationSize);   trainingSetY = y.slice(0, seperationSize);   testSetX = X.slice(seperationSize);   testSetY = y.slice(seperationSize);     train();  }  </code></pre>    <h3>4. 训练数据并测试</h3>    <pre>  <code class="language-javascript">function train()  {   knn = new KNN(trainingSetX, trainingSetY,   {   k: 7   });   test();  }  </code></pre>    <p>train方法需要2个必须的参数: 输入数据,即 <a href="/misc/goto?guid=4959750689233533920" rel="nofollow,noindex">花萼</a> 和 <a href="/misc/goto?guid=4959750689335486003" rel="nofollow,noindex">花瓣</a> 的长度和宽度;实际分类,即 <a href="/misc/goto?guid=4959750688984671333" rel="nofollow,noindex">山鸢尾</a> 、 <a href="/misc/goto?guid=4959750689066410907" rel="nofollow,noindex">变色鸢尾</a> 和 <a href="https://zh.wikipedia.org/w/index.php?title=%E7%BB%B4%E5%90%89%E5%B0%BC%E4%BA%9A%E9%B8%A2%E5%B0%BE&action=edit&redlink=1" rel="nofollow,noindex">维吉尼亚鸢尾</a> 。另外,第三个参数是可选的,用于提供调整 <strong>KNN</strong> 算法的内部参数。我将 <strong>k</strong> 参数设为7,其默认值为5。</p>    <p>训练好模型之后,就可以使用测试数据来检查准确性了。我们主要对预测出错的个数比较感兴趣。</p>    <pre>  <code class="language-javascript">function test()  {   const result = knn.predict(testSetX);   const testSetLength = testSetX.length;   const predictionError = error(result, testSetY);   console.log(`Test Set Size = ${testSetLength} and number of Misclassifications = ${predictionError}`);   predict();  }  </code></pre>    <p>比较预测值与真实值,就可以得到出错个数:</p>    <pre>  <code class="language-javascript">function error(predicted, expected)  {   let misclassifications = 0;   for (var index = 0; index < predicted.length; index++)   {   if (predicted[index] !== expected[index])   {   misclassifications++;   }   }   return misclassifications;  }  </code></pre>    <h3>5. 进行预测(可选)</h3>    <p>任意输入属性值,就可以得到预测值</p>    <pre>  <code class="language-javascript">function predict()  {   let temp = [];   prompt.start();   prompt.get(['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width'], function(err, result)   {   if (!err)   {   for (var key in result)   {   temp.push(parseFloat(result[key]));   }   console.log(`With ${temp} -- type = ${knn.predict(temp)}`);   }   });  }  </code></pre>    <h3>6. 完整程序</h3>    <p>完整的程序 <strong>index.js</strong> 是这样的:</p>    <pre>  <code class="language-javascript">const KNN = require('ml-knn');  const csv = require('csvtojson');  const prompt = require('prompt');    var knn;    const csvFilePath = 'iris.csv'; // 数据集  const names = ['sepalLength', 'sepalWidth', 'petalLength', 'petalWidth', 'type'];    let seperationSize; // 分割训练和测试数据    let data = [],   X = [],   y = [];    let trainingSetX = [],   trainingSetY = [],   testSetX = [],   testSetY = [];      csv(   {   noheader: true,   headers: names   })   .fromFile(csvFilePath)   .on('json', (jsonObj) =>   {   data.push(jsonObj); // 将数据集转换为JS对象数组   })   .on('done', (error) =>   {   seperationSize = 0.7 * data.length;   data = shuffleArray(data);   dressData();   });    function dressData()  {   let types = new Set();    data.forEach((row) =>   {   types.add(row.type);   });   let typesArray = [...types];      data.forEach((row) =>   {   let rowArray, typeNumber;   rowArray = Object.keys(row).map(key => parseFloat(row[key])).slice(0, 4);   typeNumber = typesArray.indexOf(row.type); // Convert type(String) to type(Number)     X.push(rowArray);   y.push(typeNumber);   });     trainingSetX = X.slice(0, seperationSize);   trainingSetY = y.slice(0, seperationSize);   testSetX = X.slice(seperationSize);   testSetY = y.slice(seperationSize);     train();  }      // 使用KNN算法训练数据  function train()  {   knn = new KNN(trainingSetX, trainingSetY,   {   k: 7   });   test();  }      // 测试训练的模型  function test()  {   const result = knn.predict(testSetX);   const testSetLength = testSetX.length;   const predictionError = error(result, testSetY);   console.log(`Test Set Size = ${testSetLength} and number of Misclassifications = ${predictionError}`);   predict();  }      // 计算出错个数  function error(predicted, expected)  {   let misclassifications = 0;   for (var index = 0; index < predicted.length; index++)   {   if (predicted[index] !== expected[index])   {   misclassifications++;   }   }   return misclassifications;  }      // 根据输入预测结果  function predict()  {   let temp = [];   prompt.start();   prompt.get(['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width'], function(err, result)   {   if (!err)   {   for (var key in result)   {   temp.push(parseFloat(result[key]));   }   console.log(`With ${temp} -- type = ${knn.predict(temp)}`);   }   });  }      // 混淆数据集的顺序  function shuffleArray(array)  {   for (var i = array.length - 1; i > 0; i--)   {   var j = Math.floor(Math.random() * (i + 1));   var temp = array[i];   array[i] = array[j];   array[j] = temp;   }   return array;  }  </code></pre>    <p>在控制台执行 <strong>node index.js</strong></p>    <pre>  <code class="language-javascript">$ node index.js  </code></pre>    <p>输出如下:</p>    <pre>  <code class="language-javascript">Test Set Size = 45 and number of Misclassifications = 2  prompt: Sepal Length: 1.7  prompt: Sepal Width: 2.5  prompt: Petal Length: 0.5  prompt: Petal Width: 3.4  With 1.7,2.5,0.5,3.4 -- type = 2  </code></pre>    <h3>参考链接</h3>    <ul>     <li><a href="/misc/goto?guid=4958522355535028428" rel="nofollow,noindex">K NEAREST NEIGHBOR 算法</a></li>     <li><a href="/misc/goto?guid=4959750689780782702" rel="nofollow,noindex">安德森鸢尾花卉数据集</a></li>    </ul>    <p>欢迎加入 <a href="/misc/goto?guid=4959746568903379993" rel="nofollow,noindex">我们Fundebug</a> 的 <strong>全栈BUG监控交流群: 622902485</strong> 。</p>    <p><img src="https://simg.open-open.com/show/37aab51a0b9583cf0cb9426da1e81cdc.jpg"></p>    <p>来自:https://kiwenlau.com/2017/07/10/javascript-machine-learning-knn/</p>    <p> </p>