W3Cschool
恭喜您成為首批注冊用戶
獲得88經(jīng)驗值獎勵
注意
通過在反向過程中為每個檢查點段重新運行一個正向通過段來實現(xiàn)檢查點。 這可能會導致像 RNG 狀態(tài)這樣的持久狀態(tài)比沒有檢查點的狀態(tài)更先進。 默認情況下,檢查點包括處理 RNG 狀態(tài)的邏輯,以便與非檢查點通過相比,使用 RNG(例如,通過丟棄)的檢查點通過具有確定的輸出。 根據(jù)檢查點操作的運行時間,存儲和恢復 RNG 狀態(tài)的邏輯可能會導致性能下降。 如果不需要與非檢查點通過相比確定的輸出,則在每個檢查點期間向checkpoint
或checkpoint_sequential
提供preserve_rng_state=False
,以忽略存儲和恢復 RNG 狀態(tài)。
隱藏邏輯將當前設備以及所有 cuda Tensor 參數(shù)的設備的 RNG 狀態(tài)保存并恢復到run_fn
。 但是,該邏輯無法預料用戶是否將張量移動到run_fn
本身內的新設備。 因此,如果在run_fn
中將張量移動到新設備(“新”表示不屬于[當前設備+張量參數(shù)的設備的集合]),則與非檢查點傳遞相比,確定性輸出將永遠無法保證。
torch.utils.checkpoint.checkpoint(function, *args, **kwargs)?
檢查點模型或模型的一部分
檢查點通過將計算交換為內存來工作。 檢查點部分沒有存儲整個計算圖的所有中間激活以進行向后計算,而是由而不是保存中間激活,而是在向后傳遞時重新計算它們。 它可以應用于模型的任何部分。
具體而言,在前向傳遞中,function
將以torch.no_grad()
方式運行,即不存儲中間激活。 相反,前向傳遞保存輸入元組和function
參數(shù)。 在向后遍歷中,檢索保存的輸入和function
,并再次在function
上計算正向遍歷,現(xiàn)在跟蹤中間激活,然后使用這些激活值計算梯度。
警告
檢查點不適用于 torch.autograd.grad()
,而僅適用于 torch.autograd.backward()
。
Warning
如果后退期間的function
調用與前退期間的調用有任何不同,例如,由于某些全局變量,則檢查點版本將不相等,很遺憾,無法檢測到該版本。
參數(shù)
(activation, hidden)
,則function
應正確使用第一個輸入作為activation
,第二個輸入作為hidden
function
輸入的元組退貨
在*args
上運行function
的輸出
torch.utils.checkpoint.checkpoint_sequential(functions, segments, *inputs, **kwargs)?
用于檢查點順序模型的輔助功能。
順序模型按順序(依次)執(zhí)行模塊/功能列表。 因此,我們可以將這樣的模型劃分為不同的段,并在每個段上檢查點。 除最后一個段外,所有段都將以torch.no_grad()
方式運行,即不存儲中間激活。 將保存每個檢查點線段的輸入,以便在后向傳遞中重新運行該線段。
有關檢查點的工作方式,請參見 checkpoint()
。
Warning
Checkpointing doesn't work with torch.autograd.grad()
, but only with torch.autograd.backward()
.
Parameters
torch.nn.Sequential
或要順序運行的模塊或功能列表(包含模型)。functions
的輸入Returns
在*inputs
上順序運行functions
的輸出
例
>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
Copyright©2021 w3cschool編程獅|閩ICP備15016281號-3|閩公網(wǎng)安備35020302033924號
違法和不良信息舉報電話:173-0602-2364|舉報郵箱:jubao@eeedong.com
掃描二維碼
下載編程獅App
編程獅公眾號
聯(lián)系方式:
更多建議: