PyTorch C ++ API

2025-06-25 10:47 更新

一、PyTorch C++ API 概覽

PyTorch 的 C++ API 提供了一個強大的工具集,用于在 C++ 環(huán)境中進行張量計算和深度學(xué)習(xí)模型開發(fā)。它主要包括以下幾個部分:

1.1 ATen 庫

ATen 是 PyTorch 的基礎(chǔ)張量庫,提供了豐富的張量操作和數(shù)學(xué)運算功能。

  1. #include <ATen/ATen.h>
  2. at::Tensor a = at::ones({2, 2}, at::kInt);
  3. at::Tensor b = at::randn({2, 2});
  4. auto c = a + b.to(at::kInt);

1.2 Autograd 自動求導(dǎo)

Autograd 是 PyTorch C++ API 的自動微分組件,擴展了 ATen 的功能,使其支持自動求導(dǎo)。

  1. #include <torch/torch.h>
  2. torch::Tensor a = torch::ones({2, 2}, torch::requires_grad());
  3. torch::Tensor b = torch::randn({2, 2});
  4. auto c = a + b;
  5. c.backward();

1.3 C++ 前端

C++ 前端提供了高層接口,用于構(gòu)建和訓(xùn)練神經(jīng)網(wǎng)絡(luò)模型。

  1. #include <torch/torch.h>
  2. class SimpleModel : public torch::nn::Module {
  3. public:
  4. SimpleModel() {
  5. linear = register_module("linear", torch::nn::Linear(10, 2));
  6. }
  7. torch::Tensor forward(torch::Tensor x) {
  8. return linear(x);
  9. }
  10. private:
  11. torch::nn::Linear linear;
  12. };
  13. int main() {
  14. SimpleModel model;
  15. torch::Tensor input = torch::randn({1, 10});
  16. torch::Tensor output = model(input);
  17. return 0;
  18. }

1.4 TorchScript 支持

TorchScript 是 PyTorch 的 JIT 編譯器和解釋器,支持模型的序列化和優(yōu)化。

  1. #include <torch/torch.h>
  2. int main() {
  3. // 加載 TorchScript 模型
  4. torch::jit::script::Module model;
  5. model.load("model.pt");
  6. // 執(zhí)行模型推理
  7. torch::Tensor input = torch::randn({1, 3, 224, 224});
  8. torch::Tensor output = model.forward({input});
  9. return 0;
  10. }

1.5 C++ 擴展

C++ 擴展允許開發(fā)者通過自定義 C++ 和 CUDA 代碼擴展 PyTorch 的功能。

  1. #include <torch/torch.h>
  2. torch::Tensor custom_add(torch::Tensor x, torch::Tensor y) {
  3. return x + y;
  4. }
  5. PYBIND11_MODULE(custom_ops, m) {
  6. m.def("custom_add", &custom_add, "A custom add operation");
  7. }

二、開發(fā)環(huán)境搭建

2.1 安裝 PyTorch C++ API

可以從 PyTorch 官方網(wǎng)站獲取安裝包,或通過源代碼編譯。

  1. ## 使用 conda 安裝
  2. conda install pytorch torchvision torchaudio pytorch-cpp -c pytorch

2.2 配置開發(fā)工具

推薦使用支持 C++17 的編譯器,如 GCC 9 或更高版本。同時,可以使用 CMake 來管理項目構(gòu)建。

  1. cmake_minimum_required(VERSION 3.18)
  2. project(MyPyTorchProject)
  3. find_package(Torch REQUIRED)
  4. add_executable(my_app main.cpp)
  5. target_link_libraries(my_app PRIVATE ${TORCH_LIBRARIES})

三、模型開發(fā)與訓(xùn)練

3.1 定義模型

使用 torch::nn::Module 定義神經(jīng)網(wǎng)絡(luò)模型。

  1. #include <torch/torch.h>
  2. class MyModel : public torch::nn::Module {
  3. public:
  4. MyModel() {
  5. conv1 = register_module("conv1", torch::nn::Conv2d(
  6. torch::nn::Conv2dOptions(1, 20, 5)
  7. ));
  8. fc1 = register_module("fc1", torch::nn::Linear(20 * 20 * 20, 500));
  9. fc2 = register_module("fc2", torch::nn::Linear(500, 10));
  10. }
  11. torch::Tensor forward(torch::Tensor x) {
  12. x = torch::functional::ReLU(conv1(x));
  13. x = torch::max_pool2d(x, 2);
  14. x = x.view({-1, 20 * 20 * 20});
  15. x = torch::functional::ReLU(fc1(x));
  16. x = fc2(x);
  17. return x;
  18. }
  19. private:
  20. torch::nn::Conv2d conv1;
  21. torch::nn::Linear fc1, fc2;
  22. };

