try to deploy PPO policy

This commit is contained in:
mhz
2024-09-09 23:50:10 +02:00
parent 297261d666
commit 97fbdf91c7
2 changed files with 58 additions and 7 deletions


@@ -152,6 +152,7 @@ from accelerate.utils import set_seed, ProjectConfiguration
version_base="1.1", config_path="../configs", config_name="config"
)
def test(cfg: DictConfig):
os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number
accelerator_config = ProjectConfiguration(
project_dir=os.path.join(cfg.general.log_dir, cfg.general.name),
automatic_checkpoint_naming=True,
@@ -162,6 +163,11 @@ def test(cfg: DictConfig):
project_config=accelerator_config,
gradient_accumulation_steps=cfg.train.gradient_accumulation_steps * cfg.train.n_epochs,
)
# Debug: confirm which devices are available
print(f"Available GPUs: {torch.cuda.device_count()}")
print(f"Using device: {accelerator.device}")
set_seed(cfg.train.seed, device_specific=True)
datamodule = dataset.DataModule(cfg)
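The new lines pin the visible GPUs before the Accelerator is constructed and then print what it actually picked up. A minimal sketch of that ordering, assuming cfg.general.gpu_number holds a device index or a comma-separated list (the select_gpus helper name is illustrative); the environment variable only takes effect if it is set before anything initializes CUDA:

import os

def select_gpus(gpu_number) -> None:
    # Environment values must be strings: "0" exposes one GPU, "0,1" exposes two.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_number)

select_gpus(0)

import torch  # imported after the variable is set so the restriction is respected
print(f"Available GPUs: {torch.cuda.device_count()}")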
@@ -185,13 +191,16 @@ def test(cfg: DictConfig):
"visualization_tools": visulization_tools,
}
# Debug: confirm which devices are available
print(f"Available GPUs: {torch.cuda.device_count()}")
print(f"Using device: {accelerator.device}")
if cfg.general.test_only:
cfg, _ = get_resume(cfg, model_kwargs)
os.chdir(cfg.general.test_only.split("checkpoints")[0])
elif cfg.general.resume is not None:
cfg, _ = get_resume_adaptive(cfg, model_kwargs)
os.chdir(cfg.general.resume.split("checkpoints")[0])
# os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number
model = Graph_DiT(cfg=cfg, **model_kwargs)
graph_dit_model = model
@@ -201,6 +210,7 @@ def test(cfg: DictConfig):
# optional: freeze the model
# graph_dit_model.model.requires_grad_(True)
import torch.nn.functional as F
optimizer = graph_dit_model.configure_optimizers()
train_dataloader = accelerator.prepare(datamodule.train_dataloader())
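Only the train dataloader goes through accelerator.prepare here, while the update loop further down relies on accelerator.accumulate and accelerator.backward. A minimal standalone sketch of the usual accelerate pattern, with a toy model, optimizer, and dataloader standing in for the project's objects (all names below are illustrative):

import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)

# Toy stand-ins; in the diff these would be graph_dit_model, its optimizer, and the datamodule loader.
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataset = torch.utils.data.TensorDataset(torch.randn(32, 8))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)

# prepare() moves everything to the right device and wires up gradient accumulation.
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for (batch,) in dataloader:
    with accelerator.accumulate(model):
        loss = model(batch).pow(2).mean()
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()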
@@ -256,13 +266,19 @@ def test(cfg: DictConfig):
chains_left_to_save = cfg.general.final_model_chains_to_save
samples, all_ys, batch_id = [], [], 0
samples_with_log_probs = []
test_y_collection = torch.cat(graph_dit_model.test_y_collection, dim=0)
num_examples = test_y_collection.size(0)
if cfg.general.final_model_samples_to_generate > num_examples:
ratio = cfg.general.final_model_samples_to_generate // num_examples
test_y_collection = test_y_collection.repeat(ratio+1, 1)
num_examples = test_y_collection.size(0)
def graph_reward_fn(graphs, true_graphs=None, device=None):
rewards = []
for graph in graphs:
reward = 1.0
rewards.append(reward)
return torch.tensor(rewards, dtype=torch.float32).unsqueeze(0).to(device)
while samples_left_to_generate > 0:
print(f'samples left to generate: {samples_left_to_generate}/'
f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
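The repeat block above pads the conditioning targets so there are at least as many rows as samples requested; a small worked sketch of that arithmetic, with illustrative numbers:

import torch

test_y = torch.randn(1000, 4)      # stands in for test_y_collection (1000 targets)
samples_to_generate = 2500         # stands in for cfg.general.final_model_samples_to_generate

num_examples = test_y.size(0)                       # 1000
if samples_to_generate > num_examples:
    ratio = samples_to_generate // num_examples     # 2500 // 1000 = 2
    test_y = test_y.repeat(ratio + 1, 1)            # (2 + 1) * 1000 = 3000 rows
    num_examples = test_y.size(0)                   # 3000 >= 2500
print(num_examples)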
@@ -273,9 +289,12 @@ def test(cfg: DictConfig):
# batch_y = test_y_collection[batch_id : batch_id + to_generate]
batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
cur_sample = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)[0]
cur_sample, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
samples = samples + cur_sample
reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
samples_with_log_probs.append((cur_sample, log_probs, reward))
all_ys.append(batch_y)
batch_id += to_generate
@@ -294,6 +313,35 @@ def test(cfg: DictConfig):
print("Samples:")
print(samples)
perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
samples, log_probs, rewards = samples_with_log_probs[perm]
samples = list(samples)
log_probs = list(log_probs)
print(f'log_probs: {log_probs[:5]}')
print(f'log_probs: {log_probs[0].shape}') # torch.Size([1000])
rewards = list(rewards)
for inner_epoch in range(cfg.train.n_epochs):
# print(f'rewards: {rewards[0].shape}') # torch.Size([1000])
rewards = torch.cat(rewards, dim=0)
print(f'rewards: {rewards.shape}')
advantages = (rewards - torch.mean(rewards)) / (torch.std(rewards) + 1e-6)
old_log_probs = log_probs.copy()
with accelerator.accumulate(graph_dit_model):
ratio = torch.exp(log_probs - old_log_probs)
unclipped_loss = -advantages * ratio
clipped_loss = -advantages * torch.clamp(ratio,
1.0 - cfg.ppo.clip_param,
1.0 + cfg.ppo.clip_param)
loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
accelerator.log({"loss": loss.item(), "epoch": inner_epoch})
print(f"loss: {loss.item()}, epoch: {inner_epoch}")
# trainer = Trainer(
# gradient_clip_val=cfg.train.clip_grad,
# # accelerator="cpu",
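The loop above assembles the standard PPO clipped-surrogate update: normalize rewards into advantages, form the probability ratio between the current policy and the one that generated the samples, and take the pessimistic of the clipped and unclipped terms. A minimal self-contained sketch of that objective, assuming per-sample log-probabilities and rewards have already been flattened into 1-D tensors and that the new log-probabilities come from re-evaluating the current policy (tensor names and the 0.2 clip value are illustrative):

import torch

def ppo_clipped_loss(new_log_probs, old_log_probs, rewards, clip_param=0.2):
    # Normalize rewards into advantages, as in the hunk above.
    advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-6)
    # Probability ratio between the current policy and the behaviour policy.
    ratio = torch.exp(new_log_probs - old_log_probs.detach())
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
    # Pessimistic surrogate: element-wise max of the two loss terms, then the mean.
    return torch.max(unclipped, clipped).mean()

# Illustrative usage with random tensors standing in for flattened rollout data.
old_lp = torch.randn(1000)
new_lp = old_lp + 0.01 * torch.randn(1000)   # would come from re-running the policy
rew = torch.rand(1000)
print(ppo_clipped_loss(new_lp, old_lp, rew).item())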