TensorFlow

TensorFlow Keras: Model Optimization & Training on MNIST Images (Expert)

TensorFlow 공식 튜토리얼 사이트의 이미지 분석 전문가용(고급) 예제

https://www.tensorflow.org/tutorials/quickstart/advanced

라이브러리 Import

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import datasets
import matplotlib.pyplot as plt
%matplotlib inline 

학습 과정 돌아보기

Build Model

input_shape = (28, 28, 1)
num_classes = 10


def _conv_stage(x, filters):
    """Apply two 3x3 'SAME' Conv2D + ReLU pairs, then 2x2 max-pooling and 50% dropout."""
    for _ in range(2):
        x = layers.Conv2D(filters, (3, 3), padding='SAME')(x)
        x = layers.Activation('relu')(x)
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    return layers.Dropout(0.5)(x)


inputs = layers.Input(input_shape, dtype=tf.float32)

# Feature extractor: a 32-filter stage followed by a 64-filter stage.
net = inputs
for stage_filters in (32, 64):
    net = _conv_stage(net, stage_filters)

# Classifier head: flatten -> 512-unit ReLU -> dropout -> softmax over the classes.
net = layers.Flatten()(net)
net = layers.Dense(512)(net)
net = layers.Activation('relu')(net)
net = layers.Dropout(0.5)(net)
net = layers.Dense(num_classes)(net)
net = layers.Activation('softmax')(net)

model = tf.keras.Model(inputs=inputs, outputs=net, name='Basic_CNN')

mnist = tf.keras.datasets.mnist

# Load Data from MNIST (train/test splits of images and integer labels).
(train_x, train_y), (test_x, test_y) = mnist.load_data()

# Add a trailing channel axis so images match the model's (28, 28, 1) input.
train_x = train_x[..., tf.newaxis]
test_x = test_x[..., tf.newaxis]

# Data Normalization: scale pixel values (0-255 for MNIST) into [0, 1].
# Cast to float32 before dividing: dividing the integer arrays by 255 directly
# yields float64 arrays (twice the memory), while the model Input is float32 anyway.
train_x = train_x.astype('float32') / 255.0
test_x = test_x.astype('float32') / 255.0

tf.print(train_x.dtype)
tf.print(train_x.shape)
위 print 문의 출력 값

이미지 확인

# Display the first training image; the trailing channel axis is dropped for imshow.
plt.imshow(train_x[0, ..., 0])
plt.show()
train_x.shape

tf.data.Dataset.from_tensor_slices

# Training pipeline: (image, label) pairs, shuffled with a 1000-element buffer, batches of 32.
train_ds = (
    tf.data.Dataset.from_tensor_slices((train_x, train_y))
    .shuffle(1000)
    .batch(32)
)

# Test pipeline: same batch size, no shuffling.
test_ds = tf.data.Dataset.from_tensor_slices((test_x, test_y)).batch(32)

# Peek at the first image of two training batches.
for batch_images, batch_labels in train_ds.take(2):
    tf.print(batch_images.shape)
    tf.print(batch_labels[0])
    plt.imshow(batch_images[0, :, :, 0])
    plt.colorbar()
    plt.show()

image, label = next(iter(train_ds))
image.shape, image.dtype, label.shape, label.dtype

# Show the first image of two training batches in grayscale, titled with its label.
for image, label in train_ds.take(2):
    # label[0] is a scalar tensor whose str() is 'tf.Tensor(5, shape=(), dtype=...)';
    # convert it to a plain int so the title is just the digit.
    plt.title(str(int(label[0])))
    plt.imshow(image[0, :, :, 0], 'gray')
    plt.colorbar()
    plt.show()

Keras 초보자용 예제에서는 fit에 train_x, train_y를 따로 넣어 주었지만, 여기서는 train_ds를 바로 넣어 줄 수 있다. 그 이유는 train_ds가 (이미지, 레이블) 배치를 차례로 내어 주는 tf.data.Dataset 객체이기 때문이다.

# fit() accepts the tf.data.Dataset directly because it already yields batched
# (images, labels) pairs, so no separate x/y arrays or batch_size are needed.
# Track accuracy as well, matching what the custom training loop reports.
# NOTE: the custom loop later trains this same `model`, so it continues from the
# weights left here rather than starting from scratch.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(train_ds, epochs=1)

Optimization

# Loss on integer labels vs. the model's softmax probabilities (from_logits is not set).
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()

opt = tf.keras.optimizers.Adam()

# Streaming metrics: Mean keeps a running average of the per-batch losses it is fed,
# SparseCategoricalAccuracy accumulates accuracy over all batches it has seen.
train_loss, train_accuracy = (
    tf.keras.metrics.Mean(name='train_loss'),
    tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy'),
)
test_loss, test_accuracy = (
    tf.keras.metrics.Mean(name='test_loss'),
    tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy'),
)

Training

@tf.function: TF1에서 session을 열어 그래프를 실행하던 방식처럼, 함수를 파이썬 코드 그대로 즉시(eager) 실행하지 않고 처음 호출될 때 그래프로 변환(tracing)한 뒤, 이후 호출에서는 그 그래프를 실행하도록 한다.

@tf.function
def train_step(images, labels):
    """Run one optimization step on a batch and update the training metrics.

    Args:
        images: a batch of input images for ``model``.
        labels: the integer class labels for the same batch.
    """
    with tf.GradientTape() as tape:
        # training=True is required for the Dropout layers to be active; called
        # without it, the model runs in inference mode by default and dropout
        # would never be applied during this custom training loop.
        predictions = model(images, training=True)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss(loss)  # fold this batch's loss into the running mean
    train_accuracy(labels, predictions)

@tf.function
def test_step(images, labels):
    """Evaluate one batch (no weight updates) and update the test metrics.

    Args:
        images: a batch of input images for ``model``.
        labels: the integer class labels for the same batch.
    """
    # training=False explicitly keeps Dropout disabled during evaluation,
    # mirroring the training=True passed in train_step.
    predictions = model(images, training=False)
    t_loss = loss_object(labels, predictions)

    test_loss(t_loss)
    test_accuracy(labels, predictions)

학습 시작

EPOCHS = 1

for epoch in range(EPOCHS):
    # Keras metric objects accumulate across calls; clear them at the start of
    # every epoch so each report covers only that epoch's batches.
    train_loss.reset_state()
    train_accuracy.reset_state()
    test_loss.reset_state()
    test_accuracy.reset_state()

    for image, label in train_ds:
        train_step(image, label)

    for test_image, test_label in test_ds:
        test_step(test_image, test_label)

    template = 'Epoch {}, Loss : {}, Accuracy : {}, Test Loss {}, Test Accuracy : {}'
    tf.print(template.format(epoch + 1,
                             train_loss.result(),
                             train_accuracy.result() * 100,
                             test_loss.result(),
                             test_accuracy.result() * 100))

댓글 남기기

이메일은 공개되지 않습니다. 필수 입력창은 * 로 표시되어 있습니다