From fe9ed8a25f982a95f280538c3178670e0041f1a9 Mon Sep 17 00:00:00 2001
From: Kevin Black
Date: Wed, 28 Jun 2023 10:42:30 -0700
Subject: [PATCH] Fix aesthetic scorer

---
 ddpo_pytorch/aesthetic_scorer.py | 3 +++
 ddpo_pytorch/rewards.py          | 4 ++++
 2 files changed, 7 insertions(+)

diff --git a/ddpo_pytorch/aesthetic_scorer.py b/ddpo_pytorch/aesthetic_scorer.py
index bb9af25..461fc2c 100644
--- a/ddpo_pytorch/aesthetic_scorer.py
+++ b/ddpo_pytorch/aesthetic_scorer.py
@@ -5,6 +5,7 @@ import torch
 import torch.nn as nn
 import numpy as np
 from transformers import CLIPModel, CLIPProcessor
+from PIL import Image
 
 ASSETS_PATH = resources.files("ddpo_pytorch.assets")
 
@@ -40,6 +41,8 @@ class AestheticScorer(torch.nn.Module):
 
     @torch.no_grad()
     def __call__(self, images):
+        assert isinstance(images, list)
+        assert isinstance(images[0], Image.Image)
         inputs = self.processor(images=images, return_tensors="pt")
         inputs = {k: v.cuda() for k, v in inputs.items()}
         embed = self.clip.get_image_features(**inputs)
diff --git a/ddpo_pytorch/rewards.py b/ddpo_pytorch/rewards.py
index 1328ac6..0c58533 100644
--- a/ddpo_pytorch/rewards.py
+++ b/ddpo_pytorch/rewards.py
@@ -35,6 +35,10 @@ def aesthetic_score():
     scorer = AestheticScorer().cuda()
 
     def _fn(images, prompts, metadata):
+        if isinstance(images, torch.Tensor):
+            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+            images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
+            images = [Image.fromarray(image) for image in images]
         scores = scorer(images)
         return scores, {}
 