Contents
1 Purpose
2 Method
3 Source Code
4 Results
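1 Purpose
① Become familiar with Python input/output streams;
② Learn to visualize images with matplotlib;
③ Understand the basic principles of neural networks, and learn to build a basic multilayer perceptron (MLP) classifier with the MLPClassifier class from the sklearn library;
④ Learn to optimize hyperparameters with grid search.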
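2 Method
① Read and decompress the mnist.gz file, and separate the training, validation, and test sets;
② Inspect the data structure and visualize the handwritten digits;
③ Build a multilayer perceptron model and use grid search to find the best hyperparameters;
④ Output the model's best parameters and its prediction accuracy.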
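Step ③ builds a multilayer perceptron. As a rough illustration of what such a network computes, here is a minimal NumPy sketch of a forward pass with layer widths 784 → 100 → 30 → 10. It assumes ReLU hidden layers (MLPClassifier's default activation) and a softmax output for the 10 digit classes, and it uses random, untrained weights. Training with backpropagation, which MLPClassifier performs for you, is not shown.

```python
import numpy as np

def relu(z):
    # ReLU activation: max(z, 0) element-wise
    return np.maximum(z, 0.0)

def softmax(z):
    # Subtract each row's max before exponentiating for numerical stability
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def mlp_forward(X, weights, biases):
    """Propagate a batch through ReLU hidden layers and a softmax output layer."""
    a = X
    for W, b in zip(weights[:-1], biases[:-1]):
        a = relu(a @ W + b)                      # affine transform, then nonlinearity
    return softmax(a @ weights[-1] + biases[-1])  # class probabilities

# Layer widths: 784 input pixels, hidden layers of 100 and 30 units, 10 digit classes
sizes = [784, 100, 30, 10]
rng = np.random.default_rng(0)
weights = [rng.normal(0.0, 0.05, (m, n)) for m, n in zip(sizes[:-1], sizes[1:])]
biases = [np.zeros(n) for n in sizes[1:]]

X = rng.random((5, 784))                 # five random stand-in "images"
probs = mlp_forward(X, weights, biases)  # shape (5, 10), each row sums to 1
pred = probs.argmax(axis=1)              # predicted digit for each image
```

The predicted class is the output unit with the highest probability. Training adjusts `weights` and `biases` so that this unit matches the true label.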
3 Source Code
① Launch Spyder, create a new .py file, and import the modules needed for the experiment:
```python
# Import the required modules
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import GridSearchCV
import numpy as np
import pickle
import gzip
import matplotlib.pyplot as plt
```
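② Load the data. The data is stored in the gzip-compressed file mnist.gz, so the file is decompressed and unpickled, and its contents are split into training, validation, and test sets: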
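```python
# Decompress the archive and unpickle the (training, validation, test) tuple
with gzip.open(r"D:\大二下\数据挖掘\神经网络\mnist.gz", "rb") as fp:
    training_data, valid_data, test_data = pickle.load(fp, encoding='bytes')
# Split each set into features X (flattened images) and labels y
X_training_data, y_training_data = training_data
X_valid_data, y_valid_data = valid_data
X_test_data, y_test_data = test_data
```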
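The unpacking above assumes that mnist.gz holds a single pickled 3-tuple, `((X, y), (X, y), (X, y))`. The sketch below makes that layout explicit by building a tiny in-memory archive with the same structure from fake random data and reading it back. The helpers `load_split_pickle` and `fake_split` are hypothetical names used only for this illustration.

```python
import gzip
import io
import pickle
import numpy as np

def load_split_pickle(fileobj):
    """Hypothetical helper: read a gzip-compressed pickle holding
    ((X_train, y_train), (X_valid, y_valid), (X_test, y_test))."""
    with gzip.open(fileobj, "rb") as fp:
        training, validation, test = pickle.load(fp, encoding="bytes")
    return training, validation, test

rng = np.random.default_rng(0)

def fake_split(n):
    # Hypothetical stand-in data: n random 784-pixel rows and n labels in 0..9
    return rng.random((n, 784)).astype(np.float32), rng.integers(0, 10, n)

# Write a small archive with the assumed layout into an in-memory buffer
buf = io.BytesIO()
with gzip.GzipFile(fileobj=buf, mode="wb") as gz:
    pickle.dump((fake_split(50), fake_split(10), fake_split(10)), gz)
buf.seek(0)

# Read it back and unpack it the same way as the experiment code does
(X_tr, y_tr), (X_va, y_va), (X_te, y_te) = load_split_pickle(buf)
print(X_tr.shape, y_tr.shape)  # (50, 784) (50,)
```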
③ Inspect the structure of the data in preparation for modeling:
```python
# show_data_struct prints the shape of each set, plus the first training sample and its label
def show_data_struct():
    print(X_training_data.shape, y_training_data.shape)
    print(X_valid_data.shape, y_valid_data.shape)
    print(X_test_data.shape, y_test_data.shape)
    print(X_training_data[0])
    print(y_training_data[0])

# Call show_data_struct to display the data
show_data_struct()
```
④ To get a better sense of what the data looks like, visualize some of the handwritten digits:
```python
# show_image plots the first 10 training images, titled with their labels
def show_image():
    plt.figure(1)
    for i in range(10):
        image = X_training_data[i]
        pixels = image.reshape((28, 28))  # each row is a flattened 28x28 image
        plt.subplot(5, 2, i + 1)
        plt.imshow(pixels, cmap='gray')
        plt.title(str(y_training_data[i]))
        plt.axis('off')
    plt.subplots_adjust(top=0.92, bottom=0.08, left=0.10, right=0.95, hspace=0.45, wspace=0.85)
    plt.show()

# Call show_image to display the images
show_image()
```
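An alternative design uses matplotlib's object-oriented API, `plt.subplots`, which returns the figure and a grid of axes. This makes the plotting function reusable for any set of images and lets it be tested without a display. In the minimal sketch below, `show_digits` is a hypothetical helper, random arrays stand in for `X_training_data[:10]` and `y_training_data[:10]`, and the non-interactive Agg backend is assumed.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

def show_digits(images, labels, n_rows=2, n_cols=5):
    """Hypothetical helper: plot flattened 28x28 images in a grid, titled with labels."""
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows))
    for ax, image, label in zip(axes.ravel(), images, labels):
        ax.imshow(image.reshape(28, 28), cmap="gray")
        ax.set_title(str(label))
        ax.axis("off")
    fig.tight_layout()
    return fig

rng = np.random.default_rng(0)
fake_images = rng.random((10, 784))  # stand-in for X_training_data[:10]
fake_labels = np.arange(10)          # stand-in for y_training_data[:10]
fig = show_digits(fake_images, fake_labels)
```

With real data, calling `show_digits(X_training_data[:10], y_training_data[:10])` followed by `plt.show()` in an interactive backend should display the same ten digits as `show_image`.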
⑤ Build the parameter dictionary that the grid search will use for hyperparameter optimization:
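```python
# Candidate values for MLPClassifier's hyperparameters.
# Valid solver names are 'lbfgs', 'sgd' and 'adam'. Strings such as ' adam' (leading space)
# or 'bfgs' are rejected by scikit-learn, and with the default error_score GridSearchCV
# would score those candidates as NaN, leaving only the valid solvers in the comparison.
mlp_clf__tuned_parameters = {"hidden_layer_sizes": [(100,), (100, 30)],
                             "solver": ['adam', 'sgd', 'lbfgs'],
                             "max_iter": [20],
                             "verbose": [True]
                             }
```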
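GridSearchCV tries every combination in the Cartesian product of the value lists. This grid has 2 × 3 × 1 × 1 = 6 candidates. With 5-fold cross-validation (assumed here to be GridSearchCV's default `cv=5`, as in recent scikit-learn versions) plus one final refit of the best candidate on all training data, that comes to 31 model fits. `ParameterGrid` from scikit-learn lists the candidates explicitly. The `verbose` entry is left out of this sketch because a single-valued entry does not add candidates.

```python
from sklearn.model_selection import ParameterGrid

param_grid = {
    "hidden_layer_sizes": [(100,), (100, 30)],
    "solver": ["adam", "sgd", "lbfgs"],
    "max_iter": [20],
}
candidates = list(ParameterGrid(param_grid))  # Cartesian product of the value lists
n_folds = 5                                   # assumed: GridSearchCV's default cv=5
n_fits = len(candidates) * n_folds + 1        # +1 for the final refit on all training data
for params in candidates:
    print(params)
print(len(candidates), n_fits)  # 6 31
```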
⑥ Build the multilayer perceptron with the MLPClassifier class, and use GridSearchCV to search the parameter grid for the best hyperparameters:
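```python
# Build the multilayer perceptron classifier
mlp = MLPClassifier()
# Cross-validated grid search over the parameter dictionary, using 6 parallel jobs
estimator = GridSearchCV(mlp, mlp_clf__tuned_parameters, n_jobs=6)
# Fit every candidate, then refit the best one on the full training set
estimator.fit(X_training_data, y_training_data)
# Print the best parameters
print(estimator.best_params_)
# Print the accuracy of the refitted best model on the test set
print(estimator.score(X_test_data, y_test_data))
```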
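To try the same workflow without the mnist.gz file, the sketch below swaps in scikit-learn's built-in 8×8 `load_digits` dataset. This substitute dataset, the 75/25 split, `cv=3`, `max_iter=100` and `random_state=0` are all assumptions chosen to keep the run short. The results will therefore differ from the MNIST figures reported below.

```python
import warnings
from sklearn.datasets import load_digits
from sklearn.exceptions import ConvergenceWarning
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neural_network import MLPClassifier

# Substitute dataset: 1797 8x8 digit images with pixel values 0..16
X, y = load_digits(return_X_y=True)
X = X / 16.0  # scale pixels to [0, 1], like the MNIST data used above
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)

param_grid = {
    "hidden_layer_sizes": [(100,), (100, 30)],
    "solver": ["adam", "sgd", "lbfgs"],
    "max_iter": [100],
}
search = GridSearchCV(MLPClassifier(random_state=0), param_grid, cv=3)
with warnings.catch_warnings():
    # Short runs may stop at max_iter; silence the resulting ConvergenceWarnings
    warnings.simplefilter("ignore", ConvergenceWarning)
    search.fit(X_train, y_train)

print(search.best_params_)
print(search.score(X_test, y_test))  # mean accuracy of the refitted best model
```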
4 Results
(50000, 784) (50000,)
(10000, 784) (10000,)
(10000, 784) (10000,)
[0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0.01171875 0.0703125 0.0703125 0.0703125
0.4921875 0.53125 0.68359375 0.1015625 0.6484375 0.99609375
0.96484375 0.49609375 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0.1171875 0.140625 0.3671875 0.6015625
0.6640625 0.98828125 0.98828125 0.98828125 0.98828125 0.98828125
0.87890625 0.671875 0.98828125 0.9453125 0.76171875 0.25
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.19140625
0.9296875 0.98828125 0.98828125 0.98828125 0.98828125 0.98828125
0.98828125 0.98828125 0.98828125 0.98046875 0.36328125 0.3203125
0.3203125 0.21875 0.15234375 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0.0703125 0.85546875 0.98828125
0.98828125 0.98828125 0.98828125 0.98828125 0.7734375 0.7109375
0.96484375 0.94140625 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0.3125 0.609375 0.41796875 0.98828125
0.98828125 0.80078125 0.04296875 0. 0.16796875 0.6015625
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0.0546875 0.00390625 0.6015625 0.98828125 0.3515625
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0.54296875 0.98828125 0.7421875 0.0078125 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.04296875
0.7421875 0.98828125 0.2734375 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.13671875 0.94140625
0.87890625 0.625 0.421875 0.00390625 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0.31640625 0.9375 0.98828125
0.98828125 0.46484375 0.09765625 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0.17578125 0.7265625 0.98828125 0.98828125
0.5859375 0.10546875 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0.0625 0.36328125 0.984375 0.98828125 0.73046875
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0.97265625 0.98828125 0.97265625 0.25 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0.1796875 0.5078125 0.71484375 0.98828125
0.98828125 0.80859375 0.0078125 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.15234375 0.578125
0.89453125 0.98828125 0.98828125 0.98828125 0.9765625 0.7109375
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0.09375 0.4453125 0.86328125 0.98828125 0.98828125 0.98828125
0.98828125 0.78515625 0.3046875 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0.08984375 0.2578125 0.83203125 0.98828125
0.98828125 0.98828125 0.98828125 0.7734375 0.31640625 0.0078125
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.0703125 0.66796875
0.85546875 0.98828125 0.98828125 0.98828125 0.98828125 0.76171875
0.3125 0.03515625 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0.21484375 0.671875 0.8828125 0.98828125 0.98828125 0.98828125
0.98828125 0.953125 0.51953125 0.04296875 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.53125 0.98828125
0.98828125 0.98828125 0.828125 0.52734375 0.515625 0.0625
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
5
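The printed structure shows that the training set contains 50,000 samples, while the validation and test sets contain 10,000 each. Each sample is a 784-dimensional vector, a 28×28 (= 784) pixel image flattened into one row. The pixel values are not binary 0/1 but grayscale intensities in the range [0, 1]: 0 is background, and values close to 1 are the pen strokes. The label y is the true digit shown in each image. The first training sample printed above, for example, is a 5.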
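The nonzero values in the printout all appear to be 8-bit grayscale levels divided by 256. The check below multiplies a few values copied from the printout of `X_training_data[0]` by 256 and confirms that the results are whole numbers between 0 and 255. This only demonstrates the pattern on these sampled values; it does not check the whole dataset.

```python
import numpy as np

# A few nonzero pixel values copied from the printout of X_training_data[0]
values = np.array([0.01171875, 0.0703125, 0.4921875, 0.98828125, 0.99609375])
raw = values * 256                        # undo the assumed division by 256
print(raw)
print(np.allclose(raw, np.round(raw)))    # True: every value is an integer level / 256
print(28 * 28)                            # 784 pixels per flattened image
```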
As the figure shows, the MNIST data consists of handwritten digit images paired with their labels.
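The verbose training log of the MLP classifier shows the loss decreasing steadily over 20 iterations, down to 0.2320587. Because max_iter was set to 20 and the loss was still falling, training most likely stopped at that cap rather than because the optimizer had converged.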
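For the stochastic solvers (`sgd`, `adam`), scikit-learn's MLPClassifier stores the per-epoch training loss in `loss_curve_` and the number of epochs actually run in `n_iter_`. It also issues a `ConvergenceWarning` when it stops at `max_iter`. The sketch below shows how to check whether a 20-iteration `sgd` run hit that cap. It uses small synthetic data (an assumption, not MNIST), so the loss values are not comparable to 0.2320587.

```python
import warnings
import numpy as np
from sklearn.exceptions import ConvergenceWarning
from sklearn.neural_network import MLPClassifier

# Synthetic binary problem standing in for MNIST
rng = np.random.default_rng(0)
X = rng.random((300, 20))
y = (X[:, 0] + X[:, 1] > 1).astype(int)

clf = MLPClassifier(hidden_layer_sizes=(30,), solver="sgd", max_iter=20, random_state=0)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always", ConvergenceWarning)
    clf.fit(X, y)

print(clf.n_iter_)          # epochs actually run (at most max_iter)
print(clf.loss_curve_[-1])  # training loss after the last epoch
# True if training stopped at max_iter instead of meeting the tolerance criterion
hit_cap = any(issubclass(w.category, ConvergenceWarning) for w in caught)
print(hit_cap)
```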
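The output shows the best parameters found by GridSearchCV. The first is hidden_layer_sizes=(100, 30), meaning two hidden layers with 100 and 30 neurons. The second is max_iter=20, the maximum number of training iterations (epochs, for the sgd solver). The third is solver='sgd', which is the weight optimizer, not the activation function; the activation was left at its default, relu. With these settings, the multilayer perceptron reaches an accuracy of 0.9347 on the MNIST test set, meaning roughly 9,347 of the 10,000 test images are classified correctly.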
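To give a sense of the size of the selected network, here is a count of its trainable parameters. A fully connected layer from m to n units has m·n weights plus n biases, so 784 → 100 → 30 → 10 gives (784·100 + 100) + (100·30 + 30) + (30·10 + 10) = 81,840. The sketch below computes this with a hypothetical helper `mlp_param_count` and cross-checks it against the `coefs_` and `intercepts_` of an MLPClassifier fitted briefly on random 784-feature data with 10 classes (stand-in data, not MNIST).

```python
import warnings
import numpy as np
from sklearn.exceptions import ConvergenceWarning
from sklearn.neural_network import MLPClassifier

def mlp_param_count(layer_sizes):
    """Hypothetical helper: weights plus biases of a fully connected net."""
    return sum(m * n + n for m, n in zip(layer_sizes[:-1], layer_sizes[1:]))

print(mlp_param_count([784, 100, 30, 10]))  # 81840

# Cross-check against a fitted model (random stand-in data, all 10 classes present)
rng = np.random.default_rng(0)
X = rng.random((50, 784))
y = np.arange(50) % 10
with warnings.catch_warnings():
    warnings.simplefilter("ignore", ConvergenceWarning)  # 5 iterations will not converge
    clf = MLPClassifier(hidden_layer_sizes=(100, 30), max_iter=5, random_state=0).fit(X, y)
n_fitted = sum(W.size for W in clf.coefs_) + sum(b.size for b in clf.intercepts_)
print(n_fitted)  # 81840
```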