# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
# It converts the UNet, VAE and text encoder weights found under --model_path and
# writes them as a single {"state_dict": ...} file to --checkpoint_path.

import argparse
import os
import os.path as osp

import torch


#=================#
# UNet Conversion #
#=================#

# Exact-key renames for the UNet.
unet_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ('time_embed.0.weight', 'time_embedding.linear_1.weight'),
    ('time_embed.0.bias', 'time_embedding.linear_1.bias'),
    ('time_embed.2.weight', 'time_embedding.linear_2.weight'),
    ('time_embed.2.bias', 'time_embedding.linear_2.bias'),
    ('input_blocks.0.0.weight', 'conv_in.weight'),
    ('input_blocks.0.0.bias', 'conv_in.bias'),
    ('out.0.weight', 'conv_norm_out.weight'),
    ('out.0.bias', 'conv_norm_out.bias'),
    ('out.2.weight', 'conv_out.weight'),
    ('out.2.bias', 'conv_out.bias')
]

# Substring renames applied only to keys that contain 'resnets'.
unet_conversion_map_resnet = [
    # (stable-diffusion, HF Diffusers)
    ("in_layers.0", "norm1"),
    ("in_layers.2", "conv1"),
    ("out_layers.0", "norm2"),
    ("out_layers.3", "conv2"),
    ("emb_layers.1", "time_emb_proj"),
    ("skip_connection", "conv_shortcut")
]

# Prefix renames applied to every UNet key; filled in by the loops below.
unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
    # loop over downblocks/upblocks

    for j in range(2):
        # loop over resnets/attentions for downblocks
        hf_down_res_prefix = f'down_blocks.{i}.resnets.{j}.'
        sd_down_res_prefix = f'input_blocks.{3*i + j + 1}.0.'
        unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

        if i < 3:
            # no attention layers in down_blocks.3
            hf_down_atn_prefix = f'down_blocks.{i}.attentions.{j}.'
            sd_down_atn_prefix = f'input_blocks.{3*i + j + 1}.1.'
            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

    for j in range(3):
        # loop over resnets/attentions for upblocks
        hf_up_res_prefix = f'up_blocks.{i}.resnets.{j}.'
        sd_up_res_prefix = f'output_blocks.{3*i + j}.0.'
        unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

        if i > 0:
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f'up_blocks.{i}.attentions.{j}.'
            sd_up_atn_prefix = f'output_blocks.{3*i + j}.1.'
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

    if i < 3:
        # no downsample in down_blocks.3
        hf_downsample_prefix = f'down_blocks.{i}.downsamplers.0.conv.'
        sd_downsample_prefix = f'input_blocks.{3*(i+1)}.0.op.'
        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

        # no upsample in up_blocks.3
        hf_upsample_prefix = f'up_blocks.{i}.upsamplers.0.'
        sd_upsample_prefix = f'output_blocks.{3*i + 2}.{1 if i == 0 else 2}.'
        unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

hf_mid_atn_prefix = 'mid_block.attentions.0.'
sd_mid_atn_prefix = 'middle_block.1.'
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

for j in range(2):
    hf_mid_res_prefix = f'mid_block.resnets.{j}.'
    sd_mid_res_prefix = f'middle_block.{2*j}.'
    unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))


def convert_unet_state_dict(unet_state_dict):
    """Rename the keys of a HF Diffusers UNet state dict to Stable Diffusion names.

    Args:
        unet_state_dict: mapping from HF Diffusers parameter names to values.
            It must contain every HF name listed in ``unet_conversion_map``.

    Returns:
        A new dict holding the same values under Stable Diffusion key names.

    Raises:
        KeyError: if any HF name from ``unet_conversion_map`` is absent.
    """
    # Fail early with a readable message; without this check the same KeyError
    # would surface from the final dict comprehension with only one key named.
    missing = [hf_name for _, hf_name in unet_conversion_map
               if hf_name not in unet_state_dict]
    if missing:
        raise KeyError(f"UNet state dict is missing expected keys: {missing}")

    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.
    mapping = {k: k for k in unet_state_dict.keys()}

    # 1) exact-key renames
    for sd_name, hf_name in unet_conversion_map:
        mapping[hf_name] = sd_name

    # 2) resnet-internal renames, only for keys that belong to a resnet
    #    (values of existing keys are replaced; the key set never changes)
    for k, v in mapping.items():
        if 'resnets' in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v

    # 3) block/layer prefix renames for every key
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v

    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
    return new_state_dict


#================#
# VAE Conversion #
#================#

# Substring renames applied to every VAE key, in list order.
vae_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ('nin_shortcut', 'conv_shortcut'),
    ('norm_out', 'conv_norm_out'),
    ('mid.attn_1.', 'mid_block.attentions.0.')
]

