Fix aesthetic score (again), add llava reward
@@ -30,22 +30,22 @@ class MLP(nn.Module):
 
 
 class AestheticScorer(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, dtype):
         super().__init__()
         self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
         self.mlp = MLP()
         state_dict = torch.load(ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth"))
         self.mlp.load_state_dict(state_dict)
+        self.dtype = dtype
         self.eval()
 
     @torch.no_grad()
     def __call__(self, images):
-        assert isinstance(images, list)
-        assert isinstance(images[0], Image.Image)
+        device = next(self.parameters()).device
         inputs = self.processor(images=images, return_tensors="pt")
-        inputs = {k: v.cuda() for k, v in inputs.items()}
+        inputs = {k: v.to(self.dtype).to(device) for k, v in inputs.items()}
         embed = self.clip.get_image_features(**inputs)
         # normalize embedding
         embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
-        return self.mlp(embed)
+        return self.mlp(embed).squeeze(1)
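For context, a minimal sketch of how the reworked scorer might be called after this change: the new `dtype` argument controls what the pixel inputs are cast to, and the scorer now accepts a raw uint8 image tensor instead of a list of PIL images. The batch shape below is an illustrative assumption:

```python
import torch
from ddpo_pytorch.aesthetic_scorer import AestheticScorer

# dtype should match the precision the CLIP weights are kept in (float32 here)
scorer = AestheticScorer(dtype=torch.float32).cuda()

# hypothetical batch of 4 RGB images as a uint8 NCHW tensor; the CLIPProcessor
# handles resizing and normalization, so raw pixel values are fine
images = torch.randint(0, 256, (4, 3, 512, 512), dtype=torch.uint8)

scores = scorer(images)  # shape (4,) thanks to the new .squeeze(1)
```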
@@ -32,14 +32,145 @@ def jpeg_compressibility():
 def aesthetic_score():
     from ddpo_pytorch.aesthetic_scorer import AestheticScorer
 
-    scorer = AestheticScorer().cuda()
+    scorer = AestheticScorer(dtype=torch.float32).cuda()
 
     def _fn(images, prompts, metadata):
-        if isinstance(images, torch.Tensor):
-            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
-            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
-        images = [Image.fromarray(image) for image in images]
+        images = (images * 255).round().clamp(0, 255).to(torch.uint8)
         scores = scorer(images)
         return scores, {}
 
     return _fn
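Like the other rewards in this file, `aesthetic_score` is a factory: the outer call does one-time setup and returns a `_fn(images, prompts, metadata) -> (scores, info)` closure. A sketch of that calling convention, assuming this file is importable as `ddpo_pytorch.rewards`, with illustrative inputs (this particular reward ignores the prompts and metadata):

```python
import torch
from ddpo_pytorch import rewards  # assumed module path

reward_fn = rewards.aesthetic_score()

# hypothetical batch: float images in [0, 1], NCHW, as the sampling loop produces
images = torch.rand(2, 3, 512, 512)
scores, info = reward_fn(images, ["a cat", "a dog"], [{}, {}])
# scores: tensor of shape (2,); info: empty dict for this reward
```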
+
+
+def llava_strict_satisfaction():
+    """Submits images to LLaVA and computes a reward by matching the responses to ground truth answers directly without
+    using BERTScore. Prompt metadata must have "questions" and "answers" keys. See
+    https://github.com/kvablack/LLaVA-server for server-side code.
+    """
+    import requests
+    from requests.adapters import HTTPAdapter, Retry
+    from io import BytesIO
+    import pickle
+
+    batch_size = 4
+    url = "http://127.0.0.1:8085"
+    sess = requests.Session()
+    retries = Retry(total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False)
+    sess.mount("http://", HTTPAdapter(max_retries=retries))
+
+    def _fn(images, prompts, metadata):
+        del prompts
+        if isinstance(images, torch.Tensor):
+            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
+
+        images_batched = np.array_split(images, np.ceil(len(images) / batch_size))
+        metadata_batched = np.array_split(metadata, np.ceil(len(metadata) / batch_size))
+
+        all_scores = []
+        all_info = {
+            "answers": [],
+        }
+        for image_batch, metadata_batch in zip(images_batched, metadata_batched):
+            jpeg_images = []
+
+            # Compress the images using JPEG
+            for image in image_batch:
+                img = Image.fromarray(image)
+                buffer = BytesIO()
+                img.save(buffer, format="JPEG", quality=80)
+                jpeg_images.append(buffer.getvalue())
+
+            # format for LLaVA server
+            data = {
+                "images": jpeg_images,
+                "queries": [m["questions"] for m in metadata_batch],
+            }
+            data_bytes = pickle.dumps(data)
+
+            # send a request to the llava server
+            response = sess.post(url, data=data_bytes, timeout=120)
+
+            response_data = pickle.loads(response.content)
+
+            correct = np.array(
+                [
+                    [ans in resp for ans, resp in zip(m["answers"], responses)]
+                    for m, responses in zip(metadata_batch, response_data["outputs"])
+                ]
+            )
+            scores = correct.mean(axis=-1)
+
+            all_scores += scores.tolist()
+            all_info["answers"] += response_data["outputs"]
+
+        return np.array(all_scores), {k: np.array(v) for k, v in all_info.items()}
+
+    return _fn
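The reward above requires each prompt's metadata to carry parallel "questions" and "answers" lists; the score for an image is the fraction of LLaVA responses that contain the ground-truth answer as a substring. An illustrative metadata entry (the actual content would come from the prompt dataset):

```python
# hypothetical metadata for one sampled image; questions[i] is paired with answers[i]
metadata = [
    {
        "questions": ["What animal is in the image?", "What is the animal doing?"],
        "answers": ["dog", "running"],
    },
]
# reward = mean over questions of (answer appears as a substring of the response)
```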
+
+
+def llava_bertscore():
+    """Submits images to LLaVA and computes a reward by comparing the responses to the prompts using BERTScore. See
+    https://github.com/kvablack/LLaVA-server for server-side code.
+    """
+    import requests
+    from requests.adapters import HTTPAdapter, Retry
+    from io import BytesIO
+    import pickle
+
+    batch_size = 16
+    url = "http://127.0.0.1:8085"
+    sess = requests.Session()
+    retries = Retry(total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False)
+    sess.mount("http://", HTTPAdapter(max_retries=retries))
+
+    def _fn(images, prompts, metadata):
+        del metadata
+        if isinstance(images, torch.Tensor):
+            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
+
+        images_batched = np.array_split(images, np.ceil(len(images) / batch_size))
+        prompts_batched = np.array_split(prompts, np.ceil(len(prompts) / batch_size))
+
+        all_scores = []
+        all_info = {
+            "precision": [],
+            "f1": [],
+            "outputs": [],
+        }
+        for image_batch, prompt_batch in zip(images_batched, prompts_batched):
+            jpeg_images = []
+
+            # Compress the images using JPEG
+            for image in image_batch:
+                img = Image.fromarray(image)
+                buffer = BytesIO()
+                img.save(buffer, format="JPEG", quality=80)
+                jpeg_images.append(buffer.getvalue())
+
+            # format for LLaVA server
+            data = {
+                "images": jpeg_images,
+                "queries": [["Answer concisely: what is going on in this image?"]] * len(image_batch),
+                "answers": [[f"The image contains {prompt}"] for prompt in prompt_batch],
+            }
+            data_bytes = pickle.dumps(data)
+
+            # send a request to the llava server
+            response = sess.post(url, data=data_bytes, timeout=120)
+
+            response_data = pickle.loads(response.content)
+
+            # use the recall score as the reward
+            scores = np.array(response_data["recall"]).squeeze()
+            all_scores += scores.tolist()
+
+            # save the precision and f1 scores for analysis
+            all_info["precision"] += np.array(response_data["precision"]).squeeze().tolist()
+            all_info["f1"] += np.array(response_data["f1"]).squeeze().tolist()
+            all_info["outputs"] += np.array(response_data["outputs"]).squeeze().tolist()
+
+        return np.array(all_scores), {k: np.array(v) for k, v in all_info.items()}
+
+    return _fn
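Both LLaVA rewards use the same pickle-over-HTTP protocol visible above: the request body is a pickled dict of JPEG bytes plus per-image query lists, and the response unpickles to a dict keyed by "outputs" (plus "precision"/"recall"/"f1" for the BERTScore variant). A minimal sketch of one round trip, assuming an instance of kvablack/LLaVA-server is listening on the URL below; unpickling the response is only reasonable because the server is trusted:

```python
import pickle
from io import BytesIO

import requests
from PIL import Image

img = Image.new("RGB", (512, 512))  # placeholder image
buffer = BytesIO()
img.save(buffer, format="JPEG", quality=80)

payload = pickle.dumps({
    "images": [buffer.getvalue()],
    "queries": [["Answer concisely: what is going on in this image?"]],
})
response = requests.post("http://127.0.0.1:8085", data=payload, timeout=120)
outputs = pickle.loads(response.content)["outputs"]  # one list of responses per image
```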
@@ -2,6 +2,8 @@ from collections import defaultdict
 import contextlib
 import os
 import datetime
+from concurrent import futures
+import time
 from absl import app, flags
 from ml_collections import config_flags
 from accelerate import Accelerator
@@ -227,6 +229,10 @@ def main(_):
     # Prepare everything with our `accelerator`.
     trainable_layers, optimizer = accelerator.prepare(trainable_layers, optimizer)
 
+    # executor to perform callbacks asynchronously. this is beneficial for the llava callbacks, which make a request
+    # to a remote server running llava inference.
+    executor = futures.ThreadPoolExecutor(max_workers=2)
+
     # Train!
     samples_per_epoch = config.sample.batch_size * accelerator.num_processes * config.sample.num_batches_per_epoch
     total_train_batch_size = (
@@ -298,8 +304,10 @@ def main(_):
             log_probs = torch.stack(log_probs, dim=1)  # (batch_size, num_steps, 1)
             timesteps = pipeline.scheduler.timesteps.repeat(config.sample.batch_size, 1)  # (batch_size, num_steps)
 
-            # compute rewards
-            rewards, reward_metadata = reward_fn(images, prompts, prompt_metadata)
+            # compute rewards asynchronously
+            rewards = executor.submit(reward_fn, images, prompts, prompt_metadata)
+            # yield to make sure reward computation starts
+            time.sleep(0)
 
             samples.append(
                 {
@@ -309,10 +317,21 @@ def main(_):
                     "latents": latents[:, :-1],  # each entry is the latent before timestep t
                     "next_latents": latents[:, 1:],  # each entry is the latent after timestep t
                     "log_probs": log_probs,
-                    "rewards": torch.as_tensor(rewards, device=accelerator.device),
+                    "rewards": rewards,
                 }
             )
 
+        # wait for all rewards to be computed
+        for sample in tqdm(
+            samples,
+            desc="Waiting for rewards",
+            disable=not accelerator.is_local_main_process,
+            position=0,
+        ):
+            rewards, reward_metadata = sample["rewards"].result()
+            # accelerator.print(reward_metadata)
+            sample["rewards"] = torch.as_tensor(rewards, device=accelerator.device)
+
         # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
         samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
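The change above is the standard futures pattern: each batch's reward computation is submitted to a worker thread as soon as its images are ready, sampling continues, and the loop only blocks on `.result()` once every batch has been submitted. A self-contained sketch of the same idea:

```python
import time
from concurrent import futures

def slow_reward(batch):
    time.sleep(1)  # stands in for a round trip to a remote LLaVA server
    return [x * 2 for x in batch]

executor = futures.ThreadPoolExecutor(max_workers=2)

samples = []
for batch in ([1, 2], [3, 4], [5, 6]):
    # submit() returns immediately with a Future; work proceeds in the background
    samples.append({"batch": batch, "rewards": executor.submit(slow_reward, batch)})

# block only after all batches are in flight
for sample in samples:
    sample["rewards"] = sample["rewards"].result()
```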
@@ -472,7 +491,7 @@ def main(_):
             # make sure we did an optimization step at the end of the inner epoch
             assert accelerator.sync_gradients
 
-        if epoch % config.save_freq == 0 and accelerator.is_main_process:
+        if epoch != 0 and epoch % config.save_freq == 0 and accelerator.is_main_process:
             accelerator.save_state()