Add aesthetic scorer reward function

This commit is contained in:
Kevin Black
2023-06-27 10:40:36 -07:00
parent 8cab96dea4
commit bae3f43f5f
6 changed files with 68 additions and 2 deletions

View File

@@ -8,7 +8,7 @@ def jpeg_incompressibility():
def _fn(images, prompts, metadata):
if isinstance(images, torch.Tensor):
images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
images = [Image.fromarray(image) for image in images]
buffers = [io.BytesIO() for _ in images]
for image, buffer in zip(images, buffers):
@@ -27,3 +27,17 @@ def jpeg_compressibility():
return -rew, meta
return _fn
def aesthetic_score():
from ddpo_pytorch.aesthetic_scorer import AestheticScorer
scorer = AestheticScorer().cuda()
def _fn(images, prompts, metadata):
if not isinstance(images, torch.Tensor):
images = torch.as_tensor(images)
scores = scorer(images)
return scores, {}
return _fn