From a1a1bd60f876cae5f620383146a3263790bfb109 Mon Sep 17 00:00:00 2001 From: jaywcjlove <398188662@qq.com> Date: Mon, 13 May 2024 17:01:57 +0800 Subject: [PATCH] doc: update docs/pytorch.md #649 --- docs/pytorch.md | 114 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 82 insertions(+), 32 deletions(-) diff --git a/docs/pytorch.md b/docs/pytorch.md index 8e188bc..c39db61 100644 --- a/docs/pytorch.md +++ b/docs/pytorch.md @@ -13,6 +13,7 @@ Pytorch 是一种开源机器学习框架,可加速从研究原型设计到生 - [Pytorch 官方备忘清单](https://pytorch.org/tutorials/beginner/ptcheat.html) _(pytorch.org)_ ### 认识 Pytorch + ```python from __future__ import print_function @@ -32,6 +33,7 @@ tensor([ Tensors 张量: 张量的概念类似于Numpy中的ndarray数据结构, 最大的区别在于Tensor可以利用GPU的加速功能. ### 创建一个全零矩阵 + ```python x = torch.zeros(5, 3, dtype=torch.long) @@ -95,6 +97,7 @@ tensor([[ 1.6978, -1.6979, 0.3093], ``` ### 加法操作(4) + ```python y.add_(x) @@ -118,6 +121,7 @@ tensor([-2.0902, -0.4489, -0.1441, 0.8035, -0.8341]) ### 张量形状 + ```python x = torch.randn(4, 4) @@ -178,13 +182,21 @@ tensor([2., 2., 2., 2., 2.], dtype=torch.float64) ```python >>> x = torch.rand(1, 2, 1, 28, 1) ->>> x.squeeze().shape # squeeze不加参数,默认去除所有为1的维度 + +# squeeze不加参数,默认去除所有为1的维度 +>>> x.squeeze().shape torch.Size([2, 28]) ->>> x.squeeze(dim=0).shape # squeeze加参数,去除指定为1的维度 + +# squeeze加参数,去除指定为1的维度 +>>> x.squeeze(dim=0).shape torch.Size([2, 1, 28, 1]) ->>> x.squeeze(1).shape # squeeze加参数,如果不为1,则不变 + +# squeeze加参数,如果不为1,则不变 +>>> x.squeeze(1).shape torch.Size([1, 2, 1, 28, 1]) ->>> torch.squeeze(x,-1).shape # 既可以是函数,也可以是方法 + +# 既可以是函数,也可以是方法 +>>> torch.squeeze(x,-1).shape torch.Size([1, 2, 1, 28]) ``` @@ -192,47 +204,59 @@ torch.Size([1, 2, 1, 28]) ```python >>> x = torch.rand(2, 28) ->>> x.unsqueeze(0).shape # unsqueeze必须加参数, _ 2 _ 28 _ -torch.Size([1, 2, 28]) # 参数代表在哪里添加维度 0 1 2 ->>> torch.unsqueeze(x, -1).shape # 既可以是函数,也可以是方法 +# unsqueeze必须加参数, _ 2 _ 28 _ +>>> x.unsqueeze(0).shape +# 参数代表在哪里添加维度 0 1 2 +torch.Size([1, 2, 28]) +# 既可以是函数,也可以是方法 +>>> torch.unsqueeze(x, -1).shape torch.Size([2, 28, 1]) ``` Cuda 相关 --- + ### 检查 Cuda 是否可用 + ```python >>> import torch.cuda >>> torch.cuda.is_available() >>> True ``` + ### 列出 GPU 设备 + + ```python import torch + device_count = torch.cuda.device_count() print("CUDA 设备") + for i in range(device_count): device_name = torch.cuda.get_device_name(i) total_memory = torch.cuda.get_device_properties(i).total_memory / (1024 ** 3) print(f"├── 设备 {i}: {device_name}, 容量: {total_memory:.2f} GiB") + print("└── (结束)") ``` + ### 将模型、张量等数据在 GPU 和内存之间进行搬运 + ```python import torch -# Replace 0 to your GPU device index. or use "cuda" directly. +# 将 0 替换为您的 GPU 设备索引或者直接使用 "cuda" device = f"cuda:0" -# Move to GPU +# 移动到GPU tensor_m = torch.tensor([1, 2, 3]) tensor_g = tensor_m.to(device) model_m = torch.nn.Linear(1, 1) model_g = model_m.to(device) -# Move back. +# 向后移动 tensor_m = tensor_g.cpu() model_m = model_g.cpu() ``` - 导入 Imports --- @@ -241,28 +265,71 @@ model_m = model_g.cpu() ```python # 根包 import torch -# 数据集表示和加载 +``` + +数据集表示和加载 + +```python from torch.utils.data import Dataset, DataLoader ``` ### 神经网络 API + ```python # 计算图 import torch.autograd as autograd # 计算图中的张量节点 from torch import Tensor -# 神经网络 +``` + +神经网络 + +```python import torch.nn as nn + # 层、激活等 import torch.nn.functional as F # 优化器,例如 梯度下降、ADAM等 import torch.optim as optim -# 混合前端装饰器和跟踪 jit +``` + +混合前端装饰器和跟踪 jit + +```python from torch.jit import script, trace ``` +### ONNX + + +```python +torch.onnx.export(model, dummy data, xxxx.proto) +# 导出 ONNX 格式 +# 使用经过训练的模型模型,dummy +# 数据和所需的文件名 +``` + + +加载 ONNX 模型 + +```python +model = onnx.load("alexnet.proto") +``` + +检查模型,IT 是否结构良好 + +```python +onnx.checker.check_model(model) +``` + +打印一个人类可读的,图的表示 + +```python +onnx.helper.printable_graph(model.graph) +``` + ### Torchscript 和 JIT ```python @@ -277,25 +344,8 @@ torch.jit.trace() 装饰器用于指示被跟踪代码中的数据相关控制流 -### ONNX - -```python -torch.onnx.export(model, dummy data, xxxx.proto) -# 导出 ONNX 格式 -# 使用经过训练的模型模型,dummy -# 数据和所需的文件名 - -model = onnx.load("alexnet.proto") -# 加载 ONNX 模型 -onnx.checker.check_model(model) -# 检查模型,IT 是否结构良好 - -onnx.helper.printable_graph(model.graph) -# 打印一个人类可读的,图的表示 -``` - - ### Vision + ```python # 视觉数据集,架构 & 变换