PyTorch torch.hub

2020-09-15 10:12 更新

原文:PyTorch torch.hub

Pytorch Hub 是經(jīng)過(guò)預(yù)先訓(xùn)練的模型資料庫(kù),旨在促進(jìn)研究的可重復(fù)性。

發(fā)布模型

Pytorch Hub 支持通過(guò)添加簡(jiǎn)單的hubconf.py文件將預(yù)訓(xùn)練的模型(模型定義和預(yù)訓(xùn)練的權(quán)重)發(fā)布到 github 存儲(chǔ)庫(kù);

hubconf.py可以有多個(gè)入口點(diǎn)。 每個(gè)入口點(diǎn)都定義為 python 函數(shù)(例如:您要發(fā)布的經(jīng)過(guò)預(yù)先訓(xùn)練的模型)。

  1. def entrypoint_name(*args, **kwargs):
  2. # args & kwargs are optional, for models which take positional/keyword arguments.
  3. ...

如何實(shí)現(xiàn)入口點(diǎn)?

如果我們擴(kuò)展pytorch/vision/hubconf.py中的實(shí)現(xiàn),則以下代碼段指定了resnet18模型的入口點(diǎn)。 在大多數(shù)情況下,在hubconf.py中導(dǎo)入正確的功能就足夠了。 在這里,我們僅以擴(kuò)展版本為例來(lái)說(shuō)明其工作原理。

  1. dependencies = ['torch']
  2. from torchvision.models.resnet import resnet18 as _resnet18
  3. ## resnet18 is the name of entrypoint
  4. def resnet18(pretrained=False, **kwargs):
  5. """ # This docstring shows up in hub.help()
  6. Resnet18 model
  7. pretrained (bool): kwargs, load pretrained weights into the model
  8. """
  9. # Call the model, load pretrained weights
  10. model = _resnet18(pretrained=pretrained, **kwargs)
  11. return model

  • dependencies變量是加載模型所需的軟件包名稱的列表。 請(qǐng)注意,這可能與訓(xùn)練模型所需的依賴項(xiàng)稍有不同。
  • argskwargs傳遞給實(shí)際的可調(diào)用函數(shù)。
  • 該函數(shù)的文檔字符串用作幫助消息。 它解釋了模型做什么以及允許的位置/關(guān)鍵字參數(shù)是什么。 強(qiáng)烈建議在此處添加一些示例。
  • Entrypoint 函數(shù)可以返回模型(nn.module),也可以返回輔助工具以使用戶工作流程更流暢,例如 標(biāo)記器。
  • 帶下劃線前綴的可調(diào)用項(xiàng)被視為輔助功能,不會(huì)在torch.hub.list()中顯示。
  • 預(yù)訓(xùn)練的權(quán)重既可以存儲(chǔ)在 github 存儲(chǔ)庫(kù)中,也可以由torch.hub.load_state_dict_from_url()加載。 如果少于 2GB,建議將其附加到項(xiàng)目版本,并使用該版本中的網(wǎng)址。 在上面的示例中,torchvision.models.resnet.resnet18處理pretrained,或者,您可以在入口點(diǎn)定義中添加以下邏輯。

  1. if pretrained:
  2. # For checkpoint saved in local github repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
  3. dirname = os.path.dirname(__file__)
  4. checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
  5. state_dict = torch.load(checkpoint)
  6. model.load_state_dict(state_dict)
  7. # For checkpoint saved elsewhere
  8. checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
  9. model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))

重要通知

  • 發(fā)布的模型應(yīng)至少在分支/標(biāo)簽中。 不能是隨機(jī)提交。

從集線器加載模型

Pytorch Hub 提供了便捷的 API,可通過(guò)torch.hub.list()瀏覽集線器中的所有可用模型,通過(guò)torch.hub.help()顯示文檔字符串和示例,并使用torch.hub.load()加載經(jīng)過(guò)預(yù)先訓(xùn)練的模型

  1. torch.hub.list(github, force_reload=False)?

