Fix aesthetic scorer
This commit is contained in:
parent
28d2d8c40e
commit
fe9ed8a25f
@ -5,6 +5,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from transformers import CLIPModel, CLIPProcessor
|
from transformers import CLIPModel, CLIPProcessor
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
ASSETS_PATH = resources.files("ddpo_pytorch.assets")
|
ASSETS_PATH = resources.files("ddpo_pytorch.assets")
|
||||||
|
|
||||||
@ -40,6 +41,8 @@ class AestheticScorer(torch.nn.Module):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def __call__(self, images):
|
def __call__(self, images):
|
||||||
|
assert isinstance(images, list)
|
||||||
|
assert isinstance(images[0], Image.Image)
|
||||||
inputs = self.processor(images=images, return_tensors="pt")
|
inputs = self.processor(images=images, return_tensors="pt")
|
||||||
inputs = {k: v.cuda() for k, v in inputs.items()}
|
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||||
embed = self.clip.get_image_features(**inputs)
|
embed = self.clip.get_image_features(**inputs)
|
||||||
|
@ -35,6 +35,10 @@ def aesthetic_score():
|
|||||||
scorer = AestheticScorer().cuda()
|
scorer = AestheticScorer().cuda()
|
||||||
|
|
||||||
def _fn(images, prompts, metadata):
|
def _fn(images, prompts, metadata):
|
||||||
|
if isinstance(images, torch.Tensor):
|
||||||
|
images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
|
||||||
|
images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
|
||||||
|
images = [Image.fromarray(image) for image in images]
|
||||||
scores = scorer(images)
|
scores = scorer(images)
|
||||||
return scores, {}
|
return scores, {}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user