Minor changes; add train.timestep_fraction

This commit is contained in:
Kevin Black 2023-06-27 22:17:32 -07:00
parent bae3f43f5f
commit 28d2d8c40e
5 changed files with 50 additions and 26 deletions

View File

@ -32,14 +32,15 @@ def get_config():
train.cfg = True train.cfg = True
train.adv_clip_max = 10 train.adv_clip_max = 10
train.clip_range = 1e-4 train.clip_range = 1e-4
train.timestep_fraction = 1.0
# sampling # sampling
config.sample = sample = ml_collections.ConfigDict() config.sample = sample = ml_collections.ConfigDict()
sample.num_steps = 5 sample.num_steps = 10
sample.eta = 1.0 sample.eta = 1.0
sample.guidance_scale = 5.0 sample.guidance_scale = 5.0
sample.batch_size = 1 sample.batch_size = 1
sample.num_batches_per_epoch = 1 sample.num_batches_per_epoch = 2
# prompting # prompting
config.prompt_fn = "imagenet_animals" config.prompt_fn = "imagenet_animals"
@ -49,7 +50,7 @@ def get_config():
config.reward_fn = "jpeg_compressibility" config.reward_fn = "jpeg_compressibility"
config.per_prompt_stat_tracking = ml_collections.ConfigDict() config.per_prompt_stat_tracking = ml_collections.ConfigDict()
config.per_prompt_stat_tracking.buffer_size = 64 config.per_prompt_stat_tracking.buffer_size = 16
config.per_prompt_stat_tracking.min_count = 16 config.per_prompt_stat_tracking.min_count = 16
return config return config

View File

