Fix aesthetic score (again), add llava reward

This commit is contained in:
Kevin Black
2023-07-04 00:23:33 -07:00
parent c0bc708549
commit ec499edf84
3 changed files with 164 additions and 14 deletions

View File

@@ -2,6 +2,8 @@ from collections import defaultdict
import contextlib
import os
import datetime
from concurrent import futures
import time
from absl import app, flags
from ml_collections import config_flags
from accelerate import Accelerator
@@ -227,6 +229,10 @@ def main(_):
# Prepare everything with our `accelerator`.
trainable_layers, optimizer = accelerator.prepare(trainable_layers, optimizer)
# executor to perform callbacks asynchronously. this is beneficial for the llava callbacks which makes a request to a
# remote server running llava inference.
executor = futures.ThreadPoolExecutor(max_workers=2)
# Train!
samples_per_epoch = config.sample.batch_size * accelerator.num_processes * config.sample.num_batches_per_epoch
total_train_batch_size = (
@@ -298,8 +304,10 @@ def main(_):
log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1)
timesteps = pipeline.scheduler.timesteps.repeat(config.sample.batch_size, 1) # (batch_size, num_steps)
# compute rewards
rewards, reward_metadata = reward_fn(images, prompts, prompt_metadata)
# compute rewards asynchronously
rewards = executor.submit(reward_fn, images, prompts, prompt_metadata)
# yield to to make sure reward computation starts
time.sleep(0)
samples.append(
{
@@ -309,10 +317,21 @@ def main(_):
"latents": latents[:, :-1], # each entry is the latent before timestep t
"next_latents": latents[:, 1:], # each entry is the latent after timestep t
"log_probs": log_probs,
"rewards": torch.as_tensor(rewards, device=accelerator.device),
"rewards": rewards,
}
)
# wait for all rewards to be computed
for sample in tqdm(
samples,
desc="Waiting for rewards",
disable=not accelerator.is_local_main_process,
position=0,
):
rewards, reward_metadata = sample["rewards"].result()
# accelerator.print(reward_metadata)
sample["rewards"] = torch.as_tensor(rewards, device=accelerator.device)
# collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
@@ -472,7 +491,7 @@ def main(_):
# make sure we did an optimization step at the end of the inner epoch
assert accelerator.sync_gradients
if epoch % config.save_freq == 0 and accelerator.is_main_process:
if epoch != 0 and epoch % config.save_freq == 0 and accelerator.is_main_process:
accelerator.save_state()