到目前為止,似乎我們在建立網(wǎng)絡(luò)時草率地逃脫了懲罰。具體來說,我們做了以下不符合直覺的事情,這些事情可能看起來不應(yīng)該起作用:
-
我們在沒有指定輸入維度的情況下定義了網(wǎng)絡(luò)架構(gòu)。
-
我們添加層時沒有指定前一層的輸出維度。
您可能會對我們的代碼完全運行感到驚訝。畢竟,深度學(xué)習(xí)框架無法判斷網(wǎng)絡(luò)的輸入維數(shù)。這里的技巧是框架推遲初始化,等到我們第一次通過模型傳遞數(shù)據(jù)時,動態(tài)推斷每一層的大小。
稍后,當使用卷積神經(jīng)網(wǎng)絡(luò)時,該技術(shù)將變得更加方便,因為輸入維度(即圖像的分辨率)將影響每個后續(xù)層的維度。因此,在編寫代碼時無需知道維度是多少就可以設(shè)置參數(shù)的能力可以極大地簡化指定和隨后修改模型的任務(wù)。接下來,我們將更深入地研究初始化機制。
import tensorflow as tf
首先,讓我們實例化一個 MLP。
net = nn.Sequential()
net.add(nn.Dense(256, activation='relu'))
net.add(nn.Dense(10))
net = tf.keras.models.Sequential([
tf.keras.layers.Dense(256, activation=tf.nn.relu),
tf.keras.layers.Dense(10),
])
此時,網(wǎng)絡(luò)不可能知道輸入層權(quán)重的維度,因為輸入維度仍然未知。
Consequently the framework has not yet initialized any parameters. We confirm by attempting to access the parameters below.
<bound method Block.collect_params of Sequential(
(0): Dense(-1 -> 256, Activation(relu))
(1): Dense(-1 -> 10, linear)
)>
sequential0_ (
Parameter dense0_weight (shape=(256, -1), dtype=float32)
Parameter dense0_bias (shape=(256,), dtype=float32)
Parameter dense1_weight (shape=(10, -1), dtype=float32)
Parameter dense1_bias (shape=(10,), dtype=float32)
)
Note that while the parameter objects exist, the input dimension to each layer is listed as -1. MXNet uses the special value -1 to indicate that the parameter dimension remains unknown. At this point, attempts to access net[0].weight.data()
would trigger a runtime error stating that the network must be initialized before the parameters can be accessed. Now let’s see what happens when we attempt to initialize parameters via the initialize
method.
sequential0_ (
Parameter dense0_weight (shape=(256, -1), dtype=float32)
Parameter dense0_bias (shape=(256,), dtype=float32)
Parameter dense1_weight (shape=(10, -1), dtype=float32)
Parameter dense1_bias (shape=(10,), dtype=float32)
)
As we can see, nothing has changed. When input dimensions are unknown, calls to initialize do not truly initialize the parameters. Instead, this call registers to MXNet that we wish (and optionally, according to which distribution) to initialize the parameters.
As mentioned in Section 6.2.1, parameters and the network definition are decoupled in Jax and Flax, and the user handles both manually. Flax models are stateless hence there is no parameters
attribute.
Consequently the framework has not yet initialized any parameters. We confirm by attempting to access the parameters below.
[[], []]
Note that each layer objects exist but the weights are empty. Using net.get_weights()
would throw an error since the weights have not been initialized yet.
接下來讓我們通過網(wǎng)絡(luò)傳遞數(shù)據(jù),讓框架最終初始化參數(shù)。
sequential0_ (
Parameter dense0_weight (shape=(256, 20), dtype=float32)
Parameter dense0_bias (shape=(256,), dtype=float32)
Parameter dense1_weight (shape=(10, 256), dtype=float32)
Parameter dense1_bias (shape=(10,), dtype=float32)
)
params = net.init(d2l.get_key(), jnp.zeros((2, 20)))
jax.tree_util.tree_map(lambda x: x.shape, params).tree_flatten()
(({'params': {'layers_0': {'bias': (256,), 'kernel': (20, 256)},
'layers_2': {'bias': (10,), 'kernel': (256, 10)}}},),
())
一旦我們知道輸入維度 20,框架就可以通過插入值 20 來識別第一層權(quán)重矩陣的形狀。識別出第一層的形狀后,框架進入第二層,依此類推計算圖,直到所有形狀都已知。請注意,在這種情況下,只有第一層需要延遲初始化,但框架會按順序進行初始化。一旦知道所有參數(shù)形狀,框架就可以最終初始化參數(shù)。
以下方法通過網(wǎng)絡(luò)傳入虛擬輸入以進行試運行,以推斷所有參數(shù)形狀并隨后初始化參數(shù)。當不需要默認隨機初始化時,稍后將使用它。
Parameter initialization in Flax is always done manually and handled by the user. The following method takes a dummy input and a key dictionary as argument. This key dictionary has the rngs for initializing the model parameters and dropout rng for generating the dropout mask for the models with dropout layers. More about dropout will be covered later in Section 5.6. Ultimately the method initializes the model returning the parameters. We have been using it under the hood in the previous sections as well.
6.4.1. 概括
延遲初始化可能很方便,允許框架自動推斷參數(shù)形狀,從而輕松修改架構(gòu)并消除一種常見的錯誤來源。我們可以通過模型傳遞數(shù)據(jù),讓框架最終初始化參數(shù)。
6.4.2. 練習(xí)
-
如果您將輸入維度指定給第一層而不是后續(xù)層,會發(fā)生什么情況?你會立即初始化嗎?
-
如果您指定不匹配的尺寸會發(fā)生什么?
-
如果你有不同維度的輸入,你需要做什么?提示:查看參數(shù)綁定。
評論
查看更多