计算loss和计算metric
计算loss时
# true: torch.Size([4, 512, 1024])
# pred: torch.Size([4, 19, 512, 1024])
1
2
2
计算 metric 时
# true: torch.Size([4, 512, 1024])
# pred: torch.Size([4, 512, 1024])
1
2
2
解决方法:
argmax_pred = torch.argmax(pred, dim=1).detach()
# true: torch.Size([4, 512, 1024])
# pred: torch.Size([4, 19, 512, 1024])
# argmax_pred: torch.Size([4, 512, 1024])
1
2
3
4
2
3
4
上次更新: 2021/10/03, 15:21:19
- 02
- README 美化05-20
- 03
- 常见 Tricks 代码片段05-12