對 Tensorflow 的架構以及 Session 有了基本概念,接下來要了解 Tensorflow 是怎麼利用 Variables 來 maintain state。

當訓練一個 model 的時候,variables 是用來保存和更新參數的。接下來的這個例子示範了用 variable 來做一個 counter。初始值為 0,每次往上加 1。

定義 variable

首先,定義一個 variable 叫做 state,它的初始值為 0,名字叫做 “counter”。以及一個 op 來把 1 加到 state 上。

1
2
3
4
5
6
7
8
# Create a Variable, that will be initialized to the scalar value 0.
state = tf.Variable(0, name='counter')
# Create an Op to add one to `state`.
one = tf.constant(1)
new_value = tf.add(state, one) # not adding directly by this line
update = tf.assign(state, new_value) # assign new_value to state

初始化 variable

在 Tensorflow 中,如果定義了一些 Variables,那麼一定要對它們做初始化。

1
2
3
# Variables must be initialized by running an `init` Op after having
# launched the graph. We first have to add the `init` Op to the graph.
init_op = tf.global_variables_initializer()

一直到這步,其實這些 variables 都還沒有被 activate,必須一直到 sess.run(init_op) 這一步,才算真正初始化。而要真正讓 state 往上加 1,則需要透過 sess.run(update) 這一步。

1
2
3
4
5
6
7
8
9
with tf.Session() as sess:
# Run the 'init' op
sess.run(init_op)
# Print the initial value of 'state'
print('initial value: ', sess.run(state))
# Run the op that updates 'state' and print 'state'.
for _ in range(3):
sess.run(update)
print('step', _, ': ', sess.run(state))

完整程式碼執行結果如下:

initial value:  0
step 0 :  1
step 1 :  2
step 2 :  3