Fix aesthetic score (again), add llava reward
@@ -2,6 +2,8 @@ from collections import defaultdict
 import contextlib
 import os
 import datetime
+from concurrent import futures
+import time
 from absl import app, flags
 from ml_collections import config_flags
 from accelerate import Accelerator
@@ -227,6 +229,10 @@ def main(_):
     # Prepare everything with our `accelerator`.
     trainable_layers, optimizer = accelerator.prepare(trainable_layers, optimizer)

+    # executor to perform callbacks asynchronously. this is beneficial for the llava callbacks, which make a request to a
+    # remote server running llava inference.
+    executor = futures.ThreadPoolExecutor(max_workers=2)
+
     # Train!
     samples_per_epoch = config.sample.batch_size * accelerator.num_processes * config.sample.num_batches_per_epoch
     total_train_batch_size = (
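Note: below is a minimal, self-contained sketch of the pattern this hunk sets up, assuming an I/O-bound reward callback (the commit's llava callback issues a request to a remote inference server; the stub below just sleeps). `slow_reward_fn` and all values are illustrative, not repository code.

# Illustrative only: overlap an I/O-bound reward callback with other work.
from concurrent import futures
import time

def slow_reward_fn(images, prompts, prompt_metadata):
    # stand-in for a request to a remote llava inference server
    time.sleep(1.0)
    return [0.5] * len(images), {"note": "stub metadata"}

executor = futures.ThreadPoolExecutor(max_workers=2)

future = executor.submit(slow_reward_fn, ["img_a", "img_b"], ["a cat", "a dog"], {})
# ... the main thread is free to keep sampling the next batch here ...
rewards, reward_metadata = future.result()  # blocks only once the value is needed
print(rewards, reward_metadata)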
@@ -298,8 +304,10 @@ def main(_):
             log_probs = torch.stack(log_probs, dim=1)  # (batch_size, num_steps, 1)
             timesteps = pipeline.scheduler.timesteps.repeat(config.sample.batch_size, 1)  # (batch_size, num_steps)

-            # compute rewards
-            rewards, reward_metadata = reward_fn(images, prompts, prompt_metadata)
+            # compute rewards asynchronously
+            rewards = executor.submit(reward_fn, images, prompts, prompt_metadata)
+            # yield to make sure reward computation starts
+            time.sleep(0)

             samples.append(
                 {
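Note: with this change, `rewards` is no longer a `(rewards, reward_metadata)` tuple but a `concurrent.futures.Future`; `submit` returns immediately and the tuple is only unpacked later via `.result()`. The `time.sleep(0)` is a cooperative yield intended, per the comment, to let the worker thread start before sampling continues. A small illustrative sketch with a stub `reward_fn` (not the repository's):

# Illustrative Future mechanics with a stub reward_fn.
from concurrent import futures
import time

def reward_fn(images, prompts, prompt_metadata):
    time.sleep(0.5)  # pretend this is a remote llava request
    return [1.0 for _ in images], {}

executor = futures.ThreadPoolExecutor(max_workers=2)
future = executor.submit(reward_fn, ["img"], ["a prompt"], {})
print(future.done())  # typically False: submit() does not wait for the callback
time.sleep(0)         # yield so the worker thread can begin running
rewards, reward_metadata = future.result()  # blocks until the callback finishes
print(rewards)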
@@ -309,10 +317,21 @@ def main(_):
                     "latents": latents[:, :-1],  # each entry is the latent before timestep t
                     "next_latents": latents[:, 1:],  # each entry is the latent after timestep t
                     "log_probs": log_probs,
-                    "rewards": torch.as_tensor(rewards, device=accelerator.device),
+                    "rewards": rewards,
                 }
             )

+        # wait for all rewards to be computed
+        for sample in tqdm(
+            samples,
+            desc="Waiting for rewards",
+            disable=not accelerator.is_local_main_process,
+            position=0,
+        ):
+            rewards, reward_metadata = sample["rewards"].result()
+            # accelerator.print(reward_metadata)
+            sample["rewards"] = torch.as_tensor(rewards, device=accelerator.device)
+
         # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
         samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
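Note: a toy sketch of the collation step in this hunk's trailing context lines, once every sample's "rewards" Future has been resolved to a tensor. The keys and shapes below are made up for illustration.

# Illustrative collation of per-batch sample dicts into one dict of tensors.
import torch

batch_size, num_steps, num_batches = 2, 4, 3
samples = [
    {
        "log_probs": torch.randn(batch_size, num_steps),
        "rewards": torch.randn(batch_size),
    }
    for _ in range(num_batches)
]

collated = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
print(collated["log_probs"].shape)  # torch.Size([6, 4]) -> (num_batches * batch_size, num_steps)
print(collated["rewards"].shape)    # torch.Size([6])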
@@ -472,7 +491,7 @@ def main(_):
         # make sure we did an optimization step at the end of the inner epoch
         assert accelerator.sync_gradients

-        if epoch % config.save_freq == 0 and accelerator.is_main_process:
+        if epoch != 0 and epoch % config.save_freq == 0 and accelerator.is_main_process:
             accelerator.save_state()

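Note: the last hunk only adds an `epoch != 0` guard so no checkpoint is written at epoch 0. A tiny illustration with a made-up `save_freq` of 4:

# Illustrative only: epochs that trigger accelerator.save_state() for save_freq = 4.
save_freq = 4
epochs = range(10)
print([e for e in epochs if e % save_freq == 0])             # before: [0, 4, 8]
print([e for e in epochs if e != 0 and e % save_freq == 0])  # after:  [4, 8]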