diff --git a/ddpo_pytorch/aesthetic_scorer.py b/ddpo_pytorch/aesthetic_scorer.py index bb9af25..461fc2c 100644 --- a/ddpo_pytorch/aesthetic_scorer.py +++ b/ddpo_pytorch/aesthetic_scorer.py @@ -5,6 +5,7 @@ import torch import torch.nn as nn import numpy as np from transformers import CLIPModel, CLIPProcessor +from PIL import Image ASSETS_PATH = resources.files("ddpo_pytorch.assets") @@ -40,6 +41,8 @@ class AestheticScorer(torch.nn.Module): @torch.no_grad() def __call__(self, images): + assert isinstance(images, list) + assert isinstance(images[0], Image.Image) inputs = self.processor(images=images, return_tensors="pt") inputs = {k: v.cuda() for k, v in inputs.items()} embed = self.clip.get_image_features(**inputs) diff --git a/ddpo_pytorch/rewards.py b/ddpo_pytorch/rewards.py index 1328ac6..0c58533 100644 --- a/ddpo_pytorch/rewards.py +++ b/ddpo_pytorch/rewards.py @@ -35,6 +35,10 @@ def aesthetic_score(): scorer = AestheticScorer().cuda() def _fn(images, prompts, metadata): + if isinstance(images, torch.Tensor): + images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy() + images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC + images = [Image.fromarray(image) for image in images] scores = scorer(images) return scores, {}