TensorFlowサンプルコード
Revisão | 47c42744f7cdbefcb9dd7b1c09ee0e27076560c9 (tree) |
---|---|
Hora | 2018-01-18 19:22:38 |
Autor | hylom <hylom@hylo...> |
Commiter | hylom |
add tensorboard_test5.py
@@ -0,0 +1,167 @@ | ||
1 | +#!/usr/bin/env python | |
2 | +# -*- coding: utf-8 -*- | |
3 | + | |
4 | +import tensorflow as tf | |
5 | + | |
6 | +INPUT_SIZE = 15 | |
7 | +W1_SIZE = 15 | |
8 | +OUTPUT_SIZE = 10 | |
9 | + | |
10 | +with tf.variable_scope('model') as scope: | |
11 | + | |
12 | + # 入力 | |
13 | + x1 = tf.placeholder(dtype=tf.float32, name="x1") | |
14 | + y = tf.placeholder(dtype=tf.float32, name="y") | |
15 | + | |
16 | + # 第2層 | |
17 | + tf.set_random_seed(1234) | |
18 | + W1 = tf.get_variable("W1", | |
19 | + shape=[INPUT_SIZE, W1_SIZE], | |
20 | + dtype=tf.float32, | |
21 | + initializer=tf.random_normal_initializer(stddev=0.05)) | |
22 | + b1 = tf.get_variable("b1", | |
23 | + shape=[W1_SIZE], | |
24 | + dtype=tf.float32, | |
25 | + initializer=tf.random_normal_initializer(stddev=0.05)) | |
26 | + x2 = tf.sigmoid(tf.matmul(x1, W1) + b1, name="x2") | |
27 | + | |
28 | + # W1のヒストグラムを記録 | |
29 | + tf.summary.histogram('W1', W1) | |
30 | + | |
31 | + # 第3層 | |
32 | + W2 = tf.get_variable("W2", | |
33 | + shape=[W1_SIZE, OUTPUT_SIZE], | |
34 | + dtype=tf.float32, | |
35 | + initializer=tf.random_normal_initializer(stddev=0.05)) | |
36 | + b2 = tf.get_variable("b2", | |
37 | + shape=[OUTPUT_SIZE], | |
38 | + dtype=tf.float32, | |
39 | + initializer=tf.random_normal_initializer(stddev=0.05)) | |
40 | + x3 = tf.nn.softmax(tf.matmul(x2, W2) + b2, name="x3") | |
41 | + | |
42 | + # コスト関数 | |
43 | + cross_entropy = -tf.reduce_sum(y * tf.log(x3), name="cross_entropy") | |
44 | + tf.summary.scalar('cross_entropy', cross_entropy) | |
45 | + | |
46 | + # 正答率 | |
47 | + correct = tf.equal(tf.argmax(x3,1), tf.argmax(y,1), name="correct") | |
48 | + accuracy = tf.reduce_mean(tf.cast(correct, "float"), name="accuracy") | |
49 | + tf.summary.scalar('accuracy', accuracy) | |
50 | + | |
51 | + | |
52 | + # 最適化アルゴリズムを定義 | |
53 | + global_step = tf.Variable(0, name='global_step', trainable=False) | |
54 | + optimizer = tf.train.GradientDescentOptimizer(0.01, name="optimizer") | |
55 | + minimize = optimizer.minimize(cross_entropy, global_step=global_step, name="minimize") | |
56 | + | |
57 | + # 学習結果を保存するためのオブジェクトを用意 | |
58 | + saver = tf.train.Saver() | |
59 | + | |
60 | + | |
61 | +with tf.variable_scope('pipeline') as scope: | |
62 | + ## データセットを読み込むためのパイプラインを作成する | |
63 | + # リーダーオブジェクトを作成する | |
64 | + reader = tf.TextLineReader() | |
65 | + | |
66 | + # 読み込む対象のファイルを格納したキューを作成する | |
67 | + file_queue = tf.train.string_input_producer(["digits_data.csv", "test_data.csv"]) | |
68 | + | |
69 | + # キューからデータを読み込む | |
70 | + key, value = reader.read(file_queue) | |
71 | + | |
72 | + # 読み込んだCSV型式データをデコードする | |
73 | + # [[] for i in range(16)] は | |
74 | + # [[], [], [], [], [], [], [], [], | |
75 | + # [], [], [], [], [], [], [], []]に相当 | |
76 | + data = tf.decode_csv(value, record_defaults=[[] for i in range(16)]) | |
77 | + | |
78 | + # 10件のデータを読み出す | |
79 | + # 10件ずつデータを読み出す | |
80 | + # 第1カラム(data[0])はその文字が示す数だが、 | |
81 | + # ニューラルネットワークの出力は10要素の1次元テンソルとなる。 | |
82 | + # そのため、10×10の対角行列を作成し、そのdata[0]行目を取り出す操作を行うことで | |
83 | + # 1次元テンソルに変換する。dataは浮動小数点小数型なので、このとき | |
84 | + # int32型にキャストして使用する | |
85 | + data_x, data_y, y_value = tf.train.batch([ | |
86 | + tf.stack(data[1:]), | |
87 | + tf.reshape(tf.slice(tf.eye(10), [tf.cast(data[0], tf.int32), 0], [1, 10]), [10]), | |
88 | + tf.cast(data[0], tf.int64), | |
89 | + ], 10) | |
90 | + | |
91 | +# セッションの作成 | |
92 | +sess = tf.Session() | |
93 | + | |
94 | +# 変数の初期化を実行する | |
95 | +sess.run(tf.global_variables_initializer()) | |
96 | + | |
97 | +# 学習結果を保存したファイルが存在するかを確認し、 | |
98 | +# 存在していればそれを読み出す | |
99 | +latest_filename = tf.train.latest_checkpoint("./") | |
100 | +if latest_filename: | |
101 | + print("load saved model {}".format(latest_filename)) | |
102 | + saver.restore(sess, latest_filename) | |
103 | + | |
104 | +# サマリを取得するための処理 | |
105 | +summary_op = tf.summary.merge_all() | |
106 | +summary_writer = tf.summary.FileWriter('data', graph=sess.graph) | |
107 | + | |
108 | + | |
109 | +# コーディネータの作成 | |
110 | +coord = tf.train.Coordinator() | |
111 | + | |
112 | +# キューの開始 | |
113 | +threads = tf.train.start_queue_runners(sess=sess, coord=coord) | |
114 | + | |
115 | +# ファイルからのデータの読み出し | |
116 | +# 1回目のデータ読み込み。1つ目のファイルから10件のデータが読み込まれる | |
117 | +# 1つ目のファイルには10件のデータがあるので、これで全データが読み込まれる | |
118 | +dataset_x, dataset_y, values_y = sess.run([data_x, data_y, y_value]) | |
119 | + | |
120 | +# 2回目のデータ読み込み。1つ目のファイルのデータはすべて読み出したので、 | |
121 | +# 続けて2つ目のファイルから読み込みが行われる。 | |
122 | +testdata_x, testdata_y, testvalues_y = sess.run([data_x, data_y, y_value]) | |
123 | + | |
124 | +# 学習を開始 | |
125 | +for i in range(100): | |
126 | + for j in range(100): | |
127 | + _, summary = sess.run([minimize, summary_op], {x1: dataset_x, y: dataset_y}) | |
128 | + print("CROSS ENTROPY:", sess.run(cross_entropy, {x1: dataset_x, y: dataset_y})) | |
129 | + summary_writer.add_summary(summary, global_step=tf.train.global_step(sess, global_step)) | |
130 | + | |
131 | +# 結果を保存する | |
132 | +save_path = saver.save(sess, "./model", global_step=tf.train.global_step(sess, global_step)) | |
133 | +print("Model saved to {}".format(save_path)) | |
134 | + | |
135 | +## 結果の出力 | |
136 | +# 出力テンソルの中でもっとも値が大きいもののインデックスが | |
137 | +# 正答と等しいかどうかを計算する | |
138 | +y_value = tf.placeholder(dtype=tf.int64) | |
139 | +correct = tf.equal(tf.argmax(x3,1), y_value) | |
140 | +accuracy = tf.reduce_mean(tf.cast(correct, "float")) | |
141 | + | |
142 | +# 学習に使用したデータを入力した場合の | |
143 | +# ニューラルネットワークの出力を表示 | |
144 | +print("----result----") | |
145 | +print("raw output:") | |
146 | +print(sess.run(x3,feed_dict={x1: dataset_x})) | |
147 | +print("answers:", sess.run(tf.argmax(x3, 1), feed_dict={x1: dataset_x})) | |
148 | + | |
149 | +# このときの正答率を出力 | |
150 | +print("accuracy:", sess.run(accuracy, feed_dict={x1: dataset_x, y_value: values_y})) | |
151 | + | |
152 | + | |
153 | +# テスト用データを入力した場合の | |
154 | +# ニューラルネットワークの出力を表示 | |
155 | +print("----test----") | |
156 | +print("raw output:") | |
157 | +print(sess.run(x3,feed_dict={x1: testdata_x})) | |
158 | +print("answers:", sess.run(tf.argmax(x3, 1), feed_dict={x1: testdata_x})) | |
159 | + | |
160 | +# このときの正答率を出力 | |
161 | +print("accuracy:", sess.run(accuracy, feed_dict={x1: testdata_x, y_value: testvalues_y})) | |
162 | + | |
163 | + | |
164 | + | |
165 | +# キューの終了 | |
166 | +coord.request_stop() | |
167 | +coord.join(threads) |