PyTorch torch.utils.checkpoint

2020-09-15 11:40 更新

原文: PyTorch torch.utils.checkpoint

注意

通過在反向過程中為每個檢查點段重新運行一個正向通過段來實現(xiàn)檢查點。 這可能會導致像 RNG 狀態(tài)這樣的持久狀態(tài)比沒有檢查點的狀態(tài)更先進。 默認情況下,檢查點包括處理 RNG 狀態(tài)的邏輯,以便與非檢查點通過相比,使用 RNG(例如,通過丟棄)的檢查點通過具有確定的輸出。 根據(jù)檢查點操作的運行時間,存儲和恢復 RNG 狀態(tài)的邏輯可能會導致性能下降。 如果不需要與非檢查點通過相比確定的輸出,則在每個檢查點期間向checkpointcheckpoint_sequential提供preserve_rng_state=False,以忽略存儲和恢復 RNG 狀態(tài)。

隱藏邏輯將當前設備以及所有 cuda Tensor 參數(shù)的設備的 RNG 狀態(tài)保存并恢復到run_fn。 但是,該邏輯無法預料用戶是否將張量移動到run_fn本身內的新設備。 因此,如果在run_fn中將張量移動到新設備(“新”表示不屬于[當前設備+張量參數(shù)的設備的集合]),則與非檢查點傳遞相比,確定性輸出將永遠無法保證。

  1. torch.utils.checkpoint.checkpoint(function, *args, **kwargs)?

檢查點模型或模型的一部分

檢查點通過將計算交換為內存來工作。 檢查點部分沒有存儲整個計算圖的所有中間激活以進行向后計算,而是由而不是保存中間激活,而是在向后傳遞時重新計算它們。 它可以應用于模型的任何部分。

具體而言,在前向傳遞中,function將以torch.no_grad()方式運行,即不存儲中間激活。 相反,前向傳遞保存輸入元組和function參數(shù)。 在向后遍歷中,檢索保存的輸入和function,并再次在function上計算正向遍歷,現(xiàn)在跟蹤中間激活,然后使用這些激活值計算梯度。

警告

檢查點不適用于 torch.autograd.grad() ,而僅適用于 torch.autograd.backward() 。

Warning

如果后退期間的function調用與前退期間的調用有任何不同,例如,由于某些全局變量,則檢查點版本將不相等,很遺憾,無法檢測到該版本。

參數(shù)

  • 函數(shù) –描述在模型的正向傳遞中或模型的一部分中運行的內容。 它還應該知道如何處理作為元組傳遞的輸入。 例如,在 LSTM 中,如果用戶通過(activation, hidden),則function應正確使用第一個輸入作為activation,第二個輸入作為hidden
  • reserve_rng_state (bool , 可選 默認= True 在每個檢查點期間恢復 RNG 狀態(tài)。
  • args –包含function輸入的元組

退貨

*args上運行function的輸出

  1. torch.utils.checkpoint.checkpoint_sequential(functions, segments, *inputs, **kwargs)?

用于檢查點順序模型的輔助功能。

順序模型按順序(依次)執(zhí)行模塊/功能列表。 因此,我們可以將這樣的模型劃分為不同的段,并在每個段上檢查點。 除最后一個段外,所有段都將以torch.no_grad()方式運行,即不存儲中間激活。 將保存每個檢查點線段的輸入,以便在后向傳遞中重新運行該線段。

有關檢查點的工作方式,請參見 checkpoint() 。

Warning

Checkpointing doesn't work with torch.autograd.grad(), but only with torch.autograd.backward().

Parameters

  • 功能 –一個 torch.nn.Sequential 或要順序運行的模塊或功能列表(包含模型)。
  • –在模型中創(chuàng)建的塊數(shù)
  • 輸入 –張量元組,它們是functions的輸入
  • preserve_rng_state (bool__, optional__, default=True) – Omit stashing and restoring the RNG state during each checkpoint.

Returns

*inputs上順序運行functions的輸出

  1. >>> model = nn.Sequential(...)
  2. >>> input_var = checkpoint_sequential(model, chunks, input_var)
以上內容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號