diff --git a/ddpo_pytorch/aesthetic_scorer.py b/ddpo_pytorch/aesthetic_scorer.py new file mode 100644 index 0000000..9f5263b --- /dev/null +++ b/ddpo_pytorch/aesthetic_scorer.py @@ -0,0 +1,48 @@ +# Based on https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/fe88a163f4661b4ddabba0751ff645e2e620746e/simple_inference.py + +from importlib import resources +import torch +import torch.nn as nn +import numpy as np +from transformers import CLIPModel, CLIPProcessor + +ASSETS_PATH = resources.files("ddpo_pytorch.assets") + + +class MLP(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(768, 1024), + nn.Identity(), + nn.Linear(1024, 128), + nn.Identity(), + nn.Linear(128, 64), + nn.Identity(), + nn.Linear(64, 16), + nn.Linear(16, 1), + ) + + state_dict = torch.load(ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth")) + self.load_state_dict(state_dict) + + @torch.no_grad() + def forward(self, embed): + return self.layers(embed) + + +class AestheticScorer(torch.nn.Module): + def __init__(self): + super().__init__() + self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14") + self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") + self.mlp = MLP() + + @torch.no_grad() + def __call__(self, images): + inputs = self.processor(images=images, return_tensors="pt") + inputs = {k: v.cuda() for k, v in inputs.items()} + embed = self.clip.get_image_features(**inputs) + # normalize embedding + embed = embed / embed.norm(dim=-1, keepdim=True) + return self.mlp(embed) diff --git a/ddpo_pytorch/assets/sac+logos+ava1-l14-linearMSE.pth b/ddpo_pytorch/assets/sac+logos+ava1-l14-linearMSE.pth new file mode 100644 index 0000000..7c0d8aa Binary files /dev/null and b/ddpo_pytorch/assets/sac+logos+ava1-l14-linearMSE.pth differ diff --git a/ddpo_pytorch/assets/common_animals.txt b/ddpo_pytorch/assets/simple_animals.txt similarity index 100% rename from ddpo_pytorch/assets/common_animals.txt rename to ddpo_pytorch/assets/simple_animals.txt diff --git a/ddpo_pytorch/prompts.py b/ddpo_pytorch/prompts.py index fdf6e34..4e1294d 100644 --- a/ddpo_pytorch/prompts.py +++ b/ddpo_pytorch/prompts.py @@ -40,6 +40,10 @@ def imagenet_dogs(): return from_file("imagenet_classes.txt", 151, 269) +def simple_animals(): + return from_file("simple_animals.txt") + + def nouns_activities(nouns_file, activities_file): nouns = _load_lines(nouns_file) activities = _load_lines(activities_file) diff --git a/ddpo_pytorch/rewards.py b/ddpo_pytorch/rewards.py index 9ec6218..716a8aa 100644 --- a/ddpo_pytorch/rewards.py +++ b/ddpo_pytorch/rewards.py @@ -8,7 +8,7 @@ def jpeg_incompressibility(): def _fn(images, prompts, metadata): if isinstance(images, torch.Tensor): images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy() - images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC + images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC images = [Image.fromarray(image) for image in images] buffers = [io.BytesIO() for _ in images] for image, buffer in zip(images, buffers): @@ -27,3 +27,17 @@ def jpeg_compressibility(): return -rew, meta return _fn + + +def aesthetic_score(): + from ddpo_pytorch.aesthetic_scorer import AestheticScorer + + scorer = AestheticScorer().cuda() + + def _fn(images, prompts, metadata): + if not isinstance(images, torch.Tensor): + images = torch.as_tensor(images) + scores = scorer(images) + return scores, {} + + return _fn diff --git a/scripts/train.py b/scripts/train.py index 78df021..ba474b9 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -42,7 +42,7 @@ def main(_): ) if accelerator.is_main_process: accelerator.init_trackers(project_name="ddpo-pytorch", config=config.to_dict()) - logger.info(config) + logger.info(f"\n{config}") # set seed set_seed(config.seed, device_specific=True)