使用Python训练SVM模型识别手写体数字

2457409219 8年前
   <p>支持向量机SVM(Support Vector Machine)是有监督的分类预测模型,本篇文章使用机器学习库scikit-learn中的手写数字数据集介绍使用Python对SVM模型进行训练并对手写数字进行识别的过程。</p>    <p><img src="https://simg.open-open.com/show/52ea23ffdeae4c6b1960199c9176bbd4.jpg"></p>    <h2>准备工作</h2>    <p>手写数字识别的原理是将数字的图片分割为8X8的灰度值矩阵,将这64个灰度值作为每个数字的训练集对模型进行训练。手写数字所对应的真实数字作为分类结果。在机器学习sklearn库中已经包含了不同数字的8X8灰度值矩阵,因此我们首先导入sklearn库自带的datasets数据集。然后是交叉验证库,SVM分类算法库,绘制图表库等。</p>    <pre>  #导入自带数据集  from sklearn import datasets  #导入交叉验证库  from sklearn import cross_validation  #导入SVM分类算法库  from sklearn import svm  #导入图表库  import matplotlib.pyplot as plt  #生成预测结果准确率的混淆矩阵  from sklearn import metrics</pre>    <h2>读取并查看数字矩阵</h2>    <p>从sklearn库自带的datasets数据集中读取数字的8X8矩阵信息并赋值给digits。</p>    <pre>  #读取自带数据集并赋值给digits  digits = datasets.load_digits()</pre>    <p>查看其中的数字9可以发现,手写的数字9以64个灰度值保存。从下面的8×8矩阵中很难看出这是数字9。</p>    <pre>  #查看数据集中数字9的矩阵  digits.data[9]</pre>    <p><img src="https://simg.open-open.com/show/2999bcb3b28c69f6f6b9c4e5e9542e61.jpg"></p>    <p>以灰度值的方式输出手写数字9的图像,可以看出个大概轮廓。这就是经过切割并以灰度保存的手写数字9。它所对应的64个灰度值就是模型的训练集,而真实的数字9是目标分类。我们的模型所要做的就是在已知64个灰度值与每个数字对应关系的情况下,通过对模型进行训练来对新的手写数字对应的真实数字进行分类。</p>    <pre>  #绘制图表查看数据集中数字9的图像  plt.imshow(digits.images[9], cmap=plt.cm.gray_r, interpolation='nearest')  plt.title('digits.target[9]')  plt.show()</pre>    <h2> </h2>    <p><img src="https://simg.open-open.com/show/f601b8764e695218e495e317a87569f5.png"></p>    <p>设置模型的特征X和预测目标Y</p>    <p>查看数据集中的分类目标,可以看到一共有10个分类,分布为0-9。我们将这个分类目标赋值给Y,作为模型的预测目标。</p>    <pre>  #数据集中的目标分类  digits.target</pre>    <p><img src="https://simg.open-open.com/show/de186cf518c4cfca468d371d4642e4f0.jpg"></p>    <pre>  #将数据集中的目标赋给Y  Y=digits.target</pre>    <p>手写数字的64个灰度值作为特征赋值给X,这里需要说明的是64个灰度值是以8×8矩阵的形式保持的,因此我们需要使用reshape函数重新调整矩阵的行列数。这里也就是将8×8的两维数据转换为64×1的一维数据。</p>    <pre>  #使用reshape函数对矩阵进行转换,并赋值给X  n_samples = len(digits.images)  X = digits.images.reshape((n_samples, 64))</pre>    <p>查看特征值X和预测目标Y的行数,共有1797行,也就是说数据集中共有1797个手写数字的图像,64列是经过我们转化后的灰度值。</p>    <pre>  #查看X和Y的行数  X.shape,Y.shape</pre>    <h2> </h2>    <p><img src="https://simg.open-open.com/show/311f62884e99b2b56ffc1d9fe09fbd69.jpg"></p>    <p>将数据分割为训练集和测试集</p>    <p>将1797个手写数字的灰度值采用随机抽样的方法分割为训练集和测试集,其中训练集为60%,测试集为40%。</p>    <pre>  #随机抽取生成训练集和测试集,其中训练集的比例为60%,测试集40%  X_train, X_test, y_train, y_test = cross_validation.train_test_split(X, Y, test_size=0.4, random_state=0)</pre>    <p>查看分割后的测试集数据,共有1078条数据。这些数据将用来训练SVM模型。</p>    <pre>  #查看训练集的行数  X_train.shape,y_train.shape</pre>    <h2> </h2>    <p><img src="https://simg.open-open.com/show/cac2427f63316c1a6bf23a0c1a2f18b2.jpg"></p>    <p>对SVM模型进行训练</p>    <p>将训练集数据X_train和y_train代入到SVM模型中,对模型进行训练。下面是具体的代码和结果。</p>    <pre>  #生成SVM分类模型  clf = svm.SVC(gamma=0.001)</pre>    <pre>  #使用训练集对svm分类模型进行训练  clf.fit(X_train, y_train)</pre>    <p><img src="https://simg.open-open.com/show/a61fdbe6b5feaf242711f320068008ee.jpg"></p>    <h2>使用测试集测对模型进行测试</h2>    <p>使用测试集数据X_test和y_test对训练后的SVM模型进行检验,模型对手写数字分类的准确率为99.3%。这是非常高的准确率。那么是否真的这么靠谱吗?下面我们来单独测试下。</p>    <pre>  #使用测试集衡量分类模型准确率  clf.score(X_test, y_test)</pre>    <p><img src="https://simg.open-open.com/show/a708fac9ef1f1fbf9befb311ac2bc124.jpg"></p>    <p>我们使用测试集的特征X,也就是每个手写数字的64个灰度值代入到模型中,让SVM模型进行分类。</p>    <pre>  #对测试集数据进行预测  predicted=clf.predict(X_test)</pre>    <p>然后查看前20个手写数字的分类结果,也就是手写数字所对应的真实数字。下面是具体的分类结果。</p>    <pre>  #查看前20个测试集的预测结果  predicted[:20]</pre>    <p><img src="https://simg.open-open.com/show/02ae1924ae147c88eb56eaf240a9d0bb.jpg"></p>    <p>再查看训练集中前20个分类结果,也就是真实数字的情况,并将之前的分类结果与测试集的真实结果进行对比。</p>    <pre>  #查看测试集中的真实结果  expected=y_test</pre>    <p>以下是测试集中前20个真实数字的结果,与前面SVM模型的分类结果对比,前20个结果是一致的。</p>    <pre>  #查看测试集中前20个真实结果  expected[:20]</pre>    <p><img src="https://simg.open-open.com/show/5e3c0f470878bbad405823e2b45623af.jpg"></p>    <p>使用混淆矩阵来看下SVM模型对所有测试集数据的预测与真实结果的准确率情况,下面是一个10X10的矩阵,左上角第一行第一个数字60表示实际为0,SVM模型也预测为0的个数,第一行第二个数字表示实际为0,SVM模型预测为1的数字。第二行第二个数字73表示实际为1,SVM模型也预测为1的个数。</p>    <pre>  #生成准确率的混淆矩阵(Confusion matrix)  metrics.confusion_matrix(expected, predicted)</pre>    <p><img src="https://simg.open-open.com/show/02bbf7165ee32fafb9b2ec8374b5a343.jpg"></p>    <p>从混淆矩阵中可以看到,大部分的数字SVM的分类和预测都是正确的,但也有个别的数字分类错误,例如真实的数字2,SVM模型有一次错误的分类为1,还有一次错误分类为7。</p>    <p>以下为scikit-learn官方的案例说明及代码。</p>    <p>http://scikit-learn.org/stable/auto_examples/classification/plot_digits_classification.html</p>    <p> </p>    <p>来自:http://bluewhale.cc/2016-09-06/python-svm-recognizing-hand-written-digits.html</p>    <p> </p>