diff --git a/scripts/train.py b/scripts/train.py index 58395d9..6e7214c 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -263,8 +263,8 @@ def main(_): accelerator.log( { "images": [ - wandb.Image(os.path.join(tmpdir, f"{i}.jpg"), caption=prompt) - for i, prompt in enumerate(prompts) + wandb.Image(os.path.join(tmpdir, f"{i}.jpg"), caption=f"{prompt:.25} | {reward:.2f}") + for i, (prompt, reward) in enumerate(zip(prompts, rewards)) ], }, step=global_step,