PyTorch Common Code Snippets
# 1. Basic Configuration
# 1.1 Importing Packages and Checking Versions
import torch
import torch.nn as nn
import torchvision
print(torch.__version__)               # PyTorch version
print(torch.cuda.is_available())       # whether CUDA is available
print(torch.version.cuda)              # CUDA version PyTorch was built with
print(torch.backends.cudnn.version())  # cuDNN version
print(torch.cuda.get_device_name(0))   # name of GPU 0 (requires CUDA)
# 1.2 Reproducibility and Random Seeds
import numpy as np
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True  # use deterministic cuDNN algorithms
torch.backends.cudnn.benchmark = False     # disable auto-tuning, which is non-deterministic
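# The calls above can be bundled into one helper. `seed_everything` is a hypothetical
# name (not a PyTorch API); on recent PyTorch versions, torch.use_deterministic_algorithms(True)
# is an optional stricter switch that errors out on ops with no deterministic implementation.
import os
import random

def seed_everything(seed: int = 0):
    # Hypothetical convenience wrapper around the seeding calls above.
    random.seed(seed)                         # Python built-in RNG
    np.random.seed(seed)                      # NumPy RNG
    torch.manual_seed(seed)                   # CPU RNG (also seeds CUDA on recent versions)
    torch.cuda.manual_seed_all(seed)          # all GPU RNGs
    os.environ['PYTHONHASHSEED'] = str(seed)  # hash-based operations
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(0)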
# 1.3 GPU Selection and Freeing GPU Memory
# Use the PyTorch API to select a single GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# To expose specific GPUs, e.g. GPUs 0 and 1:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
# The same can be set on the command line when launching the script:
# CUDA_VISIBLE_DEVICES=0,1 python train.py
# Release cached GPU memory
torch.cuda.empty_cache()
# Alternatively, reset the GPU from the command line:
# nvidia-smi --gpu-reset -i [gpu_id]
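# Note that torch.cuda.empty_cache() only returns *cached* blocks to the driver;
# memory held by live tensors is unaffected. A quick way to inspect what is held
# (a minimal sketch using the torch.cuda memory-stat APIs):
if torch.cuda.is_available():
    print(torch.cuda.memory_allocated() / 1024**2, 'MiB allocated by live tensors')
    print(torch.cuda.memory_reserved() / 1024**2, 'MiB reserved by the caching allocator')
    torch.cuda.empty_cache()
    print(torch.cuda.memory_reserved() / 1024**2, 'MiB reserved after empty_cache()')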
# 2. Tensor Operations
# 2.1 Basic Tensor Information
tensor = torch.randn(3,4,5)
print(tensor.type())  # data type, e.g. 'torch.FloatTensor'
print(tensor.size())  # shape of the tensor, a torch.Size (tuple-like)
print(tensor.shape)   # same as size()
print(tensor.dim())   # number of dimensions
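# type() returns a string; in modern code the dtype, device, and numel() attributes
# are usually more convenient (a small sketch):
print(tensor.dtype)    # e.g. torch.float32
print(tensor.device)   # e.g. cpu or cuda:0
print(tensor.numel())  # total number of elements (3*4*5 = 60)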
# 2.2 Converting Between torch.Tensor and np.ndarray
# torch.Tensor -> np.ndarray
tensor = torch.randn(3,4,5)
ndarray = tensor.cpu().detach().numpy()  # move to CPU and drop the autograd graph first
# np.ndarray -> torch.Tensor
tensor = torch.from_numpy(ndarray).float()
tensor = torch.from_numpy(ndarray.copy()).float()  # if ndarray has negative strides
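# Note that torch.from_numpy() and Tensor.numpy() share the underlying buffer,
# so mutating one side is visible on the other (a quick demonstration):
a = np.zeros(3)
t = torch.from_numpy(a)
a[0] = 1.0
print(t)  # tensor([1., 0., 0.], dtype=torch.float64) -- the change is visible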
# 2.3 Converting Between torch.Tensor and PIL.Image
# PyTorch image tensors use the [C, H, W] channel ordering ([N, C, H, W] with a batch
# dimension) with values in [0, 1], so conversion requires permuting axes and rescaling.
import PIL.Image
# torch.Tensor -> PIL.Image
image = PIL.Image.fromarray(
    torch.clamp(tensor*255, min=0, max=255).byte().permute(1,2,0).cpu().numpy())
image = torchvision.transforms.functional.to_pil_image(tensor)  # equivalent way
# PIL.Image -> torch.Tensor
path = 'figure.jpg'
tensor = torch.from_numpy(
    np.asarray(PIL.Image.open(path))
).permute(2,0,1).float() / 255
tensor = torchvision.transforms.functional.to_tensor(PIL.Image.open(path))  # equivalent way
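# To write a [C, H, W] float tensor in [0, 1] straight to disk without converting to
# PIL.Image by hand, torchvision also provides save_image (a minimal sketch):
from torchvision.utils import save_image
save_image(tensor, 'figure_copy.jpg')  # handles the [0,1] -> [0,255] conversion internally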
# 2.4 Common Tensor Operations
# 1. Extract the value from a single-element tensor
value = torch.rand(1).item()
# 2. Reshaping
tensor = torch.rand(2,3,4)
shape = (6, 4)
tensor = torch.reshape(tensor, shape)
# 3. Reordering dimensions
tensor = torch.rand(2,3,4)
tensor = tensor.permute(2,0,1)
# 4. Copying a tensor
# Operation               | New/Shared memory | Still in computation graph |
tensor.clone()            # | New               | Yes                        |
tensor.detach()           # | Shared            | No                         |
tensor.detach().clone()   # | New               | No                         |
# 5. Concatenating tensors
# Note the difference between torch.cat and torch.stack: torch.cat concatenates
# along an existing dimension, while torch.stack adds a new one. Given three
# 10x5 tensors, torch.cat yields a 30x5 tensor, while torch.stack yields a
# 3x10x5 tensor; a quick demonstration follows.
tensor = torch.cat(list_of_tensors, dim=0)
tensor = torch.stack(list_of_tensors, dim=0)
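# For example (a minimal demonstration of the shapes described above):
list_of_tensors = [torch.rand(10, 5) for _ in range(3)]
print(torch.cat(list_of_tensors, dim=0).shape)    # torch.Size([30, 5])
print(torch.stack(list_of_tensors, dim=0).shape)  # torch.Size([3, 10, 5])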
# 6. Finding non-zero elements
torch.nonzero(tensor)               # indices of non-zero elements
torch.nonzero(tensor == 0)          # indices of zero elements
torch.nonzero(tensor).size(0)       # number of non-zero elements
torch.nonzero(tensor == 0).size(0)  # number of zero elements
# 7. Testing whether two tensors are equal
torch.allclose(tensor1, tensor2)  # float tensors (tolerance-based)
torch.equal(tensor1, tensor2)     # int tensors (exact)
# 8. Expanding a tensor
# Expand a tensor of shape 64*512 to shape 64*512*7*7.
tensor = torch.rand(64,512)
torch.reshape(tensor, (64, 512, 1, 1)).expand(64, 512, 7, 7)
# 9. Matrix multiplication
# Matrix multiplication: (m*n) @ (n*p) -> (m*p).
result = torch.mm(tensor1, tensor2)
# Batch matrix multiplication: (b*m*n) @ (b*n*p) -> (b*m*p).
result = torch.bmm(tensor1, tensor2)
# Element-wise multiplication.
result = tensor1 * tensor2
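# torch.matmul (or the @ operator) covers both mm and bmm, and additionally
# broadcasts batch dimensions, which mm and bmm do not (a small sketch):
a = torch.rand(10, 3, 4)
b = torch.rand(4, 5)
print(torch.matmul(a, b).shape)  # torch.Size([10, 3, 5]) -- b is broadcast over the batch
print((a @ b).shape)             # same result with the @ operator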
# 2.5 Converting Integer Labels to One-Hot Encoding
# Labels in PyTorch start from 0 by default.
tensor = torch.tensor([0, 2, 1, 3])
N = tensor.size(0)
num_classes = 4
one_hot = torch.zeros(N, num_classes).long()
one_hot.scatter_(dim=1, index=torch.unsqueeze(tensor, dim=1), src=torch.ones(N, num_classes).long())
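# The same result can be obtained in one call with torch.nn.functional.one_hot:
import torch.nn.functional as F
one_hot = F.one_hot(tensor, num_classes=4)  # shape (4, 4), dtype torch.int64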
# 3. Feature Extraction with timm
# 3.1 How timm Extracts Features
# The methods below are excerpted from timm's ResNet implementation: forward_features
# runs the backbone and returns the spatial feature map before pooling, while
# forward_head applies global pooling and the classifier.
def forward_features(self, x):
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.act1(x)
    x = self.maxpool(x)
    if self.grad_checkpointing and not torch.jit.is_scripting():
        x = checkpoint_seq([self.layer1, self.layer2, self.layer3, self.layer4], x, flatten=True)
    else:
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
    return x

def forward_head(self, x, pre_logits: bool = False):
    x = self.global_pool(x)
    if self.drop_rate:
        x = F.dropout(x, p=float(self.drop_rate), training=self.training)
    return x if pre_logits else self.fc(x)

def forward(self, x):
    x = self.forward_features(x)
    x = self.forward_head(x)
    return x
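# In user code, the usual ways to get features out of a timm model are calling
# forward_features() directly or building the model with features_only=True
# (a minimal sketch; 'resnet50' is just an example model name):
import timm
model = timm.create_model('resnet50', pretrained=False)
x = torch.randn(1, 3, 224, 224)
feats = model.forward_features(x)  # e.g. torch.Size([1, 2048, 7, 7]) for resnet50
# Alternatively, build an extractor that returns one feature map per stage:
backbone = timm.create_model('resnet50', pretrained=False, features_only=True)
for f in backbone(x):
    print(f.shape)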
# 3.2 Listing All Models in timm
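# timm.list_models() returns the names of all registered architectures; a wildcard
# filter or pretrained=True narrows the list (a minimal sketch):
import timm
print(len(timm.list_models()))                # number of registered architectures
print(timm.list_models('resnet*')[:5])        # names matching a wildcard
print(timm.list_models(pretrained=True)[:5])  # only models with pretrained weights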
References
- Parts of this content are adapted from: https://zhuanlan.zhihu.com/p/104019160