##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
#########################################################################
# Efficient Neural Architecture Search via Parameter Sharing, ICML 2018 #
#########################################################################
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical

class Controller(nn.Module):
  # we refer to https://github.com/TDeVries/enas_pytorch/blob/master/models/controller.py
  def __init__(self, num_edge, num_ops, lstm_size=32, lstm_num_layers=2, tanh_constant=2.5, temperature=5.0):
    super(Controller, self).__init__()
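    # argument summary (added comments, inferred from how the fields are used below):
    #   num_edge        : number of edges in the searched cell, i.e. decisions to sample
    #   num_ops         : number of candidate operations per edge
    #   lstm_size       : input/hidden size of the controller LSTM (also the embedding size)
    #   lstm_num_layers : number of stacked LSTM layers
    #   tanh_constant, temperature : logit-scaling constants from the ENAS paper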
					
						
    # assign the attributes
    self.num_edge  = num_edge
    self.num_ops   = num_ops
    self.lstm_size = lstm_size
    self.lstm_N    = lstm_num_layers
    self.tanh_constant = tanh_constant
    self.temperature   = temperature
    # create parameters
    # learned initial input token for the first LSTM step, shape (seq=1, batch=1, lstm_size)
    self.register_parameter('input_vars', nn.Parameter(torch.Tensor(1, 1, lstm_size)))
    self.w_lstm = nn.LSTM(input_size=self.lstm_size, hidden_size=self.lstm_size, num_layers=self.lstm_N)
    self.w_embd = nn.Embedding(self.num_ops, self.lstm_size)  # embeds the previous choice as the next input
    self.w_pred = nn.Linear(self.lstm_size, self.num_ops)     # projects LSTM output to per-op logits

    nn.init.uniform_(self.input_vars         , -0.1, 0.1)
    nn.init.uniform_(self.w_lstm.weight_hh_l0, -0.1, 0.1)
    nn.init.uniform_(self.w_lstm.weight_ih_l0, -0.1, 0.1)
    nn.init.uniform_(self.w_embd.weight      , -0.1, 0.1)
    nn.init.uniform_(self.w_pred.weight      , -0.1, 0.1)
    # note: only the first LSTM layer (l0) is re-initialized above; with
    # lstm_num_layers > 1 the remaining layers keep PyTorch's default init

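  # added summary comment: forward() samples one architecture autoregressively. At
  # each of the num_edge steps the LSTM output is projected to logits over the
  # num_ops candidates, the logits are softened by `temperature` and bounded via
  # `tanh_constant * tanh(.)` (the scaling used in ENAS), an op index is drawn
  # from the resulting Categorical, and its embedding becomes the next LSTM input.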
						
  def forward(self):
    inputs, h0 = self.input_vars, None
    log_probs, entropys, sampled_arch = [], [], []
    for iedge in range(self.num_edge):
      outputs, h0 = self.w_lstm(inputs, h0)

      logits = self.w_pred(outputs)
      logits = logits / self.temperature
      logits = self.tanh_constant * torch.tanh(logits)
      # distribution over the candidate operations for this edge
      op_distribution = Categorical(logits=logits)
      op_index    = op_distribution.sample()
      sampled_arch.append( op_index.item() )

      # bookkeeping for the policy-gradient update and the entropy bonus
      op_log_prob = op_distribution.log_prob(op_index)
      log_probs.append( op_log_prob.view(-1) )
      op_entropy  = op_distribution.entropy()
      entropys.append( op_entropy.view(-1) )

      # obtain the input embedding for the next step
      inputs = self.w_embd(op_index)
    # sum over edges: total log-probability and entropy of the sampled architecture
    return torch.sum(torch.cat(log_probs)), torch.sum(torch.cat(entropys)), sampled_arch
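

# ----------------------------------------------------------------------------
# Minimal usage sketch, not part of the original file. The edge/op counts and
# the learning rate are placeholder choices, and `reward` stands in for the
# validation accuracy of the sampled child network; the update below is the
# standard REINFORCE form (reward-weighted log-likelihood plus an entropy
# bonus) that ENAS-style controllers are trained with.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
  controller = Controller(num_edge=6, num_ops=5)
  optimizer  = torch.optim.Adam(controller.parameters(), lr=3.5e-4)

  log_prob, entropy, arch = controller()  # sample one architecture
  print('sampled architecture (op index per edge):', arch)

  reward = 0.5  # placeholder: evaluate the sampled architecture to get a real reward
  loss = -log_prob * reward - 1e-4 * entropy  # maximize reward-weighted log-prob + entropy
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()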