# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert an original UnivNet generator checkpoint into a Hugging Face `UnivNetModel`."""
import argparse

import torch

from transformers import UnivNetConfig, UnivNetModel, logging


logging.set_verbosity_info()
logger = logging.get_logger("transformers.models.univnet")

# Every weight-normed conv layer in the original checkpoint stores these three tensors.
_WEIGHT_NORM_PARAM_NAMES = ("weight_g", "weight_v", "bias")


def _weight_norm_conv_mapping(old_name, new_name):
    """Map the `weight_g`, `weight_v` and `bias` tensors of one conv layer from `old_name` to `new_name`."""
    return {f"{old_name}.{param}": f"{new_name}.{param}" for param in _WEIGHT_NORM_PARAM_NAMES}


def get_kernel_predictor_key_mapping(config: UnivNetConfig, old_prefix: str = "", new_prefix: str = ""):
    """Build the original -> HF state-dict key mapping for one kernel predictor.

    Args:
        config: Model config; `config.kernel_predictor_num_blocks` sets how many resnet blocks are mapped.
        old_prefix: Key prefix of the kernel predictor in the original checkpoint.
        new_prefix: Key prefix of the kernel predictor in the HF model.

    Returns:
        A dict mapping original parameter names to HF parameter names.
    """
    mapping = {}

    # Initial conv layer (the original wraps it in a container, hence the `.0`).
    mapping.update(_weight_norm_conv_mapping(f"{old_prefix}.input_conv.0", f"{new_prefix}.input_conv"))

    # Kernel predictor resnet blocks: original sub-modules 1 and 3 are the two convs of each block.
    for i in range(config.kernel_predictor_num_blocks):
        mapping.update(
            _weight_norm_conv_mapping(f"{old_prefix}.residual_convs.{i}.1", f"{new_prefix}.resblocks.{i}.conv1")
        )
        mapping.update(
            _weight_norm_conv_mapping(f"{old_prefix}.residual_convs.{i}.3", f"{new_prefix}.resblocks.{i}.conv2")
        )

    # Kernel output conv
    mapping.update(_weight_norm_conv_mapping(f"{old_prefix}.kernel_conv", f"{new_prefix}.kernel_conv"))
    # Bias output conv
    mapping.update(_weight_norm_conv_mapping(f"{old_prefix}.bias_conv", f"{new_prefix}.bias_conv"))

    return mapping


def get_key_mapping(config: UnivNetConfig):
    """Build the full original -> HF state-dict key mapping for the UnivNet generator.

    Args:
        config: Model config; `resblock_stride_sizes` sets the number of LVC blocks and
            `resblock_dilation_sizes[i]` the number of residual convs inside block `i`.

    Returns:
        A dict mapping original parameter names to HF parameter names. Keys that are
        identical in both layouts (e.g. the initial conv layer) are not listed.
    """
    mapping = {}

    # NOTE: initial conv layer keys are the same

    # LVC Residual blocks
    for i in range(len(config.resblock_stride_sizes)):
        # LVCBlock initial convt layer
        mapping.update(_weight_norm_conv_mapping(f"res_stack.{i}.convt_pre.1", f"resblocks.{i}.convt_pre"))

        # Kernel predictor
        mapping.update(
            get_kernel_predictor_key_mapping(
                config, old_prefix=f"res_stack.{i}.kernel_predictor", new_prefix=f"resblocks.{i}.kernel_predictor"
            )
        )

        # Residual convs inside the LVC block
        for j in range(len(config.resblock_dilation_sizes[i])):
            mapping.update(
                _weight_norm_conv_mapping(f"res_stack.{i}.conv_blocks.{j}.1", f"resblocks.{i}.resblocks.{j}.conv")
            )

    # Output conv layer
    mapping.update(_weight_norm_conv_mapping("conv_post.1", "conv_post"))

    return mapping


def rename_state_dict(state_dict, keys_to_modify, keys_to_remove):
    """Return a new state dict with keys renamed and unwanted keys dropped.

    Args:
        state_dict: Mapping of original parameter names to tensors.
        keys_to_modify: Mapping of original name -> new name; keys not listed keep their name.
        keys_to_remove: Collection of original names to drop entirely.

    Returns:
        A new dict preserving the iteration order of `state_dict`.
    """
    return {
        keys_to_modify.get(key, key): value for key, value in state_dict.items() if key not in keys_to_remove
    }


def convert_univnet_checkpoint(
    checkpoint_path,
    pytorch_dump_folder_path,
    config_path=None,
    repo_id=None,
    safe_serialization=False,
):
    """Convert an original UnivNet checkpoint and save it as a HF `UnivNetModel`.

    Args:
        checkpoint_path: Path to the original checkpoint; its generator weights are read from the `"model_g"` entry.
        pytorch_dump_folder_path: Output directory passed to `save_pretrained`.
        config_path: Optional path/identifier for `UnivNetConfig.from_pretrained`; the default config is used otherwise.
        repo_id: If truthy, the converted model is also pushed to this Hub repo.
        safe_serialization: Forwarded to `save_pretrained`.

    Raises:
        KeyError: If the checkpoint has no `"model_g"` entry.
        RuntimeError: From `load_state_dict` (strict) if the renamed keys do not match the model.
    """
    # NOTE(review): torch.load unpickles arbitrary objects; only convert checkpoints from trusted sources.
    model_state_dict_base = torch.load(checkpoint_path, map_location="cpu")
    # Get the generator's state dict
    state_dict = model_state_dict_base["model_g"]

    if config_path is not None:
        config = UnivNetConfig.from_pretrained(config_path)
    else:
        config = UnivNetConfig()

    keys_to_modify = get_key_mapping(config)
    keys_to_remove = set()
    hf_state_dict = rename_state_dict(state_dict, keys_to_modify, keys_to_remove)

    model = UnivNetModel(config)
    # Apply weight norm since the original checkpoint has weight norm applied
    model.apply_weight_norm()
    model.load_state_dict(hf_state_dict)
    # Remove weight norm in preparation for inference
    model.remove_weight_norm()

    model.save_pretrained(pytorch_dump_folder_path, safe_serialization=safe_serialization)

    if repo_id:
        print("Pushing to the hub...")
        model.push_to_hub(repo_id)


def main():
    """Parse command-line arguments and run the conversion."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_path", required=True, type=str, help="Path to original checkpoint")
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
    parser.add_argument(
        "--pytorch_dump_folder_path", required=True, type=str, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
    )
    parser.add_argument(
        "--safe_serialization", action="store_true", help="Whether to save the model using `safetensors`."
    )

    args = parser.parse_args()

    # Pass by keyword so a future change to the function's parameter order cannot silently swap options.
    convert_univnet_checkpoint(
        checkpoint_path=args.checkpoint_path,
        pytorch_dump_folder_path=args.pytorch_dump_folder_path,
        config_path=args.config_path,
        repo_id=args.push_to_hub,
        safe_serialization=args.safe_serialization,
    )


if __name__ == "__main__":
    main()