Update xlayers
@@ -21,6 +21,57 @@ from procedures.advanced_main import basic_train_fn, basic_eval_fn
from procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
from datasets.synthetic_core import get_synthetic_env
from models.xcore import get_model
from xlayers import super_core


class LFNAmlp:
    """An LFNA meta-model that uses an MLP as the delta-net."""

    def __init__(self, obs_dim, hidden_sizes, act_name):
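        # obs_dim is the per-weight feature size; from adapt() below it is 4:
        # (weight value, gradient, input mean, input std).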
        self.delta_net = super_core.SuperSequential(
            super_core.SuperLinear(obs_dim, hidden_sizes[0]),
            super_core.super_name2activation[act_name](),
            super_core.SuperLinear(hidden_sizes[0], hidden_sizes[1]),
            super_core.super_name2activation[act_name](),
            super_core.SuperLinear(hidden_sizes[1], 1),
        )
        self.meta_optimizer = torch.optim.Adam(
            self.delta_net.parameters(), lr=0.01, amsgrad=True
        )

    def adapt(self, model, criterion, w_container, xs, ys):
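        # Unroll the delta-net over the (xs, ys) sequence: at each step, predict
        # a per-weight update from the current weights, their gradients, and the
        # input statistics, then apply it to obtain the next weight container.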
        containers = [w_container]
        for idx, (x, y) in enumerate(zip(xs, ys)):
            y_hat = model.forward_with_container(x, containers[-1])
            loss = criterion(y_hat, y)
            gradients = torch.autograd.grad(loss, containers[-1].tensors)
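            # Without create_graph=True these gradients enter the delta-net as
            # constants; the meta-loss reaches the delta-net only via `delta`.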
            with torch.no_grad():
                flatten_w = containers[-1].flatten().view(-1, 1)
                flatten_g = containers[-1].flatten(gradients).view(-1, 1)
                input_statistics = torch.tensor([x.mean(), x.std()]).view(1, 2)
                input_statistics = input_statistics.expand(flatten_w.numel(), -1)
            delta_inputs = torch.cat((flatten_w, flatten_g, input_statistics), dim=-1)
            delta = self.delta_net(delta_inputs).view(-1)
            # delta = torch.clamp(delta, -0.5, 0.5)
            unflatten_delta = containers[-1].unflatten(delta)
            future_container = containers[-1].additive(unflatten_delta)
            containers.append(future_container)
        # containers = containers[1:]
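        # Meta-objective: evaluate each adapted container on the next step's
        # data; step 0 is skipped because containers[0] has not been adapted yet.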
        meta_loss = []
        for idx, (x, y) in enumerate(zip(xs, ys)):
            if idx == 0:
                continue
            current_container = containers[idx]
            y_hat = model.forward_with_container(x, current_container)
            loss = criterion(y_hat, y)
            meta_loss.append(loss)
        meta_loss = sum(meta_loss)
        meta_loss.backward()
        self.meta_optimizer.step()
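        # NOTE: adapt() does not clear gradients itself; callers are expected
        # to invoke zero_grad() between calls, otherwise gradients accumulate.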

    def zero_grad(self):
        self.meta_optimizer.zero_grad()
        self.delta_net.zero_grad()


class Population:
@@ -28,11 +79,23 @@ class Population:

    def __init__(self):
        self._time2model = dict()
        self._time2score = dict()  # higher is better

-    def append(self, timestamp, model):
+    def append(self, timestamp, model, score):
        if timestamp in self._time2model:
            raise ValueError("This timestamp has been added.")
        self._time2model[timestamp] = model
        self._time2score[timestamp] = score

    def query(self, timestamp):
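        # Return the stored model whose timestamp is closest to the query. As
        # written, `timestamp - closest >= timestamp - xtime` reduces to
        # `xtime >= closest`, so the most recently appended timestamp wins.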
        closest_timestamp = None
        for xtime, model in self._time2model.items():
            if (
                closest_timestamp is None
                or timestamp - closest_timestamp >= timestamp - xtime
            ):
                closest_timestamp = xtime
        return self._time2model[closest_timestamp], closest_timestamp


def main(args):
@@ -70,100 +133,39 @@ def main(args):
    )

    w_container = base_model.named_parameters_buffers()
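    # w_container is a TensorContainer snapshot of the model's parameters and buffers.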
    criterion = torch.nn.MSELoss()
    print("There are {:} weights.".format(w_container.numel()))

    adaptor = LFNAmlp(4, (50, 20), "leaky_relu")

    pool = Population()
    pool.append(0, w_container, None)  # no score yet; the new append() signature requires one
    # LFNA meta-training
    per_epoch_time, start_time = AverageMeter(), time.time()
    for iepoch in range(args.epochs):

        need_time = "Time Left: {:}".format(
            convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
        )
        logger.log(
            "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs)
            + need_time
        )

        for ibatch in range(args.meta_batch):
            sampled_timestamp = random.randint(0, train_time_bar)
            query_w_container, query_timestamp = pool.query(sampled_timestamp)
            # def adapt(self, model, w_container, xs, ys):
            xs, ys = [], []
            for it in range(sampled_timestamp, sampled_timestamp + args.max_seq):
                xs.append(env_info["{:}-x".format(it)])
                ys.append(env_info["{:}-y".format(it)])
            adaptor.adapt(base_model, criterion, query_w_container, xs, ys)
            import pdb

            pdb.set_trace()
        print("-")

    for i, idx in enumerate(to_evaluate_indexes):

        need_time = "Time Left: {:}".format(
            convert_secs2time(
                per_timestamp_time.avg * (len(to_evaluate_indexes) - i), True
            )
        )
        logger.log(
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}][{:04d}]".format(i, len(to_evaluate_indexes), idx)
            + " "
            + need_time
        )
        # train the same data
        assert idx != 0
        historical_x = env_info["{:}-x".format(idx)]
        historical_y = env_info["{:}-y".format(idx)]
        # build model
        mean, std = historical_x.mean().item(), historical_x.std().item()
        model_kwargs = dict(input_dim=1, output_dim=1, mean=mean, std=std)
        model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
        # build optimizer
        optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
        criterion = torch.nn.MSELoss()
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[
                int(args.epochs * 0.25),
                int(args.epochs * 0.5),
                int(args.epochs * 0.75),
            ],
            gamma=0.3,
        )
        train_metric = MSEMetric()
        best_loss, best_param = None, None
        for _iepoch in range(args.epochs):
            preds = model(historical_x)
            optimizer.zero_grad()
            loss = criterion(preds, historical_y)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            # save best
            if best_loss is None or best_loss > loss.item():
                best_loss = loss.item()
                best_param = copy.deepcopy(model.state_dict())
        model.load_state_dict(best_param)
        with torch.no_grad():
            train_metric(preds, historical_y)
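            # Caveat: `preds` comes from the final epoch, not from the reloaded
            # best checkpoint, so train-mse reflects the last-epoch outputs.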
        train_results = train_metric.get_info()

        metric = ComposeMetric(MSEMetric(), SaveMetric())
        eval_dataset = torch.utils.data.TensorDataset(
            env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)]
        )
        eval_loader = torch.utils.data.DataLoader(
            eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
        )
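        # Caveat: this commit renames --batch_size to --meta_batch, so
        # args.batch_size above no longer exists and would need updating.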
        results = basic_eval_fn(eval_loader, model, metric, logger)
        log_str = (
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, env_info["total"])
            + " train-mse: {:.5f}, eval-mse: {:.5f}".format(
                train_results["mse"], results["mse"]
            )
        )
        logger.log(log_str)

        save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
            idx, env_info["total"]
        )
        save_checkpoint(
            {
                "model_state_dict": model.state_dict(),
                "model": model,
                "index": idx,
                "timestamp": env_info["{:}-timestamp".format(idx)],
            },
            save_path,
            logger,
        )
        logger.log("")

        per_timestamp_time.update(time.time() - start_time)
@@ -188,10 +190,10 @@ if __name__ == "__main__":
        help="The initial learning rate for the optimizer (default is Adam)",
    )
    parser.add_argument(
-        "--batch_size",
+        "--meta_batch",
        type=int,
-        default=512,
-        help="The batch size",
+        default=2,
+        help="The batch size for the meta-model",
    )
    parser.add_argument(
        "--epochs",
@@ -199,6 +201,12 @@ if __name__ == "__main__":
        default=1000,
        help="The total number of epochs.",
    )
    parser.add_argument(
        "--max_seq",
        type=int,
        default=5,
        help="The maximum length of the sequence.",
    )
    parser.add_argument(
        "--workers",
        type=int,

@@ -34,3 +34,4 @@ def get_model(config: Dict[Text, Any], **kwargs):
    else:
        raise TypeError("Unknown model type: {:}".format(model_type))
    return model


@@ -31,6 +31,9 @@ class SuperReLU(SuperModule):
    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        return F.relu(input, inplace=self._inplace)

    def forward_with_container(self, input, container, prefix=[]):
        return self.forward_raw(input)

    def extra_repr(self) -> str:
        return "inplace=True" if self._inplace else ""

@@ -53,6 +56,29 @@ class SuperLeakyReLU(SuperModule):
    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        return F.leaky_relu(input, self._negative_slope, self._inplace)

    def forward_with_container(self, input, container, prefix=[]):
        return self.forward_raw(input)

    def extra_repr(self) -> str:
        inplace_str = "inplace=True" if self._inplace else ""
        return "negative_slope={}{}".format(self._negative_slope, inplace_str)


class SuperTanh(SuperModule):
    """Applies the Tanh function element-wise."""

    def __init__(self) -> None:
        super(SuperTanh, self).__init__()

    @property
    def abstract_search_space(self):
        return spaces.VirtualNode(id(self))

    def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
        return self.forward_raw(input)

    def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
        return torch.tanh(input)

    def forward_with_container(self, input, container, prefix=[]):
        return self.forward_raw(input)

@@ -111,3 +111,10 @@ class SuperSequential(SuperModule):
        for module in self:
            input = module(input)
        return input

    def forward_with_container(self, input, container, prefix=[]):
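        # Each child's index in the Sequential becomes a path component, so a
        # tensor stored as "0._super_weight" addresses the first child's weight.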
        for index, module in enumerate(self):
            input = module.forward_with_container(
                input, container, prefix + [str(index)]
            )
        return input

@@ -27,8 +27,13 @@ from .super_transformer import SuperTransformerEncoderLayer

from .super_activations import SuperReLU
from .super_activations import SuperLeakyReLU
from .super_activations import SuperTanh

-super_name2activation = {"relu": SuperReLU, "leaky_relu": SuperLeakyReLU}
+super_name2activation = {
+    "relu": SuperReLU,
+    "leaky_relu": SuperLeakyReLU,
+    "tanh": SuperTanh,
+}
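# Usage sketch: act_cls = super_name2activation["tanh"]; layer = act_cls()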


from .super_trade_stem import SuperAlphaEBDv1

@@ -115,6 +115,16 @@ class SuperLinear(SuperModule):
            self._in_features, self._out_features, self._bias
        )

    def forward_with_container(self, input, container, prefix=[]):
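        # Rebuild the fully-qualified tensor names ("<prefix>._super_weight",
        # "<prefix>._super_bias") and fetch them from the container instead of
        # using this module's own parameters.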
        super_weight_name = ".".join(prefix + ["_super_weight"])
        super_weight = container.query(super_weight_name)
        super_bias_name = ".".join(prefix + ["_super_bias"])
        if container.has(super_bias_name):
            super_bias = container.query(super_bias_name)
        else:
            super_bias = None
        return F.linear(input, super_weight, super_bias)


class SuperMLPv1(SuperModule):
    """An MLP layer: FC -> Activation -> Drop -> FC -> Drop."""

@@ -39,6 +39,41 @@ class TensorContainer:
        self._param_or_buffers = []
        self._name2index = dict()

    def additive(self, tensors):
        result = TensorContainer()
        for index, name in enumerate(self._names):
            new_tensor = self._tensors[index] + tensors[index]
            result.append(name, new_tensor, self._param_or_buffers[index])
        return result

    def no_grad_clone(self):
        result = TensorContainer()
        with torch.no_grad():
            for index, name in enumerate(self._names):
                result.append(
                    name, self._tensors[index].clone(), self._param_or_buffers[index]
                )
        return result

    @property
    def tensors(self):
        return self._tensors

    def flatten(self, tensors=None):
        if tensors is None:
            tensors = self._tensors
        tensors = [tensor.view(-1) for tensor in tensors]
        return torch.cat(tensors)

    def unflatten(self, tensor):
        tensors, s = [], 0
        for raw_tensor in self._tensors:
            length = raw_tensor.numel()
            x = torch.reshape(tensor[s : s + length], shape=raw_tensor.shape)
            tensors.append(x)
            s += length
        return tensors
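    # flatten()/unflatten() are inverses w.r.t. this container's shapes; e.g.
    # `parts = container.unflatten(container.flatten())` (a hypothetical usage)
    # yields tensors with the original per-tensor shapes.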

    def append(self, name, tensor, param_or_buffer):
        if not isinstance(tensor, torch.Tensor):
            raise TypeError(
@@ -54,6 +89,23 @@ class TensorContainer:
        )
        self._name2index[name] = len(self._names) - 1

    def query(self, name):
        if not self.has(name):
            raise ValueError(
                "The {:} is not in {:}".format(name, list(self._name2index.keys()))
            )
        index = self._name2index[name]
        return self._tensors[index]

    def has(self, name):
        return name in self._name2index

    def has_prefix(self, prefix):
        for name, idx in self._name2index.items():
            if name.startswith(prefix):
                return name
        return False

    def numel(self):
        total = 0
        for tensor in self._tensors:
@@ -181,3 +233,6 @@ class SuperModule(abc.ABC, nn.Module):
                )
            )
        return outputs

    def forward_with_container(self, inputs, container, prefix=[]):
        raise NotImplementedError

@@ -161,6 +161,21 @@ class SuperSimpleLearnableNorm(SuperModule):
            mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0)
        return tensor.sub_(mean).div_(std)

    def forward_with_container(self, input, container, prefix=[]):
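        # Same normalization as forward_raw, but _mean/_std are fetched from
        # the container; abs() + eps keeps the divisor positive.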
        if not self._inplace:
            tensor = input.clone()
        else:
            tensor = input
        mean_name = ".".join(prefix + ["_mean"])
        std_name = ".".join(prefix + ["_std"])
        mean, std = (
            container.query(mean_name).to(tensor.device),
            torch.abs(container.query(std_name).to(tensor.device)) + self._eps,
        )
        while mean.ndim < tensor.ndim:
            mean, std = torch.unsqueeze(mean, dim=0), torch.unsqueeze(std, dim=0)
        return tensor.sub_(mean).div_(std)

    def extra_repr(self) -> str:
        return "mean={mean}, std={std}, inplace={inplace}".format(
            mean=self._mean.item(), std=self._std.item(), inplace=self._inplace
@@ -191,3 +206,6 @@ class SuperIdentity(SuperModule):

    def extra_repr(self) -> str:
        return "inplace={inplace}".format(inplace=self._inplace)

    def forward_with_container(self, input, container, prefix=[]):
        return self.forward_raw(input)

							
								
								
									
lib/xlayers/super_rl_actor.py (new file, 120 lines)

@@ -0,0 +1,120 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
# DISABLED / NOT-FINISHED
#####################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np  # added: needed for the log_std initialization below
from torch.distributions.normal import Normal  # added: needed by the Gaussian actors

import math
from typing import Optional, Callable

import spaces
from .super_container import SuperSequential
from .super_linear import SuperLinear
from .super_module import SuperModule  # added: base class; assumed to live in super_module as elsewhere in xlayers


class SuperActor(SuperModule):
    """An actor in RL."""

    def _distribution(self, obs):
        raise NotImplementedError

    def _log_prob_from_distribution(self, pi, act):
        raise NotImplementedError

    def forward_candidate(self, **kwargs):
        return self.forward_raw(**kwargs)

    def forward_raw(self, obs, act=None):
        # Produce action distributions for given observations, and
        # optionally compute the log likelihood of given actions under
        # those distributions.
        pi = self._distribution(obs)
        logp_a = None
        if act is not None:
            logp_a = self._log_prob_from_distribution(pi, act)
        return pi, logp_a

class SuperLfnaMetaMLP(SuperModule):
    def __init__(self, obs_dim, hidden_sizes, act_cls):
        super(SuperLfnaMetaMLP, self).__init__()
        self.delta_net = SuperSequential(
            SuperLinear(obs_dim, hidden_sizes[0]),
            act_cls(),
            SuperLinear(hidden_sizes[0], hidden_sizes[1]),
            act_cls(),
            SuperLinear(hidden_sizes[1], 1),
        )

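# NOTE: the class below shadows the SuperLfnaMetaMLP defined above; the file is
# explicitly marked DISABLED / NOT-FINISHED, so both drafts are kept as-is.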
class SuperLfnaMetaMLP(SuperModule):
    def __init__(self, obs_dim, act_dim, hidden_sizes, act_cls):
        super(SuperLfnaMetaMLP, self).__init__()
        log_std = -0.5 * np.ones(act_dim, dtype=np.float32)
        self.log_std = torch.nn.Parameter(torch.as_tensor(log_std))
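        # log_std is a state-independent learned parameter; exp() in
        # _distribution keeps the standard deviation positive.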
        self.mu_net = SuperSequential(
            SuperLinear(obs_dim, hidden_sizes[0]),
            act_cls(),
            SuperLinear(hidden_sizes[0], hidden_sizes[1]),
            act_cls(),
            SuperLinear(hidden_sizes[1], act_dim),
        )

    def _distribution(self, obs):
        mu = self.mu_net(obs)
        std = torch.exp(self.log_std)
        return Normal(mu, std)

    def _log_prob_from_distribution(self, pi, act):
        return pi.log_prob(act).sum(axis=-1)

    def forward_candidate(self, **kwargs):
        return self.forward_raw(**kwargs)

    def forward_raw(self, obs, act=None):
        # Produce action distributions for given observations, and
        # optionally compute the log likelihood of given actions under
        # those distributions.
        pi = self._distribution(obs)
        logp_a = None
        if act is not None:
            logp_a = self._log_prob_from_distribution(pi, act)
        return pi, logp_a


class SuperMLPGaussianActor(SuperModule):
    def __init__(self, obs_dim, act_dim, hidden_sizes, act_cls):
        super(SuperMLPGaussianActor, self).__init__()
        log_std = -0.5 * np.ones(act_dim, dtype=np.float32)
        self.log_std = torch.nn.Parameter(torch.as_tensor(log_std))
        self.mu_net = SuperSequential(
            SuperLinear(obs_dim, hidden_sizes[0]),
            act_cls(),
            SuperLinear(hidden_sizes[0], hidden_sizes[1]),
            act_cls(),
            SuperLinear(hidden_sizes[1], act_dim),
        )

    def _distribution(self, obs):
        mu = self.mu_net(obs)
        std = torch.exp(self.log_std)
        return Normal(mu, std)

    def _log_prob_from_distribution(self, pi, act):
        return pi.log_prob(act).sum(axis=-1)

    def forward_candidate(self, **kwargs):
        return self.forward_raw(**kwargs)

    def forward_raw(self, obs, act=None):
        # Produce action distributions for given observations, and
        # optionally compute the log likelihood of given actions under
        # those distributions.
        pi = self._distribution(obs)
        logp_a = None
        if act is not None:
            logp_a = self._log_prob_from_distribution(pi, act)
        return pi, logp_a