Architectures
Utility tool
check weights or bias
- 通过print
model.state_dict()
或model.named_parameters()
查看所有参数,信息包括name
和params
,然后使用点运算符.
查看相应的参数; 通过官网查看相应网络模块具有的
variables
,然后使用点运算符.
查看相应的参数;1
2
3
4model.fc.weight
model.fc.bias
model.rnn.weight_ih_l0
model.rnn.bias_hh_l1
Running with GPU
查看GPU是否可用,然后使用
.to()
使用相应device:1
2device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)常用GPU命令
1
2
3
4
5
6
7
8
9
10
11torch.cuda.is_available()
cuda是否可用;
torch.cuda.device_count()
返回gpu数量;
torch.cuda.get_device_name(0)
返回gpu名字,设备索引默认从0开始;
torch.cuda.current_device()
返回当前设备索引;
保存或加载模型
快速保存模型:
1
torch.save(model.state_dict(), 'path/file_name.pth')
详细保存:包括其他一下信息
1
2
3
4
5checkpoints={'label': model.n_label,
'input_size': model.input_size,
'state_dict': model.state_dict()}
with open(os.path.join(args.save+fold,model_name), 'wb') as f:
torch.save(checkpoint, f)加载:
1
model.load_state_dict((torch.load(path/file_name.pth)))
迁移学习:torchvision.models
直接从
torchvision
模块里加载模型:1
2import torchvision.models as models/ from torchvision import models
model = models.vgg16(pretrained=True/ not args.caffe_pretrain)freeze parameters1
1
2
3
4
5for param in model.features.parameters():
param.require_grad = False
```
- freeze specified parameter2delete the last MaxPooling layer
freeze the first 4 convs
features = list(model.features)[:30] for layer in features[:10]:
for p in layer.parameters(): p.requires_grad = False
features = nn.Sequential(*features)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23> 最好检查一下是否“冻结”了参数,必要时在`optimizer`使用`filter`过滤一遍,[传送门][filter]:
> `optimizer.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)`
- 删除或增加layer:转成`list`方便操作,后面再使用`nn.Sequential(*list)`构建模型,其中的`*`是将列表或元组(可能适合其他类型)数据拆分;
```
# delete the last Linear layer
classifier = list(model.classifier)
in_feature = classifier[6].in_feature
del classifier[6]
if not args.use_drop:
del classifier[2]
del classifier[5]
classifier = nn.Sequential(*classifier)
classifier.add_module(str(6), nn.Linear(in_feature, 10))
```
# python #
## 使用[os模块][os]创建文件夹 ##
> 一般先使用`os.path.exists(path)`来检查相应文件夹是否存在
1. 使用`os.mkdir(path)`创建目录;
2. 使用`os.makedirs(path)`递归创建目录;if not os.path.exists(path):
os.makedirs(path)
`