列出 <cite>github</cite> hubconf 中可用的所有入口點(diǎn)。

參數(shù)

  • github (字符串)–格式為“ repo_owner / repo_name [:tag_name]”的字符串,帶有可選的標(biāo)記/分支。 如果未指定,則默認(rèn)分支為<cite>主站</cite>。 示例:“ pytorch / vision [:hub]”
  • force_reload (bool , 可選)–是否放棄現(xiàn)有緩存并強(qiáng)制重新下載。 默認(rèn)值為<cite>否</cite>。

退貨

可用入口點(diǎn)名稱的列表

返回類型

入口點(diǎn)

  1. >>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)

  1. torch.hub.help(github, model, force_reload=False)?

顯示入口點(diǎn)<cite>模型</cite>的文檔字符串。

Parameters

  • github (字符串)–格式為< repo_owner / repo_name [:tag_name] [:HT_7]的字符串,帶有可選的標(biāo)記/分支。 如果未指定,則默認(rèn)分支為<cite>主站</cite>。 示例:“ pytorch / vision [:hub]”
  • 模型(字符串)–在存儲(chǔ)庫(kù)的 hubconf.py 中定義的入口點(diǎn)名稱字符串
  • force_reload (bool__, optional) – whether to discard the existing cache and force a fresh download. Default is <cite>False</cite>.

Example

  1. >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))

  1. torch.hub.load(github, model, *args, **kwargs)?

使用預(yù)訓(xùn)練的權(quán)重從 github 存儲(chǔ)庫(kù)加載模型。

Parameters

  • github (string) – a string with format “repo_owner/repo_name[:tag_name]” with an optional tag/branch. The default branch is <cite>master</cite> if not specified. Example: 'pytorch/vision[:hub]'
  • model (string) – a string of entrypoint name defined in repo's hubconf.py
  • args (可選*)–可調(diào)用<cite>模型</cite>的相應(yīng) args。
  • force_reload (bool , 可選)–是否無(wú)條件強(qiáng)制重新下載 github 存儲(chǔ)庫(kù)。 默認(rèn)值為<cite>否</cite>。
  • 詳細(xì) (bool , 可選)–如果為 False,則忽略有關(guān)命中本地緩存的消息。 請(qǐng)注意,有關(guān)首次下載的消息不能被靜音。 默認(rèn)值為<cite>為真</cite>。
  • \ kwargs (可選)–可調(diào)用<cite>模型</cite>的相應(yīng) kwargs。

Returns

具有相應(yīng)預(yù)訓(xùn)練權(quán)重的單個(gè)模型。

Example

  1. >>> model = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)

  1. torch.hub.download_url_to_file(url, dst, hash_prefix=None, progress=True)?

將給定 URL 上的對(duì)象下載到本地路徑。

Parameters

  • url (字符串)–要下載的對(duì)象的 URL
  • dst (字符串)–保存對(duì)象的完整路徑,例如 <cite>/ tmp / temporary_file</cite>
  • hash_prefix (字符串 可選))–如果不是 None,則下載的 SHA256 文件應(yīng)以 <cite>hash_prefix</cite> 開頭。 默認(rèn)值:無(wú)
  • 進(jìn)度 (bool , 可選)–是否顯示 stderr 的進(jìn)度條默認(rèn)值:True

Example

  1. >>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')

  1. torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False)?

將 Torch 序列化對(duì)象加載到給定的 URL。

如果下載的文件是 zip 文件,它將被自動(dòng)解壓縮。

如果 <cite>model_dir</cite> 中已經(jīng)存在該對(duì)象,則將其反序列化并返回。 <cite>model_dir</cite> 的默認(rèn)值為$TORCH_HOME/checkpoints,其中環(huán)境變量$TORCH_HOME的默認(rèn)值為$XDG_CACHE_HOME/torch。 $XDG_CACHE_HOME遵循 Linux 文件系統(tǒng)布局的 X 設(shè)計(jì)組規(guī)范,如果未設(shè)置,則默認(rèn)值為~/.cache。

