update scripts
This commit is contained in:
parent
7b977c08ec
commit
6814816d5f
@ -10,6 +10,7 @@ def select2withP(logits, tau, just_prob=False, num=2, eps=1e-7):
|
||||
while True: # a trick to avoid the gumbels bug
|
||||
gumbels = -torch.empty_like(logits).exponential_().log()
|
||||
new_logits = (logits + gumbels) / tau
|
||||
#new_logits = (logits.log_softmax(dim=1) + gumbels) / tau
|
||||
probs = nn.functional.softmax(new_logits, dim=1)
|
||||
if (not torch.isinf(gumbels).any()) and (not torch.isinf(probs).any()) and (not torch.isnan(probs).any()): break
|
||||
|
||||
|
@ -24,23 +24,11 @@ gumbel_max=5
|
||||
expected_FLOP_ratio=$4
|
||||
rseed=$5
|
||||
|
||||
PY_C="./env/bin/python"
|
||||
if [ ! -f ${PY_C} ]; then
|
||||
echo "Local Run with Python: "`which python`
|
||||
PY_C="python"
|
||||
SAVE_ROOT="./output"
|
||||
else
|
||||
echo "Cluster Run with Python: "${PY_C}
|
||||
SAVE_ROOT="./hadoop-data/TAS-checkpoints"
|
||||
mkdir -p $TORCH_HOME/TAS-checkpoints/
|
||||
cp -r ./hadoop-data/TAS-checkpoints/basemodels $TORCH_HOME/TAS-checkpoints/
|
||||
fi
|
||||
save_dir=./output/search-shape/${dataset}-${model}-${optim}-Gumbel_${gumbel_min}_${gumbel_max}-${expected_FLOP_ratio}
|
||||
|
||||
save_dir=${SAVE_ROOT}/search-shape/${dataset}-${model}-${optim}-Gumbel_${gumbel_min}_${gumbel_max}-${expected_FLOP_ratio}
|
||||
python --version
|
||||
|
||||
${PY_C} --version
|
||||
|
||||
${PY_C} ./exps/search-transformable.py --dataset ${dataset} \
|
||||
OMP_NUM_THREADS=4 python ./exps/search-transformable.py --dataset ${dataset} \
|
||||
--data_path $TORCH_HOME/cifar.python \
|
||||
--model_config ./configs/archs/CIFAR-${model}.config \
|
||||
--split_path ./.latent-data/splits/${dataset}-0.5.pth \
|
||||
@ -60,7 +48,7 @@ if [ "$rseed" = "-1" ]; then
|
||||
else
|
||||
# normal training
|
||||
xsave_dir=${save_dir}/seed-${rseed}-NMT
|
||||
${PY_C} ./exps/basic-main.py --dataset ${dataset} \
|
||||
OMP_NUM_THREADS=4 python ./exps/basic-main.py --dataset ${dataset} \
|
||||
--data_path $TORCH_HOME/cifar.python \
|
||||
--model_config ${save_dir}/seed-${rseed}-last.config \
|
||||
--optim_config ./configs/opts/CIFAR-E300-W5-L1-COS.config \
|
||||
@ -71,7 +59,7 @@ else
|
||||
--eval_frequency 1 --print_freq 100 --print_freq_eval 200
|
||||
# KD training
|
||||
xsave_dir=${save_dir}/seed-${rseed}-KDT
|
||||
${PY_C} ./exps/KD-main.py --dataset ${dataset} \
|
||||
OMP_NUM_THREADS=4 python ./exps/KD-main.py --dataset ${dataset} \
|
||||
--data_path $TORCH_HOME/cifar.python \
|
||||
--model_config ${save_dir}/seed-${rseed}-last.config \
|
||||
--optim_config ./configs/opts/CIFAR-E300-W5-L1-COS.config \
|
||||
|
@ -32,7 +32,7 @@ save_dir=${SAVE_ROOT}/search-depth/${dataset}-${model}-${optim}-Gumbel_${gumbel_
|
||||
|
||||
python --version
|
||||
|
||||
python ./exps/search-shape.py --dataset ${dataset} \
|
||||
OMP_NUM_THREADS=4 python ./exps/search-shape.py --dataset ${dataset} \
|
||||
--data_path $TORCH_HOME/cifar.python \
|
||||
--model_config ./configs/archs/CIFAR-${model}.config \
|
||||
--split_path ./.latent-data/splits/${dataset}-0.5.pth \
|
||||
@ -53,7 +53,7 @@ if [ "$rseed" = "-1" ]; then
|
||||
else
|
||||
# normal training
|
||||
xsave_dir=${save_dir}/seed-${rseed}-NMT
|
||||
python ./exps/basic-main.py --dataset ${dataset} \
|
||||
OMP_NUM_THREADS=4 python ./exps/basic-main.py --dataset ${dataset} \
|
||||
--data_path $TORCH_HOME/cifar.python \
|
||||
--model_config ${save_dir}/seed-${rseed}-last.config \
|
||||
--optim_config ./configs/opts/CIFAR-E300-W5-L1-COS.config \
|
||||
@ -64,7 +64,7 @@ else
|
||||
--eval_frequency 1 --print_freq 100 --print_freq_eval 200
|
||||
# KD training
|
||||
xsave_dir=${save_dir}/seed-${rseed}-KDT
|
||||
python ./exps/KD-main.py --dataset ${dataset} \
|
||||
OMP_NUM_THREADS=4 python ./exps/KD-main.py --dataset ${dataset} \
|
||||
--data_path $TORCH_HOME/cifar.python \
|
||||
--model_config ${save_dir}/seed-${rseed}-last.config \
|
||||
--optim_config ./configs/opts/CIFAR-E300-W5-L1-COS.config \
|
||||
|
@ -32,7 +32,7 @@ save_dir=${SAVE_ROOT}/search-width/${dataset}-${model}-${optim}-Gumbel_${gumbel_
|
||||
|
||||
python --version
|
||||
|
||||
python ./exps/search-shape.py --dataset ${dataset} \
|
||||
OMP_NUM_THREADS=4 python ./exps/search-shape.py --dataset ${dataset} \
|
||||
--data_path $TORCH_HOME/cifar.python \
|
||||
--model_config ./configs/archs/CIFAR-${model}.config \
|
||||
--split_path ./.latent-data/splits/${dataset}-0.5.pth \
|
||||
@ -53,7 +53,7 @@ if [ "$rseed" = "-1" ]; then
|
||||
else
|
||||
# normal training
|
||||
xsave_dir=${save_dir}/seed-${rseed}-NMT
|
||||
python ./exps/basic-main.py --dataset ${dataset} \
|
||||
OMP_NUM_THREADS=4 python ./exps/basic-main.py --dataset ${dataset} \
|
||||
--data_path $TORCH_HOME/cifar.python \
|
||||
--model_config ${save_dir}/seed-${rseed}-last.config \
|
||||
--optim_config ./configs/opts/CIFAR-E300-W5-L1-COS.config \
|
||||
@ -64,7 +64,7 @@ else
|
||||
--eval_frequency 1 --print_freq 100 --print_freq_eval 200
|
||||
# KD training
|
||||
xsave_dir=${save_dir}/seed-${rseed}-KDT
|
||||
python ./exps/KD-main.py --dataset ${dataset} \
|
||||
OMP_NUM_THREADS=4 python ./exps/KD-main.py --dataset ${dataset} \
|
||||
--data_path $TORCH_HOME/cifar.python \
|
||||
--model_config ${save_dir}/seed-${rseed}-last.config \
|
||||
--optim_config ./configs/opts/CIFAR-E300-W5-L1-COS.config \
|
||||
|
@ -28,7 +28,7 @@ save_dir=${SAVE_ROOT}/basic/${dataset}/${model}-${epoch}-${LR}-${batch}
|
||||
|
||||
python --version
|
||||
|
||||
python ./exps/basic-main.py --dataset ${dataset} \
|
||||
OMP_NUM_THREADS=4 python ./exps/basic-main.py --dataset ${dataset} \
|
||||
--data_path $TORCH_HOME/cifar.python \
|
||||
--model_config ./configs/archs/CIFAR-${model}.config \
|
||||
--optim_config ./configs/opts/CIFAR-${epoch}-W5-${LR}-COS.config \
|
||||
|
Loading…
Reference in New Issue
Block a user