# Modified from https://github.com/quark0/darts
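# This script trains the final recurrent cell discovered by DARTS as a word-level
# language model, following the AWD-LSTM recipe: plain SGD at first, then a
# switch to averaged SGD (NT-ASGD) once the validation loss stops improving.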
import os, gc, sys, time, math
import numpy as np
from copy import deepcopy
import torch
import torch.nn as nn
from utils import print_log, obtain_accuracy, AverageMeter
from utils import time_string, convert_secs2time
from utils import count_parameters_in_MB
from datasets import Corpus
from nas_rnn import batchify, get_batch, repackage_hidden
from nas_rnn import DARTSCell, RNNModel
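

# Return the best (largest) value recorded in the `accuracies` dict, or (0, 0) when it is empty.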
def obtain_best(accuracies):
  if len(accuracies) == 0: return (0, 0)
  tops = [value for key, value in accuracies.items()]
  s2b = sorted(tops)
  return s2b[-1]


def main_procedure(config, genotype, save_dir, print_freq, log):

  print_log('-' * 90, log)
  print_log('save-dir : {:}'.format(save_dir), log)
  print_log('genotype : {:}'.format(genotype), log)
  print_log('config   : {:}'.format(config), log)

  corpus = Corpus(config.data_path)
  train_data = batchify(corpus.train, config.train_batch, True)
  valid_data = batchify(corpus.valid, config.eval_batch , True)
  test_data  = batchify(corpus.test , config.test_batch , True)
  ntokens = len(corpus.dictionary)
  print_log("Train--Data Size : {:}".format(train_data.size()), log)
  print_log("Valid--Data Size : {:}".format(valid_data.size()), log)
  print_log("Test---Data Size : {:}".format( test_data.size()), log)
  print_log("ntokens = {:}".format(ntokens), log)

  model = RNNModel(ntokens, config.emsize, config.nhid, config.nhidlast,
                   config.dropout, config.dropouth, config.dropoutx, config.dropouti, config.dropoute,
                   cell_cls=DARTSCell, genotype=genotype)
  model = model.cuda()
  print_log('Network =>\n{:}'.format(model), log)
  print_log('Genotype : {:}'.format(genotype), log)
  print_log('Parameters : {:.3f} MB'.format(count_parameters_in_MB(model)), log)

  checkpoint_path = os.path.join(save_dir, 'checkpoint-{:}.pth'.format(config.data_name))
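  # Keep both an SGD and an ASGD optimizer alive so that both states can be
  # checkpointed together and training can switch from one to the other mid-run.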
  Soptimizer = torch.optim.SGD (model.parameters(), lr=config.LR, weight_decay=config.wdecay)
  Aoptimizer = torch.optim.ASGD(model.parameters(), lr=config.LR, t0=0, lambd=0., weight_decay=config.wdecay)
  if os.path.isfile(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['state_dict'])
    Soptimizer.load_state_dict(checkpoint['SGD_optimizer'])
    Aoptimizer.load_state_dict(checkpoint['ASGD_optimizer'])
    epoch = checkpoint['epoch']
    use_asgd = checkpoint['use_asgd']
    # Also restore the SGD validation-loss history, so a resumed run does not
    # reset the NT-ASGD trigger (the OOM-recovery path below already relies on
    # this key being present in the checkpoint).
    valid_loss_from_sgd = checkpoint['valid_loss_from_sgd']
    print_log('load checkpoint from {:} and start training from epoch {:}'.format(checkpoint_path, epoch), log)
  else:
    epoch, use_asgd, valid_loss_from_sgd = 0, False, []

  start_time, epoch_time = time.time(), AverageMeter()
  losses = {-1: 1e9}

  while epoch < config.epochs:
    need_time = convert_secs2time(epoch_time.val * (config.epochs - epoch), True)
    print_log("\n==>> {:s} [Epoch={:04d}/{:04d}] {:}".format(time_string(), epoch, config.epochs, need_time), log)
    optimizer = Aoptimizer if use_asgd else Soptimizer
    try:
      Dtime, Btime = train(model, optimizer, corpus, train_data, config, epoch, print_freq, log)
    except RuntimeError:
      # Most likely a CUDA OOM caused by an unluckily long sampled sequence:
      # free the cached memory, roll back to the last checkpoint, and retry.
      torch.cuda.empty_cache()
      checkpoint = torch.load(checkpoint_path)
      model.load_state_dict(checkpoint['state_dict'])
      Soptimizer.load_state_dict(checkpoint['SGD_optimizer'])
      Aoptimizer.load_state_dict(checkpoint['ASGD_optimizer'])
      epoch = checkpoint['epoch']
      use_asgd = checkpoint['use_asgd']
      valid_loss_from_sgd = checkpoint['valid_loss_from_sgd']
      continue
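
    # When using ASGD, evaluate with the averaged weights (the 'ax' buffers in
    # the optimizer state) by temporarily swapping them into the model.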
    if use_asgd:
      tmp = {}
      for prm in model.parameters():
        tmp[prm] = prm.data.clone()
        prm.data = Aoptimizer.state[prm]['ax'].clone()
      val_loss = evaluate(model, corpus, valid_data, config.eval_batch, config.bptt)
      for prm in model.parameters():
        prm.data = tmp[prm].clone()
    else:
      val_loss = evaluate(model, corpus, valid_data, config.eval_batch, config.bptt)
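    # NT-ASGD trigger: after at least `nonmono` epochs, switch to ASGD once the
    # current validation loss is worse than the best SGD validation loss so far.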
    if len(valid_loss_from_sgd) > config.nonmono and val_loss > min(valid_loss_from_sgd):
      use_asgd = True
    valid_loss_from_sgd.append(val_loss)
    print_log('{:} end of epoch {:3d} with {:} | valid loss {:5.2f} | valid ppl {:8.2f}'.format(time_string(), epoch, 'ASGD' if use_asgd else 'SGD', val_loss, math.exp(val_loss)), log)
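    # When the validation loss improves, checkpoint the model (for ASGD, with the
    # averaged weights swapped in) and report the test perplexity of this new best.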
    if val_loss < min(losses.values()):
      if use_asgd:
        tmp = {}
        for prm in model.parameters():
          tmp[prm] = prm.data.clone()
          prm.data = Aoptimizer.state[prm]['ax'].clone()
      torch.save({'epoch'               : epoch,
                  'use_asgd'            : use_asgd,
                  'valid_loss_from_sgd' : valid_loss_from_sgd,
                  'state_dict'          : model.state_dict(),
                  'SGD_optimizer'       : Soptimizer.state_dict(),
                  'ASGD_optimizer'      : Aoptimizer.state_dict()},
                 checkpoint_path)
      if use_asgd:
        for prm in model.parameters():
          prm.data = tmp[prm].clone()
      print_log('save into {:}'.format(checkpoint_path), log)
      if use_asgd:
        tmp = {}
        for prm in model.parameters():
          tmp[prm] = prm.data.clone()
          prm.data = Aoptimizer.state[prm]['ax'].clone()
      test_loss = evaluate(model, corpus, test_data, config.test_batch, config.bptt)
      if use_asgd:
        for prm in model.parameters():
          prm.data = tmp[prm].clone()
      print_log('| epoch={:03d} | test loss {:5.2f} | test ppl {:8.2f}'.format(epoch, test_loss, math.exp(test_loss)), log)
    losses[epoch] = val_loss
    epoch = epoch + 1
    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()

  print_log('--------------------- Finish Training ----------------', log)
  checkpoint = torch.load(checkpoint_path)
  model.load_state_dict(checkpoint['state_dict'])
  test_loss = evaluate(model, corpus, test_data , config.test_batch, config.bptt)
  print_log('| End of training | test  loss {:5.2f} | test  ppl {:8.2f}'.format(test_loss, math.exp(test_loss)), log)
  vali_loss = evaluate(model, corpus, valid_data, config.eval_batch, config.bptt)
  print_log('| End of training | valid loss {:5.2f} | valid ppl {:8.2f}'.format(vali_loss, math.exp(vali_loss)), log)
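

# Compute the average per-token negative log-likelihood of `model` on `data_source`.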
def evaluate(model, corpus, data_source, batch_size, bptt):
  # Turn on evaluation mode, which disables dropout.
  model.eval()
  total_loss, total_length = 0.0, 0.0
  with torch.no_grad():
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(batch_size)
    for i in range(0, data_source.size(0) - 1, bptt):
      data, targets = get_batch(data_source, i, bptt)
      targets = targets.view(-1)
      log_prob, hidden = model(data, hidden)
      loss = nn.functional.nll_loss(log_prob.view(-1, log_prob.size(2)), targets)
      total_loss += loss.item() * len(data)
      total_length += len(data)
      hidden = repackage_hidden(hidden)
  return total_loss / total_length
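

# Run one training epoch with variable-length BPTT and AWD-LSTM-style activation
# regularization; returns the accumulated data-loading and batch-compute times.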
def train(model, optimizer, corpus, train_data, config, epoch, print_freq, log):
  # Turn on training mode, which enables dropout.
  total_loss, data_time, batch_time = 0, AverageMeter(), AverageMeter()
  start_time = time.time()
  ntokens = len(corpus.dictionary)
  hidden_train = model.init_hidden(config.train_batch)
  batch, i = 0, 0
  while i < train_data.size(0) - 1 - 1:
    # Sample a variable BPTT length: config.bptt most of the time, occasionally halved.
    bptt = config.bptt if np.random.random() < 0.95 else config.bptt / 2.
    # Prevent excessively small or negative sequence lengths.
    seq_len = max(5, int(np.random.normal(bptt, 5)))
    # There is a very small chance of drawing a very long sequence length, resulting in OOM.
    seq_len = min(seq_len, config.bptt + config.max_seq_len_delta)
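    # Rescale the learning rate in proportion to the sampled sequence length, so
    # shorter sequences take correspondingly smaller steps; the original value is
    # restored right after optimizer.step().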
    lr2 = optimizer.param_groups[0]['lr']
    optimizer.param_groups[0]['lr'] = lr2 * seq_len / config.bptt
    model.train()

    data, targets = get_batch(train_data, i, seq_len)
    targets = targets.contiguous().view(-1)
    # count data preparation time
    data_time.update(time.time() - start_time)

    optimizer.zero_grad()
    hidden_train = repackage_hidden(hidden_train)
    log_prob, hidden_train, rnn_hs, dropped_rnn_hs = model(data, hidden_train, return_h=True)
    raw_loss = nn.functional.nll_loss(log_prob.view(-1, log_prob.size(2)), targets)
    loss = raw_loss
    # Activation Regularization: penalize unusually large activations of the last RNN layer.
    if config.alpha > 0:
      loss = loss + sum(config.alpha * dropped_rnn_h.pow(2).mean() for dropped_rnn_h in dropped_rnn_hs[-1:])
    # Temporal Activation Regularization (slowness): penalize large changes between consecutive hidden states.
    loss = loss + sum(config.beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean() for rnn_h in rnn_hs[-1:])
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip)
    optimizer.step()
    gc.collect()
    optimizer.param_groups[0]['lr'] = lr2

    total_loss += raw_loss.item()
    assert not torch.isnan(loss), '--- Epoch={:04d} :: {:03d}/{:03d} Get Loss = NaN'.format(epoch, batch, len(train_data) // config.bptt)
    batch_time.update(time.time() - start_time)
    start_time = time.time()
    batch, i = batch + 1, i + seq_len

    if batch % print_freq == 0:
      cur_loss = total_loss / print_freq
      print_log('>> Epoch: {:04d} :: {:03d}/{:03d} || loss = {:5.2f}, ppl = {:8.2f}'.format(epoch, batch, len(train_data) // config.bptt, cur_loss, math.exp(cur_loss)), log)
      total_loss = 0
  return data_time.sum, batch_time.sum
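

if __name__ == '__main__':
  # Minimal usage sketch -- an assumption, since the original script is driven by
  # an external entry point.  Every hyper-parameter value below is an illustrative
  # placeholder; only the field names come from what main_procedure actually reads.
  from types import SimpleNamespace
  from nas_rnn import DARTS_V2  # assumption: a released genotype constant exported by nas_rnn
  config = SimpleNamespace(
      data_path='./data/penn', data_name='PTB',
      train_batch=64, eval_batch=10, test_batch=1,
      emsize=850, nhid=850, nhidlast=850,
      dropout=0.75, dropouth=0.25, dropoutx=0.75, dropouti=0.2, dropoute=0.1,
      LR=20.0, wdecay=8e-7, epochs=500, nonmono=5,
      bptt=35, max_seq_len_delta=20, clip=0.25, alpha=0.0, beta=1e-3)
  save_dir = './snapshots'
  os.makedirs(save_dir, exist_ok=True)
  # print_log is assumed to accept any file-like object with write/flush.
  with open(os.path.join(save_dir, 'log.txt'), 'w') as log:
    main_procedure(config, DARTS_V2, save_dir, 200, log)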