update GDAS

This commit is contained in:
D-X-Y 2020-01-18 21:54:17 +11:00
parent fcb6007975
commit 28d354880c
4 changed files with 10 additions and 3 deletions

View File

@ -62,6 +62,8 @@ CUDA_VISIBLE_DEVICES=0 bash ./scripts-search/NAS-Bench-201/train-a-net.sh '|nor_
`|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|skip_connect~1|skip_connect~2|` represents the structure of a searched architecture. My codes will automatically print it during the searching procedure. `|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|skip_connect~0|skip_connect~1|skip_connect~2|` represents the structure of a searched architecture. My codes will automatically print it during the searching procedure.
**Tensorflow codes for GDAS are in experimental state**, which locates at `exps-tf`.
# Citation # Citation
If you find that this project helps your research, please consider citing the following paper: If you find that this project helps your research, please consider citing the following paper:

View File

@ -1,4 +1,9 @@
# [D-X-Y]
# Run GDAS
# CUDA_VISIBLE_DEVICES=0 python exps-tf/GDAS.py # CUDA_VISIBLE_DEVICES=0 python exps-tf/GDAS.py
# Run DARTS
# CUDA_VISIBLE_DEVICES=0 python exps-tf/GDAS.py --tau_max -1 --tau_min -1 --epochs 50
#
import os, sys, math, time, random, argparse import os, sys, math, time, random, argparse
import tensorflow as tf import tensorflow as tf
from pathlib import Path from pathlib import Path

View File

@ -7,10 +7,10 @@ from copy import deepcopy
from ..cell_operations import OPS from ..cell_operations import OPS
class SearchCell(tf.keras.layers.Layer): class NAS201SearchCell(tf.keras.layers.Layer):
def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False): def __init__(self, C_in, C_out, stride, max_nodes, op_names, affine=False):
super(SearchCell, self).__init__() super(NAS201SearchCell, self).__init__()
self.op_names = deepcopy(op_names) self.op_names = deepcopy(op_names)
self.max_nodes = max_nodes self.max_nodes = max_nodes

View File

@ -5,7 +5,7 @@ import tensorflow as tf
import numpy as np import numpy as np
from copy import deepcopy from copy import deepcopy
from ..cell_operations import ResNetBasicblock from ..cell_operations import ResNetBasicblock
from .search_cells import SearchCell from .search_cells import NAS201SearchCell as SearchCell
def sample_gumbel(shape, eps=1e-20): def sample_gumbel(shape, eps=1e-20):