机器学习--k-近邻(kNN)算法

jopen 10年前

  一、基本原理
        存在一个样本数据集合(也称训练样本集),并且样本集中每个数据都存在标签。输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。
我们一般只选择样本集中前k(k通常是不大于20的整数)个最相似的数据,最后选择k个最相似数据中出现次数最多的分类,作为新数据的分类。
二、算法流程
        1)计算已知类别数据集中的点与当前点之间的距离;
        2)按照距离递增次序排序;
</div>
        3)选取与当前点距离最小的k个点;
</div>
        4)确定前k个点所在类别的出现频率;
</div>
        5)返回前k个点出现频率最高的类别作为当前点的预测分类。
</div>
三、算法的特点
        优点:精度高、对异常值不敏感、无数据输入假定。
        缺点:计算复杂度高、空间复杂度高。
        适用数据范围:数值型和标称型。
四、python代码实现
1、创建数据集
def create_data_set():
    group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels
2、实施KNN算法

##############################
#功能:将每组数据划分到某个类中
#输入变量:inx, data_set,labels,k
# 分类的向量,样本数据,标签,k个近邻的样本
#输出变量:sorted_class_count[0][0] 选择最近的类别标签
##############################

def classify0(inx, data_set, labels, k):
    data_set_size = data_set.shape[0]  # 获得数组的行数

    # 利用tile(inx, (data_set_size, 1)),在原来的基础上再构造data_set_size*1的inx
    # 每行数据相当于某个矢量点的坐标
    # 对每行数据进行求和,得到一个data_set_size*1的矩阵
    # 最后计算欧式距离
    diff_mat = tile(inx, (data_set_size, 1))-data_set
    sq_diff_mat = diff_mat**2
    sq_distances = sq_diff_mat.sum(axis=1)
    distances = sq_distances**0.5

     # argsort函数返回的是数组值从小到大的索引值
    sorted_dist_indicies = distances.argsort()

    class_count = {}
    for i in xrange(k):
        vote_label = labels[sorted_dist_indicies[i]]

        # get相当于一条if...else...语句
        # 如果参数vote_label不在字典中则返回参数0,如果vote_label在字典中则返回vote_label对应的value值
        class_count[vote_label] = class_count.get(vote_label, 0) + 1

    # items以列表方式返回字典中的键值对,iteritems以迭代器对象返回键值对,而键值对以元组方式存储,即这种方式[(), ()]
    # operator.itemgetter(0)获取对象的第0个域的值,即返回的是key值
    # operator.itemgetter(1)获取对象的第1个域的值,即返回的是value值
    # operator.itemgetter定义了一个函数,通过该函数作用到对象上才能获取值
    # reverse=True是按降序排序
    sorted_class_count = sorted(class_count.iteritems(), key=operator.itemgetter(1), reverse=True)

    return sorted_class_count[0][0]

3、代码测试
def main():
    group, labels = create_data_set()
    sorted_class_labels = classify0([0, 0], group, labels, 3)
    print 'sorted_class_labels=', sorted_class_labels
if __name__ == '__main__':
    main()

来自:http://blog.csdn.net/zyl1042635242/article/details/45081065