add paddlepaddle

This commit is contained in:
parent 5eed8101a7
commit 4ffc2016b3

3  paddlepaddle/.gitignore  vendored  Normal file
@ -0,0 +1,3 @@
.DS_Store
*.whl
snapshots

137  paddlepaddle/README.md  Normal file
@ -0,0 +1,137 @@
# Image Classification based on NAS-Searched Models

This directory contains 10 image classification models.
Nine of them are automatically searched by different Neural Architecture Search (NAS) algorithms; the other one is a residual network (ResNet).
We provide codes and scripts to train these models on both CIFAR-10 and CIFAR-100.
We use the standard data augmentation, i.e., random crop, random flip, and normalization.

---
## Table of Contents

- [Installation](#installation)
- [Data Preparation](#data-preparation)
- [Training Models](#training-models)
- [Project Structure](#project-structure)
- [Citation](#citation)


### Installation

This project has the following requirements:
- Python = 3.6
- PaddlePaddle Fluid >= v0.15.0

### Data Preparation

Please download [CIFAR-10](https://dataset.bj.bcebos.com/cifar/cifar-10-python.tar.gz) and [CIFAR-100](https://dataset.bj.bcebos.com/cifar/cifar-100-python.tar.gz) before running the codes.
Note that the MD5 of the CIFAR-10-Python archive is `c58f30108f718f92721af3b95e74349a` and the MD5 of the CIFAR-100-Python archive is `eb9058c3a382ffc7106e4002c42a8d85`.
Please save both files into `${TORCH_HOME}/cifar.python`.
After data preparation, there should be two files: `${TORCH_HOME}/cifar.python/cifar-10-python.tar.gz` and `${TORCH_HOME}/cifar.python/cifar-100-python.tar.gz`.
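
The following snippet is an optional sanity check (not part of this repository) that the two archives landed in the expected location with the expected checksums; it only assumes that `TORCH_HOME` is set as described above.
```
# Optional sketch: verify the downloaded CIFAR archives against the MD5
# checksums listed above. Assumes the TORCH_HOME environment variable is set.
import os, hashlib

EXPECTED = {
    'cifar-10-python.tar.gz' : 'c58f30108f718f92721af3b95e74349a',
    'cifar-100-python.tar.gz': 'eb9058c3a382ffc7106e4002c42a8d85',
}
root = os.path.join(os.environ['TORCH_HOME'], 'cifar.python')
for name, expected_md5 in EXPECTED.items():
    path = os.path.join(root, name)
    with open(path, 'rb') as f:
        digest = hashlib.md5(f.read()).hexdigest()
    assert digest == expected_md5, '{:} has MD5 {:}, expected {:}'.format(path, digest, expected_md5)
    print('{:} : OK'.format(name))
```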

### Training Models

After setting up the environment and preparing the data, one can train the models. The main entry point is `train_cifar.py`. We also provide some scripts for easy usage.
```
bash ./scripts/base-train.sh 0 cifar-10  ResNet110
bash ./scripts/train-nas.sh  0 cifar-10  GDAS_V1
bash ./scripts/train-nas.sh  0 cifar-10  GDAS_V2
bash ./scripts/train-nas.sh  0 cifar-10  SETN
bash ./scripts/train-nas.sh  0 cifar-100 SETN
```
The first argument is the GPU ID used for training, the second is the dataset name, and the last one is the model name.
Please use `./scripts/base-train.sh` for ResNet and `./scripts/train-nas.sh` for the NAS-searched models.

### Project Structure

```
.
├──train_cifar.py [Training CNN models]
├──lib [Library for dataset, models, and others]
│  ├──models
│  │  ├──__init__.py [Import useful Classes and Functions in models]
│  │  ├──resnet.py [Define the ResNet models]
│  │  ├──operations.py [Define the atomic operation in NAS search space]
│  │  ├──genotypes.py [Define the topological structure of different NAS-searched models]
│  │  └──nas_net.py [Define the macro structure of NAS models]
│  └──utils
│     ├──__init__.py [Import useful Classes and Functions in utils]
│     ├──meter.py [Define the AverageMeter class to count the accuracy and loss]
│     ├──time_utils.py [Define some functions to print date or convert seconds into hours]
│     └──data_utils.py [Define data augmentation functions and dataset reader for CIFAR]
└──scripts [Scripts for running]
```

### Citation

If you find that this project helps your research, please consider citing these papers:
```
@inproceedings{dong2019one,
  title     = {One-Shot Neural Architecture Search via Self-Evaluated Template Network},
  author    = {Dong, Xuanyi and Yang, Yi},
  booktitle = {Proceedings of the IEEE International Conference on Computer Vision (ICCV)},
  year      = {2019}
}
@inproceedings{dong2019search,
  title     = {Searching for A Robust Neural Architecture in Four GPU Hours},
  author    = {Dong, Xuanyi and Yang, Yi},
  booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  pages     = {1761--1770},
  year      = {2019}
}
@inproceedings{liu2018darts,
  title     = {DARTS: Differentiable Architecture Search},
  author    = {Liu, Hanxiao and Simonyan, Karen and Yang, Yiming},
  booktitle = {International Conference on Learning Representations (ICLR)},
  year      = {2019}
}
@inproceedings{pham2018efficient,
  title     = {Efficient Neural Architecture Search via Parameter Sharing},
  author    = {Pham, Hieu and Guan, Melody and Zoph, Barret and Le, Quoc and Dean, Jeff},
  booktitle = {International Conference on Machine Learning (ICML)},
  pages     = {4092--4101},
  year      = {2018}
}
@inproceedings{liu2018progressive,
  title     = {Progressive Neural Architecture Search},
  author    = {Liu, Chenxi and Zoph, Barret and Neumann, Maxim and Shlens, Jonathon and Hua, Wei and Li, Li-Jia and Fei-Fei, Li and Yuille, Alan and Huang, Jonathan and Murphy, Kevin},
  booktitle = {Proceedings of the European Conference on Computer Vision (ECCV)},
  pages     = {19--34},
  year      = {2018}
}
@inproceedings{zoph2018learning,
  title     = {Learning Transferable Architectures for Scalable Image Recognition},
  author    = {Zoph, Barret and Vasudevan, Vijay and Shlens, Jonathon and Le, Quoc V},
  booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  pages     = {8697--8710},
  year      = {2018}
}
@inproceedings{real2019regularized,
  title     = {Regularized Evolution for Image Classifier Architecture Search},
  author    = {Real, Esteban and Aggarwal, Alok and Huang, Yanping and Le, Quoc V},
  booktitle = {Proceedings of the AAAI Conference on Artificial Intelligence},
  pages     = {4780--4789},
  year      = {2019}
}
```


To set up the environment, create a conda environment and install PaddlePaddle and tb-paddle:
```
conda create -n PPD python=3.6 anaconda
pip3 install paddlepaddle-gpu==1.5.1.post97
pip3 install tb-paddle
```

## Activate the PaddlePaddle environment
```
conda activate PPD
bash ./scripts/base-train.sh 0 cifar-10  ResNet110
bash ./scripts/train-nas.sh  0 cifar-10  GDAS_V1
bash ./scripts/train-nas.sh  0 cifar-10  GDAS_V2
bash ./scripts/train-nas.sh  0 cifar-10  SETN
bash ./scripts/train-nas.sh  0 cifar-100 SETN
```

## Use PyTorch training for comparison
```
#CUDA_VISIBLE_DEVICES=0 bash ./scripts/com-paddle.sh cifar10 ResNet110 -1
CUDA_VISIBLE_DEVICES=0 bash ./scripts/com-paddle.sh cifar10
```

3  paddlepaddle/lib/models/__init__.py  Normal file
@ -0,0 +1,3 @@
from .genotypes import Networks
from .nas_net import NASCifarNet
from .resnet import resnet_cifar

175  paddlepaddle/lib/models/genotypes.py  Normal file
@ -0,0 +1,175 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from collections import namedtuple

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')


# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
NASNet = Genotype(
  normal = [
    (('sep_conv_5x5', 1), ('sep_conv_3x3', 0)),
    (('sep_conv_5x5', 0), ('sep_conv_3x3', 0)),
    (('avg_pool_3x3', 1), ('skip_connect', 0)),
    (('avg_pool_3x3', 0), ('avg_pool_3x3', 0)),
    (('sep_conv_3x3', 1), ('skip_connect', 1)),
  ],
  normal_concat = [2, 3, 4, 5, 6],
  reduce = [
    (('sep_conv_5x5', 1), ('sep_conv_7x7', 0)),
    (('max_pool_3x3', 1), ('sep_conv_7x7', 0)),
    (('avg_pool_3x3', 1), ('sep_conv_5x5', 0)),
    (('skip_connect', 3), ('avg_pool_3x3', 2)),
    (('sep_conv_3x3', 2), ('max_pool_3x3', 1)),
  ],
  reduce_concat = [4, 5, 6],
)


# Progressive Neural Architecture Search, ECCV 2018
PNASNet = Genotype(
  normal = [
    (('sep_conv_5x5', 0), ('max_pool_3x3', 0)),
    (('sep_conv_7x7', 1), ('max_pool_3x3', 1)),
    (('sep_conv_5x5', 1), ('sep_conv_3x3', 1)),
    (('sep_conv_3x3', 4), ('max_pool_3x3', 1)),
    (('sep_conv_3x3', 0), ('skip_connect', 1)),
  ],
  normal_concat = [2, 3, 4, 5, 6],
  reduce = [
    (('sep_conv_5x5', 0), ('max_pool_3x3', 0)),
    (('sep_conv_7x7', 1), ('max_pool_3x3', 1)),
    (('sep_conv_5x5', 1), ('sep_conv_3x3', 1)),
    (('sep_conv_3x3', 4), ('max_pool_3x3', 1)),
    (('sep_conv_3x3', 0), ('skip_connect', 1)),
  ],
  reduce_concat = [2, 3, 4, 5, 6],
)


# Regularized Evolution for Image Classifier Architecture Search, AAAI 2019
AmoebaNet = Genotype(
  normal = [
    (('avg_pool_3x3', 0), ('max_pool_3x3', 1)),
    (('sep_conv_3x3', 0), ('sep_conv_5x5', 2)),
    (('sep_conv_3x3', 0), ('avg_pool_3x3', 3)),
    (('sep_conv_3x3', 1), ('skip_connect', 1)),
    (('skip_connect', 0), ('avg_pool_3x3', 1)),
  ],
  normal_concat = [4, 5, 6],
  reduce = [
    (('avg_pool_3x3', 0), ('sep_conv_3x3', 1)),
    (('max_pool_3x3', 0), ('sep_conv_7x7', 2)),
    (('sep_conv_7x7', 0), ('avg_pool_3x3', 1)),
    (('max_pool_3x3', 0), ('max_pool_3x3', 1)),
    (('conv_7x1_1x7', 0), ('sep_conv_3x3', 5)),
  ],
  reduce_concat = [3, 4, 6]
)


# Efficient Neural Architecture Search via Parameter Sharing, ICML 2018
ENASNet = Genotype(
  normal = [
    (('sep_conv_3x3', 1), ('skip_connect', 1)),
    (('sep_conv_5x5', 1), ('skip_connect', 0)),
    (('avg_pool_3x3', 0), ('sep_conv_3x3', 1)),
    (('sep_conv_3x3', 0), ('avg_pool_3x3', 1)),
    (('sep_conv_5x5', 1), ('avg_pool_3x3', 0)),
  ],
  normal_concat = [2, 3, 4, 5, 6],
  reduce = [
    (('sep_conv_5x5', 0), ('sep_conv_3x3', 1)), # 2
    (('sep_conv_3x3', 1), ('avg_pool_3x3', 1)), # 3
    (('sep_conv_3x3', 1), ('avg_pool_3x3', 1)), # 4
    (('avg_pool_3x3', 1), ('sep_conv_5x5', 4)), # 5
    (('sep_conv_3x3', 5), ('sep_conv_5x5', 0)),
  ],
  reduce_concat = [2, 3, 4, 5, 6],
)


# DARTS: Differentiable Architecture Search, ICLR 2019
DARTS_V1 = Genotype(
  normal=[
    (('sep_conv_3x3', 1), ('sep_conv_3x3', 0)), # step 1
    (('skip_connect', 0), ('sep_conv_3x3', 1)), # step 2
    (('skip_connect', 0), ('sep_conv_3x3', 1)), # step 3
    (('sep_conv_3x3', 0), ('skip_connect', 2))  # step 4
  ],
  normal_concat=[2, 3, 4, 5],
  reduce=[
    (('max_pool_3x3', 0), ('max_pool_3x3', 1)), # step 1
    (('skip_connect', 2), ('max_pool_3x3', 0)), # step 2
    (('max_pool_3x3', 0), ('skip_connect', 2)), # step 3
    (('skip_connect', 2), ('avg_pool_3x3', 0))  # step 4
  ],
  reduce_concat=[2, 3, 4, 5],
)


# DARTS: Differentiable Architecture Search, ICLR 2019
DARTS_V2 = Genotype(
  normal=[
    (('sep_conv_3x3', 0), ('sep_conv_3x3', 1)), # step 1
    (('sep_conv_3x3', 0), ('sep_conv_3x3', 1)), # step 2
    (('sep_conv_3x3', 1), ('skip_connect', 0)), # step 3
    (('skip_connect', 0), ('dil_conv_3x3', 2))  # step 4
  ],
  normal_concat=[2, 3, 4, 5],
  reduce=[
    (('max_pool_3x3', 0), ('max_pool_3x3', 1)), # step 1
    (('skip_connect', 2), ('max_pool_3x3', 1)), # step 2
    (('max_pool_3x3', 0), ('skip_connect', 2)), # step 3
    (('skip_connect', 2), ('max_pool_3x3', 1))  # step 4
  ],
  reduce_concat=[2, 3, 4, 5],
)


# One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019
SETN = Genotype(
  normal=[
    (('skip_connect', 0), ('sep_conv_5x5', 1)),
    (('sep_conv_5x5', 0), ('sep_conv_3x3', 1)),
    (('sep_conv_5x5', 1), ('sep_conv_5x5', 3)),
    (('max_pool_3x3', 1), ('conv_3x1_1x3', 4))],
  normal_concat=[2, 3, 4, 5],
  reduce=[
    (('sep_conv_3x3', 0), ('sep_conv_5x5', 1)),
    (('avg_pool_3x3', 0), ('sep_conv_5x5', 1)),
    (('avg_pool_3x3', 0), ('sep_conv_5x5', 1)),
    (('avg_pool_3x3', 0), ('skip_connect', 1))],
  reduce_concat=[2, 3, 4, 5],
)


# Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019
GDAS_V1 = Genotype(
  normal=[
    (('skip_connect', 0), ('skip_connect', 1)),
    (('skip_connect', 0), ('sep_conv_5x5', 2)),
    (('sep_conv_3x3', 3), ('skip_connect', 0)),
    (('sep_conv_5x5', 4), ('sep_conv_3x3', 3))],
  normal_concat=[2, 3, 4, 5],
  reduce=[
    (('sep_conv_5x5', 0), ('sep_conv_3x3', 1)),
    (('sep_conv_5x5', 2), ('sep_conv_5x5', 1)),
    (('dil_conv_5x5', 2), ('sep_conv_3x3', 1)),
    (('sep_conv_5x5', 0), ('sep_conv_5x5', 1))],
  reduce_concat=[2, 3, 4, 5],
)


Networks = {'DARTS_V1' : DARTS_V1,
            'DARTS_V2' : DARTS_V2,
            'DARTS'    : DARTS_V2,
            'NASNet'   : NASNet,
            'ENASNet'  : ENASNet,
            'AmoebaNet': AmoebaNet,
            'GDAS_V1'  : GDAS_V1,
            'PNASNet'  : PNASNet,
            'SETN'     : SETN,
           }
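
As a quick illustration (a sketch, not part of the committed file), each entry of `Networks` is a `Genotype` whose `normal`/`reduce` fields list, for every intermediate node of the cell, two operations together with the indices of the nodes they read from, while `normal_concat`/`reduce_concat` names the nodes whose outputs are concatenated as the cell output. The import below assumes `lib/models` is on the import path; inside the repository it is imported as `from models import Networks`.
```
# Hypothetical usage sketch: inspect a NAS-searched genotype by name.
from genotypes import Networks  # assumes lib/models is on the import path

genotype = Networks['GDAS_V1']
for step, (op_a, op_b) in enumerate(genotype.normal):
    # each op is a (operation-name, input-node-index) pair; new nodes start at index 2
    print('node {:}: {:} from node {:}, {:} from node {:}'.format(
        step + 2, op_a[0], op_a[1], op_b[0], op_b[1]))
print('concatenate nodes {:} as the cell output'.format(genotype.normal_concat))
```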
79  paddlepaddle/lib/models/nas_net.py  Normal file
@ -0,0 +1,79 @@
import paddle
import paddle.fluid as fluid
from .operations import OPS


def AuxiliaryHeadCIFAR(inputs, C, class_num):
  print ('AuxiliaryHeadCIFAR : inputs-shape : {:}'.format(inputs.shape))
  temp = fluid.layers.relu(inputs)
  temp = fluid.layers.pool2d(temp, pool_size=5, pool_stride=3, pool_padding=0, pool_type='avg')
  temp = fluid.layers.conv2d(temp, filter_size=1, num_filters=128, stride=1, padding=0, act=None, bias_attr=False)
  temp = fluid.layers.batch_norm(input=temp, act='relu', bias_attr=None)
  temp = fluid.layers.conv2d(temp, filter_size=1, num_filters=768, stride=2, padding=0, act=None, bias_attr=False)
  temp = fluid.layers.batch_norm(input=temp, act='relu', bias_attr=None)
  print ('AuxiliaryHeadCIFAR : last---shape : {:}'.format(temp.shape))
  predict = fluid.layers.fc(input=temp, size=class_num, act='softmax')
  return predict


def InferCell(name, inputs_prev_prev, inputs_prev, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
  print ('[{:}] C_prev_prev={:} C_prev={:}, C={:}, reduction_prev={:}, reduction={:}'.format(name, C_prev_prev, C_prev, C, reduction_prev, reduction))
  print ('inputs_prev_prev : {:}'.format(inputs_prev_prev.shape))
  print ('inputs_prev      : {:}'.format(inputs_prev.shape))
  inputs_prev_prev = OPS['skip_connect'](inputs_prev_prev, C_prev_prev, C, 2 if reduction_prev else 1)
  inputs_prev      = OPS['skip_connect'](inputs_prev, C_prev, C, 1)
  print ('inputs_prev_prev : {:}'.format(inputs_prev_prev.shape))
  print ('inputs_prev      : {:}'.format(inputs_prev.shape))
  if reduction: step_ops, concat = genotype.reduce, genotype.reduce_concat
  else        : step_ops, concat = genotype.normal, genotype.normal_concat
  states = [inputs_prev_prev, inputs_prev]
  for istep, operations in enumerate(step_ops):
    op_a, op_b = operations
    # the first operation
    #print ('-->>[{:}/{:}] [{:}] + [{:}]'.format(istep, len(step_ops), op_a, op_b))
    stride  = 2 if reduction and op_a[1] < 2 else 1
    tensor1 = OPS[ op_a[0] ](states[op_a[1]], C, C, stride)
    stride  = 2 if reduction and op_b[1] < 2 else 1
    tensor2 = OPS[ op_b[0] ](states[op_b[1]], C, C, stride)
    state   = fluid.layers.elementwise_add(x=tensor1, y=tensor2, act=None)
    assert tensor1.shape == tensor2.shape, 'invalid shape {:} vs. {:}'.format(tensor1.shape, tensor2.shape)
    print ('-->>[{:}/{:}] tensor={:} from {:} + {:}'.format(istep, len(step_ops), state.shape, tensor1.shape, tensor2.shape))
    states.append( state )
  states_to_cat = [states[x] for x in concat]
  outputs = fluid.layers.concat(states_to_cat, axis=1)
  print ('-->> output-shape : {:} from concat={:}'.format(outputs.shape, concat))
  return outputs


# NASCifarNet(inputs, 36, 6, 3, 10, 'xxx', True)
def NASCifarNet(ipt, C, N, stem_multiplier, class_num, genotype, auxiliary):
  # cifar head module
  C_curr = stem_multiplier * C
  stem = fluid.layers.conv2d(ipt, filter_size=3, num_filters=C_curr, stride=1, padding=1, act=None, bias_attr=False)
  stem = fluid.layers.batch_norm(input=stem, act=None, bias_attr=None)
  print ('stem-shape : {:}'.format(stem.shape))
  # N + 1 + N + 1 + N cells
  layer_channels   = [C    ] * N + [C*2 ] + [C*2  ] * N + [C*4 ] + [C*4  ] * N
  layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
  C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
  reduction_prev = False
  auxiliary_pred = None

  cell_results = [stem, stem]
  for index, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
    xstr = '{:02d}/{:02d}'.format(index, len(layer_channels))
    cell_result = InferCell(xstr, cell_results[-2], cell_results[-1], genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
    reduction_prev = reduction
    C_prev_prev, C_prev = C_prev, cell_result.shape[1]
    cell_results.append( cell_result )
    if auxiliary and reduction and C_curr == C*4:
      auxiliary_pred = AuxiliaryHeadCIFAR(cell_result, C_prev, class_num)

  global_P = fluid.layers.pool2d(input=cell_results[-1], pool_size=8, pool_type='avg', pool_stride=1)
  predicts = fluid.layers.fc(input=global_P, size=class_num, act='softmax')
  print ('predict-shape : {:}'.format(predicts.shape))
  if auxiliary_pred is None:
    return predicts
  else:
    return [predicts, auxiliary_pred]
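
For orientation (a sketch rather than part of the commit), `NASCifarNet` is called on a CIFAR-sized input exactly as `train_cifar.py` does: 36 initial channels, 6 cells per stage, a stem multiplier of 3, and one of the genotypes above. The import assumes `lib` has been added to `sys.path`, as `train_cifar.py` does.
```
# Hypothetical sketch of building the GDAS_V1 network graph, mirroring train_cifar.py.
import paddle.fluid as fluid
from models import NASCifarNet, Networks  # assumes paddlepaddle/lib is on sys.path

images  = fluid.layers.data(name='pixel', shape=[3, 32, 32], dtype='float32')
# 36 channels, 6 cells per stage, stem multiplier 3, 10 classes, auxiliary head enabled
predict = NASCifarNet(images, 36, 6, 3, 10, Networks['GDAS_V1'], True)
# with auxiliary=True this returns [main_softmax, auxiliary_softmax]
```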
91  paddlepaddle/lib/models/operations.py  Normal file
@ -0,0 +1,91 @@
import paddle
import paddle.fluid as fluid


OPS = {
  'none'         : lambda inputs, C_in, C_out, stride: ZERO(inputs, stride),
  'avg_pool_3x3' : lambda inputs, C_in, C_out, stride: POOL_3x3(inputs, C_in, C_out, stride, 'avg'),
  'max_pool_3x3' : lambda inputs, C_in, C_out, stride: POOL_3x3(inputs, C_in, C_out, stride, 'max'),
  'skip_connect' : lambda inputs, C_in, C_out, stride: Identity(inputs, C_in, C_out, stride),
  'sep_conv_3x3' : lambda inputs, C_in, C_out, stride: SepConv(inputs, C_in, C_out, 3, stride, 1),
  'sep_conv_5x5' : lambda inputs, C_in, C_out, stride: SepConv(inputs, C_in, C_out, 5, stride, 2),
  'sep_conv_7x7' : lambda inputs, C_in, C_out, stride: SepConv(inputs, C_in, C_out, 7, stride, 3),
  'dil_conv_3x3' : lambda inputs, C_in, C_out, stride: DilConv(inputs, C_in, C_out, 3, stride, 2, 2),
  'dil_conv_5x5' : lambda inputs, C_in, C_out, stride: DilConv(inputs, C_in, C_out, 5, stride, 4, 2),
  'conv_3x1_1x3' : lambda inputs, C_in, C_out, stride: Conv313(inputs, C_in, C_out, stride),
  'conv_7x1_1x7' : lambda inputs, C_in, C_out, stride: Conv717(inputs, C_in, C_out, stride),
}


def ReLUConvBN(inputs, C_in, C_out, kernel, stride, padding):
  temp = fluid.layers.relu(inputs)
  temp = fluid.layers.conv2d(temp, filter_size=kernel, num_filters=C_out, stride=stride, padding=padding, act=None, bias_attr=False)
  temp = fluid.layers.batch_norm(input=temp, act=None, bias_attr=None)
  return temp


def ZERO(inputs, stride):
  if stride == 1:
    return inputs * 0
  elif stride == 2:
    # pool2d takes pool_size (there is no filter_size argument)
    return fluid.layers.pool2d(inputs, pool_size=2, pool_stride=2, pool_padding=0, pool_type='avg') * 0
  else:
    raise ValueError('invalid stride of {:} not [1, 2]'.format(stride))


def Identity(inputs, C_in, C_out, stride):
  if C_in == C_out and stride == 1:
    return inputs
  elif stride == 1:
    return ReLUConvBN(inputs, C_in, C_out, 1, 1, 0)
  else:
    temp1 = fluid.layers.relu(inputs)
    temp2 = fluid.layers.pad2d(input=temp1, paddings=[0, 1, 0, 1], mode='reflect')
    temp2 = fluid.layers.slice(temp2, axes=[0, 1, 2, 3], starts=[0, 0, 1, 1], ends=[999, 999, 999, 999])
    temp1 = fluid.layers.conv2d(temp1, filter_size=1, num_filters=C_out//2, stride=stride, padding=0, act=None, bias_attr=False)
    temp2 = fluid.layers.conv2d(temp2, filter_size=1, num_filters=C_out-C_out//2, stride=stride, padding=0, act=None, bias_attr=False)
    temp  = fluid.layers.concat([temp1, temp2], axis=1)
    return fluid.layers.batch_norm(input=temp, act=None, bias_attr=None)


def POOL_3x3(inputs, C_in, C_out, stride, mode):
  if C_in == C_out:
    xinputs = inputs
  else:
    xinputs = ReLUConvBN(inputs, C_in, C_out, 1, 1, 0)
  return fluid.layers.pool2d(xinputs, pool_size=3, pool_stride=stride, pool_padding=1, pool_type=mode)


def SepConv(inputs, C_in, C_out, kernel, stride, padding):
  temp = fluid.layers.relu(inputs)
  temp = fluid.layers.conv2d(temp, filter_size=kernel, num_filters=C_in , stride=stride, padding=padding, act=None, bias_attr=False)
  temp = fluid.layers.conv2d(temp, filter_size=     1, num_filters=C_in , stride=     1, padding=      0, act=None, bias_attr=False)
  temp = fluid.layers.batch_norm(input=temp, act='relu', bias_attr=None)
  temp = fluid.layers.conv2d(temp, filter_size=kernel, num_filters=C_in , stride=     1, padding=padding, act=None, bias_attr=False)
  temp = fluid.layers.conv2d(temp, filter_size=     1, num_filters=C_out, stride=     1, padding=      0, act=None, bias_attr=False)
  temp = fluid.layers.batch_norm(input=temp, act=None  , bias_attr=None)
  return temp


def DilConv(inputs, C_in, C_out, kernel, stride, padding, dilation):
  temp = fluid.layers.relu(inputs)
  temp = fluid.layers.conv2d(temp, filter_size=kernel, num_filters=C_in , stride=stride, padding=padding, dilation=dilation, act=None, bias_attr=False)
  temp = fluid.layers.conv2d(temp, filter_size=     1, num_filters=C_out, stride=     1, padding=      0, act=None, bias_attr=False)
  temp = fluid.layers.batch_norm(input=temp, act=None, bias_attr=None)
  return temp


def Conv313(inputs, C_in, C_out, stride):
  temp = fluid.layers.relu(inputs)
  temp = fluid.layers.conv2d(temp, filter_size=(1,3), num_filters=C_out, stride=(1,stride), padding=(0,1), act=None, bias_attr=False)
  temp = fluid.layers.conv2d(temp, filter_size=(3,1), num_filters=C_out, stride=(stride,1), padding=(1,0), act=None, bias_attr=False)
  temp = fluid.layers.batch_norm(input=temp, act=None, bias_attr=None)
  return temp


def Conv717(inputs, C_in, C_out, stride):
  temp = fluid.layers.relu(inputs)
  temp = fluid.layers.conv2d(temp, filter_size=(1,7), num_filters=C_out, stride=(1,stride), padding=(0,3), act=None, bias_attr=False)
  temp = fluid.layers.conv2d(temp, filter_size=(7,1), num_filters=C_out, stride=(stride,1), padding=(3,0), act=None, bias_attr=False)
  temp = fluid.layers.batch_norm(input=temp, act=None, bias_attr=None)
  return temp
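
As a quick sketch (not in the committed files), every entry of `OPS` shares the signature `(inputs, C_in, C_out, stride) -> tensor`, which is what lets `InferCell` in `nas_net.py` look operations up by name. The import below assumes `lib/models` is on the import path.
```
# Hypothetical sketch: apply two search-space operations to a CIFAR input.
import paddle.fluid as fluid
from operations import OPS   # assumes lib/models is on the import path

x = fluid.layers.data(name='pixel', shape=[3, 32, 32], dtype='float32')
y = OPS['sep_conv_3x3'](x, 3, 16, 1)   # 3 -> 16 channels, stride 1 keeps the spatial size
z = OPS['max_pool_3x3'](y, 16, 16, 2)  # stride 2 halves the spatial resolution
print(y.shape, z.shape)                # e.g. (-1, 16, 32, 32) and (-1, 16, 16, 16)
```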
65  paddlepaddle/lib/models/resnet.py  Normal file
@ -0,0 +1,65 @@
import paddle
import paddle.fluid as fluid


def conv_bn_layer(input,
                  ch_out,
                  filter_size,
                  stride,
                  padding,
                  act='relu',
                  bias_attr=False):
  tmp = fluid.layers.conv2d(
      input=input,
      filter_size=filter_size,
      num_filters=ch_out,
      stride=stride,
      padding=padding,
      act=None,
      bias_attr=bias_attr)
  return fluid.layers.batch_norm(input=tmp, act=act)


def shortcut(input, ch_in, ch_out, stride):
  if stride == 2:
    temp = fluid.layers.pool2d(input, pool_size=2, pool_type='avg', pool_stride=2)
    temp = fluid.layers.conv2d(temp , filter_size=1, num_filters=ch_out, stride=1, padding=0, act=None, bias_attr=None)
    return temp
  elif ch_in != ch_out:
    return conv_bn_layer(input, ch_out, 1, stride, 0, None, None)
  else:
    return input


def basicblock(input, ch_in, ch_out, stride):
  tmp = conv_bn_layer(input, ch_out, 3, stride, 1)
  tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None, bias_attr=True)
  short = shortcut(input, ch_in, ch_out, stride)
  return fluid.layers.elementwise_add(x=tmp, y=short, act='relu')


def layer_warp(block_func, input, ch_in, ch_out, count, stride):
  tmp = block_func(input, ch_in, ch_out, stride)
  for i in range(1, count):
    tmp = block_func(tmp, ch_out, ch_out, 1)
  return tmp


def resnet_cifar(ipt, depth, class_num):
  # depth should be one of 20, 32, 44, 56, 110, 1202
  assert (depth - 2) % 6 == 0
  n = (depth - 2) // 6
  print('[resnet] depth : {:}, class_num : {:}'.format(depth, class_num))
  conv1 = conv_bn_layer(ipt, ch_out=16, filter_size=3, stride=1, padding=1)
  print('conv-1 : shape = {:}'.format(conv1.shape))
  res1 = layer_warp(basicblock, conv1, 16, 16, n, 1)
  print('res--1 : shape = {:}'.format(res1.shape))
  res2 = layer_warp(basicblock, res1 , 16, 32, n, 2)
  print('res--2 : shape = {:}'.format(res2.shape))
  res3 = layer_warp(basicblock, res2 , 32, 64, n, 2)
  print('res--3 : shape = {:}'.format(res3.shape))
  pool = fluid.layers.pool2d(input=res3, pool_size=8, pool_type='avg', pool_stride=1)
  print('pool   : shape = {:}'.format(pool.shape))
  predict = fluid.layers.fc(input=pool, size=class_num, act='softmax')
  print('predict: shape = {:}'.format(predict.shape))
  return predict
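
For reference, a sketch (not part of the commit) that mirrors how `inference_program` in `train_cifar.py` builds the ResNet variants; it assumes `lib/models` is on the import path.
```
# Hypothetical sketch: build ResNet-110 predictions for CIFAR-10, as
# inference_program() in train_cifar.py does for model_name == 'ResNet110'.
import paddle.fluid as fluid
from resnet import resnet_cifar   # assumes lib/models is on the import path

images  = fluid.layers.data(name='pixel', shape=[3, 32, 32], dtype='float32')
predict = resnet_cifar(images, 110, 10)   # depth 110, 10 classes -> softmax output
```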
6  paddlepaddle/lib/utils/__init__.py  Normal file
@ -0,0 +1,6 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
from .meter import AverageMeter
from .time_utils import time_for_file, time_string, time_string_short, time_print, convert_size2str, convert_secs2time
from .data_utils import reader_creator

64  paddlepaddle/lib/utils/data_utils.py  Normal file
@ -0,0 +1,64 @@
import random, tarfile
import numpy, six
from six.moves import cPickle as pickle
from PIL import Image, ImageOps


def train_cifar_augmentation(image):
  # flip
  if random.random() < 0.5: image1 = image.transpose(Image.FLIP_LEFT_RIGHT)
  else                    : image1 = image
  # random crop
  image2 = ImageOps.expand(image1, border=4, fill=0)
  i = random.randint(0, 40 - 32)
  j = random.randint(0, 40 - 32)
  image3 = image2.crop((j, i, j+32, i+32))
  # to numpy
  image3 = numpy.array(image3) / 255.0
  mean   = numpy.array([x / 255 for x in [125.3, 123.0, 113.9]]).reshape(1, 1, 3)
  std    = numpy.array([x / 255 for x in [63.0, 62.1, 66.7]]).reshape(1, 1, 3)
  return (image3 - mean) / std


def valid_cifar_augmentation(image):
  image3 = numpy.array(image) / 255.0
  mean   = numpy.array([x / 255 for x in [125.3, 123.0, 113.9]]).reshape(1, 1, 3)
  std    = numpy.array([x / 255 for x in [63.0, 62.1, 66.7]]).reshape(1, 1, 3)
  return (image3 - mean) / std


def reader_creator(filename, sub_name, is_train, cycle=False):
  def read_batch(batch):
    data   = batch[six.b('data')]
    labels = batch.get(
        six.b('labels'), batch.get(six.b('fine_labels'), None))
    assert labels is not None
    for sample, label in six.moves.zip(data, labels):
      sample = sample.reshape(3, 32, 32)
      sample = sample.transpose((1, 2, 0))
      image  = Image.fromarray(sample)
      if is_train:
        ximage = train_cifar_augmentation(image)
      else:
        ximage = valid_cifar_augmentation(image)
      ximage = ximage.transpose((2, 0, 1))
      yield ximage.astype(numpy.float32), int(label)

  def reader():
    with tarfile.open(filename, mode='r') as f:
      # use a list (not a generator) so the member names can be re-iterated when cycle=True
      names = [each_item.name for each_item in f
               if sub_name in each_item.name]

      while True:
        for name in names:
          if six.PY2:
            batch = pickle.load(f.extractfile(name))
          else:
            batch = pickle.load(
                f.extractfile(name), encoding='bytes')
          for item in read_batch(batch):
            yield item
        if not cycle:
          break

  return reader
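
As a usage sketch (mirroring `train_cifar.py` rather than adding anything new), `reader_creator` returns a sample-level reader that is then wrapped by `paddle.batch`; the import assumes `lib/utils` is on the import path and `TORCH_HOME` is set.
```
# Hypothetical sketch: build a shuffled, batched CIFAR-10 training reader.
import os, paddle
from data_utils import reader_creator   # assumes lib/utils is on the import path

data_path  = os.path.join(os.environ['TORCH_HOME'], 'cifar.python', 'cifar-10-python.tar.gz')
train_data = reader_creator(data_path, 'data_batch', is_train=True, cycle=False)
train_reader = paddle.batch(paddle.reader.shuffle(train_data, buf_size=5000), batch_size=256)

images, labels = zip(*next(train_reader()))   # one batch of (CHW float32 image, int label) pairs
print(len(images), images[0].shape)           # e.g. 256 (3, 32, 32)
```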
26  paddlepaddle/lib/utils/meter.py  Normal file
@ -0,0 +1,26 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import time, sys
import numpy as np


class AverageMeter(object):
  """Computes and stores the average and current value"""
  def __init__(self):
    self.reset()

  def reset(self):
    self.val   = 0.0
    self.avg   = 0.0
    self.sum   = 0.0
    self.count = 0.0

  def update(self, val, n=1):
    self.val    = val
    self.sum   += val * n
    self.count += n
    self.avg    = self.sum / self.count

  def __repr__(self):
    return ('{name}(val={val}, avg={avg}, count={count})'.format(name=self.__class__.__name__, **self.__dict__))
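
A short usage sketch (not part of the commit): the meter accumulates a running average weighted by the number of samples in each batch, which is how the training loop tracks loss and accuracy.
```
# Hypothetical sketch of AverageMeter usage, as in the training loop.
from meter import AverageMeter   # assumes lib/utils is on the import path

losses = AverageMeter()
losses.update(0.9, n=256)   # a batch of 256 samples with mean loss 0.9
losses.update(0.7, n=256)
print(losses)               # avg == (0.9*256 + 0.7*256) / 512 == 0.8
```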
52  paddlepaddle/lib/utils/time_utils.py  Normal file
@ -0,0 +1,52 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import time, sys
import numpy as np

def time_for_file():
  ISOTIMEFORMAT='%d-%h-at-%H-%M-%S'
  return '{}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))

def time_string():
  ISOTIMEFORMAT='%Y-%m-%d %X'
  string = '[{}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
  return string

def time_string_short():
  ISOTIMEFORMAT='%Y%m%d'
  string = '{}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) ))
  return string

def time_print(string, is_print=True):
  if (is_print):
    print('{} : {}'.format(time_string(), string))

def convert_size2str(torch_size):
  dims = len(torch_size)
  string = '['
  for idim in range(dims):
    string = string + ' {}'.format(torch_size[idim])
  return string + ']'

def convert_secs2time(epoch_time, return_str=False):
  need_hour = int(epoch_time / 3600)
  need_mins = int((epoch_time - 3600*need_hour) / 60)
  need_secs = int(epoch_time - 3600*need_hour - 60*need_mins)
  if return_str:
    # avoid shadowing the built-in str
    xstr = '[{:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)
    return xstr
  else:
    return need_hour, need_mins, need_secs

def print_log(print_string, log):
  #if isinstance(log, Logger): log.log('{:}'.format(print_string))
  if hasattr(log, 'log'): log.log('{:}'.format(print_string))
  else:
    print("{:}".format(print_string))
    if log is not None:
      log.write('{:}\n'.format(print_string))
      log.flush()
31  paddlepaddle/scripts/base-train.sh  Normal file
@ -0,0 +1,31 @@
#!/bin/bash
# bash ./scripts/base-train.sh 0 cifar-10 ResNet110
echo script name: $0
echo $# arguments
if [ "$#" -ne 3 ] ;then
  echo "Illegal number of parameters: " $#
  echo "Need 3 parameters: the GPU ID, the dataset, and the model name"
  exit 1
fi
if [ "$TORCH_HOME" = "" ]; then
  echo "Must set the TORCH_HOME environment variable for the data directory"
  exit 1
else
  echo "TORCH_HOME : $TORCH_HOME"
fi

GPU=$1
dataset=$2
model=$3

save_dir=snapshots/${dataset}-${model}

export FLAGS_fraction_of_gpu_memory_to_use="0.005"
export FLAGS_free_idle_memory=True

CUDA_VISIBLE_DEVICES=${GPU} python train_cifar.py \
	--data_path $TORCH_HOME/cifar.python/${dataset}-python.tar.gz \
	--log_dir ${save_dir} \
	--dataset ${dataset} \
	--model_name ${model} \
	--lr 0.1 --epochs 300 --batch_size 256 --step_each_epoch 196
31  paddlepaddle/scripts/train-nas.sh  Normal file
@ -0,0 +1,31 @@
#!/bin/bash
# bash ./scripts/train-nas.sh 0 cifar-10 GDAS_V1
echo script name: $0
echo $# arguments
if [ "$#" -ne 3 ] ;then
  echo "Illegal number of parameters: " $#
  echo "Need 3 parameters: the GPU ID, the dataset, and the model name"
  exit 1
fi
if [ "$TORCH_HOME" = "" ]; then
  echo "Must set the TORCH_HOME environment variable for the data directory"
  exit 1
else
  echo "TORCH_HOME : $TORCH_HOME"
fi

GPU=$1
dataset=$2
model=$3

save_dir=snapshots/${dataset}-${model}

export FLAGS_fraction_of_gpu_memory_to_use="0.005"
export FLAGS_free_idle_memory=True

CUDA_VISIBLE_DEVICES=${GPU} python train_cifar.py \
	--data_path $TORCH_HOME/cifar.python/${dataset}-python.tar.gz \
	--log_dir ${save_dir} \
	--dataset ${dataset} \
	--model_name ${model} \
	--lr 0.025 --epochs 600 --batch_size 96 --step_each_epoch 521
189  paddlepaddle/train_cifar.py  Normal file
@ -0,0 +1,189 @@
import os, sys, numpy as np, argparse
from pathlib import Path
import paddle.fluid as fluid
import math, time, paddle
import paddle.fluid.layers.ops as ops
#from tb_paddle import SummaryWriter

lib_dir = (Path(__file__).parent / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from models import resnet_cifar, NASCifarNet, Networks
from utils import AverageMeter, time_for_file, time_string, convert_secs2time
from utils import reader_creator


def inference_program(model_name, num_class):
  # The image is 32 * 32 with RGB representation.
  data_shape = [3, 32, 32]
  images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')

  if model_name == 'ResNet20':
    predict = resnet_cifar(images, 20, num_class)
  elif model_name == 'ResNet32':
    predict = resnet_cifar(images, 32, num_class)
  elif model_name == 'ResNet110':
    predict = resnet_cifar(images, 110, num_class)
  else:
    predict = NASCifarNet(images, 36, 6, 3, num_class, Networks[model_name], True)
  return predict


def train_program(predict):
  label = fluid.layers.data(name='label', shape=[1], dtype='int64')
  if isinstance(predict, (list, tuple)):
    predict, aux_predict = predict
    x_losses   = fluid.layers.cross_entropy(input=predict, label=label)
    aux_losses = fluid.layers.cross_entropy(input=aux_predict, label=label)
    x_loss     = fluid.layers.mean(x_losses)
    aux_loss   = fluid.layers.mean(aux_losses)
    loss       = x_loss + aux_loss * 0.4
    accuracy   = fluid.layers.accuracy(input=predict, label=label)
  else:
    losses     = fluid.layers.cross_entropy(input=predict, label=label)
    loss       = fluid.layers.mean(losses)
    accuracy   = fluid.layers.accuracy(input=predict, label=label)
  return [loss, accuracy]


# For training test cost
def evaluation(program, reader, fetch_list, place):
  feed_var_list = [program.global_block().var('pixel'), program.global_block().var('label')]
  feeder_test = fluid.DataFeeder(feed_list=feed_var_list, place=place)
  test_exe = fluid.Executor(place)
  losses, accuracies = AverageMeter(), AverageMeter()
  for tid, test_data in enumerate(reader()):
    loss, acc = test_exe.run(program=program, feed=feeder_test.feed(test_data), fetch_list=fetch_list)
    losses.update(float(loss), len(test_data))
    accuracies.update(float(acc)*100, len(test_data))
  return losses.avg, accuracies.avg


def cosine_decay_with_warmup(learning_rate, step_each_epoch, epochs=120):
  """Applies cosine decay to the learning rate.
  lr = 0.05 * (math.cos(epoch * (math.pi / 120)) + 1)
  decrease lr for every mini-batch and start with warmup.
  """
  from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
  from paddle.fluid.initializer import init_on_cpu
  global_step = _decay_step_counter()
  lr = fluid.layers.tensor.create_global_var(
      shape=[1],
      value=0.0,
      dtype='float32',
      persistable=True,
      name="learning_rate")

  warmup_epoch = fluid.layers.fill_constant(
      shape=[1], dtype='float32', value=float(5), force_cpu=True)

  with init_on_cpu():
    epoch = ops.floor(global_step / step_each_epoch)
    with fluid.layers.control_flow.Switch() as switch:
      with switch.case(epoch < warmup_epoch):
        decayed_lr = learning_rate * (global_step / (step_each_epoch * warmup_epoch))
        fluid.layers.tensor.assign(input=decayed_lr, output=lr)
      with switch.default():
        decayed_lr = learning_rate * \
            (ops.cos((global_step - warmup_epoch * step_each_epoch) * (math.pi / (epochs * step_each_epoch))) + 1)/2
        fluid.layers.tensor.assign(input=decayed_lr, output=lr)
  return lr


def main(xargs):

  save_dir = Path(xargs.log_dir) / time_for_file()
  save_dir.mkdir(parents=True, exist_ok=True)

  print ('save dir : {:}'.format(save_dir))
  print ('xargs    : {:}'.format(xargs))

  if xargs.dataset == 'cifar-10':
    train_data = reader_creator(xargs.data_path, 'data_batch', True , False)
    test__data = reader_creator(xargs.data_path, 'test_batch', False, False)
    class_num  = 10
    print ('create cifar-10  dataset')
  elif xargs.dataset == 'cifar-100':
    train_data = reader_creator(xargs.data_path, 'train', True , False)
    test__data = reader_creator(xargs.data_path, 'test' , False, False)
    class_num  = 100
    print ('create cifar-100 dataset')
  else:
    raise ValueError('invalid dataset : {:}'.format(xargs.dataset))

  train_reader = paddle.batch(
      paddle.reader.shuffle(train_data, buf_size=5000),
      batch_size=xargs.batch_size)

  # Reader for testing. A separated data set for testing.
  test_reader = paddle.batch(test__data, batch_size=xargs.batch_size)

  place = fluid.CUDAPlace(0)

  main_program = fluid.default_main_program()
  star_program = fluid.default_startup_program()

  # programs
  predict = inference_program(xargs.model_name, class_num)
  [loss, accuracy] = train_program(predict)
  print ('training program setup done')
  test_program = main_program.clone(for_test=True)
  print ('testing  program setup done')

  #infer_writer = SummaryWriter( str(save_dir / 'infer') )
  #infer_writer.add_paddle_graph(fluid_program=fluid.default_main_program(), verbose=True)
  #infer_writer.close()
  #print(test_program.to_string(True))

  #learning_rate = fluid.layers.cosine_decay(learning_rate=xargs.lr, step_each_epoch=xargs.step_each_epoch, epochs=xargs.epochs)
  #learning_rate = fluid.layers.cosine_decay(learning_rate=0.1, step_each_epoch=196, epochs=300)
  learning_rate = cosine_decay_with_warmup(xargs.lr, xargs.step_each_epoch, xargs.epochs)
  optimizer = fluid.optimizer.Momentum(
      learning_rate=learning_rate,
      momentum=0.9,
      regularization=fluid.regularizer.L2Decay(0.0005),
      use_nesterov=True)
  optimizer.minimize( loss )

  exe = fluid.Executor(place)

  feed_var_list_loop = [main_program.global_block().var('pixel'), main_program.global_block().var('label')]
  feeder = fluid.DataFeeder(feed_list=feed_var_list_loop, place=place)
  exe.run(star_program)

  start_time, epoch_time = time.time(), AverageMeter()
  for iepoch in range(xargs.epochs):
    losses, accuracies, steps = AverageMeter(), AverageMeter(), 0
    for step_id, train_data in enumerate(train_reader()):
      tloss, tacc, xlr = exe.run(main_program, feed=feeder.feed(train_data), fetch_list=[loss, accuracy, learning_rate])
      tloss, tacc, xlr = float(tloss), float(tacc) * 100, float(xlr)
      steps += 1
      losses.update(tloss, len(train_data))
      accuracies.update(tacc, len(train_data))
      if step_id % 100 == 0:
        print('{:} [{:03d}/{:03d}] [{:03d}] lr = {:.7f}, loss = {:.4f} ({:.4f}), accuracy = {:.2f} ({:.2f}), error={:.2f}'.format(time_string(), iepoch, xargs.epochs, step_id, xlr, tloss, losses.avg, tacc, accuracies.avg, 100-accuracies.avg))
    test_loss, test_acc = evaluation(test_program, test_reader, [loss, accuracy], place)
    need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (xargs.epochs-iepoch), True) )
    print('{:}x[{:03d}/{:03d}] {:} train-loss = {:.4f}, train-accuracy = {:.2f}, test-loss = {:.4f}, test-accuracy = {:.2f} test-error = {:.2f} [{:} steps per epoch]\n'.format(time_string(), iepoch, xargs.epochs, need_time, losses.avg, accuracies.avg, test_loss, test_acc, 100-test_acc, steps))
    if isinstance(predict, list):
      fluid.io.save_inference_model(str(save_dir / 'inference_model'), ["pixel"], predict, exe)
    else:
      fluid.io.save_inference_model(str(save_dir / 'inference_model'), ["pixel"], [predict], exe)
    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()

  print('finish training and evaluation with {:} epochs in {:}'.format(xargs.epochs, convert_secs2time(epoch_time.sum, True)))


if __name__ == '__main__':
  parser = argparse.ArgumentParser(description='Train.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('--log_dir',         type=str,   help='Save dir.')
  parser.add_argument('--dataset',         type=str,   help='The dataset name.')
  parser.add_argument('--data_path',       type=str,   help='The dataset path.')
  parser.add_argument('--model_name',      type=str,   help='The model name.')
  parser.add_argument('--lr',              type=float, help='The learning rate.')
  parser.add_argument('--batch_size',      type=int,   help='The batch size.')
  parser.add_argument('--step_each_epoch', type=int,   help='The number of steps in each epoch.')
  parser.add_argument('--epochs',          type=int,   help='The total training epochs.')
  args = parser.parse_args()
  main(args)