K最近邻(KNN)算法的java实现

jopen 10年前

1.急切学习与懒惰学习

急切学习:在给定训练元组之后、接收到测试元组之前就构造好泛化(即分类)模型。

属于急切学习的算法有:决策树、贝叶斯、基于规则的分类、后向传播分类、SVM和基于关联规则挖掘的分类等等。

 

懒惰学习:直至给定一个测试元组才开始构造泛化模型,也称为基于实例的学习法。

属于急切学习的算法有:KNN分类、基于案例的推理分类。

 

2.KNN的优缺点

优点:原理简单,实现起来比较方便。支持增量学习。能对超多边形的复杂决策空间建模。

缺点:计算开销大,需要有效的存储技术和并行硬件的支撑。

 

3.KNN算法原理

基于类比学习,通过比较训练元组和测试元组的相似度来学习。

将训练元组和测试元组看作是n维(若元组有n的属性)空间内的点,给定一条测试元组,搜索n维空间,找出与测试

元组最相近的k个点(即训练元组),最后取这k个点中的多数类作为测试元组的类别。

 

相近的度量方法:用空间内两个点的距离来度量。距离越大,表示两个点越不相似。

 

距离的选择:可采用欧几里得距离、曼哈顿距离或其它距离度量。多采用欧几里得距离,简单!

 

4.KNN算法中的细节处理

  • 数值属性规范化:将数值属性规范到0-1区间以便于计算,也可防止大数值型属性对分类的主导作用。

可选的方法有:v' = (v - vmin)/ (vmax - vmin),当然也可以采用其它的规范化方法

  • 比较的属性是分类类型而不是数值类型的:同则差为0,异则差为1.

有时候可以作更为精确的处理,比如黑色与白色的差肯定要大于灰色与白色的差。

  • 缺失值的处理:取最大的可能差,对于分类属性,如果属性A的一个或两个对应值丢失,则取差值为1;

如果A是数值属性,若两个比较的元组A属性值均缺失,则取差值为1,若只有一个缺失,另一个值为v,

则取差值为|1-v|和|0-v|中的最大值

  • 确定K的值:通过实验确定。进行若干次实验,取分类误差率最小的k值。
  • 对噪声数据或不相关属性的处理:对属性赋予相关性权重w,w越大说明属性对分类的影响越相关。对噪声数据可以将所在

的元组直接cut掉。

 

5.KNN算法流程

  • 准备数据,对数据进行预处理
  • 选用合适的数据结构存储训练数据和测试元组
  • 设定参数,如k
  • 维护一个大小为k的的按距离由大到小的优先级队列,用于存储最近邻训练元组
  • 随机从训练元组中选取k个元组作为初始的最近邻元组,分别计算测试元组到这k个元组的距离,将训练元组标号和距离存入优先级队列
  • 遍历训练元组集,计算当前训练元组与测试元组的距离,将所得距离L与优先级队列中的最大距离Lmax进行比较。若L>=Lmax,则舍弃该元组,遍历下一个元组。若L < Lmax,删除优先级队列中最大距离的元组,将当前训练元组存入优先级队列。
  • 遍历完毕,计算优先级队列中k个元组的多数类,并将其作为测试元组的类别。
  • 测试元组集测试完毕后计算误差率,继续设定不同的k值重新进行训练,最后取误差率最小的k值。

6.KNN算法的改进策略

  • 将存储的训练元组预先排序并安排在搜索树中(如何排序有待研究)
  • 并行实现
  • 部分距离计算,取n个属性的“子集”计算出部分距离,若超过设定的阈值则停止对当前元组作进一步计算。转向下一个元组。
  • 剪枝或精简:删除证明是“无用的”元组。

7.KNN算法java实现

 

本算法只适合学习使用,可以大致了解一下KNN算法的原理。

 

算法作了如下的假定与简化处理:

1.小规模数据集

2.假设所有数据及类别都是数值类型的

3.直接根据数据规模设定了k值

4.对原训练集进行测试

 

