75 lines
2.8 KiB
Python
75 lines
2.8 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
def copy_conv(module, init):
    """Copy the leading slice of ``init``'s conv parameters into ``module``.

    ``module``'s weight is overwritten in place with the first
    ``out x in`` block of ``init``'s weight. The slice size is read from
    ``module.weight``'s own shape, so grouped and depthwise convolutions
    are handled as well as dense ones. For those, dimension 1 of the weight
    is ``in_channels // groups``, not ``in_channels``. If ``module`` has a
    bias, it receives the first ``out`` entries of ``init.bias``.

    The copy runs under ``torch.no_grad()``, so the helper also works when
    called outside an autograd-disabled context.

    Args:
        module: destination ``nn.Conv2d``, modified in place.
        init: source ``nn.Conv2d``; only read.

    Raises:
        AssertionError: if either argument is not an ``nn.Conv2d``.
    """
    assert isinstance(module, nn.Conv2d), "invalid module : {:}".format(module)
    assert isinstance(init, nn.Conv2d), "invalid module : {:}".format(init)
    # Size the slice from the destination weight itself. For dense convs this
    # equals (out_channels, in_channels). For grouped convs, dimension 1 is
    # in_channels // groups.
    new_o, new_i = module.weight.shape[0], module.weight.shape[1]
    # In-place writes into leaf parameters that require grad are rejected by
    # autograd unless grad tracking is disabled.
    with torch.no_grad():
        module.weight.copy_(init.weight.detach()[:new_o, :new_i])
        if module.bias is not None:
            module.bias.copy_(init.bias.detach()[:new_o])
|
|
|
|
|
|
def copy_bn(module, init):
    """Copy the leading ``module.num_features`` channels of a BatchNorm2d.

    The affine parameters (``weight``, ``bias``) and running statistics
    (``running_mean``, ``running_var``) present on ``module`` are
    overwritten in place with the first ``module.num_features`` entries of
    the matching tensors on ``init``. Entries that are ``None`` on
    ``module`` are skipped.

    The ``num_batches_tracked`` counter is also copied when both layers
    have one. This keeps the copied statistics paired with their update
    count, which matters for ``momentum=None`` (cumulative moving average).

    The copy runs under ``torch.no_grad()``, so the helper also works when
    called outside an autograd-disabled context.

    Args:
        module: destination ``nn.BatchNorm2d``, modified in place.
        init: source ``nn.BatchNorm2d``; only read.

    Raises:
        AssertionError: if either argument is not an ``nn.BatchNorm2d``.
    """
    assert isinstance(module, nn.BatchNorm2d), "invalid module : {:}".format(module)
    assert isinstance(init, nn.BatchNorm2d), "invalid module : {:}".format(init)
    num_features = module.num_features
    with torch.no_grad():
        for name in ("weight", "bias", "running_mean", "running_var"):
            target = getattr(module, name)
            if target is not None:
                target.copy_(getattr(init, name).detach()[:num_features])
        # Scalar counter, not per-channel, so it is copied whole. It is None
        # when track_running_stats=False.
        counter = module.num_batches_tracked
        if counter is not None and init.num_batches_tracked is not None:
            counter.copy_(init.num_batches_tracked)
|
|
|
|
|
|
def copy_fc(module, init):
    """Copy the leading slice of ``init``'s linear parameters into ``module``.

    ``module.weight`` is overwritten in place with
    ``init.weight[:out_features, :in_features]``, using ``module``'s
    feature counts. If ``module`` has a bias, it receives the first
    ``out_features`` entries of ``init.bias``.

    The copy runs under ``torch.no_grad()``, so the helper also works when
    called outside an autograd-disabled context.

    Args:
        module: destination ``nn.Linear``, modified in place.
        init: source ``nn.Linear``; only read.

    Raises:
        AssertionError: if either argument is not an ``nn.Linear``.
    """
    assert isinstance(module, nn.Linear), "invalid module : {:}".format(module)
    assert isinstance(init, nn.Linear), "invalid module : {:}".format(init)
    new_i, new_o = module.in_features, module.out_features
    # In-place writes into leaf parameters that require grad are rejected by
    # autograd unless grad tracking is disabled.
    with torch.no_grad():
        module.weight.copy_(init.weight.detach()[:new_o, :new_i])
        if module.bias is not None:
            module.bias.copy_(init.bias.detach()[:new_o])
|
|
|
|
|
|
def copy_base(module, init):
    """Copy conv and batch-norm weights between two conv-unit wrappers.

    Both arguments are validated by class *name*: each must be an instance
    of a class called ``ConvBNReLU`` or ``Downsample``. Then each
    sub-module present on ``module`` is filled from the same-named
    sub-module of ``init``: ``conv`` goes through ``copy_conv``, and after
    that ``bn`` goes through ``copy_bn``. Sub-modules that are ``None`` on
    ``module`` are skipped.

    Raises:
        AssertionError: if either argument has an unsupported class name.
    """
    accepted = ("ConvBNReLU", "Downsample")
    # Validate the destination first, then the source, matching the
    # original error order.
    for candidate in (module, init):
        assert type(candidate).__name__ in accepted, "invalid module : {:}".format(candidate)

    target_conv = module.conv
    if target_conv is not None:
        copy_conv(target_conv, init.conv)

    target_bn = module.bn
    if target_bn is not None:
        copy_bn(target_bn, init.bn)
|
|
|
|
|
|
def copy_basic(module, init):
    """Copy weights between two ResNet basic blocks via ``copy_base``.

    ``conv_a`` and ``conv_b`` are always copied. The ``downsample``
    shortcut is copied only when both ``module`` and ``init`` have one.
    ``init.downsample`` is not read at all when ``module.downsample`` is
    ``None``.

    Args:
        module: destination block, modified in place.
        init: source block; only read.
    """
    copy_base(module.conv_a, init.conv_a)
    copy_base(module.conv_b, init.conv_b)
    # A destination shortcut with no source counterpart has nothing to
    # inherit, so it keeps whatever initialization it already has.
    if module.downsample is not None and init.downsample is not None:
        copy_base(module.downsample, init.downsample)
|
|
|
|
|
|
def init_from_model(network, init_model):
    """Initialize ``network`` in place from the weights of ``init_model``.

    Everything runs under ``torch.no_grad()``. The classifier is copied
    first with ``copy_fc``. Then ``init_model.layers`` and
    ``network.layers`` are walked pairwise with ``zip``, so any extra layers
    on the longer side are not visited. Each pair must share a class name.
    ``ConvBNReLU`` pairs are handled by ``copy_base`` and
    ``ResNetBasicblock`` pairs by ``copy_basic``.

    Args:
        network: destination model, modified in place.
        init_model: source model; only read.

    Raises:
        AssertionError: if a layer pair has mismatched class names.
        ValueError: if a layer's class name has no copier.
    """
    copiers = {
        "ConvBNReLU": copy_base,
        "ResNetBasicblock": copy_basic,
    }
    with torch.no_grad():
        copy_fc(network.classifier, init_model.classifier)
        for src_layer, dst_layer in zip(init_model.layers, network.layers):
            kind = type(src_layer).__name__
            assert kind == type(dst_layer).__name__, "invalid type : {:} vs {:}".format(src_layer, dst_layer)
            copier = copiers.get(kind)
            if copier is None:
                raise ValueError("unknown type name : {:}".format(kind))
            copier(dst_layer, src_layer)
|