@ -8,20 +8,25 @@ base = imp.load_source("base", os.path.join(os.path.dirname(__file__), "base.py"
def get_config(): def get_config():
config = base.get_config() config = base.get_config()
config.mixed_precision = "no" config.pretrained.model = "runwayml/stable-diffusion-v1-5"
config.mixed_precision = "fp16"
config.allow_tf32 = True config.allow_tf32 = True
config.use_lora = False config.use_lora = False
config.train.batch_size = 4 config.train.batch_size = 4
config.train.gradient_accumulation_steps = 8 config.train.gradient_accumulation_steps = 2
config.train.learning_rate = 1e-5 config.train.learning_rate = 3e-5
config.train.clip_range = 1.0 config.train.clip_range = 1e-4
# sampling # sampling
config.sample.num_steps = 50 config.sample.num_steps = 50
config.sample.batch_size = 16 config.sample.batch_size = 8
config.sample.num_batches_per_epoch = 2 config.sample.num_batches_per_epoch = 4
config.per_prompt_stat_tracking = None config.per_prompt_stat_tracking = {
"buffer_size": 16,
"min_count": 16,
}
return config return config

View File

@ -14,18 +14,15 @@ class MLP(nn.Module):
super().__init__() super().__init__()
self.layers = nn.Sequential( self.layers = nn.Sequential(
nn.Linear(768, 1024), nn.Linear(768, 1024),
nn.Identity(), nn.Dropout(0.2),
nn.Linear(1024, 128), nn.Linear(1024, 128),
nn.Identity(), nn.Dropout(0.2),
nn.Linear(128, 64), nn.Linear(128, 64),
nn.Identity(), nn.Dropout(0.1),
nn.Linear(64, 16), nn.Linear(64, 16),
nn.Linear(16, 1), nn.Linear(16, 1),
) )
state_dict = torch.load(ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth"))
self.load_state_dict(state_dict)
@torch.no_grad() @torch.no_grad()
def forward(self, embed): def forward(self, embed):
return self.layers(embed) return self.layers(embed)
@ -37,6 +34,9 @@ class AestheticScorer(torch.nn.Module):
self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14") self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
self.mlp = MLP() self.mlp = MLP()
state_dict = torch.load(ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth"))
self.mlp.load_state_dict(state_dict)
self.eval()
@torch.no_grad() @torch.no_grad()
def __call__(self, images): def __call__(self, images):
@ -44,5 +44,5 @@ class AestheticScorer(torch.nn.Module):
inputs = {k: v.cuda() for k, v in inputs.items()} inputs = {k: v.cuda() for k, v in inputs.items()}
embed = self.clip.get_image_features(**inputs) embed = self.clip.get_image_features(**inputs)
# normalize embedding # normalize embedding
embed = embed / embed.norm(dim=-1, keepdim=True) embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
return self.mlp(embed) return self.mlp(embed)

View File

@ -35,8 +35,6 @@ def aesthetic_score():
scorer = AestheticScorer().cuda() scorer = AestheticScorer().cuda()
def _fn(images, prompts, metadata): def _fn(images, prompts, metadata):
if not isinstance(images, torch.Tensor):
images = torch.as_tensor(images)
scores = scorer(images) scores = scorer(images)
return scores, {} return scores, {}

View File

@ -34,17 +34,23 @@ logger = get_logger(__name__)
def main(_): def main(_):
# basic Accelerate and logging setup # basic Accelerate and logging setup
config = FLAGS.config config = FLAGS.config
# number of timesteps within each trajectory to train on
num_train_timesteps = int(config.sample.num_steps * config.train.timestep_fraction)
accelerator = Accelerator( accelerator = Accelerator(
log_with="wandb", log_with="wandb",
mixed_precision=config.mixed_precision, mixed_precision=config.mixed_precision,
project_dir=config.logdir, project_dir=config.logdir,
gradient_accumulation_steps=config.train.gradient_accumulation_steps * config.sample.num_steps, # we always accumulate gradients across timesteps; config.train.gradient_accumulation_steps is the number of
# _samples_ to accumulate across
gradient_accumulation_steps=config.train.gradient_accumulation_steps * num_train_timesteps,
) )
if accelerator.is_main_process: if accelerator.is_main_process:
accelerator.init_trackers(project_name="ddpo-pytorch", config=config.to_dict()) accelerator.init_trackers(project_name="ddpo-pytorch", config=config.to_dict())
logger.info(f"\n{config}") logger.info(f"\n{config}")
# set seed # set seed (device_specific is very important to get different prompts on different devices)
set_seed(config.seed, device_specific=True) set_seed(config.seed, device_specific=True)
# load scheduler, tokenizer and models. # load scheduler, tokenizer and models.
@ -152,7 +158,8 @@ def main(_):
config.per_prompt_stat_tracking.min_count, config.per_prompt_stat_tracking.min_count,
) )
# for some reason, autocast is necessary for non-lora training but for lora training it uses more memory # for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
# more memory
autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast autocast = contextlib.nullcontext if config.use_lora else accelerator.autocast
# Prepare everything with our `accelerator`. # Prepare everything with our `accelerator`.
@ -289,8 +296,15 @@ def main(_):
#################### TRAINING #################### #################### TRAINING ####################
for inner_epoch in range(config.train.num_inner_epochs): for inner_epoch in range(config.train.num_inner_epochs):
# shuffle samples along batch dimension # shuffle samples along batch dimension
indices = torch.randperm(total_batch_size, device=accelerator.device) perm = torch.randperm(total_batch_size, device=accelerator.device)
samples = {k: v[indices] for k, v in samples.items()} samples = {k: v[perm] for k, v in samples.items()}
# shuffle along time dimension independently for each sample
perms = torch.stack(
[torch.randperm(num_timesteps, device=accelerator.device) for _ in range(total_batch_size)]
)
for key in ["timesteps", "latents", "next_latents", "log_probs"]:
samples[key] = samples[key][torch.arange(total_batch_size, device=accelerator.device)[:, None], perms]
# rebatch for training # rebatch for training
samples_batched = {k: v.reshape(-1, config.train.batch_size, *v.shape[1:]) for k, v in samples.items()} samples_batched = {k: v.reshape(-1, config.train.batch_size, *v.shape[1:]) for k, v in samples.items()}
@ -300,6 +314,7 @@ def main(_):
# train # train
pipeline.unet.train() pipeline.unet.train()
info = defaultdict(list)
for i, sample in tqdm( for i, sample in tqdm(
list(enumerate(samples_batched)), list(enumerate(samples_batched)),
desc=f"Epoch {epoch}.{inner_epoch}: training", desc=f"Epoch {epoch}.{inner_epoch}: training",
@ -312,9 +327,8 @@ def main(_):
else: else:
embeds = sample["prompt_embeds"] embeds = sample["prompt_embeds"]
info = defaultdict(list)
for j in tqdm( for j in tqdm(
range(num_timesteps), range(num_train_timesteps),
desc="Timestep", desc="Timestep",
position=1, position=1,
leave=False, leave=False,
@ -371,14 +385,20 @@ def main(_):
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients: if accelerator.sync_gradients:
assert (j == num_train_timesteps - 1) and (i + 1) % config.train.gradient_accumulation_steps == 0
# log training-related stuff # log training-related stuff
info = {k: torch.mean(torch.stack(v)) for k, v in info.items()} info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
info = accelerator.reduce(info, reduction="mean")
info.update({"epoch": epoch, "inner_epoch": inner_epoch}) info.update({"epoch": epoch, "inner_epoch": inner_epoch})
accelerator.log(info, step=global_step) accelerator.log(info, step=global_step)
global_step += 1 global_step += 1
info = defaultdict(list) info = defaultdict(list)
# make sure we did an optimization step at the end of the inner epoch
assert accelerator.sync_gradients
if __name__ == "__main__": if __name__ == "__main__":
app.run(main) app.run(main)