Fix aesthetic scorer

This commit is contained in:
Kevin Black
2023-06-28 10:42:30 -07:00
parent 28d2d8c40e
commit fe9ed8a25f
2 changed files with 7 additions and 0 deletions

View File

@@ -35,6 +35,10 @@ def aesthetic_score():
scorer = AestheticScorer().cuda()
def _fn(images, prompts, metadata):
if isinstance(images, torch.Tensor):
images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
images = [Image.fromarray(image) for image in images]
scores = scorer(images)
return scores, {}