Fix aesthetic scorer
This commit is contained in:
@@ -5,6 +5,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
from PIL import Image
|
||||
|
||||
ASSETS_PATH = resources.files("ddpo_pytorch.assets")
|
||||
|
||||
@@ -40,6 +41,8 @@ class AestheticScorer(torch.nn.Module):
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, images):
|
||||
assert isinstance(images, list)
|
||||
assert isinstance(images[0], Image.Image)
|
||||
inputs = self.processor(images=images, return_tensors="pt")
|
||||
inputs = {k: v.cuda() for k, v in inputs.items()}
|
||||
embed = self.clip.get_image_features(**inputs)
|
||||
|
Reference in New Issue
Block a user