검색하기귀찮아서만든블로그

[ML] MNIST 모델 학습 본문

개발

[ML] MNIST 모델 학습

hellworld 2022. 11. 28. 20:18

인공지능 기초에 대해서 공부를 시작해보려고 한다.
CNN 알고리즘을 사용한 인공지능 모델 중 가장 기초로 많은 예제가 있는 MNIST 샘플에 대하여 정리한 내용이다.
 
CNN 은 이미지 학습에 가장 많이 사용되는 알고리즘인데 그중 글씨를 인식하기 위한 모델은 MNIST라고 한다.

출처 : Adventures in machine learning

모델을 공부하기 위해서는 기초 지식이 많이 필요하다.
상기 모델은 CNN 알고리즘  Max Pooling 방식으로 압축하고 1차원으로 만든 후에 ReLU 와 Softmax 활성화 함수를 사용한 모델 예시이다.
 

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from keras.models import load_model

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# 각종 파라메터의 영향을 보기 위해 랜덤값 고정
tf.random.set_seed(1234)

# Normalizing data
x_train, x_test = x_train / 255.0, x_test / 255.0

# (60000, 28, 28) => (60000, 28, 28, 1)로 reshape
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)

# One-hot 인코딩
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(kernel_size=(3,3), filters=64, input_shape=(28,28,1), padding='same', activation='relu'),
    tf.keras.layers.Conv2D(kernel_size=(3,3), filters=64, padding='same', activation='relu'),
    tf.keras.layers.MaxPool2D(pool_size=(2,2)),

    tf.keras.layers.Conv2D(kernel_size=(3,3), filters=128, padding='same', activation='relu'),
    tf.keras.layers.Conv2D(kernel_size=(3,3), filters=256, padding='valid', activation='relu'),
    tf.keras.layers.MaxPool2D(pool_size=(2,2)),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units=512, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(units=256, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(units=10, activation='softmax')
])

model.compile(loss='categorical_crossentropy', optimizer=tf.optimizers.Adam(lr=0.001), metrics=['accuracy'])
model.summary()

model.fit(x_train, y_train, batch_size=100, epochs=3, validation_data=(x_test, y_test))

result = model.evaluate(x_test, y_test)
print("최종 예측 성공률(%): ", result[1]*100)

# 6. 모델 저장하기
model.save('mnist_model2.h5')

검색한 자료중 인식률이 높은 모델 샘플을 가져왔다. 샘플 모델은 CNN > CNN > Maxpolling + CNN > CNN > Maxpolling > 1차원 변환 > relu active func > dropout 50% > relu active func > dropout 50% > softmax active func 구조로 생성되고
숫자 이미지를 학습하고, 트레이닝 이미지로 예측률을 출력한다.
모델은 파일로 저장한다.
 

 Layer (type)                Output Shape              Param #
=================================================================
 conv2d (Conv2D)             (None, 28, 28, 64)        640

 conv2d_1 (Conv2D)           (None, 28, 28, 64)        36928

 max_pooling2d (MaxPooling2D  (None, 14, 14, 64)       0
 )

 conv2d_2 (Conv2D)           (None, 14, 14, 128)       73856

 conv2d_3 (Conv2D)           (None, 12, 12, 256)       295168

 max_pooling2d_1 (MaxPooling  (None, 6, 6, 256)        0
 2D)

 flatten (Flatten)           (None, 9216)              0

 dense (Dense)               (None, 512)               4719104

 dropout (Dropout)           (None, 512)               0

 dense_1 (Dense)             (None, 256)               131328

 dropout_1 (Dropout)         (None, 256)               0

 dense_2 (Dense)             (None, 10)                2570

=================================================================
Total params: 5,259,594
Trainable params: 5,259,594
Non-trainable params: 0
_________________________________________________________________
Epoch 1/10
600/600 [==============================] - 147s 243ms/step - loss: 0.1932 - accuracy: 0.9409 - val_loss: 0.0358 - val_accuracy: 0.9880
Epoch 2/10
600/600 [==============================] - 138s 230ms/step - loss: 0.0551 - accuracy: 0.9845 - val_loss: 0.0336 - val_accuracy: 0.9884
Epoch 3/10
600/600 [==============================] - 135s 225ms/step - loss: 0.0418 - accuracy: 0.9887 - val_loss: 0.0289 - val_accuracy: 0.9908
Epoch 4/10
229/600 [==========>...................] - ETA: 1:20 - loss: 0.0327 - accuracy: 0.9911

학습은 3번을 반복하도록 했으나.. 4번까지 수행하다가 중지되었다..?
모델의 테스트 결과는 손실률 3.2% 정확도 99.1%가 나왔다.
너무 높은 정확도는 오버피팅될 수 있기 때문에 적당한 결과물이 나온 것으로 예상된다.
 
 
코드 참조 : https://github.com/Laboputer/LearnML/blob/master/02.%20%5BPOST%5D/01.%20MNIST%2099.5%25%20with%20CNN.ipynb