Fix aesthetic scorer

This commit is contained in:
Kevin Black
2023-06-28 10:42:30 -07:00
parent 28d2d8c40e
commit fe9ed8a25f
2 changed files with 7 additions and 0 deletions

View File

@@ -5,6 +5,7 @@ import torch
import torch.nn as nn
import numpy as np
from transformers import CLIPModel, CLIPProcessor
from PIL import Image
ASSETS_PATH = resources.files("ddpo_pytorch.assets")
@@ -40,6 +41,8 @@ class AestheticScorer(torch.nn.Module):
@torch.no_grad()
def __call__(self, images):
assert isinstance(images, list)
assert isinstance(images[0], Image.Image)
inputs = self.processor(images=images, return_tensors="pt")
inputs = {k: v.cuda() for k, v in inputs.items()}
embed = self.clip.get_image_features(**inputs)