This commit is contained in:
Kevin Black
2023-11-16 22:36:46 +00:00
parent 378dd18298
commit 1958463f02
5 changed files with 227 additions and 68 deletions

View File

@@ -35,7 +35,9 @@ class AestheticScorer(torch.nn.Module):
self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
self.mlp = MLP()
state_dict = torch.load(ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth"))
state_dict = torch.load(
ASSETS_PATH.joinpath("sac+logos+ava1-l14-linearMSE.pth")
)
self.mlp.load_state_dict(state_dict)
self.dtype = dtype
self.eval()

View File

@@ -20,9 +20,13 @@ def _left_broadcast(t, shape):
def _get_variance(self, timestep, prev_timestep):
alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device)
alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(
timestep.device
)
alpha_prod_t_prev = torch.where(
prev_timestep.cpu() >= 0, self.alphas_cumprod.gather(0, prev_timestep.cpu()), self.final_alpha_cumprod
prev_timestep.cpu() >= 0,
self.alphas_cumprod.gather(0, prev_timestep.cpu()),
self.final_alpha_cumprod,
).to(timestep.device)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
@@ -86,31 +90,45 @@ def ddim_step_with_logprob(
# - pred_prev_sample -> "x_t-1"
# 1. get previous step value (=t-1)
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
prev_timestep = (
timestep - self.config.num_train_timesteps // self.num_inference_steps
)
# to prevent OOB on gather
prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1)
# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu())
alpha_prod_t_prev = torch.where(
prev_timestep.cpu() >= 0, self.alphas_cumprod.gather(0, prev_timestep.cpu()), self.final_alpha_cumprod
prev_timestep.cpu() >= 0,
self.alphas_cumprod.gather(0, prev_timestep.cpu()),
self.final_alpha_cumprod,
)
alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device)
alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(sample.device)
alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(
sample.device
)
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
pred_original_sample = (
sample - beta_prod_t ** (0.5) * model_output
) / alpha_prod_t ** (0.5)
pred_epsilon = model_output
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
pred_epsilon = (
sample - alpha_prod_t ** (0.5) * pred_original_sample
) / beta_prod_t ** (0.5)
elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
pred_original_sample = (alpha_prod_t**0.5) * sample - (
beta_prod_t**0.5
) * model_output
pred_epsilon = (alpha_prod_t**0.5) * model_output + (
beta_prod_t**0.5
) * sample
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
@@ -133,13 +151,19 @@ def ddim_step_with_logprob(
if use_clipped_model_output:
# the pred_epsilon is always re-derived from the clipped x_0 in Glide
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
pred_epsilon = (
sample - alpha_prod_t ** (0.5) * pred_original_sample
) / beta_prod_t ** (0.5)
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (
0.5
) * pred_epsilon
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample_mean = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
prev_sample_mean = (
alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
)
if prev_sample is not None and generator is not None:
raise ValueError(
@@ -149,7 +173,10 @@ def ddim_step_with_logprob(
if prev_sample is None:
variance_noise = randn_tensor(
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
model_output.shape,
generator=generator,
device=model_output.device,
dtype=model_output.dtype,
)
prev_sample = prev_sample_mean + std_dev_t * variance_noise

View File

@@ -116,7 +116,15 @@ def pipeline_with_logprob(
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
self.check_inputs(
prompt,
height,
width,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -133,7 +141,11 @@ def pipeline_with_logprob(
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None)
if cross_attention_kwargs is not None
else None
)
prompt_embeds = self._encode_prompt(
prompt,
device,
@@ -172,7 +184,9 @@ def pipeline_with_logprob(
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = (
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
@@ -187,27 +201,39 @@ def pipeline_with_logprob(
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
if do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
noise_pred = rescale_noise_cfg(
noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
)
# compute the previous noisy sample x_t -> x_t-1
latents, log_prob = ddim_step_with_logprob(self.scheduler, noise_pred, t, latents, **extra_step_kwargs)
latents, log_prob = ddim_step_with_logprob(
self.scheduler, noise_pred, t, latents, **extra_step_kwargs
)
all_latents.append(latents)
all_log_probs.append(log_prob)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
image = self.vae.decode(
latents / self.vae.config.scaling_factor, return_dict=False
)[0]
image, has_nsfw_concept = self.run_safety_checker(
image, device, prompt_embeds.dtype
)
else:
image = latents
has_nsfw_concept = None
@@ -217,7 +243,9 @@ def pipeline_with_logprob(
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
image = self.image_processor.postprocess(
image, output_type=output_type, do_denormalize=do_denormalize
)
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:

View File

@@ -35,7 +35,11 @@ def aesthetic_score():
scorer = AestheticScorer(dtype=torch.float32).cuda()
def _fn(images, prompts, metadata):
images = (images * 255).round().clamp(0, 255).to(torch.uint8)
if isinstance(images, torch.Tensor):
images = (images * 255).round().clamp(0, 255).to(torch.uint8)
else:
images = images.transpose(0, 3, 1, 2) # NHWC -> NCHW
images = torch.tensor(images, dtype=torch.uint8)
scores = scorer(images)
return scores, {}
@@ -55,7 +59,9 @@ def llava_strict_satisfaction():
batch_size = 4
url = "http://127.0.0.1:8085"
sess = requests.Session()
retries = Retry(total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False)
retries = Retry(
total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
)
sess.mount("http://", HTTPAdapter(max_retries=retries))
def _fn(images, prompts, metadata):
@@ -121,7 +127,9 @@ def llava_bertscore():
batch_size = 16
url = "http://127.0.0.1:8085"
sess = requests.Session()
retries = Retry(total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False)
retries = Retry(
total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
)
sess.mount("http://", HTTPAdapter(max_retries=retries))
def _fn(images, prompts, metadata):
@@ -152,8 +160,11 @@ def llava_bertscore():
# format for LLaVA server
data = {
"images": jpeg_images,
"queries": [["Answer concisely: what is going on in this image?"]] * len(image_batch),
"answers": [[f"The image contains {prompt}"] for prompt in prompt_batch],
"queries": [["Answer concisely: what is going on in this image?"]]
* len(image_batch),
"answers": [
[f"The image contains {prompt}"] for prompt in prompt_batch
],
}
data_bytes = pickle.dumps(data)
@@ -167,7 +178,9 @@ def llava_bertscore():
all_scores += scores.tolist()
# save the precision and f1 scores for analysis
all_info["precision"] += np.array(response_data["precision"]).squeeze().tolist()
all_info["precision"] += (
np.array(response_data["precision"]).squeeze().tolist()
)
all_info["f1"] += np.array(response_data["f1"]).squeeze().tolist()
all_info["outputs"] += np.array(response_data["outputs"]).squeeze().tolist()