Fix bugs in new env-v1
This commit is contained in:
		| @@ -28,7 +28,7 @@ from xautodl.utils import split_str2indexes | |||||||
|  |  | ||||||
| from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn | from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn | ||||||
| from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric | from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric | ||||||
| from xautodl.datasets.synthetic_core import get_synthetic_env, EnvSampler | from xautodl.datasets.synthetic_core import get_synthetic_env | ||||||
| from xautodl.models.xcore import get_model | from xautodl.models.xcore import get_model | ||||||
| from xautodl.xlayers import super_core, trunc_normal_ | from xautodl.xlayers import super_core, trunc_normal_ | ||||||
|  |  | ||||||
| @@ -244,7 +244,7 @@ def main(args): | |||||||
|         args.time_dim, |         args.time_dim, | ||||||
|         timestamps, |         timestamps, | ||||||
|         seq_length=args.seq_length, |         seq_length=args.seq_length, | ||||||
|         interval=train_env.timestamp_interval, |         interval=train_env.time_interval, | ||||||
|     ) |     ) | ||||||
|     meta_model = meta_model.to(args.device) |     meta_model = meta_model.to(args.device) | ||||||
|  |  | ||||||
| @@ -253,7 +253,7 @@ def main(args): | |||||||
|     logger.log("The base-model is\n{:}".format(base_model)) |     logger.log("The base-model is\n{:}".format(base_model)) | ||||||
|     logger.log("The meta-model is\n{:}".format(meta_model)) |     logger.log("The meta-model is\n{:}".format(meta_model)) | ||||||
|  |  | ||||||
|     batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge) |     # batch_sampler = EnvSampler(train_env, args.meta_batch, args.sampler_enlarge) | ||||||
|     pretrain_v2(base_model, meta_model, criterion, train_env, args, logger) |     pretrain_v2(base_model, meta_model, criterion, train_env, args, logger) | ||||||
|  |  | ||||||
|     # try to evaluate once |     # try to evaluate once | ||||||
| @@ -387,7 +387,7 @@ def main(args): | |||||||
|         future_time = env_info["{:}-timestamp".format(idx)].item() |         future_time = env_info["{:}-timestamp".format(idx)].item() | ||||||
|         time_seqs = [] |         time_seqs = [] | ||||||
|         for iseq in range(args.seq_length): |         for iseq in range(args.seq_length): | ||||||
|             time_seqs.append(future_time - iseq * eval_env.timestamp_interval) |             time_seqs.append(future_time - iseq * eval_env.time_interval) | ||||||
|         time_seqs.reverse() |         time_seqs.reverse() | ||||||
|         with torch.no_grad(): |         with torch.no_grad(): | ||||||
|             meta_model.eval() |             meta_model.eval() | ||||||
| @@ -409,7 +409,7 @@ def main(args): | |||||||
|  |  | ||||||
|         # creating the new meta-time-embedding |         # creating the new meta-time-embedding | ||||||
|         distance = meta_model.get_closest_meta_distance(future_time) |         distance = meta_model.get_closest_meta_distance(future_time) | ||||||
|         if distance < eval_env.timestamp_interval: |         if distance < eval_env.time_interval: | ||||||
|             continue |             continue | ||||||
|         # |         # | ||||||
|         new_param = meta_model.create_meta_embed() |         new_param = meta_model.create_meta_embed() | ||||||
|   | |||||||
| @@ -16,8 +16,8 @@ class LFNA_Meta(super_core.SuperModule): | |||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         shape_container, |         shape_container, | ||||||
|         layer_embedding, |         layer_dim, | ||||||
|         time_embedding, |         time_dim, | ||||||
|         meta_timestamps, |         meta_timestamps, | ||||||
|         mha_depth: int = 2, |         mha_depth: int = 2, | ||||||
|         dropout: float = 0.1, |         dropout: float = 0.1, | ||||||
| @@ -39,53 +39,41 @@ class LFNA_Meta(super_core.SuperModule): | |||||||
|  |  | ||||||
|         self.register_parameter( |         self.register_parameter( | ||||||
|             "_super_layer_embed", |             "_super_layer_embed", | ||||||
|             torch.nn.Parameter(torch.Tensor(self._num_layers, layer_embedding)), |             torch.nn.Parameter(torch.Tensor(self._num_layers, layer_dim)), | ||||||
|         ) |         ) | ||||||
|         self.register_parameter( |         self.register_parameter( | ||||||
|             "_super_meta_embed", |             "_super_meta_embed", | ||||||
|             torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_embedding)), |             torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_dim)), | ||||||
|         ) |         ) | ||||||
|         self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps)) |         self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps)) | ||||||
|         # register a time difference buffer |         # register a time difference buffer | ||||||
|         time_interval = [-i * self._interval for i in range(self._seq_length)] |         time_interval = [-i * self._interval for i in range(self._seq_length)] | ||||||
|         time_interval.reverse() |         time_interval.reverse() | ||||||
|         self.register_buffer("_time_interval", torch.Tensor(time_interval)) |         self.register_buffer("_time_interval", torch.Tensor(time_interval)) | ||||||
|         self._time_embed_dim = time_embedding |         self._time_embed_dim = time_dim | ||||||
|         self._append_meta_embed = dict(fixed=None, learnt=None) |         self._append_meta_embed = dict(fixed=None, learnt=None) | ||||||
|         self._append_meta_timestamps = dict(fixed=None, learnt=None) |         self._append_meta_timestamps = dict(fixed=None, learnt=None) | ||||||
|  |  | ||||||
|         self._tscalar_embed = super_core.SuperDynamicPositionE( |         self._tscalar_embed = super_core.SuperDynamicPositionE( | ||||||
|             time_embedding, scale=500 |             time_dim, scale=1 / interval | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         # build transformer |         # build transformer | ||||||
|         self._trans_att = super_core.SuperQKVAttentionV2( |         self._trans_att = super_core.SuperQKVAttentionV2( | ||||||
|             qk_att_dim=time_embedding, |             qk_att_dim=time_dim, | ||||||
|             in_v_dim=time_embedding, |             in_v_dim=time_dim, | ||||||
|             hidden_dim=time_embedding, |             hidden_dim=time_dim, | ||||||
|             num_heads=4, |             num_heads=4, | ||||||
|             proj_dim=time_embedding, |             proj_dim=time_dim, | ||||||
|             qkv_bias=True, |             qkv_bias=True, | ||||||
|             attn_drop=None, |             attn_drop=None, | ||||||
|             proj_drop=dropout, |             proj_drop=dropout, | ||||||
|         ) |         ) | ||||||
|         """ |  | ||||||
|         self._trans_att = super_core.SuperQKVAttention( |  | ||||||
|             time_embedding, |  | ||||||
|             time_embedding, |  | ||||||
|             time_embedding, |  | ||||||
|             time_embedding, |  | ||||||
|             num_heads=4, |  | ||||||
|             qkv_bias=True, |  | ||||||
|             attn_drop=None, |  | ||||||
|             proj_drop=dropout, |  | ||||||
|         ) |  | ||||||
|         """ |  | ||||||
|         layers = [] |         layers = [] | ||||||
|         for ilayer in range(mha_depth): |         for ilayer in range(mha_depth): | ||||||
|             layers.append( |             layers.append( | ||||||
|                 super_core.SuperTransformerEncoderLayer( |                 super_core.SuperTransformerEncoderLayer( | ||||||
|                     time_embedding * 2, |                     time_dim * 2, | ||||||
|                     4, |                     4, | ||||||
|                     True, |                     True, | ||||||
|                     4, |                     4, | ||||||
| @@ -95,14 +83,14 @@ class LFNA_Meta(super_core.SuperModule): | |||||||
|                     use_mask=True, |                     use_mask=True, | ||||||
|                 ) |                 ) | ||||||
|             ) |             ) | ||||||
|         layers.append(super_core.SuperLinear(time_embedding * 2, time_embedding)) |         layers.append(super_core.SuperLinear(time_dim * 2, time_dim)) | ||||||
|         self._meta_corrector = super_core.SuperSequential(*layers) |         self._meta_corrector = super_core.SuperSequential(*layers) | ||||||
|  |  | ||||||
|         model_kwargs = dict( |         model_kwargs = dict( | ||||||
|             config=dict(model_type="dual_norm_mlp"), |             config=dict(model_type="dual_norm_mlp"), | ||||||
|             input_dim=layer_embedding + time_embedding, |             input_dim=layer_dim + time_dim, | ||||||
|             output_dim=max(self._numel_per_layer), |             output_dim=max(self._numel_per_layer), | ||||||
|             hidden_dims=[(layer_embedding + time_embedding) * 2] * 3, |             hidden_dims=[(layer_dim + time_dim) * 2] * 3, | ||||||
|             act_cls="gelu", |             act_cls="gelu", | ||||||
|             norm_cls="layer_norm_1d", |             norm_cls="layer_norm_1d", | ||||||
|             dropout=dropout, |             dropout=dropout, | ||||||
| @@ -193,11 +181,6 @@ class LFNA_Meta(super_core.SuperModule): | |||||||
|         # timestamps is a batch of sequence of timestamps |         # timestamps is a batch of sequence of timestamps | ||||||
|         batch, seq = timestamps.shape |         batch, seq = timestamps.shape | ||||||
|         meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed |         meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed | ||||||
|         """ |  | ||||||
|         timestamp_q_embed = self._tscalar_embed(timestamps) |  | ||||||
|         timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1)) |  | ||||||
|         timestamp_v_embed = meta_embeds.unsqueeze(dim=0) |  | ||||||
|         """ |  | ||||||
|         timestamp_v_embed = meta_embeds.unsqueeze(dim=0) |         timestamp_v_embed = meta_embeds.unsqueeze(dim=0) | ||||||
|         timestamp_qk_att_embed = self._tscalar_embed( |         timestamp_qk_att_embed = self._tscalar_embed( | ||||||
|             torch.unsqueeze(timestamps, dim=-1) - meta_timestamps |             torch.unsqueeze(timestamps, dim=-1) - meta_timestamps | ||||||
| @@ -212,7 +195,6 @@ class LFNA_Meta(super_core.SuperModule): | |||||||
|             > self._thresh |             > self._thresh | ||||||
|         ) |         ) | ||||||
|         timestamp_embeds = self._trans_att( |         timestamp_embeds = self._trans_att( | ||||||
|             # timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask |  | ||||||
|             timestamp_qk_att_embed, |             timestamp_qk_att_embed, | ||||||
|             timestamp_v_embed, |             timestamp_v_embed, | ||||||
|             mask, |             mask, | ||||||
|   | |||||||
| @@ -21,6 +21,8 @@ class DynamicGenerator(abc.ABC): | |||||||
|  |  | ||||||
|  |  | ||||||
| class GaussianDGenerator(DynamicGenerator): | class GaussianDGenerator(DynamicGenerator): | ||||||
|  |     """Generate data from Gaussian distribution.""" | ||||||
|  |  | ||||||
|     def __init__(self, mean_functors, cov_functors, trunc=(-1, 1)): |     def __init__(self, mean_functors, cov_functors, trunc=(-1, 1)): | ||||||
|         super(GaussianDGenerator, self).__init__() |         super(GaussianDGenerator, self).__init__() | ||||||
|         self._ndim = assert_list_tuple(mean_functors) |         self._ndim = assert_list_tuple(mean_functors) | ||||||
| @@ -41,6 +43,10 @@ class GaussianDGenerator(DynamicGenerator): | |||||||
|             assert assert_list_tuple(trunc) == 2 and trunc[0] < trunc[1] |             assert assert_list_tuple(trunc) == 2 and trunc[0] < trunc[1] | ||||||
|         self._trunc = trunc |         self._trunc = trunc | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def ndim(self): | ||||||
|  |         return self._ndim | ||||||
|  |  | ||||||
|     def __call__(self, time, num): |     def __call__(self, time, num): | ||||||
|         mean_list = [functor(time) for functor in self._mean_functors] |         mean_list = [functor(time) for functor in self._mean_functors] | ||||||
|         cov_matrix = [ |         cov_matrix = [ | ||||||
|   | |||||||
| @@ -115,7 +115,7 @@ class SyntheticDEnv(data.Dataset): | |||||||
|             name=self.__class__.__name__, |             name=self.__class__.__name__, | ||||||
|             cur_num=len(self), |             cur_num=len(self), | ||||||
|             total=len(self._time_generator), |             total=len(self._time_generator), | ||||||
|             ndim=self._ndim, |             ndim=self._data_generator.ndim, | ||||||
|             num_per_task=self._num_per_task, |             num_per_task=self._num_per_task, | ||||||
|             xrange_min=self.min_timestamp, |             xrange_min=self.min_timestamp, | ||||||
|             xrange_max=self.max_timestamp, |             xrange_max=self.max_timestamp, | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user