書籍転載:TensorFlowはじめました ― 実践!最新Googleマシンラーニング(4)
TensorFlowでデータの読み込み ― 画像を分類するCIFAR-10の基礎
転載4回目。今回から「畳み込みニューラルネットワーク」のモデルを構築して、CIFAR-10のデータセットを使った学習と評価を行う。今回はデータの読み込みを説明。
書籍『TensorFlowはじめました ― 実践!最新Googleマシンラーニング』から全7本の記事を転載します。本稿はその4回目です。今回から「第2章 CIFAR-10の学習と評価」に入ります。今回はデータの読み込みについて説明します。
書籍転載について
本コーナーは、インプレスR&D[Next Publishing]発行の書籍『TensorFlowはじめました ― 実践!最新Googleマシンラーニング』の中から、特にBuild Insiderの読者に有用だと考えられる項目を編集部が選び、同社の許可を得て転載したものです。
『TensorFlowはじめました ― 実践!最新Googleマシンラーニング』(Kindle電子書籍もしくはオンデマンドペーパーバック)の詳細や購入はAmazon.co.jpのページをご覧ください。書籍全体の目次は連載INDEXページに掲載しています。プログラムのダウンロードは、「TensorFlowはじめました」のサポート用フォームから行えます。
ご注意
本記事は、書籍の内容を改変することなく、そのまま転載したものです。このため用字用語の統一ルールなどはBuild Insiderのそれとは一致しません。あらかじめご了承ください。
■
第2章 CIFAR-10の学習と評価
CIFAR-10*1は、10種類の画像を分類する「多クラス分類」と呼ばれるタスクの画像セットです。
- *1 CIFAR: Canadian Institute For Advanced Research。
本章では、「畳み込みニューラルネットワーク」のモデルを構築して、CIFAR-10のデータセットを使った学習と評価を行います。
2.1 データの読み込み
データの入手
CIFAR-10のデータセットは、トロント大学のAlex Krizhevsky氏が配布しています。
- The CIFAR-10 dataset : https://www.cs.toronto.edu/~kriz/cifar.html
Pythonで直接読み込める形式(Pickle形式)が用意されてますが、今回はバイナリ形式の「CIFAR-10 binary version (suitable for C programs)」をダウンロードします。
ダウンロードしたZIPファイルには、6つのファイルが含まれています。
data_batch_[1-5].bin
が訓練データ、test_batch.bin
がテストデータです。以後、これらのファイルをディレクトリ./data
に配置したとして解説を進めます。
CIFAR-10のデータ構造
CIFAR-10のデータは、JPEGやPNGなど一般的な画像形式ではありません。そのため、データ構造に合わせてプログラムの中で読み込み、さらにTensorFlowで取り扱うことができる形式へ変換する必要があります。
1つのファイル(データセット)には、10,000個のレコードが含まれています(図2.2)。
レコードは固定長の3,073バイト。先頭の1バイトがラベルで、残り3,072バイトは縦横32pxの画像データを直列化したものです。一般的なBitmap形式と違い、RGB各チャンネルのデータが1,024バイトずつ並ぶ構造になっています。
画像ラベルと種類(クラス)の対応は図2.1の通りです。
読み込みと構造変更
リスト2.1は、CIFAR-10形式のデータセットを読み込むプログラムです。
まず、Cifar10Reader
のコンストラクタに、読み込むデータセット(ファイル)の名前を指定します。次に、read
メソッドに「レコード番号」を指定すると対応するレコード(Cifar10Record
)が得られます。
# coding: UTF-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
class Cifar10Record(object):
width = 32
height = 32
depth = 3
def set_label(self, label_byte):
self.label = np.frombuffer(label_byte, dtype=np.uint8)
def set_image(self, image_bytes):
byte_buffer = np.frombuffer(image_bytes, dtype=np.int8)
reshaped_array = np.reshape(byte_buffer,
[self.depth, self.width, self.height])
self.byte_array = np.transpose(reshaped_array, [1, 2, 0])
self.byte_array = self.byte_array.astype(np.float32)
class Cifar10Reader(object):
def __init__(self, filename):
if not os.path.exists(filename):
print(filename + ' is not exist')
return
self.bytestream = open(filename, mode="rb")
def close(self):
if not self.bytestream:
self.bytestream.close()
def read(self, index):
result = Cifar10Record()
label_bytes = 1
image_bytes = result.height * result.width * result.depth
record_bytes = label_bytes + image_bytes
self.bytestream.seek(record_bytes * index, 0)
result.set_label(self.bytestream.read(label_bytes))
result.set_image(self.bytestream.read(image_bytes))
return result
|
read
メソッド内で取得する画像データは、最初、直列化された1次元配列です。TensorFlowで取り扱うには、一般的なBitmap形式の「幅・高さ・チャンネル」の構造に変換する必要があります。
そこで、リスト2.1ではまず、numpyのreshape
で一次元配列から「チャンネル・幅・高さ」の3次元配列に変換しています。次にnumpyのtranspose
関数を用いて「幅・高さ・チャンネル」の形式に配列を転置して、最後にfloat32
に型を変換しています。
PNG形式で書き出し
CIFAR-10形式から画像データを取り出すことができたので、取得した画像データを一般的な画像形式で保存してみましょう(リスト2.2)。
この行程はTensorFlowを学ぶ上では必要ありませんが、画像データの操作の練習として試してみてください。
# coding: UTF-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
from PIL import Image
from reader import Cifar10Reader
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('file', None, "処理するファイルのパス")
tf.app.flags.DEFINE_integer('offset', 0, "読み飛ばすレコード数")
tf.app.flags.DEFINE_integer('length', 16, "読み込んで変換するレコード数")
basename = os.path.basename(FLAGS.file)
path = os.path.dirname(FLAGS.file)
reader = Cifar10Reader(FLAGS.file)
stop = FLAGS.offset + FLAGS.length
for index in range(FLAGS.offset, stop):
image = reader.read(index)
print('label: %d' % image.label)
imageshow = Image.fromarray(image.byte_array.astype(np.uint8))
file_name = '%s-%02d-%d.png' % (basename, index, image.label)
file = os.path.join(path, file_name)
with open(file, mode='wb') as out:
imageshow.save(out, format='png')
reader.close()
|
処理するファイルをコマンドラインの引数に指定して実行すると、図2.3のような画像が、データセットと同じディレクトリに書き出されます。
$ python3 convert_cifar10_png.py --file ./data/data_batch_1.bin
|
Note: tf.app.flags
リスト2.2冒頭にあるtf.app.flagsは、コマンドラインの引数を簡単に設定する機能を提供します。
これはGoogleがオープンソースで公開している「gflags」のPython実装「python-gflags」と同等のものです。
■
今回は「画像データの読み込み」を行いました。次回は、「推論」(=画像の種類・クラスを判別)について説明します。
※以下では、本稿の前後を合わせて5回分(第1回~第5回)のみ表示しています。
連載の全タイトルを参照するには、[この記事の連載 INDEX]を参照してください。
1. TensorFlowとは? データフローグラフを構築・実行してみよう
技術書オンリー即売会「技術書典」で頒布された同名出版物をベースとして制作されたTensorFlowの入門書籍を転載開始。その1回目として、データフローグラフや定数といったTensorFlowの基礎を説明する。
3. TensorFlowの“テンソル(Tensor)”とは? TensorBoardの使い方
転載3回目。テンソル(Tensor)とTensorBoardによるグラフの可視化を解説する。「第1章 TensorFlowの基礎」は今回で完結。
4. 【現在、表示中】≫ TensorFlowでデータの読み込み ― 画像を分類するCIFAR-10の基礎
転載4回目。今回から「畳み込みニューラルネットワーク」のモデルを構築して、CIFAR-10のデータセットを使った学習と評価を行う。今回はデータの読み込みを説明。
5. TensorFlowによる推論 ― 画像を分類するCIFAR-10の基礎
転載5回目。CIFAR-10データセットを使った学習と評価を行う。画像データの読み込みが終わったので、今回は画像の種類(クラス)を判別、つまり「推論」について説明する。