From a0558220560ba773235d37189c96586e14b10b0b Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 19 Aug 2022 15:31:23 +0000
Subject: [PATCH] up

---
Notes (not part of the commit message; text between the "---" separator and
the first "diff --git" line is not applied):

  Debugging aid. The patch runs an open_clip based reference SafetyChecker
  (new file safety.py) next to the pipeline's own safety checker inside
  StableDiffusionPipeline.__call__ and, per image, prints the
  'special_scores' / 'concept_scores' of both checkers whenever
  torch.allclose(..., atol=1e-3) reports a difference. The comparison uses
  dict.values(), so it relies on both checkers emitting their scores in the
  same key order.

  Points to review before this goes anywhere near a release:
  * safety.py builds the ViT-L-14 "openai" model at import time, and
    pipeline_stable_diffusion.py instantiates SafetyChecker() at import
    time; the device defaults to 'cuda'.
  * The settings file path /home/patrick/safety_settings.yml is hard-coded
    and the file object passed to yaml.safe_load is never closed explicitly.
  * get_image_embeds() leaves img_tensor unbound when the input is neither
    a list nor a torch.Tensor.
  * "special_care" entries are built with a set literal
    {concept_name, score}, which is unordered.
  * `from cgitb import reset` and `from typing import OrderedDict` are
    unused in safety.py.
  * StableDiffusionSafetyChecker.forward() now returns three values instead
    of two; every caller has to unpack the extra `result`.
  * The From: header carries no e-mail address.

 .../pipeline_stable_diffusion.py              | 19 +++-
 .../pipelines/stable_diffusion/safety.py      | 94 +++++++++++++++++++
 .../stable_diffusion/safety_checker.py        |  2 +-
 3 files changed, 113 insertions(+), 2 deletions(-)
 create mode 100755 src/diffusers/pipelines/stable_diffusion/safety.py

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index baff1db97092..8dd546a596f1 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -10,6 +10,9 @@
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from .safety_checker import StableDiffusionSafetyChecker
+from .safety import SafetyChecker
+
+original_checker = SafetyChecker()
 
 
 class StableDiffusionPipeline(DiffusionPipeline):
@@ -149,9 +152,23 @@ def __call__(
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
 
+        original_result = original_checker(self.numpy_to_pil(image))
+
         # run safety checker
         safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(torch_device)
-        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
+        image, has_nsfw_concept, result = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
+
+        def check_values_the_same(dict_1, dict_2, name):
+            dict_1_values = dict_1[name].values()
+            dict_2_values = dict_2[name].values()
+            the_same = torch.allclose(torch.tensor(list(dict_1_values)), torch.tensor(list(dict_2_values)), atol=1e-3)
+            if not the_same:
+                print("Original", dict_1[name])
+                print("Diffusers", dict_2[name])
+
+        for dict_1, dict_2 in zip(original_result, result):
+            for name in ['special_scores', 'concept_scores']:
+                check_values_the_same(dict_1, dict_2, name)
 
         if output_type == "pil":
             image = self.numpy_to_pil(image)
diff --git a/src/diffusers/pipelines/stable_diffusion/safety.py b/src/diffusers/pipelines/stable_diffusion/safety.py
new file mode 100755
index 000000000000..ea333ec362f2
--- /dev/null
+++ b/src/diffusers/pipelines/stable_diffusion/safety.py
@@ -0,0 +1,94 @@
+#!/usr/bin/env python3
+from cgitb import reset
+from typing import OrderedDict
+import torch, torch.nn as nn
+import open_clip
+import numpy as np
+import yaml
+
+from open_clip import create_model_and_transforms
+
+model, _, preprocess = create_model_and_transforms("ViT-L-14", "openai")
+
+def normalized(a, axis=-1, order=2):
+    """Normalize the given array along the specified axis in order to"""
+    l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
+    l2[l2 == 0] = 1
+    return a / np.expand_dims(l2, axis)
+
+def pw_cosine_distance(input_a, input_b):
+    normalized_input_a = torch.nn.functional.normalize(input_a)
+    normalized_input_b = torch.nn.functional.normalize(input_b)
+    return torch.mm(normalized_input_a, normalized_input_b.T)
+
+class SafetyChecker(nn.Module):
+    def __init__(self, device = 'cuda') -> None:
+        super().__init__()
+        self.clip_model = model.to(device)
+        self.preprocess = preprocess
+        self.device = device
+        safety_settings = yaml.safe_load(open("/home/patrick/safety_settings.yml", "r"))
+        self.concepts_dict = dict(safety_settings["nsfw"]["concepts"])
+        self.special_care_dict = dict(safety_settings["special"]["concepts"])
+        self.concept_embeds = self.get_text_embeds(
+            list(self.concepts_dict.keys()))
+        self.special_care_embeds = self.get_text_embeds(
+            list(self.special_care_dict.keys()))
+
+    def get_image_embeds(self, input):
+        """Get embeddings for images or tensor"""
+        with torch.cuda.amp.autocast():
+            with torch.no_grad():
+                # Preprocess if input is a list of PIL images
+                if isinstance(input, list):
+                    l = []
+                    for image in input:
+                        l.append(self.preprocess(image))
+                    img_tensor = torch.stack(l)
+                # input is a tensor
+                elif isinstance(input, torch.Tensor):
+                    img_tensor = input
+                return self.clip_model.encode_image(img_tensor.half().to(self.device))
+
+    def get_text_embeds(self, input):
+        """Get text embeddings for a list of text"""
+        with torch.cuda.amp.autocast():
+            with torch.no_grad():
+                input = open_clip.tokenize(input).to(self.device)
+                return(self.clip_model.encode_text(input))
+
+    def forward(self, images):
+        """Get embeddings for images and output nsfw and concept scores"""
+        image_embeds = self.get_image_embeds(images)
+        concept_list = list(self.concepts_dict.keys())
+        special_list = list(self.special_care_dict.keys())
+        special_cos_dist = pw_cosine_distance(image_embeds,
+            self.special_care_embeds).cpu().numpy()
+        cos_dist = pw_cosine_distance(image_embeds,
+            self.concept_embeds).cpu().numpy()
+        result = []
+        for i in range(image_embeds.shape[0]):
+            result_img = {
+                "special_scores":{},
+                "special_care":[],
+                "concept_scores":{},
+                "bad_concepts":[]}
+            adjustment = 0.05
+            for j in range(len(special_cos_dist[0])):
+                concept_name = special_list[j]
+                concept_cos = special_cos_dist[i][j]
+                concept_threshold = self.special_care_dict[concept_name]
+                result_img["special_scores"][concept_name] = round(
+                    concept_cos - concept_threshold + adjustment,3)
+                if result_img["special_scores"][concept_name] > 0:
+                    result_img["special_care"].append({concept_name,result_img["special_scores"][concept_name]})
+                    adjustment = 0.01
+            for j in range(len(cos_dist[0])):
+                concept_name = concept_list[j]
+                concept_cos = cos_dist[i][j]
+                concept_threshold = self.concepts_dict[concept_name]
+                result_img["concept_scores"][concept_name] = round(concept_cos - concept_threshold + adjustment,3)
+                if result_img["concept_scores"][concept_name]> 0:
+                    result_img["bad_concepts"].append(concept_name)
+            result.append(result_img)
+        return result
diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py
index 3c43d7ffd988..5a53361b6d9a 100644
--- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py
+++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py
@@ -74,4 +74,4 @@ def forward(self, clip_input, images):
                 " Try again with a different prompt and/or seed."
             )
 
-        return images, has_nsfw_concepts
+        return images, has_nsfw_concepts, result