now we add reward wait to test
@@ -312,6 +312,7 @@ def test(cfg: DictConfig):
         #     reward = 1.0
         #     rewards.append(reward)
         return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
+    old_log_probs = None
     while samples_left_to_generate > 0:
         print(f'samples left to generate: {samples_left_to_generate}/'
             f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
@@ -327,6 +328,15 @@ def test(cfg: DictConfig):
         samples = samples + cur_sample
         reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
         advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
+        if old_log_probs is None:
+            old_log_probs = log_probs.clone()
+        ratio = torch.exp(log_probs - old_log_probs)
+        unclipped_loss = -advantages * ratio
+        clipped_loss = -advantages * torch.clamp(ratio, 1.0 - cfg.ppo.clip_param, 1.0 + cfg.ppo.clip_param)
+        loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
+        accelerator.backward(loss)
+        optimizer.step()
+        optimizer.zero_grad()


         samples_with_log_probs.append((cur_sample, log_probs, reward))
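For context, the second hunk applies a standard PPO clipped-surrogate update inside the sampling loop: per-batch rewards are normalized into advantages, an importance ratio is formed against the log-probabilities captured on the first pass, and the ratio is clamped to [1 - clip_param, 1 + clip_param] before taking the pessimistic (elementwise max) loss. The sketch below restates that loss in isolation with toy tensors; the function name ppo_clipped_loss and the example shapes are illustrative only and not part of the repository.

    import torch

    def ppo_clipped_loss(log_probs, old_log_probs, rewards, clip_param=0.2):
        # Normalize rewards into advantages (same 1e-6 epsilon as in the diff).
        advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-6)
        # Importance ratio between current and reference policy log-probabilities.
        ratio = torch.exp(log_probs - old_log_probs)
        unclipped = -advantages * ratio
        clipped = -advantages * torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
        # Elementwise max of the negated objectives gives the clipped PPO loss.
        return torch.max(unclipped, clipped).mean()

    # Toy usage: one batch of 4 generated samples.
    log_probs = torch.randn(4, requires_grad=True)    # current policy log-probs
    old_log_probs = log_probs.detach().clone()        # frozen reference log-probs
    rewards = torch.rand(4)                           # per-sample rewards
    loss = ppo_clipped_loss(log_probs, old_log_probs, rewards)
    loss.backward()                                   # gradient flows only through log_probs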