함께 배워가는 학생개발자
MNIST 숫자 인식 본문
MNIST
0 ~ 9 손글씨 데이터를 이용하여 인식률을 확인하고 그래프를 나타내는 프로그램
주석으로 코드 설명했으니 참고하시기 바랍니다.
import tensorflow as tf
import matplotlib.pyplot as plt
import random
from tensorflow.examples.tutorials.mnist import input_data
# data를 mnist 변수에 대입
# one_hot
mnist = input_data.read_data_sets("MNIST_data/", one_hot = True)
# classes가 10 -> 인식 할 숫자가 0 ~ 9
nb_classes = 10
# X, Y 생성 (?, 784) (?, 10)
# 784는 픽셀 개수 (28 by 28)
X = tf.placeholder(tf.float32, [None, 784])
Y = tf.placeholder(tf.float32, [None, nb_classes])
# 변수 W,b 생성 (784, 10) (10)
W = tf.Variable(tf.random_normal([784, nb_classes]))
b = tf.Variable(tf.random_normal([nb_classes]))
# Softmax
# Cross entropy
# is_correct using arg_max
# Calculate accuracy
# softmax를 이용한 hypothesis 생성
hypothesis = tf.nn.softmax(tf.matmul(X,W) + b)
# Cross entropy
# 0 < value < 1
cost = tf.reduce_mean(-tf.reduce_sum(Y * tf.log(hypothesis) ,axis = 1))
optimizer = tf.train.GradientDescentOptimizer(learning_rate = 0.1).minimize(cost)
# Test model (array)
# hypothesis, Y의 arg_max 값이 같은지 True/False 로 확인
is_correct = tf.equal(tf.arg_max(hypothesis, 1), tf.arg_max(Y, 1))
# accuracy 계산
# is_correct를 float 형으로 바꿔 평균 계산
accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))
# parameters
# 1epochs - 전체 데이터 셋을 한번 트레이닝 시킨 것
# batch_size - 몇 개씩 잘라서 학습할 것인지 (메모리 효율)
training_epochs = 20
batch_size = 100
# 세션 열기
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# Training cycle
# 15번 학습
for epoch in range(training_epochs):
avg_cost = 0
# 전체 Size 를 batch_size로 잘라서 학습
# 한번에 몇 개씩 학습할 것인지 설정
total_batch = int(mnist.train.num_examples / batch_size)
for i in range(total_batch):
# x_data, y_data
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
# cost 계산
c, _ = sess.run([cost, optimizer], feed_dict={X: batch_xs, Y: batch_ys})
# batch_size만큼 나눠서 avg_cost와 합 계산
avg_cost += c / total_batch
print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.9f}'.format(avg_cost))
print('Training Finish')
print("Accuracy:", accuracy.eval(session = sess, feed_dict={X:mnist.test.images, Y:mnist.test.labels}))
# Random 숫자 읽기
r = random.randint(0, mnist.test.num_examples - 1)
# test할 data 숫자 읽기
print("Label:", sess.run(tf.argmax(mnist.test.labels[r:r+1], 1)))
# hypothesis 대입
print("Prediction:", sess.run(tf.argmax(hypothesis, 1), feed_dict = {X : mnist.test.images[r:r+1]}))
plt.imshow(mnist.test.images[r:r+1].reshape(28, 28), cmap = 'Greys', interpolation = 'nearest')
plt.show()
출처 : 모두를 위한 딥러닝
'머신러닝' 카테고리의 다른 글
Faster R-CNN (0) | 2019.06.17 |
---|---|
Normalization (0) | 2017.05.15 |
Softmax function (0) | 2017.05.15 |
Comments