3.2 數(shù)據(jù)加載與處理

使用 torch::data 加載和處理數(shù)據(jù)。

  1. #include <torch/data.h>
  2. #include <torch/datasets.h>
  3. using namespace torch::data;
  4. class MyDataset : public torch::data::Dataset<MyDataset> {
  5. public:
  6. MyDataset(std::string path) : path_(std::move(path)) {}
  7. torch::data::Example<> get(size_t index) override {
  8. // 實現(xiàn)數(shù)據(jù)加載邏輯
  9. torch::Tensor data = ...;
  10. torch::Tensor target = ...;
  11. return {data, target};
  12. }
  13. torch::optional<size_t> size() const override {
  14. return 1000; // 數(shù)據(jù)集大小
  15. }
  16. private:
  17. std::string path_;
  18. };
  19. int main() {
  20. auto dataset = MyDataset("data");
  21. auto dataloader = make_data_loader(
  22. dataset,
  23. DataLoaderOptions().batch_size(32).workers(4)
  24. );
  25. for (auto& batch : *dataloader) {
  26. auto data = batch.data;
  27. auto target = batch.target;
  28. // 訓(xùn)練邏輯
  29. }
  30. return 0;
  31. }

3.3 模型訓(xùn)練與優(yōu)化

使用優(yōu)化器進行模型訓(xùn)練。

  1. #include <torch/optim.h>
  2. int main() {
  3. MyModel model;
  4. auto optimizer = torch::optim::SGD(
  5. model.parameters(),
  6. torch::optim::SGDOptions(0.01).momentum(0.9)
  7. );
  8. for (auto& batch : *dataloader) {
  9. auto data = batch.data;
  10. auto target = batch.target;
  11. optimizer.zero_grad();
  12. auto output = model(data);
  13. auto loss = torch::nn::functional::nll_loss(output, target);
  14. loss.backward();
  15. optimizer.step();
  16. }
  17. return 0;
  18. }

四、模型推理與部署

4.1 模型保存與加載

保存和加載模型參數(shù)或整個模型。

  1. #include <torch/serialize.h>
  2. int main() {
  3. MyModel model;
  4. // 保存模型
  5. torch::save(model, "model.pth");
  6. // 加載模型
  7. MyModel loaded_model;
  8. torch::load(loaded_model, "model.pth");
  9. return 0;
  10. }

4.2 TorchScript 模型推理

加載并運行 TorchScript 模型。

  1. #include <torch/jit.h>
  2. int main() {
  3. // 加載 TorchScript 模型
  4. torch::jit::script::Module model;
  5. model.load("model.pt");
  6. // 模型推理
  7. torch::Tensor input = torch::randn({1, 3, 224, 224});
  8. std::vector<torch::jit::IValue> inputs;
  9. inputs.push_back(input);
  10. torch::Tensor output = model.forward(inputs).toTensor();
  11. return 0;
  12. }

五、性能優(yōu)化技巧

5.1 使用混合精度訓(xùn)練

在訓(xùn)練過程中使用混合精度加速計算。

  1. #include <torchcuda.h>
  2. int main() {
  3. MyModel model;
  4. model.cuda();
  5. auto scaler = torch::cuda::amp::GradScaler();
  6. for (auto& batch : *dataloader) {
  7. auto data = batch.data.cuda();
  8. auto target = batch.target.cuda();
  9. scaler.scale(loss).backward();
  10. scaler.step(optimizer);
  11. scaler.update();
  12. }
  13. return 0;
  14. }

5.2 多 GPU 并行訓(xùn)練

使用多 GPU 進行模型并行訓(xùn)練。

  1. #include <torch/distributed.h>
  2. #include <torch/data/distributed.h>
  3. int main() {
  4. // 初始化分布式環(huán)境
  5. torch::distributed::init_process_group(torch::distributed::Backend::NCCL, std::string("env://"));
  6. int rank = torch::distributed::get_rank();
  7. int world_size = torch::distributed::get_world_size();
  8. MyModel model;
  9. model.cuda(rank);
  10. // 數(shù)據(jù)并行
  11. auto model_ddp = torch::nn::DataParallel(model);
  12. return 0;
  13. }

六、總結(jié)與展望

PyTorch C++ API 提供了強大的功能,使開發(fā)者能夠在 C++ 環(huán)境中高效地進行深度學(xué)習(xí)模型的開發(fā)和部署。通過合理利用 ATen、Autograd、C++ 前端、TorchScript 和 C++ 擴展,可以構(gòu)建高性能的機器學(xué)習(xí)應(yīng)用。

關(guān)注編程獅(W3Cschool)平臺,獲取更多 PyTorch C++ API 開發(fā)相關(guān)的教程和案例。

以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號