
Introduction to the LeNet Network


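1. Background

This article walks through training a LeNet-style convolutional network on the CIFAR-10 image dataset and using it for prediction.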

2. The CIFAR-10 Image Dataset

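CIFAR-10 is a dataset of 60,000 three-channel (RGB) color images of 32×32 pixels, divided into 10 classes of 6,000 images each. 50,000 images form the training set and 10,000 the test set.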
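As a quick sanity check on these sizes, the sketch below computes how much memory the 50,000 training images occupy as loaded (`uint8`, one byte per value) and after the `astype('float32')` conversion in the loading code that follows (four bytes per value). `array_bytes` is a hypothetical helper introduced only for this illustration.

```python
import numpy as np


def array_bytes(shape, dtype):
    """Bytes needed to hold a dense array of the given shape and dtype (hypothetical helper)."""
    return int(np.prod(shape)) * np.dtype(dtype).itemsize


train_shape = (50000, 32, 32, 3)                     # shape of X_train
uint8_bytes = array_bytes(train_shape, 'uint8')      # raw pixels as returned by load_data()
float32_bytes = array_bytes(train_shape, 'float32')  # after astype('float32')

print(f'uint8:   {uint8_bytes / 2**20:.1f} MiB')
print(f'float32: {float32_bytes / 2**20:.1f} MiB')
```

So the normalized float32 training set needs roughly 586 MiB, four times the raw uint8 data.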

Data loading and preprocessing:

from keras.datasets import cifar10
from keras.utils import to_categorical  # keras.utils.np_utils no longer exists in recent Keras


def load_and_proc_data():
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    print('X_train shape', X_train.shape)  # X_train shape (50000, 32, 32, 3)
    print(X_train.shape[0], 'train samples')
    print(X_test.shape[0], 'test samples')

    # Scale pixel values from [0, 255] to [0, 1]
    X_train = X_train.astype('float32')
    X_test = X_test.astype('float32')
    X_train /= 255
    X_test /= 255

    # Convert class vectors (integer labels) to binary class matrices (one-hot)
    y_train = to_categorical(y_train, NB_CLASSES)
    y_test = to_categorical(y_test, NB_CLASSES)
    return X_train, X_test, y_train, y_test
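`to_categorical` turns each integer label into a one-hot row. The NumPy sketch below reproduces that conversion for the `(n, 1)` label layout that `cifar10.load_data()` returns, to show what `y_train` looks like after preprocessing. `one_hot` is a hypothetical helper written for illustration, not a Keras function.

```python
import numpy as np


def one_hot(labels, num_classes):
    """Convert integer labels of shape (n,) or (n, 1) to a float32 (n, num_classes) one-hot matrix."""
    labels = np.asarray(labels).reshape(-1)            # flatten (n, 1) -> (n,)
    out = np.zeros((labels.size, num_classes), dtype='float32')
    out[np.arange(labels.size), labels] = 1.0          # set exactly one column per row
    return out


y = np.array([[6], [9], [0]])   # example labels in the (n, 1) layout used by CIFAR-10
Y = one_hot(y, 10)
print(Y.shape)
print(Y)
```

Each row contains a single 1.0 at the column given by the label, which is the target format that the `categorical_crossentropy` loss used later expects.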

3. Defining the LeNet Model

3.1 Network with a single convolutional layer

from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Activation, Flatten, Dense, Dropout
from keras.datasets import cifar10
from keras.utils import to_categorical
from keras.optimizers import RMSprop


class LeNet:
    @staticmethod
    def build(input_shape, classes):
        model = Sequential()
        # CONV => RELU => POOL => DROPOUT
        model.add(Conv2D(32, kernel_size=3, padding='same', input_shape=input_shape))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
        model.add(Dropout(0.25))
        # FLATTEN => FC(512) => RELU => DROPOUT
        model.add(Flatten())
        model.add(Dense(512))
        model.add(Activation('relu'))
        model.add(Dropout(0.5))
        # Softmax classifier
        model.add(Dense(classes))
        model.add(Activation('softmax'))
        model.summary()  # print a summary of the network
        return model

3.2 Model structure and parameters
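The parameter counts that `model.summary()` reports for the model in 3.1 can be worked out by hand: a `Conv2D` layer has `k*k*c_in*filters + filters` trainable values, a `Dense` layer has `n_in*n_out + n_out`, and activation, pooling, dropout and flatten layers have none. The sketch below applies these formulas to a 32×32×3 input with `padding='same'` and one 2×2 pooling step. `conv_params` and `dense_params` are hypothetical helpers, and `model.summary()` remains the authoritative source.

```python
def conv_params(k, c_in, filters):
    """Weights (k*k*c_in*filters) plus one bias per filter for a k x k Conv2D layer."""
    return k * k * c_in * filters + filters


def dense_params(n_in, n_out):
    """Weights (n_in*n_out) plus one bias per unit for a Dense layer."""
    return n_in * n_out + n_out


# Conv2D(32, 3, padding='same') on a 32x32x3 input keeps the 32x32 size
conv1 = conv_params(3, 3, 32)
# MaxPooling2D(2, 2) halves it to 16x16x32, which Flatten turns into one vector
flat = 16 * 16 * 32
fc1 = dense_params(flat, 512)   # Dense(512)
fc2 = dense_params(512, 10)     # Dense(classes) with classes = 10
total = conv1 + fc1 + fc2

for name, n in [('conv1', conv1), ('dense_512', fc1), ('dense_10', fc2), ('total', total)]:
    print(f'{name:>10}: {n:,}')
```

By this count the model has 4,200,842 trainable parameters, about 99.9% of them in the first `Dense` layer that connects the 8,192 flattened features to 512 units.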

3.3 Increasing model depth (multiple convolutional layers)

class LeNet:
    @staticmethod
    def build(input_shape, classes):
        model = Sequential()
        # Block 1: CONV => RELU => CONV => RELU => POOL => DROPOUT (32 filters)
        model.add(Conv2D(32, kernel_size=3, padding='same', input_shape=input_shape))
        model.add(Activation('relu'))
        model.add(Conv2D(32, kernel_size=3, padding='same'))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
        model.add(Dropout(0.25))
        # Block 2: same pattern with 64 filters
        model.add(Conv2D(64, kernel_size=3, padding='same'))
        model.add(Activation('relu'))
        model.add(Conv2D(64, kernel_size=3, padding='same'))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
        model.add(Dropout(0.25))
        # Classifier: FLATTEN => FC(512) => RELU => DROPOUT => FC(classes) => SOFTMAX
        model.add(Flatten())
        model.add(Dense(512))
        model.add(Activation('relu'))
        model.add(Dropout(0.5))
        model.add(Dense(classes))
        model.add(Activation('softmax'))
        model.summary()  # print a summary of the network
        return model
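Counting parameters by hand for this deeper model (`Conv2D`: `k*k*c_in*filters + filters`; `Dense`: `n_in*n_out + n_out`) shows an effect that may look surprising: it has roughly half as many parameters as the single-convolution model in 3.1 (2,168,362 versus 4,200,842 by the same count). The second pooling layer shrinks the input of the 512-unit `Dense` layer from 16×16×32 = 8,192 to 8×8×64 = 4,096 values, which saves far more than the extra convolutions add. `count_params` and its tuple layer format are hypothetical, written only for this illustration.

```python
def count_params(layers, in_channels=3):
    """Sum weights and biases over ('conv', k, filters) and ('dense', n_in, n_out) specs.

    Activation, pooling, dropout and flatten layers have no trainable parameters,
    so they are left out of the list.
    """
    total, c_in = 0, in_channels
    for kind, a, b in layers:
        if kind == 'conv':        # a = kernel size, b = number of filters
            total += a * a * c_in * b + b
            c_in = b              # the next conv layer sees b input channels
        elif kind == 'dense':     # a = number of inputs, b = number of units
            total += a * b + b
        else:
            raise ValueError(f'unknown layer kind: {kind}')
    return total


# 'same' padding keeps the spatial size; each 2x2 pooling halves it: 32 -> 16 -> 8
deep = [('conv', 3, 32), ('conv', 3, 32),
        ('conv', 3, 64), ('conv', 3, 64),
        ('dense', 8 * 8 * 64, 512),   # Flatten output: 4,096 values
        ('dense', 512, 10)]

print(f'{count_params(deep):,} trainable parameters')
```

The four convolutions together contribute only about 65,600 parameters; the bulk is still in the first `Dense` layer.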

4. Training and evaluating the model

def model_train(X_train, y_train):
    OPTIMIZER = RMSprop()
    model = LeNet.build(input_shape=INPUT_SHAPE, classes=NB_CLASSES)
    model.compile(loss='categorical_crossentropy', optimizer=OPTIMIZER,
                  metrics=['accuracy'])
    history = model.fit(X_train, y_train, batch_size=BATCH_SIZE, epochs=NB_EPOCH,
                        verbose=1, validation_split=VALIDATION_SPLIT)
    # plot_picture(history)
    return model


def model_evaluate(model, X_test, y_test):
    score = model.evaluate(X_test, y_test, batch_size=BATCH_SIZE, verbose=1)
    print('Test score: ', score[0])  # score[0]: categorical cross-entropy loss on the test set
    print('Test acc: ', score[1])    # score[1]: accuracy on the test set
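With `validation_split=VALIDATION_SPLIT`, `fit` sets aside the last 20% of `X_train`/`y_train` (selected before shuffling, according to the Keras documentation) and reports `val_loss`/`val_accuracy` on it after each epoch; the test set is only used by `model_evaluate`. The arithmetic below works out the resulting split and the number of batches per epoch with the constants from the main function. It mirrors the documented behaviour rather than Keras's exact internal code.

```python
import math

n_train_total = 50000      # CIFAR-10 training images
validation_split = 0.2     # VALIDATION_SPLIT
batch_size = 128           # BATCH_SIZE

split_at = int(n_train_total * (1 - validation_split))
n_fit, n_val = split_at, n_train_total - split_at    # rows [0, split_at) train, the rest validate
steps_per_epoch = math.ceil(n_fit / batch_size)      # the last batch is smaller (64 samples)

print(n_fit, 'samples for fitting,', n_val, 'for validation')
print(steps_per_epoch, 'batches per epoch')
```

So each epoch trains on 40,000 images in 313 batches and validates on the remaining 10,000.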

5. Plotting accuracy and loss

import matplotlib.pyplot as plt


def plot_picture(history):
    print(history.history.keys())
    # ----------- accuracy -----------
    plt.plot(history.history['accuracy'])
    plt.plot(history.history['val_accuracy'])
    plt.title('model acc')
    plt.ylabel('acc')
    plt.xlabel('epoch')
    # val_* values come from validation_split, not from the test set
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()
    # ----------- loss -----------
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()
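`history.history` is a plain dict that maps each metric name (`'loss'`, `'accuracy'`, `'val_loss'`, `'val_accuracy'`) to a list with one value per epoch, so it can also be queried directly, for example to find the epoch with the highest validation accuracy. The numbers below are made up purely to illustrate the call, and `best_epoch` is a hypothetical helper.

```python
def best_epoch(history_dict, metric='val_accuracy'):
    """Return (1-based epoch, value) where `metric` reaches its maximum."""
    values = history_dict[metric]
    idx = max(range(len(values)), key=values.__getitem__)
    return idx + 1, values[idx]


# Made-up values for illustration only (not results from this article)
fake_history = {
    'accuracy':     [0.40, 0.55, 0.63, 0.70, 0.76],
    'val_accuracy': [0.48, 0.58, 0.64, 0.66, 0.65],
}
epoch, acc = best_epoch(fake_history)
print(f'best val_accuracy {acc:.2f} at epoch {epoch}')
```

A validation curve that peaks and then flattens or drops while training accuracy keeps rising, as in these toy numbers, is the typical sign of overfitting that the two plots help reveal.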

6. Saving the model

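def model_save(model):
    # Save the network architecture as JSON
    model_json = model.to_json()
    with open('cifar10_architecture.json', 'w') as f:
        f.write(model_json)
    # Save the network weights (Keras 3 requires the file name to end in '.weights.h5')
    model.save_weights('cifar10_weights.h5', overwrite=True)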

7. Loading the model and predicting
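A minimal sketch of this step, assuming the two files written by `model_save`: rebuild the architecture with `keras.models.model_from_json`, restore the weights with `load_weights`, and turn the softmax outputs of `model.predict` into class names with `argmax`. The function names and the class-name list (the standard CIFAR-10 label order) are illustrative additions, not the author's code.

```python
import numpy as np

# Standard CIFAR-10 class names, indexed by label 0-9
CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']


def load_saved_model(json_path='cifar10_architecture.json',
                     weights_path='cifar10_weights.h5'):
    """Rebuild the model from its JSON architecture and load the saved weights (hypothetical helper)."""
    from keras.models import model_from_json  # imported here so decode_predictions works without Keras
    with open(json_path) as f:
        model = model_from_json(f.read())
    model.load_weights(weights_path)
    return model


def decode_predictions(probs):
    """Map an (n, 10) array of softmax outputs to the most likely class names."""
    return [CIFAR10_CLASSES[i] for i in np.argmax(probs, axis=1)]


# Usage (needs the saved files and the preprocessed X_test from load_and_proc_data):
# model = load_saved_model()
# print(decode_predictions(model.predict(X_test[:5])))
```

`predict` works on the reloaded model as-is; calling `evaluate` requires running `compile` again first, because only the architecture and weights were saved.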

8. Main function

NB_EPOCH = 50
BATCH_SIZE = 128
VALIDATION_SPLIT = 0.2
IMG_ROWS, IMG_COLS = 32, 32
IMG_CHANNELS = 3
INPUT_SHAPE = (IMG_ROWS, IMG_COLS, IMG_CHANNELS)  # note the order: rows, cols, channels (channels_last)
NB_CLASSES = 10

if __name__ == '__main__':
    X_train, X_test, y_train, y_test = load_and_proc_data()
    model = model_train(X_train, y_train)
    # model_save(model)
    model_evaluate(model, X_test, y_test)

Model output (`Test score` is the loss on the test set, `Test acc` the accuracy):

Test score:  1.3542113304138184
Test acc:  0.6733999848365784

9.