for i in range(4):
    # down_blocks have two resnets
    for j in range(2):
        hf_down_prefix = f'encoder.down_blocks.{i}.resnets.{j}.'
        sd_down_prefix = f'encoder.down.{i}.block.{j}.'
        vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

    if i < 3:
        hf_downsample_prefix = f'down_blocks.{i}.downsamplers.0.'
        sd_downsample_prefix = f'down.{i}.downsample.'
        vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

        hf_upsample_prefix = f'up_blocks.{i}.upsamplers.0.'
        sd_upsample_prefix = f'up.{3-i}.upsample.'
        vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

    # up_blocks have three resnets
    # also, up blocks in hf are numbered in reverse from sd
    for j in range(3):
        hf_up_prefix = f'decoder.up_blocks.{i}.resnets.{j}.'
        sd_up_prefix = f'decoder.up.{3-i}.block.{j}.'
        vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
    hf_mid_res_prefix = f'mid_block.resnets.{i}.'
    sd_mid_res_prefix = f'mid.block_{i+1}.'
    vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))

# Substring renames applied only to keys that contain 'attentions'.
vae_conversion_map_attn = [
    # (stable-diffusion, HF Diffusers)
    ('norm.', 'group_norm.'),
    ('q.', 'query.'),
    ('k.', 'key.'),
    ('v.', 'value.'),
    ('proj_out.', 'proj_attn.')
]


def reshape_weight_for_sd(w):
    """Return ``w`` reshaped with two trailing singleton dimensions appended."""
    # convert HF linear weights to SD conv2d weights
    return w.reshape(*w.shape, 1, 1)


def convert_vae_state_dict(vae_state_dict):
    """Rename the keys of a HF Diffusers VAE state dict to Stable Diffusion names.

    Mid-block attention ``q``/``k``/``v``/``proj_out`` weights are additionally
    passed through ``reshape_weight_for_sd``; a message is printed for each.

    Args:
        vae_state_dict: mapping from HF Diffusers parameter names to values.

    Returns:
        A new dict with Stable Diffusion key names.
    """
    mapping = {k: k for k in vae_state_dict.keys()}

    # general renames for every key
    for k, v in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v

    # attention-internal renames, selected by the *original* HF key
    for k, v in mapping.items():
        if 'attentions' in k:
            for sd_part, hf_part in vae_conversion_map_attn:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v

    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}

    weights_to_convert = ['q', 'k', 'v', 'proj_out']
    # Iterate over a snapshot so the dict is never written to while its live
    # items() view is being walked.
    for k, v in list(new_state_dict.items()):
        for weight_name in weights_to_convert:
            if f'mid.attn_1.{weight_name}.weight' in k:
                print(f'Reshaping {k} for SD format')
                new_state_dict[k] = reshape_weight_for_sd(v)
    return new_state_dict


#=========================#
# Text Encoder Conversion #
#=========================#
# pretty much a no-op


def convert_text_enc_state_dict(text_enc_dict):
    """Return the text encoder state dict unchanged (same object)."""
    return text_enc_dict


def main():
    """Parse CLI arguments, convert the three sub-models and save one checkpoint."""
    parser = argparse.ArgumentParser()

    # Both options are required=True, so argparse itself exits with a usage
    # error when one is missing; no further None checks are needed.
    parser.add_argument("--model_path", default=None, type=str, required=True,
                        help="Path to the model to convert.")
    parser.add_argument("--checkpoint_path", default=None, type=str, required=True,
                        help="Path to the output model.")

    args = parser.parse_args()

    unet_path = osp.join(args.model_path, 'unet', 'diffusion_pytorch_model.bin')
    vae_path = osp.join(args.model_path, 'vae', 'diffusion_pytorch_model.bin')
    text_enc_path = osp.join(args.model_path, 'text_encoder', 'pytorch_model.bin')

    # SECURITY NOTE(review): torch.load unpickles the file; only run this
    # script on model directories you trust.

    # Convert the UNet model
    unet_state_dict = torch.load(unet_path, map_location=torch.device("cpu"))
    unet_state_dict = convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

    # Convert the VAE model
    vae_state_dict = torch.load(vae_path, map_location=torch.device("cpu"))
    vae_state_dict = convert_vae_state_dict(vae_state_dict)
    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # Convert the text encoder model
    text_enc_dict = torch.load(text_enc_path, map_location=torch.device("cpu"))
    text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
    text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}

    # Put together new checkpoint
    state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
    state_dict = {"state_dict": state_dict}
    torch.save(state_dict, args.checkpoint_path)


if __name__ == "__main__":
    main()