Parameters

  • url (string) – URL of the object to download
  • model_dir (字符串 , 可選)–保存對(duì)象的目錄
  • map_location (可選)–指定如何重新映射存儲(chǔ)位置的函數(shù)或命令(請(qǐng)參見 torch.load)
  • 進(jìn)度 (bool , 可選)–是否顯示 stderr 進(jìn)度條。 默認(rèn)值:True
  • check_hash (bool 可選)–如果為 True,則 URL 的文件名部分應(yīng)遵循命名約定filename-<sha256>.ext,其中[ <sha256>是文件內(nèi)容的 SHA256 哈希值的前 8 位或更多位。 哈希用于確保唯一的名稱并驗(yàn)證文件的內(nèi)容。 默認(rèn)值:False

Example

  1. >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

運(yùn)行加載的模型:

注意,torch.load()中的*args, **kwargs用于實(shí)例化模型。 加載模型后,如何找到可以使用該模型的功能? 建議的工作流程是

  • dir(model)查看模型的所有可用方法。
  • help(model.foo)檢查model.foo需要執(zhí)行哪些參數(shù)

為了幫助用戶探索而又不來(lái)回參考文檔,我們強(qiáng)烈建議回購(gòu)所有者使功能幫助消息清晰明了。 包含一個(gè)最小的工作示例也很有幫助。

我下載的模型保存在哪里?

這些位置按以下順序使用

  • 呼叫hub.set_dir(<PATH_TO_HUB_DIR>)
  • $TORCH_HOME/hub,如果設(shè)置了環(huán)境變量TORCH_HOME。
  • $XDG_CACHE_HOME/torch/hub,如果設(shè)置了環(huán)境變量XDG_CACHE_HOME。
  • ~/.cache/torch/hub

  1. torch.hub.set_dir(d)?

(可選)將 hub_dir 設(shè)置為本地目錄,以保存下載的模型&權(quán)重。

如果未調(diào)用set_dir,則默認(rèn)路徑為$TORCH_HOME/hub,其中環(huán)境變量$TORCH_HOME默認(rèn)為$XDG_CACHE_HOME/torch。 $XDG_CACHE_HOME遵循 Linux 文件系統(tǒng)布局的 X 設(shè)計(jì)組規(guī)范,如果未設(shè)置環(huán)境變量,則默認(rèn)值為~/.cache。

Parameters

d (字符串)–本地文件夾的路徑,用于保存下載的模型&權(quán)重。

緩存邏輯

默認(rèn)情況下,加載文件后我們不會(huì)清理文件。 如果hub_dir中已經(jīng)存在,則集線器默認(rèn)使用緩存。

用戶可以通過(guò)調(diào)用hub.load(..., force_reload=True)來(lái)強(qiáng)制重新加載。 這將刪除現(xiàn)有的 github 文件夾和下載的權(quán)重,重新初始化新的下載。 當(dāng)更新發(fā)布到同一分支時(shí),此功能很有用,用戶可以跟上最新版本。

已知限制:

Torch 集線器通過(guò)導(dǎo)入軟件包來(lái)進(jìn)行工作,就像安裝軟件包一樣。 在 Python 中導(dǎo)入會(huì)帶來(lái)一些副作用。 例如,您可以在 Python 緩存sys.modulessys.path_importer_cache中看到新項(xiàng)目,這是正常的 Python 行為。

在這里值得一提的已知限制是用戶無(wú)法相同的 python 進(jìn)程中加載同一存儲(chǔ)庫(kù)的兩個(gè)不同分支。 就像在 Python 中安裝兩個(gè)具有相同名稱的軟件包一樣,這是不好的。 快取可能會(huì)加入聚會(huì),如果您實(shí)際嘗試的話會(huì)給您帶來(lái)驚喜。 當(dāng)然,將它們分別加載是完全可以的。

以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)