K最近邻(KNN)算法的java实现
1.急切学习与懒惰学习
急切学习:在给定训练元组之后、接收到测试元组之前就构造好泛化(即分类)模型。
属于急切学习的算法有:决策树、贝叶斯、基于规则的分类、后向传播分类、SVM和基于关联规则挖掘的分类等等。
懒惰学习:直至给定一个测试元组才开始构造泛化模型,也称为基于实例的学习法。
属于急切学习的算法有:KNN分类、基于案例的推理分类。
2.KNN的优缺点
优点:原理简单,实现起来比较方便。支持增量学习。能对超多边形的复杂决策空间建模。
缺点:计算开销大,需要有效的存储技术和并行硬件的支撑。
3.KNN算法原理
基于类比学习,通过比较训练元组和测试元组的相似度来学习。
将训练元组和测试元组看作是n维(若元组有n的属性)空间内的点,给定一条测试元组,搜索n维空间,找出与测试
元组最相近的k个点(即训练元组),最后取这k个点中的多数类作为测试元组的类别。
相近的度量方法:用空间内两个点的距离来度量。距离越大,表示两个点越不相似。
距离的选择:可采用欧几里得距离、曼哈顿距离或其它距离度量。多采用欧几里得距离,简单!
4.KNN算法中的细节处理
- 数值属性规范化:将数值属性规范到0-1区间以便于计算,也可防止大数值型属性对分类的主导作用。
可选的方法有:v' = (v - vmin)/ (vmax - vmin),当然也可以采用其它的规范化方法
- 比较的属性是分类类型而不是数值类型的:同则差为0,异则差为1.
有时候可以作更为精确的处理,比如黑色与白色的差肯定要大于灰色与白色的差。
- 缺失值的处理:取最大的可能差,对于分类属性,如果属性A的一个或两个对应值丢失,则取差值为1;
如果A是数值属性,若两个比较的元组A属性值均缺失,则取差值为1,若只有一个缺失,另一个值为v,
则取差值为|1-v|和|0-v|中的最大值
- 确定K的值:通过实验确定。进行若干次实验,取分类误差率最小的k值。
- 对噪声数据或不相关属性的处理:对属性赋予相关性权重w,w越大说明属性对分类的影响越相关。对噪声数据可以将所在
的元组直接cut掉。
5.KNN算法流程
- 准备数据,对数据进行预处理
- 选用合适的数据结构存储训练数据和测试元组
- 设定参数,如k
- 维护一个大小为k的的按距离由大到小的优先级队列,用于存储最近邻训练元组
- 随机从训练元组中选取k个元组作为初始的最近邻元组,分别计算测试元组到这k个元组的距离,将训练元组标号和距离存入优先级队列
- 遍历训练元组集,计算当前训练元组与测试元组的距离,将所得距离L与优先级队列中的最大距离Lmax进行比较。若L>=Lmax,则舍弃该元组,遍历下一个元组。若L < Lmax,删除优先级队列中最大距离的元组,将当前训练元组存入优先级队列。
- 遍历完毕,计算优先级队列中k个元组的多数类,并将其作为测试元组的类别。
- 测试元组集测试完毕后计算误差率,继续设定不同的k值重新进行训练,最后取误差率最小的k值。
6.KNN算法的改进策略
- 将存储的训练元组预先排序并安排在搜索树中(如何排序有待研究)
- 并行实现
- 部分距离计算,取n个属性的“子集”计算出部分距离,若超过设定的阈值则停止对当前元组作进一步计算。转向下一个元组。
- 剪枝或精简:删除证明是“无用的”元组。
7.KNN算法java实现
本算法只适合学习使用,可以大致了解一下KNN算法的原理。
算法作了如下的假定与简化处理:
1.小规模数据集
2.假设所有数据及类别都是数值类型的
3.直接根据数据规模设定了k值
4.对原训练集进行测试
KNN实现代码如下:
package KNN; /** * KNN结点类,用来存储最近邻的k个元组相关的信息 * @author Rowen * @qq 443773264 * @mail luowen3405@163.com * @blog blog.csdn.net/luowen3405 * @data 2011.03.25 */ public class KNNNode { private int index; // 元组标号 private double distance; // 与测试元组的距离 private String c; // 所属类别 public KNNNode(int index, double distance, String c) { super(); this.index = index; this.distance = distance; this.c = c; } public int getIndex() { return index; } public void setIndex(int index) { this.index = index; } public double getDistance() { return distance; } public void setDistance(double distance) { this.distance = distance; } public String getC() { return c; } public void setC(String c) { this.c = c; } }
package KNN; import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.PriorityQueue; /** * KNN算法主体类 * @author Rowen * @qq 443773264 * @mail luowen3405@163.com * @blog blog.csdn.net/luowen3405 * @data 2011.03.25 */ public class KNN { /** * 设置优先级队列的比较函数,距离越大,优先级越高 */ private Comparator<KNNNode> comparator = new Comparator<KNNNode>() { public int compare(KNNNode o1, KNNNode o2) { if (o1.getDistance() >= o2.getDistance()) { return 1; } else { return 0; } } }; /** * 获取K个不同的随机数 * @param k 随机数的个数 * @param max 随机数最大的范围 * @return 生成的随机数数组 */ public List<Integer> getRandKNum(int k, int max) { List<Integer> rand = new ArrayList<Integer>(k); for (int i = 0; i < k; i++) { int temp = (int) (Math.random() * max); if (!rand.contains(temp)) { rand.add(temp); } else { i--; } } return rand; } /** * 计算测试元组与训练元组之前的距离 * @param d1 测试元组 * @param d2 训练元组 * @return 距离值 */ public double calDistance(List<Double> d1, List<Double> d2) { double distance = 0.00; for (int i = 0; i < d1.size(); i++) { distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i)); } return distance; } /** * 执行KNN算法,获取测试元组的类别 * @param datas 训练数据集 * @param testData 测试元组 * @param k 设定的K值 * @return 测试元组的类别 */ public String knn(List<List<Double>> datas, List<Double> testData, int k) { PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k, comparator); List<Integer> randNum = getRandKNum(k, datas.size()); for (int i = 0; i < k; i++) { int index = randNum.get(i); List<Double> currData = datas.get(index); String c = currData.get(currData.size() - 1).toString(); KNNNode node = new KNNNode(index, calDistance(testData, currData), c); pq.add(node); } for (int i = 0; i < datas.size(); i++) { List<Double> t = datas.get(i); double distance = calDistance(testData, t); KNNNode top = pq.peek(); if (top.getDistance() > distance) { pq.remove(); pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString())); } } return getMostClass(pq); } /** * 获取所得到的k个最近邻元组的多数类 * @param pq 存储k个最近近邻元组的优先级队列 * @return 多数类的名称 */ private String getMostClass(PriorityQueue<KNNNode> pq) { Map<String, Integer> classCount = new HashMap<String, Integer>(); for (int i = 0; i < pq.size(); i++) { KNNNode node = pq.remove(); String c = node.getC(); if (classCount.containsKey(c)) { classCount.put(c, classCount.get(c) + 1); } else { classCount.put(c, 1); } } int maxIndex = -1; int maxCount = 0; Object[] classes = classCount.keySet().toArray(); for (int i = 0; i < classes.length; i++) { if (classCount.get(classes[i]) > maxCount) { maxIndex = i; maxCount = classCount.get(classes[i]); } } return classes[maxIndex].toString(); } }
package KNN; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.util.ArrayList; import java.util.List; /** * KNN算法测试类 * @author Rowen * @qq 443773264 * @mail luowen3405@163.com * @blog blog.csdn.net/luowen3405 * @data 2011.03.25 */ public class TestKNN { /** * 从数据文件中读取数据 * @param datas 存储数据的集合对象 * @param path 数据文件的路径 */ public void read(List<List<Double>> datas, String path){ try { BufferedReader br = new BufferedReader(new FileReader(new File(path))); String data = br.readLine(); List<Double> l = null; while (data != null) { String t[] = data.split(" "); l = new ArrayList<Double>(); for (int i = 0; i < t.length; i++) { l.add(Double.parseDouble(t[i])); } datas.add(l); data = br.readLine(); } } catch (Exception e) { e.printStackTrace(); } } /** * 程序执行入口 * @param args */ public static void main(String[] args) { TestKNN t = new TestKNN(); String datafile = new File("").getAbsolutePath() + File.separator + "datafile"; String testfile = new File("").getAbsolutePath() + File.separator + "testfile"; try { List<List<Double>> datas = new ArrayList<List<Double>>(); List<List<Double>> testDatas = new ArrayList<List<Double>>(); t.read(datas, datafile); t.read(testDatas, testfile); KNN knn = new KNN(); for (int i = 0; i < testDatas.size(); i++) { List<Double> test = testDatas.get(i); System.out.print("测试元组: "); for (int j = 0; j < test.size(); j++) { System.out.print(test.get(j) + " "); } System.out.print("类别为: "); System.out.println(Math.round(Float.parseFloat((knn.knn(datas, test, 3))))); } } catch (Exception e) { e.printStackTrace(); } } }
训练数据文件:
1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0
1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5
程序运行结果:
测试元组: 1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 类别为: 1 测试元组: 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 类别为: 1 测试元组: 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 类别为: 1 测试元组: 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 类别为: 0 测试元组: 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 类别为: 1 测试元组: 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 类别为: 0
由结果可以看出,分类的测试结果是比较准确的!
转自:http://blog.csdn.net/luowen3405/article/details/6278764
参考:http://blog.csdn.net/xlm289348/article/details/8876353
http://coolshell.cn/articles/8052.html#more-8052