Pytorch Hub 是經(jīng)過(guò)預(yù)先訓(xùn)練的模型資料庫(kù),旨在促進(jìn)研究的可重復(fù)性。
Pytorch Hub 支持通過(guò)添加簡(jiǎn)單的hubconf.py
文件將預(yù)訓(xùn)練的模型(模型定義和預(yù)訓(xùn)練的權(quán)重)發(fā)布到 github 存儲(chǔ)庫(kù);
hubconf.py
可以有多個(gè)入口點(diǎn)。 每個(gè)入口點(diǎn)都定義為 python 函數(shù)(例如:您要發(fā)布的經(jīng)過(guò)預(yù)先訓(xùn)練的模型)。
def entrypoint_name(*args, **kwargs):
# args & kwargs are optional, for models which take positional/keyword arguments.
...
如果我們擴(kuò)展pytorch/vision/hubconf.py
中的實(shí)現(xiàn),則以下代碼段指定了resnet18
模型的入口點(diǎn)。 在大多數(shù)情況下,在hubconf.py
中導(dǎo)入正確的功能就足夠了。 在這里,我們僅以擴(kuò)展版本為例來(lái)說(shuō)明其工作原理。
dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18
## resnet18 is the name of entrypoint
def resnet18(pretrained=False, **kwargs):
""" # This docstring shows up in hub.help()
Resnet18 model
pretrained (bool): kwargs, load pretrained weights into the model
"""
# Call the model, load pretrained weights
model = _resnet18(pretrained=pretrained, **kwargs)
return model
dependencies
變量是加載模型所需的軟件包名稱的列表。 請(qǐng)注意,這可能與訓(xùn)練模型所需的依賴項(xiàng)稍有不同。args
和kwargs
傳遞給實(shí)際的可調(diào)用函數(shù)。torch.hub.list()
中顯示。torch.hub.load_state_dict_from_url()
加載。 如果少于 2GB,建議將其附加到項(xiàng)目版本,并使用該版本中的網(wǎng)址。 在上面的示例中,torchvision.models.resnet.resnet18
處理pretrained
,或者,您可以在入口點(diǎn)定義中添加以下邏輯。if pretrained:
# For checkpoint saved in local github repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
dirname = os.path.dirname(__file__)
checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
state_dict = torch.load(checkpoint)
model.load_state_dict(state_dict)
# For checkpoint saved elsewhere
checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
Pytorch Hub 提供了便捷的 API,可通過(guò)torch.hub.list()
瀏覽集線器中的所有可用模型,通過(guò)torch.hub.help()
顯示文檔字符串和示例,并使用torch.hub.load()
加載經(jīng)過(guò)預(yù)先訓(xùn)練的模型
torch.hub.list(github, force_reload=False)?
列出 <cite>github</cite> hubconf 中可用的所有入口點(diǎn)。
參數(shù)
退貨
可用入口點(diǎn)名稱的列表
返回類型
入口點(diǎn)
例
>>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
torch.hub.help(github, model, force_reload=False)?
顯示入口點(diǎn)<cite>模型</cite>的文檔字符串。
Parameters
Example
>>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
torch.hub.load(github, model, *args, **kwargs)?
使用預(yù)訓(xùn)練的權(quán)重從 github 存儲(chǔ)庫(kù)加載模型。
Parameters
Returns
具有相應(yīng)預(yù)訓(xùn)練權(quán)重的單個(gè)模型。
Example
>>> model = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)
torch.hub.download_url_to_file(url, dst, hash_prefix=None, progress=True)?
將給定 URL 上的對(duì)象下載到本地路徑。
Parameters
Example
>>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')
torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False)?
將 Torch 序列化對(duì)象加載到給定的 URL。
如果下載的文件是 zip 文件,它將被自動(dòng)解壓縮。
如果 <cite>model_dir</cite> 中已經(jīng)存在該對(duì)象,則將其反序列化并返回。 <cite>model_dir</cite> 的默認(rèn)值為$TORCH_HOME/checkpoints
,其中環(huán)境變量$TORCH_HOME
的默認(rèn)值為$XDG_CACHE_HOME/torch
。 $XDG_CACHE_HOME
遵循 Linux 文件系統(tǒng)布局的 X 設(shè)計(jì)組規(guī)范,如果未設(shè)置,則默認(rèn)值為~/.cache
。
Parameters
filename-<sha256>.ext
,其中[ <sha256>
是文件內(nèi)容的 SHA256 哈希值的前 8 位或更多位。 哈希用于確保唯一的名稱并驗(yàn)證文件的內(nèi)容。 默認(rèn)值:FalseExample
>>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
注意,torch.load()
中的*args, **kwargs
用于實(shí)例化模型。 加載模型后,如何找到可以使用該模型的功能? 建議的工作流程是
dir(model)
查看模型的所有可用方法。help(model.foo)
檢查model.foo
需要執(zhí)行哪些參數(shù)為了幫助用戶探索而又不來(lái)回參考文檔,我們強(qiáng)烈建議回購(gòu)所有者使功能幫助消息清晰明了。 包含一個(gè)最小的工作示例也很有幫助。
這些位置按以下順序使用
hub.set_dir(<PATH_TO_HUB_DIR>)
$TORCH_HOME/hub
,如果設(shè)置了環(huán)境變量TORCH_HOME
。$XDG_CACHE_HOME/torch/hub
,如果設(shè)置了環(huán)境變量XDG_CACHE_HOME
。~/.cache/torch/hub
torch.hub.set_dir(d)?
(可選)將 hub_dir 設(shè)置為本地目錄,以保存下載的模型&權(quán)重。
如果未調(diào)用set_dir
,則默認(rèn)路徑為$TORCH_HOME/hub
,其中環(huán)境變量$TORCH_HOME
默認(rèn)為$XDG_CACHE_HOME/torch
。 $XDG_CACHE_HOME
遵循 Linux 文件系統(tǒng)布局的 X 設(shè)計(jì)組規(guī)范,如果未設(shè)置環(huán)境變量,則默認(rèn)值為~/.cache
。
Parameters
d (字符串)–本地文件夾的路徑,用于保存下載的模型&權(quán)重。
默認(rèn)情況下,加載文件后我們不會(huì)清理文件。 如果hub_dir
中已經(jīng)存在,則集線器默認(rèn)使用緩存。
用戶可以通過(guò)調(diào)用hub.load(..., force_reload=True)
來(lái)強(qiáng)制重新加載。 這將刪除現(xiàn)有的 github 文件夾和下載的權(quán)重,重新初始化新的下載。 當(dāng)更新發(fā)布到同一分支時(shí),此功能很有用,用戶可以跟上最新版本。
Torch 集線器通過(guò)導(dǎo)入軟件包來(lái)進(jìn)行工作,就像安裝軟件包一樣。 在 Python 中導(dǎo)入會(huì)帶來(lái)一些副作用。 例如,您可以在 Python 緩存sys.modules
和sys.path_importer_cache
中看到新項(xiàng)目,這是正常的 Python 行為。
在這里值得一提的已知限制是用戶無(wú)法在相同的 python 進(jìn)程中加載同一存儲(chǔ)庫(kù)的兩個(gè)不同分支。 就像在 Python 中安裝兩個(gè)具有相同名稱的軟件包一樣,這是不好的。 快取可能會(huì)加入聚會(huì),如果您實(shí)際嘗試的話會(huì)給您帶來(lái)驚喜。 當(dāng)然,將它們分別加載是完全可以的。
更多建議: