Fix aesthetic score (again), add llava reward
This commit is contained in:
		| @@ -30,22 +30,22 @@ class MLP(nn.Module): | ||||
|  | ||||
|  | ||||
| class AestheticScorer(torch.nn.Module): | ||||
|     def __init__(self): | ||||
|     def __init__(self, dtype): | ||||
|         super().__init__() | ||||
|         self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14") | ||||
|         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") | ||||
|         self.mlp = MLP() | ||||
|         state_dict = torch.load(ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth")) | ||||
|         self.mlp.load_state_dict(state_dict) | ||||
|         self.dtype = dtype | ||||
|         self.eval() | ||||
|  | ||||
|     @torch.no_grad() | ||||
|     def __call__(self, images): | ||||
|         assert isinstance(images, list) | ||||
|         assert isinstance(images[0], Image.Image) | ||||
|         device = next(self.parameters()).device | ||||
|         inputs = self.processor(images=images, return_tensors="pt") | ||||
|         inputs = {k: v.cuda() for k, v in inputs.items()} | ||||
|         inputs = {k: v.to(self.dtype).to(device) for k, v in inputs.items()} | ||||
|         embed = self.clip.get_image_features(**inputs) | ||||
|         # normalize embedding | ||||
|         embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True) | ||||
|         return self.mlp(embed) | ||||
|         return self.mlp(embed).squeeze(1) | ||||
|   | ||||
| @@ -32,14 +32,145 @@ def jpeg_compressibility(): | ||||
| def aesthetic_score(): | ||||
|     from ddpo_pytorch.aesthetic_scorer import AestheticScorer | ||||
|  | ||||
|     scorer = AestheticScorer().cuda() | ||||
|     scorer = AestheticScorer(dtype=torch.float32).cuda() | ||||
|  | ||||
|     def _fn(images, prompts, metadata): | ||||
|         if isinstance(images, torch.Tensor): | ||||
|             images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy() | ||||
|             images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC | ||||
|         images = [Image.fromarray(image) for image in images] | ||||
|         images = (images * 255).round().clamp(0, 255).to(torch.uint8) | ||||
|         scores = scorer(images) | ||||
|         return scores, {} | ||||
|  | ||||
|     return _fn | ||||
|  | ||||
|  | ||||
| def llava_strict_satisfaction(): | ||||
|     """Submits images to LLaVA and computes a reward by matching the responses to ground truth answers directly without | ||||
|     using BERTScore. Prompt metadata must have "questions" and "answers" keys. See | ||||
|     https://github.com/kvablack/LLaVA-server for server-side code. | ||||
|     """ | ||||
|     import requests | ||||
|     from requests.adapters import HTTPAdapter, Retry | ||||
|     from io import BytesIO | ||||
|     import pickle | ||||
|  | ||||
|     batch_size = 4 | ||||
|     url = "http://127.0.0.1:8085" | ||||
|     sess = requests.Session() | ||||
|     retries = Retry(total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False) | ||||
|     sess.mount("http://", HTTPAdapter(max_retries=retries)) | ||||
|  | ||||
|     def _fn(images, prompts, metadata): | ||||
|         del prompts | ||||
|         if isinstance(images, torch.Tensor): | ||||
|             images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy() | ||||
|             images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC | ||||
|  | ||||
|         images_batched = np.array_split(images, np.ceil(len(images) / batch_size)) | ||||
|         metadata_batched = np.array_split(metadata, np.ceil(len(metadata) / batch_size)) | ||||
|  | ||||
|         all_scores = [] | ||||
|         all_info = { | ||||
|             "answers": [], | ||||
|         } | ||||
|         for image_batch, metadata_batch in zip(images_batched, metadata_batched): | ||||
|             jpeg_images = [] | ||||
|  | ||||
|             # Compress the images using JPEG | ||||
|             for image in image_batch: | ||||
|                 img = Image.fromarray(image) | ||||
|                 buffer = BytesIO() | ||||
|                 img.save(buffer, format="JPEG", quality=80) | ||||
|                 jpeg_images.append(buffer.getvalue()) | ||||
|  | ||||
|             # format for LLaVA server | ||||
|             data = { | ||||
|                 "images": jpeg_images, | ||||
|                 "queries": [m["questions"] for m in metadata_batch], | ||||
|             } | ||||
|             data_bytes = pickle.dumps(data) | ||||
|  | ||||
|             # send a request to the llava server | ||||
|             response = sess.post(url, data=data_bytes, timeout=120) | ||||
|  | ||||
|             response_data = pickle.loads(response.content) | ||||
|  | ||||
|             correct = np.array( | ||||
|                 [ | ||||
|                     [ans in resp for ans, resp in zip(m["answers"], responses)] | ||||
|                     for m, responses in zip(metadata_batch, response_data["outputs"]) | ||||
|                 ] | ||||
|             ) | ||||
|             scores = correct.mean(axis=-1) | ||||
|  | ||||
|             all_scores += scores.tolist() | ||||
|             all_info["answers"] += response_data["outputs"] | ||||
|  | ||||
|         return np.array(all_scores), {k: np.array(v) for k, v in all_info.items()} | ||||
|  | ||||
|     return _fn | ||||
|  | ||||
|  | ||||
| def llava_bertscore(): | ||||
|     """Submits images to LLaVA and computes a reward by comparing the responses to the prompts using BERTScore. See | ||||
|     https://github.com/kvablack/LLaVA-server for server-side code. | ||||
|     """ | ||||
|     import requests | ||||
|     from requests.adapters import HTTPAdapter, Retry | ||||
|     from io import BytesIO | ||||
|     import pickle | ||||
|  | ||||
|     batch_size = 16 | ||||
|     url = "http://127.0.0.1:8085" | ||||
|     sess = requests.Session() | ||||
|     retries = Retry(total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False) | ||||
|     sess.mount("http://", HTTPAdapter(max_retries=retries)) | ||||
|  | ||||
|     def _fn(images, prompts, metadata): | ||||
|         del metadata | ||||
|         if isinstance(images, torch.Tensor): | ||||
|             images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy() | ||||
|             images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC | ||||
|  | ||||
|         images_batched = np.array_split(images, np.ceil(len(images) / batch_size)) | ||||
|         prompts_batched = np.array_split(prompts, np.ceil(len(prompts) / batch_size)) | ||||
|  | ||||
|         all_scores = [] | ||||
|         all_info = { | ||||
|             "precision": [], | ||||
|             "f1": [], | ||||
|             "outputs": [], | ||||
|         } | ||||
|         for image_batch, prompt_batch in zip(images_batched, prompts_batched): | ||||
|             jpeg_images = [] | ||||
|  | ||||
|             # Compress the images using JPEG | ||||
|             for image in image_batch: | ||||
|                 img = Image.fromarray(image) | ||||
|                 buffer = BytesIO() | ||||
|                 img.save(buffer, format="JPEG", quality=80) | ||||
|                 jpeg_images.append(buffer.getvalue()) | ||||
|  | ||||
|             # format for LLaVA server | ||||
|             data = { | ||||
|                 "images": jpeg_images, | ||||
|                 "queries": [["Answer concisely: what is going on in this image?"]] * len(image_batch), | ||||
|                 "answers": [[f"The image contains {prompt}"] for prompt in prompt_batch], | ||||
|             } | ||||
|             data_bytes = pickle.dumps(data) | ||||
|  | ||||
|             # send a request to the llava server | ||||
|             response = sess.post(url, data=data_bytes, timeout=120) | ||||
|  | ||||
|             response_data = pickle.loads(response.content) | ||||
|  | ||||
|             # use the recall score as the reward | ||||
|             scores = np.array(response_data["recall"]).squeeze() | ||||
|             all_scores += scores.tolist() | ||||
|  | ||||
|             # save the precision and f1 scores for analysis | ||||
|             all_info["precision"] += np.array(response_data["precision"]).squeeze().tolist() | ||||
|             all_info["f1"] += np.array(response_data["f1"]).squeeze().tolist() | ||||
|             all_info["outputs"] += np.array(response_data["outputs"]).squeeze().tolist() | ||||
|  | ||||
|         return np.array(all_scores), {k: np.array(v) for k, v in all_info.items()} | ||||
|  | ||||
|     return _fn | ||||
|   | ||||
| @@ -2,6 +2,8 @@ from collections import defaultdict | ||||
| import contextlib | ||||
| import os | ||||
| import datetime | ||||
| from concurrent import futures | ||||
| import time | ||||
| from absl import app, flags | ||||
| from ml_collections import config_flags | ||||
| from accelerate import Accelerator | ||||
| @@ -227,6 +229,10 @@ def main(_): | ||||
|     # Prepare everything with our `accelerator`. | ||||
|     trainable_layers, optimizer = accelerator.prepare(trainable_layers, optimizer) | ||||
|  | ||||
|     # executor to perform callbacks asynchronously. this is beneficial for the llava callbacks which makes a request to a | ||||
|     # remote server running llava inference. | ||||
|     executor = futures.ThreadPoolExecutor(max_workers=2) | ||||
|  | ||||
|     # Train! | ||||
|     samples_per_epoch = config.sample.batch_size * accelerator.num_processes * config.sample.num_batches_per_epoch | ||||
|     total_train_batch_size = ( | ||||
| @@ -298,8 +304,10 @@ def main(_): | ||||
|             log_probs = torch.stack(log_probs, dim=1)  # (batch_size, num_steps, 1) | ||||
|             timesteps = pipeline.scheduler.timesteps.repeat(config.sample.batch_size, 1)  # (batch_size, num_steps) | ||||
|  | ||||
|             # compute rewards | ||||
|             rewards, reward_metadata = reward_fn(images, prompts, prompt_metadata) | ||||
|             # compute rewards asynchronously | ||||
|             rewards = executor.submit(reward_fn, images, prompts, prompt_metadata) | ||||
|             # yield to to make sure reward computation starts | ||||
|             time.sleep(0) | ||||
|  | ||||
|             samples.append( | ||||
|                 { | ||||
| @@ -309,10 +317,21 @@ def main(_): | ||||
|                     "latents": latents[:, :-1],  # each entry is the latent before timestep t | ||||
|                     "next_latents": latents[:, 1:],  # each entry is the latent after timestep t | ||||
|                     "log_probs": log_probs, | ||||
|                     "rewards": torch.as_tensor(rewards, device=accelerator.device), | ||||
|                     "rewards": rewards, | ||||
|                 } | ||||
|             ) | ||||
|  | ||||
|         # wait for all rewards to be computed | ||||
|         for sample in tqdm( | ||||
|             samples, | ||||
|             desc="Waiting for rewards", | ||||
|             disable=not accelerator.is_local_main_process, | ||||
|             position=0, | ||||
|         ): | ||||
|             rewards, reward_metadata = sample["rewards"].result() | ||||
|             # accelerator.print(reward_metadata) | ||||
|             sample["rewards"] = torch.as_tensor(rewards, device=accelerator.device) | ||||
|  | ||||
|         # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...) | ||||
|         samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()} | ||||
|  | ||||
| @@ -472,7 +491,7 @@ def main(_): | ||||
|             # make sure we did an optimization step at the end of the inner epoch | ||||
|             assert accelerator.sync_gradients | ||||
|  | ||||
|         if epoch % config.save_freq == 0 and accelerator.is_main_process: | ||||
|         if epoch != 0 and epoch % config.save_freq == 0 and accelerator.is_main_process: | ||||
|             accelerator.save_state() | ||||
|  | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user