import torch
import torch.nn as nn

# Class names accepted by ``copy_base`` (matched by name, not by isinstance,
# so the defining classes need not be importable here).
_BASE_TYPE_NAMES = ("ConvBNReLU", "Downsample")


@torch.no_grad()
def copy_conv(module, init):
    """Copy the leading slice of ``init``'s Conv2d parameters into ``module``.

    The first ``module.weight.shape[0]`` output channels and first
    ``module.weight.shape[1]`` input-channel slots of ``init.weight`` are
    copied; the trailing (kernel) dimensions are not sliced, so they must be
    broadcastable to ``module``'s or ``copy_`` raises. The bias is copied
    only when ``module`` has one.

    Runs under ``torch.no_grad()`` so it can be called on parameters that
    require grad without an enclosing no-grad context.

    Args:
        module: target ``nn.Conv2d``; its parameters are overwritten in place.
        init: source ``nn.Conv2d`` providing the values.

    Raises:
        AssertionError: if either argument is not an ``nn.Conv2d``.
    """
    assert isinstance(module, nn.Conv2d), "invalid module : {:}".format(module)
    assert isinstance(init, nn.Conv2d), "invalid module : {:}".format(init)
    # Slice by the target weight's own shape: for a grouped conv, weight dim 1
    # is in_channels // groups, so slicing by in_channels would over-slice the
    # source and the copy would fail with a shape mismatch. For groups == 1
    # this is identical to (out_channels, in_channels).
    new_o, new_i = module.weight.shape[:2]
    module.weight.copy_(init.weight.detach()[:new_o, :new_i])
    if module.bias is not None:
        module.bias.copy_(init.bias.detach()[:new_o])


@torch.no_grad()
def copy_bn(module, init):
    """Copy the first ``module.num_features`` BatchNorm2d entries from ``init``.

    ``weight``, ``bias``, ``running_mean`` and ``running_var`` are each copied
    (in that order) only when ``module`` has that tensor, i.e. it is not
    ``None``. ``num_batches_tracked`` is not copied.

    Args:
        module: target ``nn.BatchNorm2d``; overwritten in place.
        init: source ``nn.BatchNorm2d``.

    Raises:
        AssertionError: if either argument is not an ``nn.BatchNorm2d``.
    """
    assert isinstance(module, nn.BatchNorm2d), "invalid module : {:}".format(module)
    assert isinstance(init, nn.BatchNorm2d), "invalid module : {:}".format(init)
    num_features = module.num_features
    for name in ("weight", "bias", "running_mean", "running_var"):
        target = getattr(module, name)
        if target is not None:
            target.copy_(getattr(init, name).detach()[:num_features])


@torch.no_grad()
def copy_fc(module, init):
    """Copy the leading ``[out_features, in_features]`` slice of ``init`` into ``module``.

    Args:
        module: target ``nn.Linear``; overwritten in place.
        init: source ``nn.Linear``.

    Raises:
        AssertionError: if either argument is not an ``nn.Linear``.
    """
    assert isinstance(module, nn.Linear), "invalid module : {:}".format(module)
    assert isinstance(init, nn.Linear), "invalid module : {:}".format(init)
    new_i, new_o = module.in_features, module.out_features
    module.weight.copy_(init.weight.detach()[:new_o, :new_i])
    if module.bias is not None:
        module.bias.copy_(init.bias.detach()[:new_o])


@torch.no_grad()
def copy_base(module, init):
    """Copy a ``ConvBNReLU``/``Downsample`` unit's ``conv`` and ``bn`` from ``init``.

    Each sub-module is copied only when ``module`` has it (not ``None``).

    Args:
        module: target unit exposing ``.conv`` and ``.bn`` attributes.
        init: source unit of one of the same accepted class names.

    Raises:
        AssertionError: if either argument's class name is not accepted.
    """
    assert type(module).__name__ in _BASE_TYPE_NAMES, "invalid module : {:}".format(module)
    assert type(init).__name__ in _BASE_TYPE_NAMES, "invalid module : {:}".format(init)
    if module.conv is not None:
        copy_conv(module.conv, init.conv)
    if module.bn is not None:
        copy_bn(module.bn, init.bn)


@torch.no_grad()
def copy_basic(module, init):
    """Copy a residual basic block (``conv_a``, ``conv_b``, ``downsample``).

    The downsample branch is copied only when BOTH blocks have one; if only
    ``module`` has it, it is left untouched.

    Args:
        module: target block exposing ``conv_a``, ``conv_b``, ``downsample``.
        init: source block with the same attributes.
    """
    copy_base(module.conv_a, init.conv_a)
    copy_base(module.conv_b, init.conv_b)
    if module.downsample is not None and init.downsample is not None:
        copy_base(module.downsample, init.downsample)


@torch.no_grad()
def init_from_model(network, init_model):
    """Initialise ``network`` in place from the leading slices of ``init_model``.

    Copies the ``classifier`` and then each entry of ``layers``. Layers are
    paired positionally with ``zip``, so trailing layers in the longer of the
    two sequences are left untouched.

    Args:
        network: target model exposing ``classifier`` and ``layers``.
        init_model: source model exposing the same attributes.

    Raises:
        AssertionError: if paired layers have different class names.
        ValueError: if a layer is neither ``ConvBNReLU`` nor ``ResNetBasicblock``.
    """
    copy_fc(network.classifier, init_model.classifier)
    for base, target in zip(init_model.layers, network.layers):
        base_name = type(base).__name__
        assert base_name == type(target).__name__, "invalid type : {:} vs {:}".format(
            base, target
        )
        if base_name == "ConvBNReLU":
            copy_base(target, base)
        elif base_name == "ResNetBasicblock":
            copy_basic(target, base)
        else:
            raise ValueError("unknown type name : {:}".format(base_name))