"""Graph builders for a CIFAR-style residual network on the paddle.fluid layers API."""
import paddle
import paddle.fluid as fluid


def conv_bn_layer(input, ch_out, filter_size, stride, padding, act='relu',
                  bias_attr=False):
    """Apply a 2-D convolution followed by batch normalization.

    The convolution itself is created without an activation; ``act`` is
    handed to the batch-norm layer instead.

    Args:
        input: Variable fed to ``fluid.layers.conv2d``.
        ch_out: Number of convolution filters.
        filter_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        act: Activation name passed to ``batch_norm`` (``None`` for none).
        bias_attr: Forwarded unchanged to ``conv2d``.

    Returns:
        The output of ``fluid.layers.batch_norm``.
    """
    conv_out = fluid.layers.conv2d(
        input=input,
        num_filters=ch_out,
        filter_size=filter_size,
        stride=stride,
        padding=padding,
        act=None,
        bias_attr=bias_attr,
    )
    return fluid.layers.batch_norm(input=conv_out, act=act)


def shortcut(input, ch_in, ch_out, stride):
    """Build the skip branch that is summed with a residual block's main path.

    Three cases, checked in this order:

    * ``stride == 2``: 2x2 average pooling with stride 2, then a 1x1
      convolution to ``ch_out`` channels (no batch norm, no activation).
      ``ch_in`` is not consulted in this case.
    * ``ch_in != ch_out``: a 1x1 conv + batch-norm projection without
      activation, using the given ``stride``.
    * otherwise: ``input`` is returned untouched.
    """
    if stride == 2:
        pooled = fluid.layers.pool2d(
            input=input, pool_size=2, pool_type='avg', pool_stride=2)
        return fluid.layers.conv2d(
            input=pooled,
            num_filters=ch_out,
            filter_size=1,
            stride=1,
            padding=0,
            act=None,
            bias_attr=None,
        )

    if ch_in != ch_out:
        return conv_bn_layer(
            input, ch_out, filter_size=1, stride=stride, padding=0,
            act=None, bias_attr=None)

    return input


def basicblock(input, ch_in, ch_out, stride):
    """Build one residual block: two 3x3 conv-bn layers plus a shortcut.

    The first conv-bn uses ``stride`` and the default ReLU; the second uses
    stride 1, no activation and ``bias_attr=True``. The result is added
    element-wise to ``shortcut(input, ...)`` and passed through ReLU.
    """
    # The main path is created before the shortcut, matching the original
    # layer-creation order.
    residual = conv_bn_layer(
        input, ch_out, filter_size=3, stride=stride, padding=1)
    residual = conv_bn_layer(
        residual, ch_out, filter_size=3, stride=1, padding=1,
        act=None, bias_attr=True)
    identity = shortcut(input, ch_in, ch_out, stride)
    return fluid.layers.elementwise_add(x=residual, y=identity, act='relu')


def layer_warp(block_func, input, ch_in, ch_out, count, stride):
    """Chain ``count`` blocks produced by ``block_func`` into one stage.

    The first block maps ``ch_in`` to ``ch_out`` with the given ``stride``;
    each following block maps ``ch_out`` to ``ch_out`` with stride 1. The
    first block is always created, even when ``count`` is less than 1.

    Args:
        block_func: Callable invoked as ``block_func(x, ch_in, ch_out, stride)``.
        input: Input handed to the first block.
        ch_in: Input channel count of the first block.
        ch_out: Output channel count of every block.
        count: Total number of blocks in the stage.
        stride: Stride of the first block.

    Returns:
        The output of the last block.
    """
    out = block_func(input, ch_in, ch_out, stride)
    for _ in range(1, count):
        out = block_func(out, ch_out, ch_out, 1)
    return out


def resnet_cifar(ipt, depth, class_num):
    """Assemble the full network and return its softmax prediction layer.

    Layout: a 3x3 conv-bn stem with 16 channels, three ``basicblock`` stages
    (16->16 stride 1, 16->32 stride 2, 32->64 stride 2) holding
    ``(depth - 2) // 6`` blocks each, an 8x8 average pool with stride 1,
    and a fully connected softmax layer of size ``class_num``. The shape of
    every stage output is printed while the graph is built.

    Args:
        ipt: Input variable of the network.
        depth: Network depth; ``depth - 2`` must be divisible by 6.
        class_num: Size of the final fully connected layer.

    Returns:
        The output of the final ``fluid.layers.fc`` layer.

    Raises:
        AssertionError: If ``(depth - 2) % 6 != 0``.
    """
    # depth should be one of 20, 32, 44, 56, 110, 1202
    assert (depth - 2) % 6 == 0
    blocks_per_stage = (depth - 2) // 6
    print('[resnet] depth : {:}, class_num : {:}'.format(depth, class_num))

    features = conv_bn_layer(
        ipt, ch_out=16, filter_size=3, stride=1, padding=1)
    print('conv-1 : shape = {:}'.format(features.shape))

    # (report template, input channels, output channels, first-block stride)
    stages = (
        ('res--1 : shape = {:}', 16, 16, 1),
        ('res--2 : shape = {:}', 16, 32, 2),
        ('res--3 : shape = {:}', 32, 64, 2),
    )
    for report, stage_in, stage_out, stage_stride in stages:
        features = layer_warp(
            basicblock, features, stage_in, stage_out,
            blocks_per_stage, stage_stride)
        print(report.format(features.shape))

    # NOTE(review): pool_size=8 presumably assumes an 8x8 map here (i.e. a
    # 32x32 network input) -- confirm against the caller.
    pooled = fluid.layers.pool2d(
        input=features, pool_size=8, pool_type='avg', pool_stride=1)
    print('pool : shape = {:}'.format(pooled.shape))

    predict = fluid.layers.fc(input=pooled, size=class_num, act='softmax')
    print('predict: shape = {:}'.format(predict.shape))
    return predict