Original

Handwritten Digit Recognition: Small Dataset


1. Handwritten digit dataset

  • from sklearn.datasets import load_digits
  • digits = load_digits()
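As a quick check of what `load_digits()` returns, the following sketch prints the shapes of the main fields. It only uses attributes of the scikit-learn digits dataset; nothing here is specific to the author's setup.

```python
from sklearn.datasets import load_digits

digits = load_digits()

# data: flattened images, one row of 64 pixel values per sample
print(digits.data.shape)    # (1797, 64)
# images: the same samples kept as 8x8 grids
print(digits.images.shape)  # (1797, 8, 8)
# target: integer class labels 0-9
print(digits.target.shape)  # (1797,)
# pixel intensities range from 0 to 16, not 0 to 255
print(digits.data.min(), digits.data.max())  # 0.0 16.0
```

The small 8x8 image size is what makes this a "small dataset" compared with 28x28 MNIST.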

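2. Image data preprocessing

  • x: normalize with MinMaxScaler()
  • y: one-hot encode with OneHotEncoder() or to_categorical
  • Split into training and test sets
  • Reshape into the tensor structure the network expects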
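The four preprocessing steps above can be sketched as follows. This is a minimal illustration, assuming the scikit-learn digits dataset, an 80/20 split, and a channels-last tensor of shape (n, 8, 8, 1); the split ratio and `random_state` are arbitrary choices, not values from the original post.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder

digits = load_digits()
X = digits.data.astype("float32")   # (n_samples, 64)
y = digits.target.reshape(-1, 1)    # column vector, as OneHotEncoder expects 2-D input

# Split first so the scaler and encoder are fitted on training data only
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=7, stratify=digits.target)

# x: scale each pixel column into [0, 1] using the training-set range
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# y: one-hot encode the labels (the encoder returns a sparse matrix by default)
encoder = OneHotEncoder()
y_train = encoder.fit_transform(y_train).toarray()
y_test = encoder.transform(y_test).toarray()

# Tensor structure: (samples, height, width, channels) for a 2-D CNN
X_train = X_train.reshape(-1, 8, 8, 1)
X_test = X_test.reshape(-1, 8, 8, 1)

print(X_train.shape, y_train.shape)
```

Note that MinMaxScaler rescales each pixel position independently; dividing every pixel by 16 is a simpler alternative that keeps relative intensities across positions.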

3. Design the convolutional neural network architecture

  • Draw the model architecture diagram and explain the design rationale.
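One possible architecture for 8x8 grayscale digits is sketched below. It is an assumption made for illustration, not the author's design: the helper name `build_cnn`, the layer sizes, and the dropout rate are all hypothetical. The rationale is that two small convolution and pooling stages are about as deep as an 8x8 input allows before the feature map shrinks to 2x2.

```python
from tensorflow import keras
from tensorflow.keras import layers


def build_cnn(input_shape=(8, 8, 1), num_classes=10):
    """Hypothetical small CNN for 8x8 single-channel digit images."""
    model = keras.Sequential([
        keras.Input(shape=input_shape),
        # 'same' padding keeps the 8x8 spatial size after convolution
        layers.Conv2D(16, (3, 3), padding="same", activation="relu"),  # -> 8x8x16
        layers.MaxPooling2D((2, 2)),                                   # -> 4x4x16
        layers.Conv2D(32, (3, 3), padding="same", activation="relu"),  # -> 4x4x32
        layers.MaxPooling2D((2, 2)),                                   # -> 2x2x32
        layers.Flatten(),                                              # -> 128
        layers.Dropout(0.25),  # light regularization for a small dataset
        layers.Dense(64, activation="relu"),
        layers.Dense(num_classes, activation="softmax"),
    ])
    model.compile(loss="categorical_crossentropy",
                  optimizer="adam",
                  metrics=["accuracy"])
    return model


model = build_cnn()
model.summary()  # prints the layer-by-layer structure and parameter counts
```

The `model.summary()` output can serve as a textual stand-in for the architecture diagram.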

4. Model training
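A minimal training sketch on the digits dataset might look like the following. The network, the number of epochs, the batch size, and the 10% validation split are assumptions chosen to keep the example short; they are not tuned values from the original post.

```python
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras import layers

digits = load_digits()
# Scale pixels from 0..16 to 0..1 and add a channel axis
X = (digits.images / 16.0).astype("float32").reshape(-1, 8, 8, 1)
y = keras.utils.to_categorical(digits.target, num_classes=10)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=7)

# Deliberately tiny model so the sketch trains in a few seconds
model = keras.Sequential([
    keras.Input(shape=(8, 8, 1)),
    layers.Conv2D(16, (3, 3), padding="same", activation="relu"),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(10, activation="softmax"),
])
model.compile(loss="categorical_crossentropy",
              optimizer="adam",
              metrics=["accuracy"])

# Hold out part of the training data for validation; the test set stays untouched
history = model.fit(X_train, y_train,
                    validation_split=0.1,
                    epochs=5, batch_size=32, verbose=0)

# history.history maps each metric name to a list with one value per epoch
print(sorted(history.history.keys()))
```

Using `validation_split` rather than passing the test set as `validation_data` keeps the test set free for the final evaluation.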

5. Model evaluation

  • model.evaluate()
  • Cross-tabulation and confusion matrix
  • pandas.crosstab
  • seaborn.heatmap
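To show what `pandas.crosstab` and `seaborn.heatmap` produce, here is a small sketch on made-up labels. The arrays `y_true` and `y_pred` are toy values invented for illustration, not model output, and the headless `Agg` backend is an assumption so the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; drop this line to see a window
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Toy labels (invented for illustration)
y_true = np.array([0, 0, 1, 1, 2, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 2, 0])

# Rows are actual classes, columns are predicted classes
table = pd.crosstab(pd.Series(y_true, name="actual"),
                    pd.Series(y_pred, name="predicted"))
print(table)

# Correct predictions sit on the diagonal
accuracy = np.trace(table.values) / table.values.sum()
print("accuracy: %.3f" % accuracy)  # 5 of 7 correct -> 0.714

# fmt="d" prints the counts as plain integers inside each cell
ax = sns.heatmap(table, annot=True, fmt="d", cmap="YlGnBu")
ax.set_title("Confusion matrix (toy data)")
# plt.show()  # uncomment when running with an interactive backend
```

Off-diagonal cells show which digits are confused with which; for example, row 2, column 0 counts samples of class 2 predicted as class 0.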

Implementation code

# author: 陌攻
import numpy as np
import pandas as pd
import seaborn as sns
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Sequential
from tensorflow.keras.utils import to_categorical

# Seed NumPy's random number generator
seed = 7
np.random.seed(seed)

# Load the data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
y_test1 = y_test  # keep the integer test labels for the cross-tabulation later

# Data preparation
# # Flatten each 28x28 image into a 784-element float32 vector
num_pixels = X_train.shape[1] * X_train.shape[2]
X_train = X_train.reshape(X_train.shape[0], num_pixels).astype('float32')
X_test = X_test.reshape(X_test.shape[0], num_pixels).astype('float32')

# # Scale pixel values from the range 0-255 to 0-1
X_train = X_train / 255
X_test = X_test / 255

# # One-hot encode the labels
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
num_classes = y_test.shape[1]

# MLP model
def baseline_model():
    model = Sequential()
    model.add(Input(shape=(num_pixels,)))
    # Older Keras spelled this argument init='normal'; it is now kernel_initializer
    model.add(Dense(num_pixels, kernel_initializer='random_normal', activation='relu'))
    model.add(Dense(num_classes, kernel_initializer='random_normal', activation='softmax'))
    model.summary()
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model

# Build the model
model = baseline_model()

# Train the model (the old nb_epoch argument is now called epochs)
model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=10, batch_size=200, verbose=2)

# Save the model with Keras's own serializer instead of joblib
model.save('NumberModel.h5')

# Load the model
# from tensorflow.keras.models import load_model
# model = load_model('NumberModel.h5')

# Model evaluation
scores = model.evaluate(X_test, y_test, verbose=0)
print("Accuracy: %.2f%%" % (scores[1]*100))  # print the accuracy

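# Cross-tabulation and confusion matrix
# # Predict on the test data
y_pred = model.predict(X_test)

# # Convert the (10000, 10) probability array back to digit labels of shape (10000,)
y_pred = np.argmax(y_pred, axis=1).reshape(-1)

# # Rows are the true digits, columns are the predicted digits
a = pd.crosstab(np.array(y_test1), y_pred)

# # Convert to a DataFrame (crosstab already returns one, so this is a no-op copy)
df = pd.DataFrame(a)

# # Print the confusion matrix
print(df)

# # Plot the cross-tabulation as a heatmap
from matplotlib import pyplot as plt
# fmt='d' shows the counts as integers; the single-letter color must be lowercase
sns.heatmap(df, annot=True, fmt='d', cmap="YlGnBu", linewidths=0.2, linecolor='g')
plt.show()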

Output figures:

[Figure: confusion matrix (printed output)]

[Figure: cross-tabulation heatmap]

python
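algorithms
  • Author: 陌攻
  • Published: 2021-04-28 14:02
  • License: free to share, non-commercial, no derivatives, attribution required (Creative Commons 3.0 license)
  • Reposting on official accounts: please add the author's official-account QR code at the end of the article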