KNN实现代码如下:

    package KNN;        /**        * KNN结点类,用来存储最近邻的k个元组相关的信息        * @author Rowen        * @qq 443773264        * @mail luowen3405@163.com        * @blog blog.csdn.net/luowen3405        * @data 2011.03.25        */        public class KNNNode {            private int index; // 元组标号            private double distance; // 与测试元组的距离            private String c; // 所属类别            public KNNNode(int index, double distance, String c) {                super();                this.index = index;                this.distance = distance;                this.c = c;            }                                    public int getIndex() {                return index;            }            public void setIndex(int index) {                this.index = index;            }            public double getDistance() {                return distance;            }            public void setDistance(double distance) {                this.distance = distance;            }            public String getC() {                return c;            }            public void setC(String c) {                this.c = c;            }        }  

 

    package KNN;        import java.util.ArrayList;        import java.util.Comparator;        import java.util.HashMap;        import java.util.List;        import java.util.Map;        import java.util.PriorityQueue;                /**        * KNN算法主体类        * @author Rowen        * @qq 443773264        * @mail luowen3405@163.com        * @blog blog.csdn.net/luowen3405        * @data 2011.03.25        */        public class KNN {            /**            * 设置优先级队列的比较函数,距离越大,优先级越高            */            private Comparator<KNNNode> comparator = new Comparator<KNNNode>() {                public int compare(KNNNode o1, KNNNode o2) {                    if (o1.getDistance() >= o2.getDistance()) {                        return 1;                    } else {                        return 0;                    }                }            };            /**            * 获取K个不同的随机数            * @param k 随机数的个数            * @param max 随机数最大的范围            * @return 生成的随机数数组            */            public List<Integer> getRandKNum(int k, int max) {                List<Integer> rand = new ArrayList<Integer>(k);                for (int i = 0; i < k; i++) {                    int temp = (int) (Math.random() * max);                    if (!rand.contains(temp)) {                        rand.add(temp);                    } else {                        i--;                    }                }                return rand;            }            /**            * 计算测试元组与训练元组之前的距离            * @param d1 测试元组            * @param d2 训练元组            * @return 距离值            */            public double calDistance(List<Double> d1, List<Double> d2) {                double distance = 0.00;                for (int i = 0; i < d1.size(); i++) {                    distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i));                }                return distance;            }            /**            * 执行KNN算法,获取测试元组的类别            * @param datas 训练数据集            * @param testData 测试元组            * @param k 设定的K值            * @return 测试元组的类别            */            public String knn(List<List<Double>> datas, List<Double> testData, int k) {                PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k, comparator);                List<Integer> randNum = getRandKNum(k, datas.size());                for (int i = 0; i < k; i++) {                    int index = randNum.get(i);                    List<Double> currData = datas.get(index);                    String c = currData.get(currData.size() - 1).toString();                    KNNNode node = new KNNNode(index, calDistance(testData, currData), c);                    pq.add(node);                }                for (int i = 0; i < datas.size(); i++) {                    List<Double> t = datas.get(i);                    double distance = calDistance(testData, t);                    KNNNode top = pq.peek();                    if (top.getDistance() > distance) {                        pq.remove();                        pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString()));                    }                }                                return getMostClass(pq);            }            /**            * 获取所得到的k个最近邻元组的多数类            * @param pq 存储k个最近近邻元组的优先级队列            * @return 多数类的名称            */            private String getMostClass(PriorityQueue<KNNNode> pq) {                Map<String, Integer> classCount = new HashMap<String, Integer>();                for (int i = 0; i < pq.size(); i++) {                    KNNNode node = pq.remove();                    String c = node.getC();                    if (classCount.containsKey(c)) {                        classCount.put(c, classCount.get(c) + 1);                    } else {                        classCount.put(c, 1);                    }                }                int maxIndex = -1;                int maxCount = 0;                Object[] classes = classCount.keySet().toArray();                for (int i = 0; i < classes.length; i++) {                    if (classCount.get(classes[i]) > maxCount) {                        maxIndex = i;                        maxCount = classCount.get(classes[i]);                    }                }                return classes[maxIndex].toString();            }        }  

    package KNN;        import java.io.BufferedReader;        import java.io.File;        import java.io.FileReader;        import java.util.ArrayList;        import java.util.List;        /**        * KNN算法测试类        * @author Rowen        * @qq 443773264        * @mail luowen3405@163.com        * @blog blog.csdn.net/luowen3405        * @data 2011.03.25        */        public class TestKNN {                        /**            * 从数据文件中读取数据            * @param datas 存储数据的集合对象            * @param path 数据文件的路径            */            public void read(List<List<Double>> datas, String path){                try {                    BufferedReader br = new BufferedReader(new FileReader(new File(path)));                    String data = br.readLine();                    List<Double> l = null;                    while (data != null) {                        String t[] = data.split(" ");                        l = new ArrayList<Double>();                        for (int i = 0; i < t.length; i++) {                            l.add(Double.parseDouble(t[i]));                        }                        datas.add(l);                        data = br.readLine();                    }                } catch (Exception e) {                    e.printStackTrace();                }            }                        /**            * 程序执行入口            * @param args            */            public static void main(String[] args) {                TestKNN t = new TestKNN();                String datafile = new File("").getAbsolutePath() + File.separator + "datafile";                String testfile = new File("").getAbsolutePath() + File.separator + "testfile";                try {                    List<List<Double>> datas = new ArrayList<List<Double>>();                    List<List<Double>> testDatas = new ArrayList<List<Double>>();                    t.read(datas, datafile);                    t.read(testDatas, testfile);                    KNN knn = new KNN();                    for (int i = 0; i < testDatas.size(); i++) {                        List<Double> test = testDatas.get(i);                        System.out.print("测试元组: ");                        for (int j = 0; j < test.size(); j++) {                            System.out.print(test.get(j) + " ");                        }                        System.out.print("类别为: ");                        System.out.println(Math.round(Float.parseFloat((knn.knn(datas, test, 3)))));                    }                } catch (Exception e) {                    e.printStackTrace();                }            }        }  

 

训练数据文件:

    1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1        1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1        1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1        1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0        1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1        1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0  

 

    1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5        1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8        1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2        1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5        1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5        1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5  

 

程序运行结果:

    测试元组: 1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 类别为: 1        测试元组: 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 类别为: 1        测试元组: 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 类别为: 1        测试元组: 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 类别为: 0        测试元组: 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 类别为: 1        测试元组: 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 类别为: 0  

 

由结果可以看出,分类的测试结果是比较准确的!

 

 

转自:http://blog.csdn.net/luowen3405/article/details/6278764

 

参考:http://blog.csdn.net/xlm289348/article/details/8876353

http://coolshell.cn/articles/8052.html#more-8052