Fix aesthetic score (again), add llava reward

Kevin Black 2023-07-04 00:23:33 -07:00
parent c0bc708549
commit ec499edf84
3 changed files with 164 additions and 14 deletions

View File

@@ -30,22 +30,22 @@ class MLP(nn.Module):
 class AestheticScorer(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, dtype):
         super().__init__()
         self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
         self.mlp = MLP()
         state_dict = torch.load(ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth"))
         self.mlp.load_state_dict(state_dict)
+        self.dtype = dtype
         self.eval()
 
     @torch.no_grad()
     def __call__(self, images):
-        assert isinstance(images, list)
-        assert isinstance(images[0], Image.Image)
+        device = next(self.parameters()).device
         inputs = self.processor(images=images, return_tensors="pt")
-        inputs = {k: v.cuda() for k, v in inputs.items()}
+        inputs = {k: v.to(self.dtype).to(device) for k, v in inputs.items()}
         embed = self.clip.get_image_features(**inputs)
         # normalize embedding
         embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
-        return self.mlp(embed)
+        return self.mlp(embed).squeeze(1)
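
Note: a minimal usage sketch of the reworked scorer (the dummy tensor shapes are assumptions; with the new __call__, a uint8 NCHW tensor can be fed straight to the CLIP processor, which rescales to [0, 1] itself):

    import torch
    from ddpo_pytorch.aesthetic_scorer import AestheticScorer

    # dtype now controls what the CLIP inputs are cast to before the forward pass
    scorer = AestheticScorer(dtype=torch.float32).cuda()

    # a dummy batch of two RGB images, uint8, NCHW
    images = torch.randint(0, 256, (2, 3, 512, 512), dtype=torch.uint8)
    scores = scorer(images)
    print(scores.shape)  # torch.Size([2]) -- .squeeze(1) drops the MLP's trailing dim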

View File

@@ -32,14 +32,145 @@ def jpeg_compressibility():
 def aesthetic_score():
     from ddpo_pytorch.aesthetic_scorer import AestheticScorer
 
-    scorer = AestheticScorer().cuda()
+    scorer = AestheticScorer(dtype=torch.float32).cuda()
 
     def _fn(images, prompts, metadata):
-        if isinstance(images, torch.Tensor):
-            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
-            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
-        images = [Image.fromarray(image) for image in images]
+        images = (images * 255).round().clamp(0, 255).to(torch.uint8)
         scores = scorer(images)
         return scores, {}
 
     return _fn
+
+
+def llava_strict_satisfaction():
+    """Submits images to LLaVA and computes a reward by matching the responses to ground truth answers directly
+    without using BERTScore. Prompt metadata must have "questions" and "answers" keys. See
+    https://github.com/kvablack/LLaVA-server for server-side code.
+    """
+    import requests
+    from requests.adapters import HTTPAdapter, Retry
+    from io import BytesIO
+    import pickle
+
+    batch_size = 4
+    url = "http://127.0.0.1:8085"
+    sess = requests.Session()
+    retries = Retry(total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False)
+    sess.mount("http://", HTTPAdapter(max_retries=retries))
+
+    def _fn(images, prompts, metadata):
+        del prompts
+        if isinstance(images, torch.Tensor):
+            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
+
+        images_batched = np.array_split(images, np.ceil(len(images) / batch_size))
+        metadata_batched = np.array_split(metadata, np.ceil(len(metadata) / batch_size))
+
+        all_scores = []
+        all_info = {
+            "answers": [],
+        }
+        for image_batch, metadata_batch in zip(images_batched, metadata_batched):
+            jpeg_images = []
+
+            # Compress the images using JPEG
+            for image in image_batch:
+                img = Image.fromarray(image)
+                buffer = BytesIO()
+                img.save(buffer, format="JPEG", quality=80)
+                jpeg_images.append(buffer.getvalue())
+
+            # format for LLaVA server
+            data = {
+                "images": jpeg_images,
+                "queries": [m["questions"] for m in metadata_batch],
+            }
+            data_bytes = pickle.dumps(data)
+
+            # send a request to the llava server
+            response = sess.post(url, data=data_bytes, timeout=120)
+            response_data = pickle.loads(response.content)
+
+            correct = np.array(
+                [
+                    [ans in resp for ans, resp in zip(m["answers"], responses)]
+                    for m, responses in zip(metadata_batch, response_data["outputs"])
+                ]
+            )
+            scores = correct.mean(axis=-1)
+
+            all_scores += scores.tolist()
+            all_info["answers"] += response_data["outputs"]
+
+        return np.array(all_scores), {k: np.array(v) for k, v in all_info.items()}
+
+    return _fn
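
Note on the wire format: both LLaVA reward functions speak the same pickled-dict protocol. The exact server behavior lives in https://github.com/kvablack/LLaVA-server; the sketch below only restates the contract implied by the client code above, with made-up values:

    from io import BytesIO
    import pickle
    import numpy as np
    from PIL import Image

    # one dummy JPEG, built the same way as in the reward functions
    img = Image.fromarray(np.zeros((64, 64, 3), dtype=np.uint8))
    buffer = BytesIO()
    img.save(buffer, format="JPEG", quality=80)

    # request: raw JPEG bytes plus one list of questions per image
    request_body = pickle.dumps({
        "images": [buffer.getvalue()],
        "queries": [["is there a dog in this image?"]],
    })

    # expected reply (values hypothetical): "outputs" holds one list of generated
    # answers per image, aligned with "queries"; when "answers" are also sent,
    # llava_bertscore additionally reads "recall", "precision", and "f1"
    expected_reply = {"outputs": [["No, there is no dog."]]}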
+
+
+def llava_bertscore():
+    """Submits images to LLaVA and computes a reward by comparing the responses to the prompts using BERTScore. See
+    https://github.com/kvablack/LLaVA-server for server-side code.
+    """
+    import requests
+    from requests.adapters import HTTPAdapter, Retry
+    from io import BytesIO
+    import pickle
+
+    batch_size = 16
+    url = "http://127.0.0.1:8085"
+    sess = requests.Session()
+    retries = Retry(total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False)
+    sess.mount("http://", HTTPAdapter(max_retries=retries))
+
+    def _fn(images, prompts, metadata):
+        del metadata
+        if isinstance(images, torch.Tensor):
+            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
+
+        images_batched = np.array_split(images, np.ceil(len(images) / batch_size))
+        prompts_batched = np.array_split(prompts, np.ceil(len(prompts) / batch_size))
+
+        all_scores = []
+        all_info = {
+            "precision": [],
+            "f1": [],
+            "outputs": [],
+        }
+        for image_batch, prompt_batch in zip(images_batched, prompts_batched):
+            jpeg_images = []
+
+            # Compress the images using JPEG
+            for image in image_batch:
+                img = Image.fromarray(image)
+                buffer = BytesIO()
+                img.save(buffer, format="JPEG", quality=80)
+                jpeg_images.append(buffer.getvalue())
+
+            # format for LLaVA server
+            data = {
+                "images": jpeg_images,
+                "queries": [["Answer concisely: what is going on in this image?"]] * len(image_batch),
+                "answers": [[f"The image contains {prompt}"] for prompt in prompt_batch],
+            }
+            data_bytes = pickle.dumps(data)
+
+            # send a request to the llava server
+            response = sess.post(url, data=data_bytes, timeout=120)
+            response_data = pickle.loads(response.content)
+
+            # use the recall score as the reward
+            scores = np.array(response_data["recall"]).squeeze()
+            all_scores += scores.tolist()
+
+            # save the precision and f1 scores for analysis
+            all_info["precision"] += np.array(response_data["precision"]).squeeze().tolist()
+            all_info["f1"] += np.array(response_data["f1"]).squeeze().tolist()
+            all_info["outputs"] += np.array(response_data["outputs"]).squeeze().tolist()
+
+        return np.array(all_scores), {k: np.array(v) for k, v in all_info.items()}
+
+    return _fn
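
Note: all three reward functions share one interface. Calling the factory returns a _fn(images, prompts, metadata) closure that maps a batch to per-sample scores plus an info dict, which is what the training loop below consumes. A standalone sketch, assuming a LLaVA server is already listening on 127.0.0.1:8085 and using dummy inputs:

    import numpy as np
    import torch

    reward_fn = llava_bertscore()

    images = torch.rand(2, 3, 512, 512)  # NCHW floats in [0, 1], as the sampler produces
    prompts = np.array(["a dog riding a bike", "a cat on a couch"])
    rewards, info = reward_fn(images, prompts, metadata=None)  # metadata is unused here
    print(rewards.shape)  # (2,)
    print(sorted(info))   # ['f1', 'outputs', 'precision']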

View File

@@ -2,6 +2,8 @@ from collections import defaultdict
 import contextlib
 import os
 import datetime
+from concurrent import futures
+import time
 from absl import app, flags
 from ml_collections import config_flags
 from accelerate import Accelerator
@@ -227,6 +229,10 @@ def main(_):
     # Prepare everything with our `accelerator`.
     trainable_layers, optimizer = accelerator.prepare(trainable_layers, optimizer)
 
+    # executor to perform callbacks asynchronously. this is beneficial for the llava callbacks, which make requests
+    # to a remote server running llava inference.
+    executor = futures.ThreadPoolExecutor(max_workers=2)
+
     # Train!
     samples_per_epoch = config.sample.batch_size * accelerator.num_processes * config.sample.num_batches_per_epoch
     total_train_batch_size = (
@@ -298,8 +304,10 @@ def main(_):
             log_probs = torch.stack(log_probs, dim=1)  # (batch_size, num_steps, 1)
             timesteps = pipeline.scheduler.timesteps.repeat(config.sample.batch_size, 1)  # (batch_size, num_steps)
 
-            # compute rewards
-            rewards, reward_metadata = reward_fn(images, prompts, prompt_metadata)
+            # compute rewards asynchronously
+            rewards = executor.submit(reward_fn, images, prompts, prompt_metadata)
+            # yield to make sure reward computation starts
+            time.sleep(0)
 
             samples.append(
                 {
@@ -309,10 +317,21 @@ def main(_):
                     "latents": latents[:, :-1],  # each entry is the latent before timestep t
                     "next_latents": latents[:, 1:],  # each entry is the latent after timestep t
                     "log_probs": log_probs,
-                    "rewards": torch.as_tensor(rewards, device=accelerator.device),
+                    "rewards": rewards,
                 }
             )
 
+        # wait for all rewards to be computed
+        for sample in tqdm(
+            samples,
+            desc="Waiting for rewards",
+            disable=not accelerator.is_local_main_process,
+            position=0,
+        ):
+            rewards, reward_metadata = sample["rewards"].result()
+            # accelerator.print(reward_metadata)
+            sample["rewards"] = torch.as_tensor(rewards, device=accelerator.device)
+
         # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
         samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
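
Note: the net effect of the two hunks above is that reward computation now overlaps with sampling. executor.submit returns a concurrent.futures.Future that is stashed in the sample dict in place of the reward tensor, and the futures are only resolved after every sampling batch has been queued. The pattern in isolation, with a toy stand-in for reward_fn:

    from concurrent import futures
    import time

    def slow_reward(batch):
        time.sleep(1)  # stands in for a round trip to the remote LLaVA server
        return [0.0] * len(batch), {}

    executor = futures.ThreadPoolExecutor(max_workers=2)
    pending = [executor.submit(slow_reward, ["img"] * 4) for _ in range(4)]
    time.sleep(0)  # yield the GIL so the worker threads actually start
    # ... the next sampling batch would run here, concurrently ...
    results = [f.result() for f in pending]  # blocks until each future resolves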
@@ -472,7 +491,7 @@ def main(_):
         # make sure we did an optimization step at the end of the inner epoch
         assert accelerator.sync_gradients
 
-        if epoch % config.save_freq == 0 and accelerator.is_main_process:
+        if epoch != 0 and epoch % config.save_freq == 0 and accelerator.is_main_process:
             accelerator.save_state()