diff --git a/Hallo2/README.md b/Hallo2/README.md new file mode 100644 index 00000000..343e0e96 --- /dev/null +++ b/Hallo2/README.md @@ -0,0 +1,76 @@ +# Hallo2项目配置文档 + +### 小组成员 + +黄松毅 吴京桥 熊康慈 + +### 镜像文件下载和导入 + +百度云盘:https://pan.baidu.com/s/1D7w-AarTui4qsTPO_wiNfg +提取码:eigy + +该云盘中有镜像压缩文件:hallo2.tar,先从云盘中下载完整docker镜像文件hallo2.tar,然后将其导入到服务器中,使用其加载对应镜像,命令为:`docker load -i hallo2.tar`,然后可以查看现有镜像:`docker images`,应该会出现一个名为hallo2,版本为v5的镜像。 +接着需要基于hallo2:v5镜像构建容器,命令为: +```bash +docker run -it --rm \ +--gpus all \ +-v /path/to/your/input_image:/app/input.jpg \ +-v /path/to/your/input_audio_text:/app/input.wav \ +-v /path/to/your/output_dir:/app/output \ +hallo2:v5 +``` + +其中,我们需要指定--gpus为all,否则哪怕容器装有cuda driver和cuda都没办法调用到主机上的gpu。其中我们建议输入的图片格式为jpg格式,音频格式为wav格式,并将图片和音频导入到容器中时候名字都为input。 + +### 运行项目生成视频 + +在输入上述指令之后,我们就成功进入hallo2容器中。 + +该容器的工作目录为/app,因此我们需要先进入/app目录:`cd /app`,随后我们输入生成视频的指令: +```bash +python scripts/inference_long.py \ +--config configs/inference/long.yaml \ +--source_image ./input.jpg \ +--driving_audio ./input.wav \ +--pose_weight 1.0 \ +--face_weight 1.0 \ +--lip_weight 1.0 \ +--face_expand_ratio 1.0 \ +&& python scripts/video_sr.py \ +--input_path output_long/debug/input/merge_video.mp4 \ +--output_path output/ +--bg_upsampler realesrgan --face_upsample -w 1 -s 4 +``` + +即可以在容器中的/app/output下找到生成的merge_video.mp4。 + +注:运行7s的视频,需要运行20min左右。 + +### 对视频进行评估 + +之后我们在app目录下找到一个evaluation文件夹,进入该文件夹:`cd evaluation`,之后可以找到该目录下的evaluation.py文件,该文件用于评估指标: + +![alt text](evaluation_help.png) + +其中,original_video_path需要指出原视频路径,generated_video_path需要指出生成视频路径,output_dir需要指出输出数据的保存路径,另外几个参数为是否要计算该评估指标,值为1表示需要计算。 + +之后,我们使用之前就保存在该容器中的两个视频分别作为原始视频和生成视频来测试评估代码,输入指令如下: +```bash +Python evaluation.py \ +--original_video_path ./examples/merge_video.mp4 \ +--generated_video_path ../output_long/debug/1/merge_video.mp4 \ +--output_dir ./output --psnr 1 --fid 1 --lse 1 +``` + +该指令计算psnr、fid和lse(LSE-C和LSE-D),并且输出在目前文件夹(/app/evaluation)的/output文件夹下的evaluation.txt。该txt文件结尾输出数据结构为: + +原始视频路径: +PSNR:... +FID:... +...(其他指标) + +注:如果不在容器环境下运行evaluation.py,需要将pytorch_fid的fid_score.py中的get_activations()函数中的dataloader中的num_workers设置为0 + +---------- + +注:**如有任何问题**,请联系2223915400@qq.com或者vx:hsy190613 \ No newline at end of file diff --git a/Hallo2/evaluation/FID/FID.py b/Hallo2/evaluation/FID/FID.py new file mode 100644 index 00000000..0f382f0d --- /dev/null +++ b/Hallo2/evaluation/FID/FID.py @@ -0,0 +1,43 @@ +import torch +from pytorch_fid import fid_score +# import logging + +# # 配置日志记录 +# logging.basicConfig( +# filename='fid_score.log', # 日志文件名 +# filemode='a', # 追加模式 +# level=logging.INFO, # 日志记录级别 +# format='%(asctime)s - %(levelname)s - %(message)s' +# ) + +def calculate_fid(real_images_folder='./original_frames', generated_images_folder='./generated_frames'): + # 设置真实数据和生成数据文件夹路径 + # real_images_folder = 'raw_results' + # generated_images_folder = 'final_results' + + # 设置参数 + new_batch_size = 16 + device = 'cuda' if torch.cuda.is_available() else 'cpu' + dims = 2048 # 使用 Inception 模型的默认特征维度 + + # try: + # 计算 FID 值 + fid_value = fid_score.calculate_fid_given_paths( + [real_images_folder, generated_images_folder], + batch_size=new_batch_size, + device=device, + dims=dims + ) + # logging.info(f'FID value: {fid_value}') + # print(f'FID value: {fid_value}') + + print(__file__) + + return fid_value + # except Exception as e: + # logging.error(f'Error occurred while calculating FID: {str(e)}') + + + +if __name__ == '__main__': + calculate_fid() diff --git a/Hallo2/evaluation/FID/__pycache__/FID.cpython-310.pyc b/Hallo2/evaluation/FID/__pycache__/FID.cpython-310.pyc new file mode 100644 index 00000000..3bfe1522 Binary files /dev/null and b/Hallo2/evaluation/FID/__pycache__/FID.cpython-310.pyc differ diff --git a/Hallo2/evaluation/LSE/SyncNetInstance_calc_scores.py b/Hallo2/evaluation/LSE/SyncNetInstance_calc_scores.py new file mode 100644 index 00000000..e3354da0 --- /dev/null +++ b/Hallo2/evaluation/LSE/SyncNetInstance_calc_scores.py @@ -0,0 +1,245 @@ +#!/usr/bin/python +#-*- coding: utf-8 -*- +# Video 25 FPS, Audio 16000HZ + +import torch +import numpy +import time, pdb, argparse, subprocess, os, math, glob +import cv2 +import python_speech_features + +from scipy import signal +from scipy.io import wavfile +from LSE.SyncNetModel import * +from shutil import rmtree +from moviepy.editor import VideoFileClip + +# ==================== Get OFFSET ==================== + +def calc_pdist(feat1, feat2, vshift=10): + + win_size = vshift*2+1 + + feat2p = torch.nn.functional.pad(feat2,(0,0,vshift,vshift)) + + dists = [] + + for i in range(0,len(feat1)): + + dists.append(torch.nn.functional.pairwise_distance(feat1[[i],:].repeat(win_size, 1), feat2p[i:i+win_size,:])) + + return dists + +# ==================== MAIN DEF ==================== + +class SyncNetInstance(torch.nn.Module): + + def __init__(self, dropout = 0, num_layers_in_fc_layers = 1024): + super(SyncNetInstance, self).__init__() + + self.__S__ = S(num_layers_in_fc_layers = num_layers_in_fc_layers).cuda() + + def evaluate(self, opt, videofile): + + self.__S__.eval() + + # ========== ========== + # Convert files + # ========== ========== + + if os.path.exists(os.path.join(opt.tmp_dir,opt.reference)): + rmtree(os.path.join(opt.tmp_dir,opt.reference)) + + os.makedirs(os.path.join(opt.tmp_dir,opt.reference)) + + # ========== ========== + # Save jpg & wav + # ========== ========== + + # 一直有问题 + # command = ("ffmpeg -loglevel error -y -i %s -threads 1 -f image2 %s" % (videofile,os.path.join(opt.tmp_dir,opt.reference,'%06d.jpg'))) + # output = subprocess.call(command, shell=True, stdout=None) + + # command = ("ffmpeg -loglevel error -y -i %s -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 %s" % (videofile,os.path.join(opt.tmp_dir,opt.reference,'audio.wav'))) + # output = subprocess.call(command, shell=True, stdout=None) + + # 提取音频并保存为 wav 格式(使用 moviepy) + video = VideoFileClip(videofile) + audio = video.audio + audio_path = os.path.join(opt.tmp_dir,opt.reference, "audio.wav") + audio.write_audiofile(audio_path, codec='pcm_s16le') # 保存为 .wav 格式 + + # 使用 OpenCV 提取每一帧并保存为 jpg + frame_dir = os.path.join(opt.tmp_dir,opt.reference) + # if not os.path.exists(frame_dir): + # os.makedirs(frame_dir) + + # 打开视频文件 + cap = cv2.VideoCapture(videofile) + + # 读取视频帧并保存 + frame_count = 0 + while True: + ret, frame = cap.read() + if not ret: + break # 如果没有更多帧,跳出循环 + + # 保存每帧图像 + frame_path = os.path.join(frame_dir, f"{frame_count:06d}.jpg") + cv2.imwrite(frame_path, frame) + # print(f"Saved frame {frame_count:06d}.jpg") + frame_count += 1 + + # 释放资源 + cap.release() + + # ========== ========== + # Load video + # ========== ========== + + images = [] + + flist = glob.glob(os.path.join(opt.tmp_dir,opt.reference,'*.jpg')) + flist.sort() + + for fname in flist: + img_input = cv2.imread(fname) + img_input = cv2.resize(img_input, (224,224)) #HARD CODED, CHANGE BEFORE RELEASE + images.append(img_input) + + im = numpy.stack(images,axis=3) + im = numpy.expand_dims(im,axis=0) + im = numpy.transpose(im,(0,3,4,1,2)) + + imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float()) + + # ========== ========== + # Load audio + # ========== ========== + + sample_rate, audio = wavfile.read(os.path.join(opt.tmp_dir,opt.reference,'audio.wav')) + mfcc = zip(*python_speech_features.mfcc(audio,sample_rate)) + mfcc = numpy.stack([numpy.array(i) for i in mfcc]) + + cc = numpy.expand_dims(numpy.expand_dims(mfcc,axis=0),axis=0) + cct = torch.autograd.Variable(torch.from_numpy(cc.astype(float)).float()) + + # ========== ========== + # Check audio and video input length + # ========== ========== + + #if (float(len(audio))/16000) != (float(len(images))/25) : + # print("WARNING: Audio (%.4fs) and video (%.4fs) lengths are different."%(float(len(audio))/16000,float(len(images))/25)) + + min_length = min(len(images),math.floor(len(audio)/640)) + + # ========== ========== + # Generate video and audio feats + # ========== ========== + + lastframe = min_length-5 + im_feat = [] + cc_feat = [] + + tS = time.time() + for i in range(0,lastframe,opt.batch_size): + + im_batch = [ imtv[:,:,vframe:vframe+5,:,:] for vframe in range(i,min(lastframe,i+opt.batch_size)) ] + im_in = torch.cat(im_batch,0) + im_out = self.__S__.forward_lip(im_in.cuda()) + im_feat.append(im_out.data.cpu()) + + cc_batch = [ cct[:,:,:,vframe*4:vframe*4+20] for vframe in range(i,min(lastframe,i+opt.batch_size)) ] + cc_in = torch.cat(cc_batch,0) + cc_out = self.__S__.forward_aud(cc_in.cuda()) + cc_feat.append(cc_out.data.cpu()) + + im_feat = torch.cat(im_feat,0) + cc_feat = torch.cat(cc_feat,0) + + # ========== ========== + # Compute offset + # ========== ========== + + #print('Compute time %.3f sec.' % (time.time()-tS)) + + dists = calc_pdist(im_feat,cc_feat,vshift=opt.vshift) + mdist = torch.mean(torch.stack(dists,1),1) + + minval, minidx = torch.min(mdist,0) + + offset = opt.vshift-minidx + conf = torch.median(mdist) - minval + + fdist = numpy.stack([dist[minidx].numpy() for dist in dists]) + # fdist = numpy.pad(fdist, (3,3), 'constant', constant_values=15) + fconf = torch.median(mdist).numpy() - fdist + fconfm = signal.medfilt(fconf,kernel_size=9) + + numpy.set_printoptions(formatter={'float': '{: 0.3f}'.format}) + #print('Framewise conf: ') + #print(fconfm) + #print('AV offset: \t%d \nMin dist: \t%.3f\nConfidence: \t%.3f' % (offset,minval,conf)) + + dists_npy = numpy.array([ dist.numpy() for dist in dists ]) + return offset.numpy(), conf.numpy(), minval.numpy() + + def extract_feature(self, opt, videofile): + + self.__S__.eval() + + # ========== ========== + # Load video + # ========== ========== + cap = cv2.VideoCapture(videofile) + + frame_num = 1 + images = [] + while frame_num: + frame_num += 1 + ret, image = cap.read() + if ret == 0: + break + + images.append(image) + + im = numpy.stack(images,axis=3) + im = numpy.expand_dims(im,axis=0) + im = numpy.transpose(im,(0,3,4,1,2)) + + imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float()) + + # ========== ========== + # Generate video feats + # ========== ========== + + lastframe = len(images)-4 + im_feat = [] + + tS = time.time() + for i in range(0,lastframe,opt.batch_size): + + im_batch = [ imtv[:,:,vframe:vframe+5,:,:] for vframe in range(i,min(lastframe,i+opt.batch_size)) ] + im_in = torch.cat(im_batch,0) + im_out = self.__S__.forward_lipfeat(im_in.cuda()) + im_feat.append(im_out.data.cpu()) + + im_feat = torch.cat(im_feat,0) + + # ========== ========== + # Compute offset + # ========== ========== + + print('Compute time %.3f sec.' % (time.time()-tS)) + + return im_feat + + + def loadParameters(self, path): + loaded_state = torch.load(path, map_location=lambda storage, loc: storage) + + self_state = self.__S__.state_dict() + + for name, param in loaded_state.items(): + + self_state[name].copy_(param) \ No newline at end of file diff --git a/Hallo2/evaluation/LSE/SyncNetModel.py b/Hallo2/evaluation/LSE/SyncNetModel.py new file mode 100644 index 00000000..acb81362 --- /dev/null +++ b/Hallo2/evaluation/LSE/SyncNetModel.py @@ -0,0 +1,117 @@ +#!/usr/bin/python +#-*- coding: utf-8 -*- + +import torch +import torch.nn as nn + +def save(model, filename): + with open(filename, "wb") as f: + torch.save(model, f) + print("%s saved."%filename) + +def load(filename): + net = torch.load(filename) + return net + +class S(nn.Module): + def __init__(self, num_layers_in_fc_layers = 1024): + super(S, self).__init__() + + self.__nFeatures__ = 24 + self.__nChs__ = 32 + self.__midChs__ = 32 + + self.netcnnaud = nn.Sequential( + nn.Conv2d(1, 64, kernel_size=(3,3), stride=(1,1), padding=(1,1)), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=(1,1), stride=(1,1)), + + nn.Conv2d(64, 192, kernel_size=(3,3), stride=(1,1), padding=(1,1)), + nn.BatchNorm2d(192), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=(3,3), stride=(1,2)), + + nn.Conv2d(192, 384, kernel_size=(3,3), padding=(1,1)), + nn.BatchNorm2d(384), + nn.ReLU(inplace=True), + + nn.Conv2d(384, 256, kernel_size=(3,3), padding=(1,1)), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + + nn.Conv2d(256, 256, kernel_size=(3,3), padding=(1,1)), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=(3,3), stride=(2,2)), + + nn.Conv2d(256, 512, kernel_size=(5,4), padding=(0,0)), + nn.BatchNorm2d(512), + nn.ReLU(), + ) + + self.netfcaud = nn.Sequential( + nn.Linear(512, 512), + nn.BatchNorm1d(512), + nn.ReLU(), + nn.Linear(512, num_layers_in_fc_layers), + ) + + self.netfclip = nn.Sequential( + nn.Linear(512, 512), + nn.BatchNorm1d(512), + nn.ReLU(), + nn.Linear(512, num_layers_in_fc_layers), + ) + + self.netcnnlip = nn.Sequential( + nn.Conv3d(3, 96, kernel_size=(5,7,7), stride=(1,2,2), padding=0), + nn.BatchNorm3d(96), + nn.ReLU(inplace=True), + nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2)), + + nn.Conv3d(96, 256, kernel_size=(1,5,5), stride=(1,2,2), padding=(0,1,1)), + nn.BatchNorm3d(256), + nn.ReLU(inplace=True), + nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1)), + + nn.Conv3d(256, 256, kernel_size=(1,3,3), padding=(0,1,1)), + nn.BatchNorm3d(256), + nn.ReLU(inplace=True), + + nn.Conv3d(256, 256, kernel_size=(1,3,3), padding=(0,1,1)), + nn.BatchNorm3d(256), + nn.ReLU(inplace=True), + + nn.Conv3d(256, 256, kernel_size=(1,3,3), padding=(0,1,1)), + nn.BatchNorm3d(256), + nn.ReLU(inplace=True), + nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2)), + + nn.Conv3d(256, 512, kernel_size=(1,6,6), padding=0), + nn.BatchNorm3d(512), + nn.ReLU(inplace=True), + ) + + def forward_aud(self, x): + + mid = self.netcnnaud(x) # N x ch x 24 x M + mid = mid.view((mid.size()[0], -1)) # N x (ch x 24) + out = self.netfcaud(mid) + + return out + + def forward_lip(self, x): + + mid = self.netcnnlip(x) + mid = mid.view((mid.size()[0], -1)) # N x (ch x 24) + out = self.netfclip(mid) + + return out + + def forward_lipfeat(self, x): + + mid = self.netcnnlip(x) + out = mid.view((mid.size()[0], -1)) # N x (ch x 24) + + return out \ No newline at end of file diff --git a/Hallo2/evaluation/LSE/__pycache__/SyncNetInstance_calc_scores.cpython-310.pyc b/Hallo2/evaluation/LSE/__pycache__/SyncNetInstance_calc_scores.cpython-310.pyc new file mode 100644 index 00000000..094a7d6e Binary files /dev/null and b/Hallo2/evaluation/LSE/__pycache__/SyncNetInstance_calc_scores.cpython-310.pyc differ diff --git a/Hallo2/evaluation/LSE/__pycache__/SyncNetModel.cpython-310.pyc b/Hallo2/evaluation/LSE/__pycache__/SyncNetModel.cpython-310.pyc new file mode 100644 index 00000000..2fd405ac Binary files /dev/null and b/Hallo2/evaluation/LSE/__pycache__/SyncNetModel.cpython-310.pyc differ diff --git a/Hallo2/evaluation/LSE/__pycache__/calculate_scores_LRS.cpython-310.pyc b/Hallo2/evaluation/LSE/__pycache__/calculate_scores_LRS.cpython-310.pyc new file mode 100644 index 00000000..23f433c3 Binary files /dev/null and b/Hallo2/evaluation/LSE/__pycache__/calculate_scores_LRS.cpython-310.pyc differ diff --git a/Hallo2/evaluation/LSE/calculate_scores_LRS.py b/Hallo2/evaluation/LSE/calculate_scores_LRS.py new file mode 100644 index 00000000..3c91afed --- /dev/null +++ b/Hallo2/evaluation/LSE/calculate_scores_LRS.py @@ -0,0 +1,228 @@ +# #!/usr/bin/python +# #-*- coding: utf-8 -*- + +# import time, pdb, argparse, subprocess +# import glob +# import os +# from tqdm import tqdm +# from moviepy.editor import VideoFileClip +# from shutil import rmtree + +# from LSE.SyncNetInstance_calc_scores import * + +# def calculate_scores(video_path): +# # 视频切割 +# def split_video(input_folder, output_folder, segment_duration=30): +# # 获取 /data/ 文件夹下的所有视频文件 +# video_files = [f for f in os.listdir(input_folder) if f.endswith(('.mp4', '.avi', '.mov'))] + +# # 如果 /merge/ 文件夹不存在,则创建 +# if os.path.exists(output_folder): +# rmtree(output_folder) +# os.makedirs(output_folder) + +# for video_file in video_files: +# # print(video_file) +# # 构建输入视频文件的完整路径 +# video_path = os.path.join(input_folder, video_file) + +# # 使用 VideoFileClip 读取视频文件 +# with VideoFileClip(video_path) as video: +# # 获取视频的总时长(秒) +# total_duration = video.duration + +# # 计算分割的视频段数 +# num_segments = int(total_duration // segment_duration) + 1 + +# # 对视频进行切割 +# for i in range(num_segments): +# # 计算每个片段的起始时间和结束时间 +# start_time = i * segment_duration +# end_time = min((i + 1) * segment_duration, total_duration) + +# # 从视频中切割出指定时间段 +# clip = video.subclip(start_time, end_time) + +# # 构建输出视频文件的路径 +# output_filename = f"{os.path.splitext(video_file)[0]}_segment_{i + 1:04d}.mp4" +# output_path = os.path.join(output_folder, output_filename) + +# # 保存切割出的片段 +# clip.write_videofile(output_path, codec="libx264", audio_codec="aac") + +# print(f"Saved {output_filename} to {output_folder}") + +# # 只执行第一个视频 +# break + +# # ==================== LOAD PARAMS ==================== + + +# parser = argparse.ArgumentParser(description = "SyncNet") + +# parser.add_argument('--initial_model', type=str, default="./LSE/data/syncnet_v2.model", help='') +# parser.add_argument('--batch_size', type=int, default='20', help='') +# parser.add_argument('--vshift', type=int, default='15', help='') +# parser.add_argument('--data_input', type=str, default=f"{video_path}", help='') +# parser.add_argument('--data_root', type=str, default="./LSE/data/merge/", help='') +# parser.add_argument('--tmp_dir', type=str, default="./LSE/data/work/pytmp", help='') +# parser.add_argument('--reference', type=str, default="demo", help='') + +# opt = parser.parse_args() + + +# # ==================== RUN EVALUATION ==================== + +# s = SyncNetInstance() + +# s.loadParameters(opt.initial_model) + +# split_video(opt.data_input,opt.data_root) + +# #print("Model %s loaded."%opt.initial_model) +# merge_path = os.path.join(opt.data_root, "*.mp4") + +# all_videos = glob.glob(merge_path) + +# prog_bar = tqdm(range(len(all_videos))) +# avg_confidence = 0. +# avg_min_distance = 0. + + +# for videofile_idx in prog_bar: +# videofile = all_videos[videofile_idx] +# # print(videofile) +# offset, confidence, min_distance = s.evaluate(opt, videofile=videofile) +# avg_confidence += confidence +# avg_min_distance += min_distance +# prog_bar.set_description('Avg Confidence: {}, Avg Minimum Dist: {}'.format(round(avg_confidence / (videofile_idx + 1), 3), round(avg_min_distance / (videofile_idx + 1), 3))) +# prog_bar.refresh() + +# print('Average Confidence: {}'.format(avg_confidence/len(all_videos))) +# print('Average Minimum Distance: {}'.format(avg_min_distance/len(all_videos))) + +# return avg_confidence/len(all_videos), avg_min_distance/len(all_videos) + +#!/usr/bin/python +#-*- coding: utf-8 -*- + +import time, pdb, argparse, subprocess +import glob +import os +from tqdm import tqdm +from moviepy.editor import VideoFileClip +from shutil import rmtree + +from LSE.SyncNetInstance_calc_scores import * + +def calculate_scores(video_path): + # 视频切割 + def split_video(input_video_path, output_dir='./data/merge', segment_duration=15): + # 如果 /merge/ 文件夹不存在,则创建 + if os.path.exists(output_dir): + rmtree(output_dir) + os.makedirs(output_dir) + + # 打开视频文件 + video = VideoFileClip(input_video_path) + + # 获取视频的总时长 + video_duration = video.duration # 视频总时长(秒) + + # 切割视频并保存 + segment_start = 0 + segment_count = 0 + + while segment_start < video_duration: + # 计算切割的结束时间 + segment_end = min(segment_start + segment_duration, video_duration) + + # 切割出一个片段 + video_segment = video.subclip(segment_start, segment_end) + + # 保存切割后的片段 + segment_filename = f"{output_dir}/segment_{segment_count:03d}.mp4" + video_segment.write_videofile(segment_filename, codec='libx264') + + print(f"Saved segment {segment_count:03d} from {segment_start} to {segment_end}") + + # 更新切割的开始时间和片段计数器 + segment_start = segment_end + segment_count += 1 + + video.close() + + + # ==================== LOAD PARAMS ==================== + + class SyncNetConfig: + def __init__(self, + initial_model="./LSE/data/syncnet_v2.model", + batch_size=20, + vshift=15, + data_input=None, + data_root="./LSE/data/merge/", + tmp_dir="./LSE/data/work/pytmp", + reference="demo"): + self.initial_model = initial_model + self.batch_size = batch_size + self.vshift = vshift + self.data_input = data_input + self.data_root = data_root + self.tmp_dir = tmp_dir + self.reference = reference + + # parser = argparse.ArgumentParser(description = "SyncNet") + + # parser.add_argument('--initial_model', type=str, default="./LSE/data/syncnet_v2.model", help='') + # parser.add_argument('--batch_size', type=int, default='20', help='') + # parser.add_argument('--vshift', type=int, default='15', help='') + # parser.add_argument('--data_input', type=str, default=f"{video_path}", help='') + # parser.add_argument('--data_root', type=str, default="./LSE/data/merge/", help='') + # parser.add_argument('--tmp_dir', type=str, default="./LSE/data/work/pytmp", help='') + # parser.add_argument('--reference', type=str, default="demo", help='') + + opt_new = SyncNetConfig(data_input=f"{video_path}") + + + # ==================== RUN EVALUATION ==================== + + s = SyncNetInstance() + + s.loadParameters(opt_new.initial_model) + + split_video(opt_new.data_input,opt_new.data_root) + + #print("Model %s loaded."%opt.initial_model) + merge_path = os.path.join(opt_new.data_root, "*.mp4") + + all_videos = glob.glob(merge_path) + + prog_bar = tqdm(range(len(all_videos))) + avg_confidence = 0. + avg_min_distance = 0. + + + for videofile_idx in prog_bar: + videofile = all_videos[videofile_idx] + # print(videofile) + offset, confidence, min_distance = s.evaluate(opt_new, videofile=videofile) + avg_confidence += confidence + avg_min_distance += min_distance + prog_bar.set_description('Avg Confidence: {}, Avg Minimum Dist: {}'.format(round(avg_confidence / (videofile_idx + 1), 3), round(avg_min_distance / (videofile_idx + 1), 3))) + prog_bar.refresh() + + print(f"{video_path}:") + print('Average Confidence: {}'.format(avg_confidence/len(all_videos))) + print('Average Minimum Distance: {}'.format(avg_min_distance/len(all_videos))) + + return avg_confidence/len(all_videos), avg_min_distance/len(all_videos) + +if __name__ == '__main__': + video_paths = ['./data/input/Obama.mp4','./data/input/Jae-in.mp4','./data/input/Lieu.mp4','./data/input/Macron.mp4','./data/input/May.mp4', + './data/input/Shaheen.mp4'] + # 填入需要进行计算的video路径 + # for video_path in video_paths: + video_path = './data/input/Lieu.mp4' + calculate_scores(video_path) + diff --git a/Hallo2/evaluation/LSE/calculate_scores_real_videos.py b/Hallo2/evaluation/LSE/calculate_scores_real_videos.py new file mode 100644 index 00000000..775880d9 --- /dev/null +++ b/Hallo2/evaluation/LSE/calculate_scores_real_videos.py @@ -0,0 +1,45 @@ +#!/usr/bin/python +#-*- coding: utf-8 -*- + +import time, pdb, argparse, subprocess, pickle, os, gzip, glob + +from SyncNetInstance_calc_scores import * + +# ==================== PARSE ARGUMENT ==================== + +parser = argparse.ArgumentParser(description = "SyncNet") +parser.add_argument('--initial_model', type=str, default="data/syncnet_v2.model", help='') +parser.add_argument('--batch_size', type=int, default='20', help='') +parser.add_argument('--vshift', type=int, default='15', help='') +parser.add_argument('--data_dir', type=str, default='data/work', help='') +parser.add_argument('--videofile', type=str, default='', help='') +parser.add_argument('--reference', type=str, default='', help='') +opt = parser.parse_args() + +setattr(opt,'avi_dir',os.path.join(opt.data_dir,'pyavi')) +setattr(opt,'tmp_dir',os.path.join(opt.data_dir,'pytmp')) +setattr(opt,'work_dir',os.path.join(opt.data_dir,'pywork')) +setattr(opt,'crop_dir',os.path.join(opt.data_dir,'pycrop')) + + +# ==================== LOAD MODEL AND FILE LIST ==================== + +s = SyncNetInstance() + +s.loadParameters(opt.initial_model) +#print("Model %s loaded."%opt.initial_model) + +flist = glob.glob(os.path.join(opt.crop_dir,opt.reference,'0*.avi')) +flist.sort() + +# ==================== GET OFFSETS ==================== + +dists = [] +for idx, fname in enumerate(flist): + offset, conf, dist = s.evaluate(opt,videofile=fname) + print (str(dist)+" "+str(conf)) + +# ==================== PRINT RESULTS TO FILE ==================== + +#with open(os.path.join(opt.work_dir,opt.reference,'activesd.pckl'), 'wb') as fil: +# pickle.dump(dists, fil) diff --git a/Hallo2/evaluation/LSE/data/input/processed.mp4 b/Hallo2/evaluation/LSE/data/input/processed.mp4 new file mode 100644 index 00000000..9772c211 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/input/processed.mp4 differ diff --git a/Hallo2/evaluation/LSE/data/merge/segment_000.mp4 b/Hallo2/evaluation/LSE/data/merge/segment_000.mp4 new file mode 100644 index 00000000..d8684e0a Binary files /dev/null and b/Hallo2/evaluation/LSE/data/merge/segment_000.mp4 differ diff --git a/Hallo2/evaluation/LSE/data/syncnet_v2.model b/Hallo2/evaluation/LSE/data/syncnet_v2.model new file mode 100644 index 00000000..d714bbf3 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/syncnet_v2.model differ diff --git a/Hallo2/evaluation/LSE/data/test.py b/Hallo2/evaluation/LSE/data/test.py new file mode 100644 index 00000000..3bd80c8a --- /dev/null +++ b/Hallo2/evaluation/LSE/data/test.py @@ -0,0 +1,47 @@ +import os +from moviepy.editor import VideoFileClip + +def split_video(input_folder, output_folder, segment_duration=30): + # 获取 /data/ 文件夹下的所有视频文件 + video_files = [f for f in os.listdir(input_folder) if f.endswith(('.mp4', '.avi', '.mov'))] + + # 如果 /merge/ 文件夹不存在,则创建 + if not os.path.exists(output_folder): + os.makedirs(output_folder) + + for video_file in video_files: + print(video_file) + # 构建输入视频文件的完整路径 + video_path = os.path.join(input_folder, video_file) + + # 使用 VideoFileClip 读取视频文件 + with VideoFileClip(video_path) as video: + # 获取视频的总时长(秒) + total_duration = video.duration + + # 计算分割的视频段数 + num_segments = int(total_duration // segment_duration) + 1 + + # 对视频进行切割 + for i in range(num_segments): + # 计算每个片段的起始时间和结束时间 + start_time = i * segment_duration + end_time = min((i + 1) * segment_duration, total_duration) + + # 从视频中切割出指定时间段 + clip = video.subclip(start_time, end_time) + + # 构建输出视频文件的路径 + output_filename = f"{os.path.splitext(video_file)[0]}_segment_{i + 1}.mp4" + output_path = os.path.join(output_folder, output_filename) + + # 保存切割出的片段 + clip.write_videofile(output_path, codec="libx264", audio_codec="aac") + + print(f"Saved {output_filename} to {output_folder}") + +if __name__ == "__main__": + input_folder = "./LSE/data/input" # 输入视频文件夹路径 + output_folder = "./LSE/data/merge" # 输出切割后视频的文件夹路径 + + split_video(input_folder, output_folder) diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000000.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000000.jpg new file mode 100644 index 00000000..198a7142 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000000.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000001.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000001.jpg new file mode 100644 index 00000000..e6eb7d3f Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000001.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000002.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000002.jpg new file mode 100644 index 00000000..1a62cd33 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000002.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000003.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000003.jpg new file mode 100644 index 00000000..d472fcb4 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000003.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000004.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000004.jpg new file mode 100644 index 00000000..86341ad2 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000004.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000005.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000005.jpg new file mode 100644 index 00000000..bc57eff2 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000005.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000006.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000006.jpg new file mode 100644 index 00000000..6fb480d4 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000006.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000007.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000007.jpg new file mode 100644 index 00000000..0b1e5c69 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000007.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000008.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000008.jpg new file mode 100644 index 00000000..dda06f0c Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000008.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000009.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000009.jpg new file mode 100644 index 00000000..df85b309 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000009.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000010.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000010.jpg new file mode 100644 index 00000000..b52229be Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000010.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000011.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000011.jpg new file mode 100644 index 00000000..bf475e18 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000011.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000012.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000012.jpg new file mode 100644 index 00000000..39785700 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000012.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000013.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000013.jpg new file mode 100644 index 00000000..56dd3108 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000013.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000014.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000014.jpg new file mode 100644 index 00000000..32402502 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000014.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000015.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000015.jpg new file mode 100644 index 00000000..93808367 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000015.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000016.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000016.jpg new file mode 100644 index 00000000..c5f807fa Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000016.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000017.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000017.jpg new file mode 100644 index 00000000..a8199e59 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000017.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000018.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000018.jpg new file mode 100644 index 00000000..2dbdd06a Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000018.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000019.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000019.jpg new file mode 100644 index 00000000..277a005a Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000019.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000020.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000020.jpg new file mode 100644 index 00000000..733fd6e2 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000020.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000021.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000021.jpg new file mode 100644 index 00000000..c17f6bcd Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000021.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000022.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000022.jpg new file mode 100644 index 00000000..c6bef6d7 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000022.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000023.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000023.jpg new file mode 100644 index 00000000..72d690b9 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000023.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000024.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000024.jpg new file mode 100644 index 00000000..7709df43 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000024.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000025.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000025.jpg new file mode 100644 index 00000000..58c33216 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000025.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000026.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000026.jpg new file mode 100644 index 00000000..b0994704 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000026.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000027.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000027.jpg new file mode 100644 index 00000000..87e674aa Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000027.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000028.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000028.jpg new file mode 100644 index 00000000..a620733a Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000028.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000029.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000029.jpg new file mode 100644 index 00000000..1530e923 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000029.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000030.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000030.jpg new file mode 100644 index 00000000..0813a0cf Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000030.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000031.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000031.jpg new file mode 100644 index 00000000..27833fd6 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000031.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000032.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000032.jpg new file mode 100644 index 00000000..70313555 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000032.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000033.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000033.jpg new file mode 100644 index 00000000..05800df4 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000033.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000034.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000034.jpg new file mode 100644 index 00000000..837a518a Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000034.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000035.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000035.jpg new file mode 100644 index 00000000..ef48eb6e Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000035.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000036.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000036.jpg new file mode 100644 index 00000000..513bbdf8 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000036.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000037.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000037.jpg new file mode 100644 index 00000000..2ad8bb38 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000037.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000038.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000038.jpg new file mode 100644 index 00000000..8d81c227 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000038.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000039.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000039.jpg new file mode 100644 index 00000000..d247ce81 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000039.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000040.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000040.jpg new file mode 100644 index 00000000..1a55c647 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000040.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000041.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000041.jpg new file mode 100644 index 00000000..410293fa Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000041.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000042.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000042.jpg new file mode 100644 index 00000000..1d0d33e3 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000042.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000043.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000043.jpg new file mode 100644 index 00000000..3933fc93 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000043.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000044.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000044.jpg new file mode 100644 index 00000000..9672b523 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000044.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000045.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000045.jpg new file mode 100644 index 00000000..1217f622 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000045.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000046.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000046.jpg new file mode 100644 index 00000000..bc2352c5 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000046.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000047.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000047.jpg new file mode 100644 index 00000000..c5a77005 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000047.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000048.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000048.jpg new file mode 100644 index 00000000..1b874aa5 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000048.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000049.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000049.jpg new file mode 100644 index 00000000..f22f6691 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000049.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000050.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000050.jpg new file mode 100644 index 00000000..d4d498b7 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000050.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000051.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000051.jpg new file mode 100644 index 00000000..9a5e0e88 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000051.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000052.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000052.jpg new file mode 100644 index 00000000..87bca8ba Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000052.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000053.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000053.jpg new file mode 100644 index 00000000..9ffbbda5 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000053.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000054.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000054.jpg new file mode 100644 index 00000000..4c5d8a5e Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000054.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000055.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000055.jpg new file mode 100644 index 00000000..93546b14 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000055.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000056.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000056.jpg new file mode 100644 index 00000000..8538400f Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000056.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000057.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000057.jpg new file mode 100644 index 00000000..e5b0a842 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000057.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000058.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000058.jpg new file mode 100644 index 00000000..cca7f246 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000058.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000059.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000059.jpg new file mode 100644 index 00000000..85a69daf Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000059.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000060.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000060.jpg new file mode 100644 index 00000000..492d6b19 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000060.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000061.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000061.jpg new file mode 100644 index 00000000..3da9522c Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000061.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000062.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000062.jpg new file mode 100644 index 00000000..87df8fc9 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000062.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000063.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000063.jpg new file mode 100644 index 00000000..2385bd51 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000063.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000064.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000064.jpg new file mode 100644 index 00000000..0239ac74 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000064.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000065.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000065.jpg new file mode 100644 index 00000000..76d5dc94 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000065.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000066.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000066.jpg new file mode 100644 index 00000000..6e6278b8 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000066.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000067.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000067.jpg new file mode 100644 index 00000000..a1cdfc52 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000067.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000068.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000068.jpg new file mode 100644 index 00000000..a5b901ab Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000068.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000069.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000069.jpg new file mode 100644 index 00000000..b4041aab Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000069.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000070.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000070.jpg new file mode 100644 index 00000000..38a95631 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000070.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000071.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000071.jpg new file mode 100644 index 00000000..e26f8d12 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000071.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000072.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000072.jpg new file mode 100644 index 00000000..ab92aaa6 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000072.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000073.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000073.jpg new file mode 100644 index 00000000..352cc10a Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000073.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000074.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000074.jpg new file mode 100644 index 00000000..fe4a9898 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000074.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000075.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000075.jpg new file mode 100644 index 00000000..b2dd776a Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000075.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000076.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000076.jpg new file mode 100644 index 00000000..42ba84d8 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000076.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000077.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000077.jpg new file mode 100644 index 00000000..94c762e4 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000077.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000078.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000078.jpg new file mode 100644 index 00000000..4e6efabf Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000078.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000079.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000079.jpg new file mode 100644 index 00000000..46f3031d Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000079.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000080.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000080.jpg new file mode 100644 index 00000000..1b0a5f72 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000080.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000081.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000081.jpg new file mode 100644 index 00000000..efcc161b Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000081.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000082.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000082.jpg new file mode 100644 index 00000000..8a8bdcdb Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000082.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000083.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000083.jpg new file mode 100644 index 00000000..c3fe847f Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000083.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000084.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000084.jpg new file mode 100644 index 00000000..7c46910f Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000084.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000085.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000085.jpg new file mode 100644 index 00000000..de200ca5 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000085.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000086.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000086.jpg new file mode 100644 index 00000000..8a569f25 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000086.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000087.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000087.jpg new file mode 100644 index 00000000..5b2c786d Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000087.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000088.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000088.jpg new file mode 100644 index 00000000..4f53353a Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000088.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000089.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000089.jpg new file mode 100644 index 00000000..56f8544f Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000089.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000090.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000090.jpg new file mode 100644 index 00000000..759972c4 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000090.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000091.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000091.jpg new file mode 100644 index 00000000..a4617215 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000091.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000092.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000092.jpg new file mode 100644 index 00000000..805d759d Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000092.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000093.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000093.jpg new file mode 100644 index 00000000..ac46df2b Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000093.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000094.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000094.jpg new file mode 100644 index 00000000..1b0c739a Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000094.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000095.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000095.jpg new file mode 100644 index 00000000..10fd5b08 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000095.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000096.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000096.jpg new file mode 100644 index 00000000..455a564a Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000096.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000097.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000097.jpg new file mode 100644 index 00000000..a9a81bb8 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000097.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000098.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000098.jpg new file mode 100644 index 00000000..66ecf932 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000098.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000099.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000099.jpg new file mode 100644 index 00000000..c7e8a7c2 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000099.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000100.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000100.jpg new file mode 100644 index 00000000..f943fab2 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000100.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000101.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000101.jpg new file mode 100644 index 00000000..033baac0 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000101.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000102.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000102.jpg new file mode 100644 index 00000000..aa9cf294 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000102.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000103.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000103.jpg new file mode 100644 index 00000000..c61fab0b Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000103.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000104.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000104.jpg new file mode 100644 index 00000000..5cb94ec4 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000104.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000105.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000105.jpg new file mode 100644 index 00000000..7bd1bc08 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000105.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000106.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000106.jpg new file mode 100644 index 00000000..e02a8589 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000106.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000107.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000107.jpg new file mode 100644 index 00000000..f8351c97 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000107.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000108.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000108.jpg new file mode 100644 index 00000000..61b8886a Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000108.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000109.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000109.jpg new file mode 100644 index 00000000..2d9b40d9 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000109.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000110.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000110.jpg new file mode 100644 index 00000000..e6f60c31 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000110.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000111.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000111.jpg new file mode 100644 index 00000000..f0a3403f Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000111.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000112.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000112.jpg new file mode 100644 index 00000000..a5f65fb8 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000112.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000113.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000113.jpg new file mode 100644 index 00000000..e3628640 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000113.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000114.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000114.jpg new file mode 100644 index 00000000..bc3cee6a Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000114.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000115.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000115.jpg new file mode 100644 index 00000000..e83ccfe3 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000115.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000116.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000116.jpg new file mode 100644 index 00000000..084e815f Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000116.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000117.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000117.jpg new file mode 100644 index 00000000..3e4da6a2 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000117.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000118.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000118.jpg new file mode 100644 index 00000000..d62eeed2 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000118.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000119.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000119.jpg new file mode 100644 index 00000000..284cdfeb Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000119.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000120.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000120.jpg new file mode 100644 index 00000000..3772a520 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000120.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000121.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000121.jpg new file mode 100644 index 00000000..610f7de6 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000121.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000122.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000122.jpg new file mode 100644 index 00000000..40b49877 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000122.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000123.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000123.jpg new file mode 100644 index 00000000..b8d87ea6 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000123.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000124.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000124.jpg new file mode 100644 index 00000000..3a47300c Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000124.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000125.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000125.jpg new file mode 100644 index 00000000..0340ff9d Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000125.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000126.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000126.jpg new file mode 100644 index 00000000..74128205 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000126.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000127.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000127.jpg new file mode 100644 index 00000000..ed362400 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000127.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000128.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000128.jpg new file mode 100644 index 00000000..117e73b6 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000128.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000129.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000129.jpg new file mode 100644 index 00000000..47159ccf Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000129.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000130.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000130.jpg new file mode 100644 index 00000000..21b4d208 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000130.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000131.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000131.jpg new file mode 100644 index 00000000..4a7504ef Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000131.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000132.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000132.jpg new file mode 100644 index 00000000..a3b7e44f Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000132.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000133.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000133.jpg new file mode 100644 index 00000000..d9ff4332 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000133.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000134.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000134.jpg new file mode 100644 index 00000000..de76d426 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000134.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000135.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000135.jpg new file mode 100644 index 00000000..45bbc1bf Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000135.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000136.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000136.jpg new file mode 100644 index 00000000..2f8c979d Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000136.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000137.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000137.jpg new file mode 100644 index 00000000..54a9f93b Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000137.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000138.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000138.jpg new file mode 100644 index 00000000..e688b53c Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000138.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000139.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000139.jpg new file mode 100644 index 00000000..959ba80a Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000139.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000140.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000140.jpg new file mode 100644 index 00000000..0f7d9114 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000140.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000141.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000141.jpg new file mode 100644 index 00000000..6e6dfeae Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000141.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000142.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000142.jpg new file mode 100644 index 00000000..2fbef846 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000142.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000143.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000143.jpg new file mode 100644 index 00000000..548c8b8f Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000143.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000144.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000144.jpg new file mode 100644 index 00000000..f5edd8ab Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000144.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000145.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000145.jpg new file mode 100644 index 00000000..96abb5ce Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000145.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000146.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000146.jpg new file mode 100644 index 00000000..20d17cc4 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000146.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000147.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000147.jpg new file mode 100644 index 00000000..9b7972e3 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000147.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000148.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000148.jpg new file mode 100644 index 00000000..06821d25 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000148.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000149.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000149.jpg new file mode 100644 index 00000000..3cde509a Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000149.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000150.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000150.jpg new file mode 100644 index 00000000..8048af36 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000150.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000151.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000151.jpg new file mode 100644 index 00000000..ac180153 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000151.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000152.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000152.jpg new file mode 100644 index 00000000..eea71b77 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000152.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000153.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000153.jpg new file mode 100644 index 00000000..904f20c9 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000153.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000154.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000154.jpg new file mode 100644 index 00000000..ea45314a Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000154.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000155.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000155.jpg new file mode 100644 index 00000000..9af6ba87 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000155.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000156.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000156.jpg new file mode 100644 index 00000000..11211e12 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000156.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000157.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000157.jpg new file mode 100644 index 00000000..21ac292d Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000157.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000158.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000158.jpg new file mode 100644 index 00000000..216ce471 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000158.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000159.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000159.jpg new file mode 100644 index 00000000..bb665aa3 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000159.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000160.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000160.jpg new file mode 100644 index 00000000..eb6e8336 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000160.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000161.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000161.jpg new file mode 100644 index 00000000..ad9185a8 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000161.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000162.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000162.jpg new file mode 100644 index 00000000..ef3af722 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000162.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000163.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000163.jpg new file mode 100644 index 00000000..dc7bd7ad Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000163.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000164.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000164.jpg new file mode 100644 index 00000000..9b5949ae Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000164.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000165.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000165.jpg new file mode 100644 index 00000000..269e26d4 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000165.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000166.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000166.jpg new file mode 100644 index 00000000..8606cef5 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000166.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000167.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000167.jpg new file mode 100644 index 00000000..032073d9 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000167.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000168.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000168.jpg new file mode 100644 index 00000000..dfeb1f39 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000168.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000169.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000169.jpg new file mode 100644 index 00000000..a051683f Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000169.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000170.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000170.jpg new file mode 100644 index 00000000..0eafa058 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000170.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000171.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000171.jpg new file mode 100644 index 00000000..83523034 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000171.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000172.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000172.jpg new file mode 100644 index 00000000..6a6df48c Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000172.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000173.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000173.jpg new file mode 100644 index 00000000..63ef6b89 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000173.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000174.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000174.jpg new file mode 100644 index 00000000..ee05d802 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000174.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000175.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000175.jpg new file mode 100644 index 00000000..bde0354a Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000175.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000176.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000176.jpg new file mode 100644 index 00000000..e7ac8905 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000176.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000177.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000177.jpg new file mode 100644 index 00000000..b5d548b5 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000177.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000178.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000178.jpg new file mode 100644 index 00000000..e0a02ac0 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000178.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000179.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000179.jpg new file mode 100644 index 00000000..1a326a40 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000179.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000180.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000180.jpg new file mode 100644 index 00000000..0211935d Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000180.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000181.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000181.jpg new file mode 100644 index 00000000..81751c8c Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000181.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000182.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000182.jpg new file mode 100644 index 00000000..ebfa40eb Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000182.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000183.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000183.jpg new file mode 100644 index 00000000..15b09fd8 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000183.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000184.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000184.jpg new file mode 100644 index 00000000..55fb7f62 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000184.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000185.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000185.jpg new file mode 100644 index 00000000..94163616 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000185.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000186.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000186.jpg new file mode 100644 index 00000000..733fc7da Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000186.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000187.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000187.jpg new file mode 100644 index 00000000..b7711d6a Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000187.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000188.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000188.jpg new file mode 100644 index 00000000..bda1c73e Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000188.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/000189.jpg b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000189.jpg new file mode 100644 index 00000000..0841df46 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/000189.jpg differ diff --git a/Hallo2/evaluation/LSE/data/work/pytmp/demo/audio.wav b/Hallo2/evaluation/LSE/data/work/pytmp/demo/audio.wav new file mode 100644 index 00000000..3f354ad1 Binary files /dev/null and b/Hallo2/evaluation/LSE/data/work/pytmp/demo/audio.wav differ diff --git a/Hallo2/evaluation/LSE/test.py b/Hallo2/evaluation/LSE/test.py new file mode 100644 index 00000000..cef19f27 --- /dev/null +++ b/Hallo2/evaluation/LSE/test.py @@ -0,0 +1,5 @@ +import ffmpeg + +f='./merge_mp4' + +print(f.endswith('mp4')) \ No newline at end of file diff --git a/Hallo2/evaluation/README.md b/Hallo2/evaluation/README.md new file mode 100644 index 00000000..99944c60 --- /dev/null +++ b/Hallo2/evaluation/README.md @@ -0,0 +1,35 @@ +# 评估代码 + +该文件夹存放了评估代码,由evaluation.py进行评估指标的计算,其相关选项为: +```bash +usage: evaluation.py [-h] --original_video_path ORIGINAL_VIDEO_PATH --generated_video_path GENERATED_VIDEO_PATH --output_dir OUTPUT_DIR [--batch_size BATCH_SIZE] [--niqe NIQE] + [--ssim SSIM] [--psnr PSNR] [--fid FID] [--lse LSE] + +Total + +options: + -h, --help show this help message and exit + --original_video_path ORIGINAL_VIDEO_PATH + original video path + --generated_video_path GENERATED_VIDEO_PATH + generated video path + --output_dir OUTPUT_DIR + output directory + --batch_size BATCH_SIZE + batch size + --niqe NIQE whether to calculate NIQE + --ssim SSIM whether to calculate SSIM + --psnr PSNR whether to calculate PSNR + --fid FID whether to calculate FID + --lse LSE whether to calculate LSE-C & D +``` + +其中,original_video_path、generated_video_path和output_dir分别指定原视频地址、生成视频地址和输出文件夹,其次niqe、ssim、psnr、fid和lse(LSE-C&LSE-D)分别代表对应的指标是否需要进行计算,默认是不进行计算,需要直接指定其值为1,如:`--lse 1`。 + +默认输出文件夹路径为`./output`,在该文件夹下会有一个evaluation.txt文件,该文件记录之前评估结果,结构为: +```txt +input_dir: + NIQE:... + LSE-C:... + ... +``` \ No newline at end of file diff --git a/Hallo2/evaluation/evaluation.py b/Hallo2/evaluation/evaluation.py new file mode 100644 index 00000000..ef0d1fee --- /dev/null +++ b/Hallo2/evaluation/evaluation.py @@ -0,0 +1,212 @@ +import cv2 +import numpy as np +from moviepy.editor import VideoFileClip,AudioFileClip +from skimage.metrics import structural_similarity as ssim +# from skimage.metrics import peak_signal_noise_ratio as psnr +from skimage import measure +import librosa +import torch +from torchvision import models, transforms +from scipy.linalg import sqrtm +import os +import time +import shutil +import argparse + +# NIQE +from niqe_python.main import niqe + +# FID +from FID.FID import calculate_fid + +# LSE +from LSE.calculate_scores_LRS import calculate_scores + +# 1. 计算SSIM指标 +def calculate_ssim_batch(original_images, generated_images): + ssim_scores = [] + for orig_img, gen_img in zip(original_images, generated_images): + # print(orig_img.shape) + orig_img = cv2.cvtColor(orig_img, cv2.COLOR_RGB2GRAY) + gen_img = cv2.cvtColor(gen_img, cv2.COLOR_RGB2GRAY) + ssim_index, _ = ssim(orig_img, gen_img, full=True) + ssim_scores.append(ssim_index) + return np.mean(ssim_scores) + +# 2. 计算PSNR指标 +def calculate_psnr_batch(original_images, generated_images): + psnr_scores = [] + for orig_img, gen_img in zip(original_images, generated_images): + orig_img = cv2.cvtColor(orig_img, cv2.COLOR_RGB2GRAY) + gen_img = cv2.cvtColor(gen_img, cv2.COLOR_RGB2GRAY) + mse = np.mean((orig_img - gen_img) ** 2) + if mse == 0: + psnr = 100 + else: + PIXEL_MAX = 255.0 + psnr = 20 * np.log10(PIXEL_MAX / np.sqrt(mse)) + # psnr_result = psnr(orig_img, gen_img) + psnr_scores.append(psnr) + return np.mean(psnr_scores) +# 3. 计算NIQE指标 +def calculate_niqe_batch(original_images): + niqe_scores = [] + for orig_img in original_images: + orig_img = cv2.cvtColor(orig_img, cv2.COLOR_RGB2GRAY) + niqe_score = niqe(orig_img) # This is a placeholder + niqe_scores.append(niqe_score) + return np.mean(niqe_scores) + + + +def save_frames(video_path, output_dir): + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + os.makedirs(output_dir) + + cap = cv2.VideoCapture(video_path) + frame_idx = 0 + + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + frame=cv2.resize(frame, (2048,2048)) + frame_filename = os.path.join(output_dir, f"frame_{frame_idx:04d}.jpg") + cv2.imwrite(frame_filename, frame) + frame_idx += 1 + + cap.release() + + return frame_idx + +def read_frames_batch(frame_dir, batch_size, len, start_idx=0): + frame_files = sorted([f for f in os.listdir(frame_dir) if f.endswith('.jpg')]) + + batch_frames = [] + for idx in range(start_idx, start_idx + batch_size): + if idx < len: + frame_path = os.path.join(frame_dir, frame_files[idx]) + frame = cv2.imread(frame_path) + batch_frames.append(frame) + + return np.array(batch_frames), start_idx + batch_size + +def process_video_in_batches(original_video_path, generated_video_path, opt, batch_size=32): + original_output_dir = 'original_frames' + generated_output_dir = 'generated_frames' + + start_time = time.time() + + original_len = save_frames(original_video_path,original_output_dir) + generated_len = save_frames(generated_video_path,generated_output_dir) + total_len = min(original_len, generated_len) + + print(f"load photo total time: {time.time()-start_time}") + + start_idx=0 + ssim_result = 0 + psnr_result = 0 + niqe_result = 0 + + start_time = time.time() + + while True: + # 读取批次数据 + original_batch, next_idx = read_frames_batch(original_output_dir, batch_size, total_len, start_idx) + generated_batch, _ = read_frames_batch(generated_output_dir, batch_size, total_len, start_idx) + + # print(original_batch.shape) + + if len(original_batch) == 0 or len(generated_batch) == 0: + break + + batch_len = len(original_batch) + + print(f"time: {time.time()-start_time}") + if opt.ssim==True: + # 计算指标(例如,SSIM) + batch_ssim_result = calculate_ssim_batch(original_batch, generated_batch) + print(f"\tSSIM for batch starting at frame {start_idx}: {batch_ssim_result}, time: {time.time()-start_time}") + ssim_result += batch_ssim_result * batch_len + + if opt.psnr==True: + # 计算指标(例如,PSNR) + batch_psnr_result = calculate_psnr_batch(original_batch, generated_batch) + print(f"\tPSNR for batch starting at frame {start_idx}: {batch_psnr_result}, time: {time.time()-start_time}") + psnr_result += batch_psnr_result * batch_len + + if opt.niqe==True: + # 计算指标(例如,NIQE) + batch_niqe_result = calculate_niqe_batch(original_batch) + print(f"\tNIQE for batch starting at frame {start_idx}: {batch_niqe_result}, time: {time.time()-start_time}") + niqe_result += batch_niqe_result * batch_len + + # 更新start_idx,以便读取下一个批次 + start_idx = next_idx + + + if opt.lse==True: + print("LSE:") + # 计算指标(LSE-C & LSE-D) + lse_c,lse_d = calculate_scores(generated_video_path) + + # 确保输出目录存在 + os.makedirs(opt.output_dir, exist_ok=True) + + # 假设要写入的结果 + evaluation_result = f"{opt.original_video_path}:\n" + if opt.ssim==True: + evaluation_result += f"\tSSIM: {ssim_result*1.0/total_len}\n" + if opt.psnr==True: + evaluation_result += f"\tPSNR: {psnr_result*1.0/total_len}\n" + if opt.niqe==True: + evaluation_result += f"\tNIQE: {niqe_result*1.0/total_len}\n" + if opt.lse==True: + evaluation_result += f"\tLSE-C: {lse_c}\n\tLSE-D: {lse_d}\n" + + + # 写入到 evaluation.txt 文件 + evaluation_file_path = os.path.join(opt.output_dir, "evaluation.txt") + + with open(evaluation_file_path, "a") as f: + f.write(evaluation_result) + +# 输入参数 +parser = argparse.ArgumentParser(description = "Total") + +# parser.add_argument('--initial_model', type=str, default="./LSE/data/syncnet_v2.model", help='') +# parser.add_argument('--batch_size', type=int, default='20', help='') +# parser.add_argument('--vshift', type=int, default='15', help='') +# parser.add_argument('--data_input', type=str, default=f"{video_path}", help='') +# parser.add_argument('--data_root', type=str, default="./LSE/data/merge/", help='') +# parser.add_argument('--tmp_dir', type=str, default="./LSE/data/work/pytmp", help='') +# parser.add_argument('--reference', type=str, default="demo", help='') + +parser.add_argument('--original_video_path', type=str, default="./input/processed.mp4", help='original video path', required=True) +parser.add_argument('--generated_video_path', type=str, default="./input/merge_video.mp4", help='generated video path', required=True) +parser.add_argument('--output_dir', type=str, default="./output", help='output directory',required=True) +parser.add_argument('--batch_size', type=int, default='32', help='batch size') +parser.add_argument('--niqe', type=bool, default=False, help='whether to calculate NIQE', required=False) +parser.add_argument('--ssim', type=bool, default=False, help='whether to calculate SSIM', required=False) +parser.add_argument('--psnr', type=bool, default=False, help='whether to calculate PSNR', required=False) +parser.add_argument('--fid', type=bool, default=False, help='whether to calculate FID', required=False) +parser.add_argument('--lse', type=bool, default=False, help='whether to calculate LSE-C & D', required=False) + +opt = parser.parse_args() + +sum = opt.niqe + opt.ssim + opt.psnr + opt.fid + opt.lse +# 至少选择一个指标 +if sum == 0: + print("At least one metric should be selected") + exit(1) + +process_video_in_batches(opt.original_video_path, opt.generated_video_path, opt, opt.batch_size) + +if opt.fid==True: + # 计算指标(ID) + fid_result = calculate_fid() + print(f"\tFID: {fid_result}\n") + evaluation_file_path = os.path.join(opt.output_dir, "evaluation.txt") + with open(evaluation_file_path, "a") as f: + f.write(f"\tFID: {fid_result}\n") diff --git a/Hallo2/evaluation/niqe_python/__pycache__/main.cpython-310.pyc b/Hallo2/evaluation/niqe_python/__pycache__/main.cpython-310.pyc new file mode 100644 index 00000000..d602d5e9 Binary files /dev/null and b/Hallo2/evaluation/niqe_python/__pycache__/main.cpython-310.pyc differ diff --git a/Hallo2/evaluation/niqe_python/data/niqe_image_params.mat b/Hallo2/evaluation/niqe_python/data/niqe_image_params.mat new file mode 100644 index 00000000..53df0998 Binary files /dev/null and b/Hallo2/evaluation/niqe_python/data/niqe_image_params.mat differ diff --git a/Hallo2/evaluation/niqe_python/main.py b/Hallo2/evaluation/niqe_python/main.py new file mode 100644 index 00000000..0cfc4b92 --- /dev/null +++ b/Hallo2/evaluation/niqe_python/main.py @@ -0,0 +1,251 @@ +import numpy as np +import scipy.misc +import scipy.io +from os.path import dirname +from os.path import join +import scipy +from PIL import Image +import numpy as np +import scipy.ndimage +import numpy as np +import scipy.special +import math +import time + +gamma_range = np.arange(0.2, 10, 0.001) +a = scipy.special.gamma(2.0 / gamma_range) +a *= a +b = scipy.special.gamma(1.0 / gamma_range) +c = scipy.special.gamma(3.0 / gamma_range) +prec_gammas = a / (b * c) + + +def aggd_features(imdata): + # flatten imdata + imdata.shape = (len(imdata.flat),) + imdata2 = imdata * imdata + left_data = imdata2[imdata < 0] + right_data = imdata2[imdata >= 0] + left_mean_sqrt = 0 + right_mean_sqrt = 0 + if len(left_data) > 0: + left_mean_sqrt = np.sqrt(np.average(left_data)) + if len(right_data) > 0: + right_mean_sqrt = np.sqrt(np.average(right_data)) + + if right_mean_sqrt != 0: + gamma_hat = left_mean_sqrt / right_mean_sqrt + else: + gamma_hat = np.inf + # solve r-hat norm + + imdata2_mean = np.mean(imdata2) + if imdata2_mean != 0: + r_hat = (np.average(np.abs(imdata)) ** 2) / (np.average(imdata2)) + else: + r_hat = np.inf + rhat_norm = r_hat * (((math.pow(gamma_hat, 3) + 1) * (gamma_hat + 1)) / math.pow(math.pow(gamma_hat, 2) + 1, 2)) + + # solve alpha by guessing values that minimize ro + pos = np.argmin((prec_gammas - rhat_norm) ** 2); + alpha = gamma_range[pos] + + gam1 = scipy.special.gamma(1.0 / alpha) + gam2 = scipy.special.gamma(2.0 / alpha) + gam3 = scipy.special.gamma(3.0 / alpha) + + aggdratio = np.sqrt(gam1) / np.sqrt(gam3) + bl = aggdratio * left_mean_sqrt + br = aggdratio * right_mean_sqrt + + # mean parameter + N = (br - bl) * (gam2 / gam1) # *aggdratio + return (alpha, N, bl, br, left_mean_sqrt, right_mean_sqrt) + + +def ggd_features(imdata): + nr_gam = 1 / prec_gammas + sigma_sq = np.var(imdata) + E = np.mean(np.abs(imdata)) + rho = sigma_sq / E ** 2 + pos = np.argmin(np.abs(nr_gam - rho)) + return gamma_range[pos], sigma_sq + + +def paired_product(new_im): + shift1 = np.roll(new_im.copy(), 1, axis=1) + shift2 = np.roll(new_im.copy(), 1, axis=0) + shift3 = np.roll(np.roll(new_im.copy(), 1, axis=0), 1, axis=1) + shift4 = np.roll(np.roll(new_im.copy(), 1, axis=0), -1, axis=1) + + H_img = shift1 * new_im + V_img = shift2 * new_im + D1_img = shift3 * new_im + D2_img = shift4 * new_im + + return (H_img, V_img, D1_img, D2_img) + + +def gen_gauss_window(lw, sigma): + sd = np.float32(sigma) + lw = int(lw) + weights = [0.0] * (2 * lw + 1) + weights[lw] = 1.0 + sum = 1.0 + sd *= sd + for ii in range(1, lw + 1): + tmp = np.exp(-0.5 * np.float32(ii * ii) / sd) + weights[lw + ii] = tmp + weights[lw - ii] = tmp + sum += 2.0 * tmp + for ii in range(2 * lw + 1): + weights[ii] /= sum + return weights + + +def compute_image_mscn_transform(image, C=1, avg_window=None, extend_mode='constant'): + if avg_window is None: + avg_window = gen_gauss_window(3, 7.0 / 6.0) + assert len(np.shape(image)) == 2 + h, w = np.shape(image) + mu_image = np.zeros((h, w), dtype=np.float32) + var_image = np.zeros((h, w), dtype=np.float32) + image = np.array(image).astype('float32') + scipy.ndimage.correlate1d(image, avg_window, 0, mu_image, mode=extend_mode) + scipy.ndimage.correlate1d(mu_image, avg_window, 1, mu_image, mode=extend_mode) + scipy.ndimage.correlate1d(image ** 2, avg_window, 0, var_image, mode=extend_mode) + scipy.ndimage.correlate1d(var_image, avg_window, 1, var_image, mode=extend_mode) + var_image = np.sqrt(np.abs(var_image - mu_image ** 2)) + return (image - mu_image) / (var_image + C), var_image, mu_image + + +def _niqe_extract_subband_feats(mscncoefs): + # alpha_m, = extract_ggd_features(mscncoefs) + alpha_m, N, bl, br, lsq, rsq = aggd_features(mscncoefs.copy()) + pps1, pps2, pps3, pps4 = paired_product(mscncoefs) + alpha1, N1, bl1, br1, lsq1, rsq1 = aggd_features(pps1) + alpha2, N2, bl2, br2, lsq2, rsq2 = aggd_features(pps2) + alpha3, N3, bl3, br3, lsq3, rsq3 = aggd_features(pps3) + alpha4, N4, bl4, br4, lsq4, rsq4 = aggd_features(pps4) + return np.array([alpha_m, (bl + br) / 2.0, + alpha1, N1, bl1, br1, # (V) + alpha2, N2, bl2, br2, # (H) + alpha3, N3, bl3, bl3, # (D1) + alpha4, N4, bl4, bl4, # (D2) + ]) + + +def get_patches_train_features(img, patch_size, stride=8): + return _get_patches_generic(img, patch_size, 1, stride) + + +def get_patches_test_features(img, patch_size, stride=8): + return _get_patches_generic(img, patch_size, 0, stride) + + +def extract_on_patches(img, patch_size): + h, w = img.shape + patch_size = int(patch_size) + patches = [] + for j in range(0, h - patch_size + 1, patch_size): + for i in range(0, w - patch_size + 1, patch_size): + patch = img[j:j + patch_size, i:i + patch_size] + patches.append(patch) + + patches = np.array(patches) + + patch_features = [] + for p in patches: + patch_features.append(_niqe_extract_subband_feats(p)) + patch_features = np.array(patch_features) + + return patch_features + + +def _get_patches_generic(img, patch_size, is_train, stride): + h, w = np.shape(img) + if h < patch_size or w < patch_size: + print("Input image is too small") + exit(0) + + # ensure that the patch divides evenly into img + hoffset = (h % patch_size) + woffset = (w % patch_size) + + if hoffset > 0: + img = img[:-hoffset, :] + if woffset > 0: + img = img[:, :-woffset] + + img = img.astype(np.float32) + + # 使用Pillow来缩小图像 + img2 = Image.fromarray(img) # 转换为PIL图像 + img2 = img2.resize((img2.width // 2, img2.height // 2), Image.BICUBIC) # 缩小为原来的一半 + img2 = np.array(img2) # 转回为numpy数组 + + # 继续进行MSCN变换和特征提取等后续步骤 + mscn1, var, mu = compute_image_mscn_transform(img) + mscn1 = mscn1.astype(np.float32) + + mscn2, _, _ = compute_image_mscn_transform(img2) + mscn2 = mscn2.astype(np.float32) + + feats_lvl1 = extract_on_patches(mscn1, patch_size) + feats_lvl2 = extract_on_patches(mscn2, patch_size/2) + + feats = np.hstack((feats_lvl1, feats_lvl2)) # feats_lvl3)) + + return feats + + + +def niqe(inputImgData): + patch_size = 96 + module_path = dirname(__file__) + + # TODO: memoize + params = scipy.io.loadmat(join(module_path, 'data', 'niqe_image_params.mat')) + pop_mu = np.ravel(params["pop_mu"]) + pop_cov = params["pop_cov"] + + M, N = inputImgData.shape + + # assert C == 1, "niqe called with videos containing %d channels. Please supply only the luminance channel" % (C,) + assert M > ( + patch_size * 2 + 1), "niqe called with small frame size, requires > 192x192 resolution video using current training parameters" + assert N > ( + patch_size * 2 + 1), "niqe called with small frame size, requires > 192x192 resolution video using current training parameters" + + feats = get_patches_test_features(inputImgData, patch_size) + sample_mu = np.mean(feats, axis=0) + sample_cov = np.cov(feats.T) + + X = sample_mu - pop_mu + covmat = ((pop_cov + sample_cov) / 2.0) + pinvmat = scipy.linalg.pinv(covmat) + niqe_score = np.sqrt(np.dot(np.dot(X, pinvmat), X)) + + return niqe_score + + +if __name__ == "__main__": + ref = np.array(Image.open('./test_imgs/0325.png').convert('LA'))[:, :, 0] # ref + # dis = np.array(Image.open('./test_imgs/bikes_distorted.bmp').convert('LA'))[:, :, 0] # dis + # 记录开始时间 + start_time = time.time() + print('NIQE of ref bikes image is: %0.3f' % niqe(ref)) + # print('NIQE of dis bikes image is: %0.3f' % niqe(dis)) + # + # ref = np.array(Image.open('./test_imgs/parrots.bmp').convert('LA'))[:, :, 0] # ref + # dis = np.array(Image.open('./test_imgs/parrots_distorted.bmp').convert('LA'))[:, :, 0] # dis + # + # print('NIQE of ref parrot image is: %0.3f' % niqe(ref)) + # print('NIQE of dis parrot image is: %0.3f' % niqe(dis)) + # 记录结束时间 + end_time = time.time() + + # 计算所用时间 + elapsed_time = end_time - start_time + + print(f'Time taken to compute NIQE score: {elapsed_time:.3f} seconds') diff --git a/Hallo2/evaluation_help.png b/Hallo2/evaluation_help.png new file mode 100644 index 00000000..289b21d4 Binary files /dev/null and b/Hallo2/evaluation_help.png differ diff --git a/Hallo2/hallo2/.gitignore b/Hallo2/hallo2/.gitignore new file mode 100644 index 00000000..1691a1b3 --- /dev/null +++ b/Hallo2/hallo2/.gitignore @@ -0,0 +1,171 @@ +# running cache +mlruns/ + +# Test directories +test_data/ +pretrained_models/ + +# Poetry project +poetry.lock + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# IDE +.idea/ +.vscode/ +pretrained_models +test_data +output_long +hq_results diff --git a/Hallo2/hallo2/LICENSE b/Hallo2/hallo2/LICENSE new file mode 100644 index 00000000..77e0c563 --- /dev/null +++ b/Hallo2/hallo2/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Fusion Lab: Generative Vision Lab of Fudan University + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Hallo2/hallo2/README.md b/Hallo2/hallo2/README.md new file mode 100644 index 00000000..caf5e645 --- /dev/null +++ b/Hallo2/hallo2/README.md @@ -0,0 +1,131 @@ +# Hallo2配置 + +本文件夹为Hallo2项目源文件夹,需要下载预训练模型参数(https://huggingface.co/fudan-generative-ai/hallo2),验收现场拷贝给助教。放在该文件夹下,结构为: + +```text +./pretrained_models/ +|-- audio_separator/ +| |-- download_checks.json +| |-- mdx_model_data.json +| |-- vr_model_data.json +| `-- Kim_Vocal_2.onnx +|-- CodeFormer/ +| |-- codeformer.pth +| `-- vqgan_code1024.pth +|-- face_analysis/ +| `-- models/ +| |-- face_landmarker_v2_with_blendshapes.task # face landmarker model from mediapipe +| |-- 1k3d68.onnx +| |-- 2d106det.onnx +| |-- genderage.onnx +| |-- glintr100.onnx +| `-- scrfd_10g_bnkps.onnx +|-- facelib +| |-- detection_mobilenet0.25_Final.pth +| |-- detection_Resnet50_Final.pth +| |-- parsing_parsenet.pth +| |-- yolov5l-face.pth +| `-- yolov5n-face.pth +|-- hallo2 +| |-- net_g.pth +| `-- net.pth +|-- motion_module/ +| `-- mm_sd_v15_v2.ckpt +|-- realesrgan +| `-- RealESRGAN_x2plus.pth +|-- sd-vae-ft-mse/ +| |-- config.json +| `-- diffusion_pytorch_model.safetensors +|-- stable-diffusion-v1-5/ +| `-- unet/ +| |-- config.json +| `-- diffusion_pytorch_model.safetensors +`-- wav2vec/ + `-- wav2vec2-base-960h/ + |-- config.json + |-- feature_extractor_config.json + |-- model.safetensors + |-- preprocessor_config.json + |-- special_tokens_map.json + |-- tokenizer_config.json + `-- vocab.json +``` + +所需要环境为Ubuntu22.04/20.04、Cuda11.8 + +通过conda创建hallo项目环境:`conda create -n hallo python=3.10`,并且按照相关库: +```bash +pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118 +pip install -r requirements.txt +``` + +此外,还需要按照ffmpeg:`apt-get ffmpeg` + +该项目共有两个部分,第一部分为长时间动画,第二部分为高分辨率动画: + +- 长时间动画 + +只需要运行`scripts/inference_long.py`,指令为: + +```bash +python scripts/inference_long.py --config ./configs/inference/long.yaml +``` + +其中,`scripts/inference_long.py`有关更多选项: + +```bash +usage: inference_long.py [-h] [-c CONFIG] [--source_image SOURCE_IMAGE] [--driving_audio DRIVING_AUDIO] [--pose_weight POSE_WEIGHT] + [--face_weight FACE_WEIGHT] [--lip_weight LIP_WEIGHT] [--face_expand_ratio FACE_EXPAND_RATIO] + +options: + -h, --help show this help message and exit + -c CONFIG, --config CONFIG + --source_image SOURCE_IMAGE + source image + --driving_audio DRIVING_AUDIO + driving audio + --pose_weight POSE_WEIGHT + weight of pose + --face_weight FACE_WEIGHT + weight of face + --lip_weight LIP_WEIGHT + weight of lip + --face_expand_ratio FACE_EXPAND_RATIO + face region +``` + +- 高分辨率动画 + +只需要运行`scripts/video_sr.py`,指令为: + +```bash +python scripts/video_sr.py --input_path [input_video] --output_path [output_dir] --bg_upsampler realesrgan --face_upsample -w 1 -s 4 +``` + +其中,`scripts/video_sr.py`有关更多选项: + +```bash +usage: video_sr.py [-h] [-i INPUT_PATH] [-o OUTPUT_PATH] [-w FIDELITY_WEIGHT] [-s UPSCALE] [--has_aligned] [--only_center_face] [--draw_box] + [--detection_model DETECTION_MODEL] [--bg_upsampler BG_UPSAMPLER] [--face_upsample] [--bg_tile BG_TILE] [--suffix SUFFIX] + +options: + -h, --help show this help message and exit + -i INPUT_PATH, --input_path INPUT_PATH + Input video + -o OUTPUT_PATH, --output_path OUTPUT_PATH + Output folder. + -w FIDELITY_WEIGHT, --fidelity_weight FIDELITY_WEIGHT + Balance the quality and fidelity. Default: 0.5 + -s UPSCALE, --upscale UPSCALE + The final upsampling scale of the image. Default: 2 + --has_aligned Input are cropped and aligned faces. Default: False + --only_center_face Only restore the center face. Default: False + --draw_box Draw the bounding box for the detected faces. Default: False + --detection_model DETECTION_MODEL + Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n. Default: retinaface_resnet50 + --bg_upsampler BG_UPSAMPLER + Background upsampler. Optional: realesrgan + --face_upsample Face upsampler after enhancement. Default: False + --bg_tile BG_TILE Tile size for background sampler. Default: 400 + --suffix SUFFIX Suffix of the restored faces. Default: None +``` diff --git a/Hallo2/hallo2/accelerate_config.yaml b/Hallo2/hallo2/accelerate_config.yaml new file mode 100644 index 00000000..6fa766f1 --- /dev/null +++ b/Hallo2/hallo2/accelerate_config.yaml @@ -0,0 +1,21 @@ +compute_environment: LOCAL_MACHINE +debug: true +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 1 + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: "no" +main_training_function: main +mixed_precision: "fp16" +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/Hallo2/hallo2/assets/framework_1.jpg b/Hallo2/hallo2/assets/framework_1.jpg new file mode 100644 index 00000000..9d6eaf32 Binary files /dev/null and b/Hallo2/hallo2/assets/framework_1.jpg differ diff --git a/Hallo2/hallo2/assets/framework_2.jpg b/Hallo2/hallo2/assets/framework_2.jpg new file mode 100644 index 00000000..b56c53b9 Binary files /dev/null and b/Hallo2/hallo2/assets/framework_2.jpg differ diff --git a/Hallo2/hallo2/assets/wechat.jpeg b/Hallo2/hallo2/assets/wechat.jpeg new file mode 100644 index 00000000..f641fd9c Binary files /dev/null and b/Hallo2/hallo2/assets/wechat.jpeg differ diff --git a/Hallo2/hallo2/basicsr/VERSION b/Hallo2/hallo2/basicsr/VERSION new file mode 100644 index 00000000..1892b926 --- /dev/null +++ b/Hallo2/hallo2/basicsr/VERSION @@ -0,0 +1 @@ +1.3.2 diff --git a/Hallo2/hallo2/basicsr/__init__.py b/Hallo2/hallo2/basicsr/__init__.py new file mode 100644 index 00000000..c7ffcccd --- /dev/null +++ b/Hallo2/hallo2/basicsr/__init__.py @@ -0,0 +1,11 @@ +# https://github.com/xinntao/BasicSR +# flake8: noqa +from .archs import * +from .data import * +from .losses import * +from .metrics import * +from .models import * +from .ops import * +from .train import * +from .utils import * +from .version import __gitsha__, __version__ diff --git a/Hallo2/hallo2/basicsr/archs/__init__.py b/Hallo2/hallo2/basicsr/archs/__init__.py new file mode 100644 index 00000000..cfb1e4d7 --- /dev/null +++ b/Hallo2/hallo2/basicsr/archs/__init__.py @@ -0,0 +1,25 @@ +import importlib +from copy import deepcopy +from os import path as osp + +from basicsr.utils import get_root_logger, scandir +from basicsr.utils.registry import ARCH_REGISTRY + +__all__ = ['build_network'] + +# automatically scan and import arch modules for registry +# scan all the files under the 'archs' folder and collect files ending with +# '_arch.py' +arch_folder = osp.dirname(osp.abspath(__file__)) +arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] +# import all the arch modules +_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames] + + +def build_network(opt): + opt = deepcopy(opt) + network_type = opt.pop('type') + net = ARCH_REGISTRY.get(network_type)(**opt) + logger = get_root_logger() + logger.info(f'Network [{net.__class__.__name__}] is created.') + return net diff --git a/Hallo2/hallo2/basicsr/archs/arcface_arch.py b/Hallo2/hallo2/basicsr/archs/arcface_arch.py new file mode 100644 index 00000000..fe5afb7b --- /dev/null +++ b/Hallo2/hallo2/basicsr/archs/arcface_arch.py @@ -0,0 +1,245 @@ +import torch.nn as nn +from basicsr.utils.registry import ARCH_REGISTRY + + +def conv3x3(inplanes, outplanes, stride=1): + """A simple wrapper for 3x3 convolution with padding. + + Args: + inplanes (int): Channel number of inputs. + outplanes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + """ + return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + """Basic residual block used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. + """ + expansion = 1 # output channel expansion ratio + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class IRBlock(nn.Module): + """Improved residual block (IR Block) used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. + use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True. + """ + expansion = 1 # output channel expansion ratio + + def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True): + super(IRBlock, self).__init__() + self.bn0 = nn.BatchNorm2d(inplanes) + self.conv1 = conv3x3(inplanes, inplanes) + self.bn1 = nn.BatchNorm2d(inplanes) + self.prelu = nn.PReLU() + self.conv2 = conv3x3(inplanes, planes, stride) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + self.use_se = use_se + if self.use_se: + self.se = SEBlock(planes) + + def forward(self, x): + residual = x + out = self.bn0(x) + out = self.conv1(out) + out = self.bn1(out) + out = self.prelu(out) + + out = self.conv2(out) + out = self.bn2(out) + if self.use_se: + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.prelu(out) + + return out + + +class Bottleneck(nn.Module): + """Bottleneck block used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. + """ + expansion = 4 # output channel expansion ratio + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class SEBlock(nn.Module): + """The squeeze-and-excitation block (SEBlock) used in the IRBlock. + + Args: + channel (int): Channel number of inputs. + reduction (int): Channel reduction ration. Default: 16. + """ + + def __init__(self, channel, reduction=16): + super(SEBlock, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel), + nn.Sigmoid()) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +@ARCH_REGISTRY.register() +class ResNetArcFace(nn.Module): + """ArcFace with ResNet architectures. + + Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition. + + Args: + block (str): Block used in the ArcFace architecture. + layers (tuple(int)): Block numbers in each layer. + use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True. + """ + + def __init__(self, block, layers, use_se=True): + if block == 'IRBlock': + block = IRBlock + self.inplanes = 64 + self.use_se = use_se + super(ResNetArcFace, self).__init__() + + self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.prelu = nn.PReLU() + self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.bn4 = nn.BatchNorm2d(512) + self.dropout = nn.Dropout() + self.fc5 = nn.Linear(512 * 8 * 8, 512) + self.bn5 = nn.BatchNorm1d(512) + + # initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, num_blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se)) + self.inplanes = planes + for _ in range(1, num_blocks): + layers.append(block(self.inplanes, planes, use_se=self.use_se)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.bn4(x) + x = self.dropout(x) + x = x.view(x.size(0), -1) + x = self.fc5(x) + x = self.bn5(x) + + return x \ No newline at end of file diff --git a/Hallo2/hallo2/basicsr/archs/arch_util.py b/Hallo2/hallo2/basicsr/archs/arch_util.py new file mode 100644 index 00000000..bad45ab3 --- /dev/null +++ b/Hallo2/hallo2/basicsr/archs/arch_util.py @@ -0,0 +1,318 @@ +import collections.abc +import math +import torch +import torchvision +import warnings +from distutils.version import LooseVersion +from itertools import repeat +from torch import nn as nn +from torch.nn import functional as F +from torch.nn import init as init +from torch.nn.modules.batchnorm import _BatchNorm + +from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv +from basicsr.utils import get_root_logger + + +@torch.no_grad() +def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): + """Initialize network weights. + + Args: + module_list (list[nn.Module] | nn.Module): Modules to be initialized. + scale (float): Scale initialized weights, especially for residual + blocks. Default: 1. + bias_fill (float): The value to fill bias. Default: 0 + kwargs (dict): Other arguments for initialization function. + """ + if not isinstance(module_list, list): + module_list = [module_list] + for module in module_list: + for m in module.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, _BatchNorm): + init.constant_(m.weight, 1) + if m.bias is not None: + m.bias.data.fill_(bias_fill) + + +def make_layer(basic_block, num_basic_block, **kwarg): + """Make layers by stacking the same blocks. + + Args: + basic_block (nn.module): nn.module class for basic block. + num_basic_block (int): number of blocks. + + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. + """ + layers = [] + for _ in range(num_basic_block): + layers.append(basic_block(**kwarg)) + return nn.Sequential(*layers) + + +class ResidualBlockNoBN(nn.Module): + """Residual block without BN. + + It has a style of: + ---Conv-ReLU-Conv-+- + |________________| + + Args: + num_feat (int): Channel number of intermediate features. + Default: 64. + res_scale (float): Residual scale. Default: 1. + pytorch_init (bool): If set to True, use pytorch default init, + otherwise, use default_init_weights. Default: False. + """ + + def __init__(self, num_feat=64, res_scale=1, pytorch_init=False): + super(ResidualBlockNoBN, self).__init__() + self.res_scale = res_scale + self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.relu = nn.ReLU(inplace=True) + + if not pytorch_init: + default_init_weights([self.conv1, self.conv2], 0.1) + + def forward(self, x): + identity = x + out = self.conv2(self.relu(self.conv1(x))) + return identity + out * self.res_scale + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True): + """Warp an image or feature map with optical flow. + + Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2), normal value. + interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'. + padding_mode (str): 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. + align_corners (bool): Before pytorch 1.3, the default value is + align_corners=True. After pytorch 1.3, the default value is + align_corners=False. Here, we use the True as default. + + Returns: + Tensor: Warped image or feature map. + """ + assert x.size()[-2:] == flow.size()[1:3] + _, _, h, w = x.size() + # create mesh grid + grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + + vgrid = grid + flow + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners) + + # TODO, what if align_corners=False + return output + + +def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False): + """Resize a flow according to ratio or shape. + + Args: + flow (Tensor): Precomputed flow. shape [N, 2, H, W]. + size_type (str): 'ratio' or 'shape'. + sizes (list[int | float]): the ratio for resizing or the final output + shape. + 1) The order of ratio should be [ratio_h, ratio_w]. For + downsampling, the ratio should be smaller than 1.0 (i.e., ratio + < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e., + ratio > 1.0). + 2) The order of output_size should be [out_h, out_w]. + interp_mode (str): The mode of interpolation for resizing. + Default: 'bilinear'. + align_corners (bool): Whether align corners. Default: False. + + Returns: + Tensor: Resized flow. + """ + _, _, flow_h, flow_w = flow.size() + if size_type == 'ratio': + output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1]) + elif size_type == 'shape': + output_h, output_w = sizes[0], sizes[1] + else: + raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.') + + input_flow = flow.clone() + ratio_h = output_h / flow_h + ratio_w = output_w / flow_w + input_flow[:, 0, :, :] *= ratio_w + input_flow[:, 1, :, :] *= ratio_h + resized_flow = F.interpolate( + input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners) + return resized_flow + + +# TODO: may write a cpp file +def pixel_unshuffle(x, scale): + """ Pixel unshuffle. + + Args: + x (Tensor): Input feature with shape (b, c, hh, hw). + scale (int): Downsample ratio. + + Returns: + Tensor: the pixel unshuffled feature. + """ + b, c, hh, hw = x.size() + out_channel = c * (scale**2) + assert hh % scale == 0 and hw % scale == 0 + h = hh // scale + w = hw // scale + x_view = x.view(b, c, h, scale, w, scale) + return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) + + +class DCNv2Pack(ModulatedDeformConvPack): + """Modulated deformable conv for deformable alignment. + + Different from the official DCNv2Pack, which generates offsets and masks + from the preceding features, this DCNv2Pack takes another different + features to generate offsets and masks. + + Ref: + Delving Deep into Deformable Alignment in Video Super-Resolution. + """ + + def forward(self, x, feat): + out = self.conv_offset(feat) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + + offset_absmean = torch.mean(torch.abs(offset)) + if offset_absmean > 50: + logger = get_root_logger() + logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.') + + if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'): + return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding, + self.dilation, mask) + else: + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, + self.dilation, self.groups, self.deformable_groups) + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. ' + 'The distribution of values may be incorrect.', + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + low = norm_cdf((a - mean) / std) + up = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [low, up], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * low - 1, 2 * up - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. + + From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py + + The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +# From PyTorch +def _ntuple(n): + + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple \ No newline at end of file diff --git a/Hallo2/hallo2/basicsr/archs/codeformer_arch.py b/Hallo2/hallo2/basicsr/archs/codeformer_arch.py new file mode 100644 index 00000000..e74dea82 --- /dev/null +++ b/Hallo2/hallo2/basicsr/archs/codeformer_arch.py @@ -0,0 +1,362 @@ +import math +import numpy as np +import torch +from torch import nn, Tensor +import torch.nn.functional as F +from typing import Optional, List + +from basicsr.archs.vqgan_arch import * +from basicsr.utils import get_root_logger +from basicsr.utils.registry import ARCH_REGISTRY + +from einops import rearrange + +from tqdm import tqdm + +def calc_mean_std(feat, eps=1e-5): + """Calculate mean and std for adaptive_instance_normalization. + + Args: + feat (Tensor): 4D tensor. + eps (float): A small value added to the variance to avoid + divide-by-zero. Default: 1e-5. + """ + size = feat.size() + assert len(size) == 4, 'The input feature should be 4D tensor.' + b, c = size[:2] + feat_var = feat.view(b, c, -1).var(dim=2) + eps + feat_std = feat_var.sqrt().view(b, c, 1, 1) + feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1) + return feat_mean, feat_std + + +def adaptive_instance_normalization(content_feat, style_feat): + """Adaptive instance normalization. + + Adjust the reference features to have the similar color and illuminations + as those in the degradate features. + + Args: + content_feat (Tensor): The reference feature. + style_feat (Tensor): The degradate features. + """ + size = content_feat.size() + style_mean, style_std = calc_mean_std(style_feat) + content_mean, content_std = calc_mean_std(content_feat) + normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x, mask=None): + if mask is None: + mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") + + +class TransformerSALayer(nn.Module): + def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"): + super().__init__() + self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout) + # Implementation of Feedforward model - MLP + self.linear1 = nn.Linear(embed_dim, dim_mlp) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_mlp, embed_dim) + + self.norm1 = nn.LayerNorm(embed_dim) + self.norm2 = nn.LayerNorm(embed_dim) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.temp_norm = nn.LayerNorm(embed_dim) + self.temp_dropout = nn.Dropout(dropout) + self.temp_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout) + + self.temp_ffn_norm = nn.LayerNorm(embed_dim) + self.temp_linear1 = nn.Linear(embed_dim, dim_mlp) + self.temp_ffn_dropout = nn.Dropout(dropout) + self.temp_linear2 = nn.Linear(dim_mlp, embed_dim) + + self.activation = _get_activation_fn(activation) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward(self, tgt, video_length, batch_size, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None + ): + + # self attention + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + + # ffn + tgt2 = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout2(tgt2) + + # tmp attn + tgt = rearrange(tgt, "d (b f) c -> f (b d) c", f=video_length) + tgt2 = self.temp_norm(tgt) + query_pos = rearrange(query_pos, "d (b f) c -> f (b d) c", f=video_length) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.temp_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.temp_dropout(tgt2) + tgt = rearrange(tgt, "f (b d) c -> d (b f) c", b=batch_size) + + # ffn + tgt2 = self.temp_ffn_norm(tgt) + tgt2 = self.temp_linear2(self.temp_ffn_dropout(self.activation(self.temp_linear1(tgt2)))) + tgt = tgt + self.temp_ffn_dropout(tgt2) + + return tgt + +class Fuse_sft_block(nn.Module): + def __init__(self, in_ch, out_ch): + super().__init__() + self.encode_enc = ResBlock(2*in_ch, out_ch) + + self.scale = nn.Sequential( + nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) + + self.shift = nn.Sequential( + nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) + + def forward(self, enc_feat, dec_feat, w=1): + enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1)) + scale = self.scale(enc_feat) + shift = self.shift(enc_feat) + out = w * (dec_feat * scale + shift) + dec_feat + + return out + + +@ARCH_REGISTRY.register() +class CodeFormer(VQAutoEncoder): + def __init__(self, dim_embd=512, n_head=8, n_layers=9, + codebook_size=1024, latent_size=256, + connect_list=['32', '64', '128', '256'], + fix_modules=['quantize','generator'], vqgan_path=None): + super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size) + + if vqgan_path is not None: + m, n = self.load_state_dict( + torch.load(vqgan_path, map_location='cpu')['params_ema']) + + # if fix_modules is not None: + # for module in fix_modules: + # for param in getattr(self, module).parameters(): + # param.requires_grad = False + + self.connect_list = connect_list + self.n_layers = n_layers + self.dim_embd = dim_embd + self.dim_mlp = dim_embd*2 + + self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd)) + self.feat_emb = nn.Linear(256, self.dim_embd) + + # transformer + self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0) + for _ in range(self.n_layers)]) + + # logits_predict head + self.idx_pred_layer = nn.Sequential( + nn.LayerNorm(dim_embd), + nn.Linear(dim_embd, codebook_size, bias=False)) + + self.channels = { + '16': 512, + '32': 256, + '64': 256, + '128': 128, + '256': 128, + '512': 64, + } + + # after second residual block for > 16, before attn layer for ==16 + self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18} + # after first residual block for > 16, before attn layer for ==16 + self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21} + + # fuse_convs_dict + self.fuse_convs_dict = nn.ModuleDict() + for f_size in self.connect_list: + in_ch = self.channels[f_size] + self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch) + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def forward(self, x, w=0, detach_16=True, code_only=False, adain=False): + b, f, _, _, _ = x.shape + x = rearrange(x, "b f c h w -> (b f) c h w") + # ################### Encoder ##################### + enc_feat_dict = {} + out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list] + for i, block in enumerate(self.encoder.blocks): + x = block(x) + if i in out_list: + enc_feat_dict[str(x.shape[-1])] = x.clone() + + lq_feat = x + # ################# Transformer ################### + # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat) + pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1) + # BCHW -> BC(HW) -> (HW)BC + feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1)) + query_emb = feat_emb + # Transformer encoder + for layer in self.ft_layers: + query_emb = layer(query_emb, query_pos=pos_emb, video_length=f, batch_size=b) + + # output logits + logits = self.idx_pred_layer(query_emb) # (hw)bn + logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n + + if code_only: # for training stage II + # logits doesn't need softmax before cross_entropy loss + return logits, lq_feat + + soft_one_hot = F.softmax(logits, dim=2) + _, top_idx = torch.topk(soft_one_hot, 1, dim=2) + quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256]) + + + if detach_16: + quant_feat = quant_feat.detach() # for training stage III + if adain: + quant_feat = adaptive_instance_normalization(quant_feat, lq_feat) + + # ################## Generator #################### + x = quant_feat + fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list] + + + for i, block in enumerate(self.generator.blocks): + x = block(x) + if i in fuse_list: # fuse after i-th block + f_size = str(x.shape[-1]) + if w>0: + x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w) + + x = rearrange(x, "(b f) c h w -> b f c h w", f=f) + out = x + # logits doesn't need softmax before cross_entropy loss + return out, logits, lq_feat + + + def inference(self, x, w=0, detach_16=True, adain=False): + with torch.no_grad(): + b, f, _, _, _ = x.shape + x = rearrange(x, "b f c h w -> (b f) c h w") + # ################### Encoder ##################### + enc_feat_dict = {} + out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list] + for i, block in enumerate(self.encoder.blocks): + x = block(x) + if i in out_list: + enc_feat_dict[str(x.shape[-1])] = x.detach().cpu().clone() + + lq_feat = x.detach() + pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1) + # BCHW -> BC(HW) -> (HW)BC + feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1)) + query_emb = feat_emb + # Transformer encoder + for layer in self.ft_layers: + query_emb = layer(query_emb, query_pos=pos_emb, video_length=f, batch_size=b) + + # output logits + logits = self.idx_pred_layer(query_emb) # (hw)bn + logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n + + + soft_one_hot = F.softmax(logits, dim=2) + _, top_idx = torch.topk(soft_one_hot, 1, dim=2) + quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256]) + + + if detach_16: + quant_feat = quant_feat.detach() # for training stage III + if adain: + quant_feat = adaptive_instance_normalization(quant_feat, lq_feat) + + # ################## Generator #################### + x = quant_feat + fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list] + + + for i, block in enumerate(self.generator.blocks): + x = block(x) + if i in fuse_list: # fuse after i-th block + f_size = str(x.shape[-1]) + if w>0: + x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].to(x.device), x, w) + + x = rearrange(x, "(b f) c h w -> b f c h w", f=f) + # logits doesn't need softmax before cross_entropy loss + return x, top_idx + \ No newline at end of file diff --git a/Hallo2/hallo2/basicsr/archs/rrdbnet_arch.py b/Hallo2/hallo2/basicsr/archs/rrdbnet_arch.py new file mode 100644 index 00000000..49a2d6c2 --- /dev/null +++ b/Hallo2/hallo2/basicsr/archs/rrdbnet_arch.py @@ -0,0 +1,119 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + +from basicsr.utils.registry import ARCH_REGISTRY +from .arch_util import default_init_weights, make_layer, pixel_unshuffle + + +class ResidualDenseBlock(nn.Module): + """Residual Dense Block. + + Used in RRDB block in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. + """ + + def __init__(self, num_feat=64, num_grow_ch=32): + super(ResidualDenseBlock, self).__init__() + self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) + self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + # Emperically, we use 0.2 to scale the residual for better performance + return x5 * 0.2 + x + + +class RRDB(nn.Module): + """Residual in Residual Dense Block. + + Used in RRDB-Net in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. + """ + + def __init__(self, num_feat, num_grow_ch=32): + super(RRDB, self).__init__() + self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) + + def forward(self, x): + out = self.rdb1(x) + out = self.rdb2(out) + out = self.rdb3(out) + # Emperically, we use 0.2 to scale the residual for better performance + return out * 0.2 + x + + +@ARCH_REGISTRY.register() +class RRDBNet(nn.Module): + """Networks consisting of Residual in Residual Dense Block, which is used + in ESRGAN. + + ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. + + We extend ESRGAN for scale x2 and scale x1. + Note: This is one option for scale 1, scale 2 in RRDBNet. + We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size + and enlarge the channel size before feeding inputs into the main ESRGAN architecture. + + Args: + num_in_ch (int): Channel number of inputs. + num_out_ch (int): Channel number of outputs. + num_feat (int): Channel number of intermediate features. + Default: 64 + num_block (int): Block number in the trunk network. Defaults: 23 + num_grow_ch (int): Channels for each growth. Default: 32. + """ + + def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32): + super(RRDBNet, self).__init__() + self.scale = scale + if scale == 2: + num_in_ch = num_in_ch * 4 + elif scale == 1: + num_in_ch = num_in_ch * 16 + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) + self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + # upsample + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + if self.scale == 2: + feat = pixel_unshuffle(x, scale=2) + elif self.scale == 1: + feat = pixel_unshuffle(x, scale=4) + else: + feat = x + feat = self.conv_first(feat) + body_feat = self.conv_body(self.body(feat)) + feat = feat + body_feat + # upsample + feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) + feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) + out = self.conv_last(self.lrelu(self.conv_hr(feat))) + return out \ No newline at end of file diff --git a/Hallo2/hallo2/basicsr/archs/vgg_arch.py b/Hallo2/hallo2/basicsr/archs/vgg_arch.py new file mode 100644 index 00000000..23bb0103 --- /dev/null +++ b/Hallo2/hallo2/basicsr/archs/vgg_arch.py @@ -0,0 +1,161 @@ +import os +import torch +from collections import OrderedDict +from torch import nn as nn +from torchvision.models import vgg as vgg + +from basicsr.utils.registry import ARCH_REGISTRY + +VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth' +NAMES = { + 'vgg11': [ + 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', + 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', + 'pool5' + ], + 'vgg13': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', + 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5' + ], + 'vgg16': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', + 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', + 'pool5' + ], + 'vgg19': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1', + 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', + 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5' + ] +} + + +def insert_bn(names): + """Insert bn layer after each conv. + + Args: + names (list): The list of layer names. + + Returns: + list: The list of layer names with bn layers. + """ + names_bn = [] + for name in names: + names_bn.append(name) + if 'conv' in name: + position = name.replace('conv', '') + names_bn.append('bn' + position) + return names_bn + + +@ARCH_REGISTRY.register() +class VGGFeatureExtractor(nn.Module): + """VGG network for feature extraction. + + In this implementation, we allow users to choose whether use normalization + in the input feature and the type of vgg network. Note that the pretrained + path must fit the vgg type. + + Args: + layer_name_list (list[str]): Forward function returns the corresponding + features according to the layer_name_list. + Example: {'relu1_1', 'relu2_1', 'relu3_1'}. + vgg_type (str): Set the type of vgg network. Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image. Importantly, + the input feature must in the range [0, 1]. Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + requires_grad (bool): If true, the parameters of VGG network will be + optimized. Default: False. + remove_pooling (bool): If true, the max pooling operations in VGG net + will be removed. Default: False. + pooling_stride (int): The stride of max pooling operation. Default: 2. + """ + + def __init__(self, + layer_name_list, + vgg_type='vgg19', + use_input_norm=True, + range_norm=False, + requires_grad=False, + remove_pooling=False, + pooling_stride=2): + super(VGGFeatureExtractor, self).__init__() + + self.layer_name_list = layer_name_list + self.use_input_norm = use_input_norm + self.range_norm = range_norm + + self.names = NAMES[vgg_type.replace('_bn', '')] + if 'bn' in vgg_type: + self.names = insert_bn(self.names) + + # only borrow layers that will be used to avoid unused params + max_idx = 0 + for v in layer_name_list: + idx = self.names.index(v) + if idx > max_idx: + max_idx = idx + + if os.path.exists(VGG_PRETRAIN_PATH): + vgg_net = getattr(vgg, vgg_type)(pretrained=False) + state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage) + vgg_net.load_state_dict(state_dict) + else: + vgg_net = getattr(vgg, vgg_type)(pretrained=True) + + features = vgg_net.features[:max_idx + 1] + + modified_net = OrderedDict() + for k, v in zip(self.names, features): + if 'pool' in k: + # if remove_pooling is true, pooling operation will be removed + if remove_pooling: + continue + else: + # in some cases, we may want to change the default stride + modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) + else: + modified_net[k] = v + + self.vgg_net = nn.Sequential(modified_net) + + if not requires_grad: + self.vgg_net.eval() + for param in self.parameters(): + param.requires_grad = False + else: + self.vgg_net.train() + for param in self.parameters(): + param.requires_grad = True + + if self.use_input_norm: + # the mean is for image with range [0, 1] + self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + # the std is for image with range [0, 1] + self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + if self.range_norm: + x = (x + 1) / 2 + if self.use_input_norm: + x = (x - self.mean) / self.std + output = {} + + for key, layer in self.vgg_net._modules.items(): + x = layer(x) + if key in self.layer_name_list: + output[key] = x.clone() + + return output diff --git a/Hallo2/hallo2/basicsr/archs/vqgan_arch.py b/Hallo2/hallo2/basicsr/archs/vqgan_arch.py new file mode 100644 index 00000000..3a65de10 --- /dev/null +++ b/Hallo2/hallo2/basicsr/archs/vqgan_arch.py @@ -0,0 +1,435 @@ +''' +VQGAN code, adapted from the original created by the Unleashing Transformers authors: +https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py + +''' +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import copy +from basicsr.utils import get_root_logger +from basicsr.utils.registry import ARCH_REGISTRY + + +def normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +@torch.jit.script +def swish(x): + return x*torch.sigmoid(x) + + +# Define VQVAE classes +class VectorQuantizer(nn.Module): + def __init__(self, codebook_size, emb_dim, beta): + super(VectorQuantizer, self).__init__() + self.codebook_size = codebook_size # number of embeddings + self.emb_dim = emb_dim # dimension of embedding + self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 + self.embedding = nn.Embedding(self.codebook_size, self.emb_dim) + self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.emb_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \ + 2 * torch.matmul(z_flattened, self.embedding.weight.t()) + + mean_distance = torch.mean(d) + # find closest encodings + min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) + # min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False) + # [0-1], higher score, higher confidence + # min_encoding_scores = torch.exp(-min_encoding_scores/10) + + min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z) + min_encodings.scatter_(1, min_encoding_indices, 1) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) + # compute loss for embedding + loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2) + # preserve gradients + z_q = z + (z_q - z).detach() + + # perplexity + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q, loss, { + "perplexity": perplexity, + "min_encodings": min_encodings, + "min_encoding_indices": min_encoding_indices, + "mean_distance": mean_distance + } + + def get_codebook_feat(self, indices, shape): + # input indices: batch*token_num -> (batch*token_num)*1 + # shape: batch, height, width, channel + indices = indices.view(-1,1) + min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices) + min_encodings.scatter_(1, indices, 1) + # get quantized latent vectors + z_q = torch.matmul(min_encodings.float(), self.embedding.weight) + + if shape is not None: # reshape back to match original input shape + z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous() + + return z_q + + +class GumbelQuantizer(nn.Module): + def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0): + super().__init__() + self.codebook_size = codebook_size # number of embeddings + self.emb_dim = emb_dim # dimension of embedding + self.straight_through = straight_through + self.temperature = temp_init + self.kl_weight = kl_weight + self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits + self.embed = nn.Embedding(codebook_size, emb_dim) + + def forward(self, z): + hard = self.straight_through if self.training else True + + logits = self.proj(z) + + soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard) + + z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) + + # + kl divergence to the prior loss + qy = F.softmax(logits, dim=1) + diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean() + min_encoding_indices = soft_one_hot.argmax(dim=1) + + return z_q, diff, { + "min_encoding_indices": min_encoding_indices + } + + +class Downsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + + return x + + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels=None): + super(ResBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.norm1 = normalize(in_channels) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = normalize(out_channels) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x_in): + x = x_in + x = self.norm1(x) + x = swish(x) + x = self.conv1(x) + x = self.norm2(x) + x = swish(x) + x = self.conv2(x) + if self.in_channels != self.out_channels: + x_in = self.conv_out(x_in) + + return x + x_in + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h*w) + q = q.permute(0, 2, 1) + k = k.reshape(b, c, h*w) + w_ = torch.bmm(q, k) + w_ = w_ * (int(c)**(-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h*w) + w_ = w_.permute(0, 2, 1) + h_ = torch.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x+h_ + + +class Encoder(nn.Module): + def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions): + super().__init__() + self.nf = nf + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.attn_resolutions = attn_resolutions + + curr_res = self.resolution + in_ch_mult = (1,)+tuple(ch_mult) + + blocks = [] + # initial convultion + blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1)) + + # residual and downsampling blocks, with attention on smaller res (16x16) + for i in range(self.num_resolutions): + block_in_ch = nf * in_ch_mult[i] + block_out_ch = nf * ch_mult[i] + for _ in range(self.num_res_blocks): + blocks.append(ResBlock(block_in_ch, block_out_ch)) + block_in_ch = block_out_ch + if curr_res in attn_resolutions: + blocks.append(AttnBlock(block_in_ch)) + + if i != self.num_resolutions - 1: + blocks.append(Downsample(block_in_ch)) + curr_res = curr_res // 2 + + # non-local attention block + blocks.append(ResBlock(block_in_ch, block_in_ch)) + blocks.append(AttnBlock(block_in_ch)) + blocks.append(ResBlock(block_in_ch, block_in_ch)) + + # normalise and convert to latent size + blocks.append(normalize(block_in_ch)) + blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1)) + self.blocks = nn.ModuleList(blocks) + + def forward(self, x): + for block in self.blocks: + x = block(x) + + return x + + +class Generator(nn.Module): + def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions): + super().__init__() + self.nf = nf + self.ch_mult = ch_mult + self.num_resolutions = len(self.ch_mult) + self.num_res_blocks = res_blocks + self.resolution = img_size + self.attn_resolutions = attn_resolutions + self.in_channels = emb_dim + self.out_channels = 3 + block_in_ch = self.nf * self.ch_mult[-1] + curr_res = self.resolution // 2 ** (self.num_resolutions-1) + + blocks = [] + # initial conv + blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)) + + # non-local attention block + blocks.append(ResBlock(block_in_ch, block_in_ch)) + blocks.append(AttnBlock(block_in_ch)) + blocks.append(ResBlock(block_in_ch, block_in_ch)) + + for i in reversed(range(self.num_resolutions)): + block_out_ch = self.nf * self.ch_mult[i] + + for _ in range(self.num_res_blocks): + blocks.append(ResBlock(block_in_ch, block_out_ch)) + block_in_ch = block_out_ch + + if curr_res in self.attn_resolutions: + blocks.append(AttnBlock(block_in_ch)) + + if i != 0: + blocks.append(Upsample(block_in_ch)) + curr_res = curr_res * 2 + + blocks.append(normalize(block_in_ch)) + blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)) + + self.blocks = nn.ModuleList(blocks) + + + def forward(self, x): + for block in self.blocks: + x = block(x) + + return x + + +@ARCH_REGISTRY.register() +class VQAutoEncoder(nn.Module): + def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256, + beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None): + super().__init__() + logger = get_root_logger() + self.in_channels = 3 + self.nf = nf + self.n_blocks = res_blocks + self.codebook_size = codebook_size + self.embed_dim = emb_dim + self.ch_mult = ch_mult + self.resolution = img_size + self.attn_resolutions = attn_resolutions + self.quantizer_type = quantizer + self.encoder = Encoder( + self.in_channels, + self.nf, + self.embed_dim, + self.ch_mult, + self.n_blocks, + self.resolution, + self.attn_resolutions + ) + if self.quantizer_type == "nearest": + self.beta = beta #0.25 + self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta) + elif self.quantizer_type == "gumbel": + self.gumbel_num_hiddens = emb_dim + self.straight_through = gumbel_straight_through + self.kl_weight = gumbel_kl_weight + self.quantize = GumbelQuantizer( + self.codebook_size, + self.embed_dim, + self.gumbel_num_hiddens, + self.straight_through, + self.kl_weight + ) + self.generator = Generator( + self.nf, + self.embed_dim, + self.ch_mult, + self.n_blocks, + self.resolution, + self.attn_resolutions + ) + + if model_path is not None: + chkpt = torch.load(model_path, map_location='cpu') + if 'params_ema' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema']) + logger.info(f'vqgan is loaded from: {model_path} [params_ema]') + elif 'params' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) + logger.info(f'vqgan is loaded from: {model_path} [params]') + else: + raise ValueError(f'Wrong params!') + + + def forward(self, x): + x = self.encoder(x) + quant, codebook_loss, quant_stats = self.quantize(x) + x = self.generator(quant) + return x, codebook_loss, quant_stats + + + +# patch based discriminator +@ARCH_REGISTRY.register() +class VQGANDiscriminator(nn.Module): + def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None): + super().__init__() + + layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)] + ndf_mult = 1 + ndf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + ndf_mult_prev = ndf_mult + ndf_mult = min(2 ** n, 8) + layers += [ + nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False), + nn.BatchNorm2d(ndf * ndf_mult), + nn.LeakyReLU(0.2, True) + ] + + ndf_mult_prev = ndf_mult + ndf_mult = min(2 ** n_layers, 8) + + layers += [ + nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False), + nn.BatchNorm2d(ndf * ndf_mult), + nn.LeakyReLU(0.2, True) + ] + + layers += [ + nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map + self.main = nn.Sequential(*layers) + + if model_path is not None: + chkpt = torch.load(model_path, map_location='cpu') + if 'params_d' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d']) + elif 'params' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) + else: + raise ValueError(f'Wrong params!') + + def forward(self, x): + return self.main(x) \ No newline at end of file diff --git a/Hallo2/hallo2/basicsr/data/__init__.py b/Hallo2/hallo2/basicsr/data/__init__.py new file mode 100644 index 00000000..c6adb4bb --- /dev/null +++ b/Hallo2/hallo2/basicsr/data/__init__.py @@ -0,0 +1,100 @@ +import importlib +import numpy as np +import random +import torch +import torch.utils.data +from copy import deepcopy +from functools import partial +from os import path as osp + +from basicsr.data.prefetch_dataloader import PrefetchDataLoader +from basicsr.utils import get_root_logger, scandir +from basicsr.utils.dist_util import get_dist_info +from basicsr.utils.registry import DATASET_REGISTRY + +__all__ = ['build_dataset', 'build_dataloader'] + +# automatically scan and import dataset modules for registry +# scan all the files under the data folder with '_dataset' in file names +data_folder = osp.dirname(osp.abspath(__file__)) +dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] +# import all the dataset modules +_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames] + + +def build_dataset(dataset_opt): + """Build dataset from options. + + Args: + dataset_opt (dict): Configuration for dataset. It must constain: + name (str): Dataset name. + type (str): Dataset type. + """ + dataset_opt = deepcopy(dataset_opt) + dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) + logger = get_root_logger() + logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.') + return dataset + + +def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): + """Build dataloader. + + Args: + dataset (torch.utils.data.Dataset): Dataset. + dataset_opt (dict): Dataset options. It contains the following keys: + phase (str): 'train' or 'val'. + num_worker_per_gpu (int): Number of workers for each GPU. + batch_size_per_gpu (int): Training batch size for each GPU. + num_gpu (int): Number of GPUs. Used only in the train phase. + Default: 1. + dist (bool): Whether in distributed training. Used only in the train + phase. Default: False. + sampler (torch.utils.data.sampler): Data sampler. Default: None. + seed (int | None): Seed. Default: None + """ + phase = dataset_opt['phase'] + rank, _ = get_dist_info() + if phase == 'train': + if dist: # distributed training + batch_size = dataset_opt['batch_size_per_gpu'] + num_workers = dataset_opt['num_worker_per_gpu'] + else: # non-distributed training + multiplier = 1 if num_gpu == 0 else num_gpu + batch_size = dataset_opt['batch_size_per_gpu'] * multiplier + num_workers = dataset_opt['num_worker_per_gpu'] * multiplier + dataloader_args = dict( + dataset=dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + sampler=sampler, + drop_last=True) + if sampler is None: + dataloader_args['shuffle'] = True + dataloader_args['worker_init_fn'] = partial( + worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None + elif phase in ['val', 'test']: # validation + dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) + else: + raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.") + + dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) + + prefetch_mode = dataset_opt.get('prefetch_mode') + if prefetch_mode == 'cpu': # CPUPrefetcher + num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) + logger = get_root_logger() + logger.info(f'Use {prefetch_mode} prefetch dataloader: ' f'num_prefetch_queue = {num_prefetch_queue}') + return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) + else: + # prefetch_mode=None: Normal dataloader + # prefetch_mode='cuda': dataloader for CUDAPrefetcher + return torch.utils.data.DataLoader(**dataloader_args) + + +def worker_init_fn(worker_id, num_workers, rank, seed): + # Set the worker seed to num_workers * rank + worker_id + seed + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) diff --git a/Hallo2/hallo2/basicsr/data/data_sampler.py b/Hallo2/hallo2/basicsr/data/data_sampler.py new file mode 100644 index 00000000..575452d9 --- /dev/null +++ b/Hallo2/hallo2/basicsr/data/data_sampler.py @@ -0,0 +1,48 @@ +import math +import torch +from torch.utils.data.sampler import Sampler + + +class EnlargedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + + Modified from torch.utils.data.distributed.DistributedSampler + Support enlarging the dataset for iteration-based training, for saving + time when restart the dataloader after each epoch + + Args: + dataset (torch.utils.data.Dataset): Dataset used for sampling. + num_replicas (int | None): Number of processes participating in + the training. It is usually the world_size. + rank (int | None): Rank of the current process within num_replicas. + ratio (int): Enlarging ratio. Default: 1. + """ + + def __init__(self, dataset, num_replicas, rank, ratio=1): + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(self.total_size, generator=g).tolist() + + dataset_size = len(self.dataset) + indices = [v % dataset_size for v in indices] + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/Hallo2/hallo2/basicsr/data/data_util.py b/Hallo2/hallo2/basicsr/data/data_util.py new file mode 100644 index 00000000..44a71910 --- /dev/null +++ b/Hallo2/hallo2/basicsr/data/data_util.py @@ -0,0 +1,392 @@ +import cv2 +import math +import numpy as np +import torch +from os import path as osp +from PIL import Image, ImageDraw +from torch.nn import functional as F + +from basicsr.data.transforms import mod_crop +from basicsr.utils import img2tensor, scandir + + +def read_img_seq(path, require_mod_crop=False, scale=1): + """Read a sequence of images from a given folder path. + + Args: + path (list[str] | str): List of image paths or image folder path. + require_mod_crop (bool): Require mod crop for each image. + Default: False. + scale (int): Scale factor for mod_crop. Default: 1. + + Returns: + Tensor: size (t, c, h, w), RGB, [0, 1]. + """ + if isinstance(path, list): + img_paths = path + else: + img_paths = sorted(list(scandir(path, full_path=True))) + imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths] + if require_mod_crop: + imgs = [mod_crop(img, scale) for img in imgs] + imgs = img2tensor(imgs, bgr2rgb=True, float32=True) + imgs = torch.stack(imgs, dim=0) + return imgs + + +def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'): + """Generate an index list for reading `num_frames` frames from a sequence + of images. + + Args: + crt_idx (int): Current center index. + max_frame_num (int): Max number of the sequence of images (from 1). + num_frames (int): Reading num_frames frames. + padding (str): Padding mode, one of + 'replicate' | 'reflection' | 'reflection_circle' | 'circle' + Examples: current_idx = 0, num_frames = 5 + The generated frame indices under different padding mode: + replicate: [0, 0, 0, 1, 2] + reflection: [2, 1, 0, 1, 2] + reflection_circle: [4, 3, 0, 1, 2] + circle: [3, 4, 0, 1, 2] + + Returns: + list[int]: A list of indices. + """ + assert num_frames % 2 == 1, 'num_frames should be an odd number.' + assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.' + + max_frame_num = max_frame_num - 1 # start from 0 + num_pad = num_frames // 2 + + indices = [] + for i in range(crt_idx - num_pad, crt_idx + num_pad + 1): + if i < 0: + if padding == 'replicate': + pad_idx = 0 + elif padding == 'reflection': + pad_idx = -i + elif padding == 'reflection_circle': + pad_idx = crt_idx + num_pad - i + else: + pad_idx = num_frames + i + elif i > max_frame_num: + if padding == 'replicate': + pad_idx = max_frame_num + elif padding == 'reflection': + pad_idx = max_frame_num * 2 - i + elif padding == 'reflection_circle': + pad_idx = (crt_idx - num_pad) - (i - max_frame_num) + else: + pad_idx = i - num_frames + else: + pad_idx = i + indices.append(pad_idx) + return indices + + +def paired_paths_from_lmdb(folders, keys): + """Generate paired paths from lmdb files. + + Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is: + + lq.lmdb + ├── data.mdb + ├── lock.mdb + ├── meta_info.txt + + The data.mdb and lock.mdb are standard lmdb files and you can refer to + https://lmdb.readthedocs.io/en/release/ for more details. + + The meta_info.txt is a specified txt file to record the meta information + of our datasets. It will be automatically created when preparing + datasets by our provided dataset tools. + Each line in the txt file records + 1)image name (with extension), + 2)image shape, + 3)compression level, separated by a white space. + Example: `baboon.png (120,125,3) 1` + + We use the image name without extension as the lmdb key. + Note that we use the same key for the corresponding lq and gt images. + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. + keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + Note that this key is different from lmdb keys. + + Returns: + list[str]: Returned path list. + """ + assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. ' + f'But got {len(folders)}') + assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}') + input_folder, gt_folder = folders + input_key, gt_key = keys + + if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')): + raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb ' + f'formats. But received {input_key}: {input_folder}; ' + f'{gt_key}: {gt_folder}') + # ensure that the two meta_info files are the same + with open(osp.join(input_folder, 'meta_info.txt')) as fin: + input_lmdb_keys = [line.split('.')[0] for line in fin] + with open(osp.join(gt_folder, 'meta_info.txt')) as fin: + gt_lmdb_keys = [line.split('.')[0] for line in fin] + if set(input_lmdb_keys) != set(gt_lmdb_keys): + raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.') + else: + paths = [] + for lmdb_key in sorted(input_lmdb_keys): + paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)])) + return paths + + +def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl): + """Generate paired paths from an meta information file. + + Each line in the meta information file contains the image names and + image shape (usually for gt), separated by a white space. + + Example of an meta information file: + ``` + 0001_s001.png (480,480,3) + 0001_s002.png (480,480,3) + ``` + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. + keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + meta_info_file (str): Path to the meta information file. + filename_tmpl (str): Template for each filename. Note that the + template excludes the file extension. Usually the filename_tmpl is + for files in the input folder. + + Returns: + list[str]: Returned path list. + """ + assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. ' + f'But got {len(folders)}') + assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}') + input_folder, gt_folder = folders + input_key, gt_key = keys + + with open(meta_info_file, 'r') as fin: + gt_names = [line.split(' ')[0] for line in fin] + + paths = [] + for gt_name in gt_names: + basename, ext = osp.splitext(osp.basename(gt_name)) + input_name = f'{filename_tmpl.format(basename)}{ext}' + input_path = osp.join(input_folder, input_name) + gt_path = osp.join(gt_folder, gt_name) + paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)])) + return paths + + +def paired_paths_from_folder(folders, keys, filename_tmpl): + """Generate paired paths from folders. + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. + keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + filename_tmpl (str): Template for each filename. Note that the + template excludes the file extension. Usually the filename_tmpl is + for files in the input folder. + + Returns: + list[str]: Returned path list. + """ + assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. ' + f'But got {len(folders)}') + assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}') + input_folder, gt_folder = folders + input_key, gt_key = keys + + input_paths = list(scandir(input_folder)) + gt_paths = list(scandir(gt_folder)) + assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: ' + f'{len(input_paths)}, {len(gt_paths)}.') + paths = [] + for gt_path in gt_paths: + basename, ext = osp.splitext(osp.basename(gt_path)) + input_name = f'{filename_tmpl.format(basename)}{ext}' + input_path = osp.join(input_folder, input_name) + assert input_name in input_paths, (f'{input_name} is not in ' f'{input_key}_paths.') + gt_path = osp.join(gt_folder, gt_path) + paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)])) + return paths + + +def paths_from_folder(folder): + """Generate paths from folder. + + Args: + folder (str): Folder path. + + Returns: + list[str]: Returned path list. + """ + + paths = list(scandir(folder)) + paths = [osp.join(folder, path) for path in paths] + return paths + + +def paths_from_lmdb(folder): + """Generate paths from lmdb. + + Args: + folder (str): Folder path. + + Returns: + list[str]: Returned path list. + """ + if not folder.endswith('.lmdb'): + raise ValueError(f'Folder {folder}folder should in lmdb format.') + with open(osp.join(folder, 'meta_info.txt')) as fin: + paths = [line.split('.')[0] for line in fin] + return paths + + +def generate_gaussian_kernel(kernel_size=13, sigma=1.6): + """Generate Gaussian kernel used in `duf_downsample`. + + Args: + kernel_size (int): Kernel size. Default: 13. + sigma (float): Sigma of the Gaussian kernel. Default: 1.6. + + Returns: + np.array: The Gaussian kernel. + """ + from scipy.ndimage import filters as filters + kernel = np.zeros((kernel_size, kernel_size)) + # set element at the middle to one, a dirac delta + kernel[kernel_size // 2, kernel_size // 2] = 1 + # gaussian-smooth the dirac, resulting in a gaussian filter + return filters.gaussian_filter(kernel, sigma) + + +def duf_downsample(x, kernel_size=13, scale=4): + """Downsamping with Gaussian kernel used in the DUF official code. + + Args: + x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w). + kernel_size (int): Kernel size. Default: 13. + scale (int): Downsampling factor. Supported scale: (2, 3, 4). + Default: 4. + + Returns: + Tensor: DUF downsampled frames. + """ + assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.' + + squeeze_flag = False + if x.ndim == 4: + squeeze_flag = True + x = x.unsqueeze(0) + b, t, c, h, w = x.size() + x = x.view(-1, 1, h, w) + pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2 + x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect') + + gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale) + gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0) + x = F.conv2d(x, gaussian_filter, stride=scale) + x = x[:, :, 2:-2, 2:-2] + x = x.view(b, t, c, x.size(2), x.size(3)) + if squeeze_flag: + x = x.squeeze(0) + return x + + +def brush_stroke_mask(img, color=(255,255,255)): + min_num_vertex = 8 + max_num_vertex = 28 + mean_angle = 2*math.pi / 5 + angle_range = 2*math.pi / 12 + # training large mask ratio (training setting) + min_width = 30 + max_width = 70 + # very large mask ratio (test setting and refine after 200k) + # min_width = 80 + # max_width = 120 + def generate_mask(H, W, img=None): + average_radius = math.sqrt(H*H+W*W) / 8 + mask = Image.new('RGB', (W, H), 0) + if img is not None: mask = img # Image.fromarray(img) + + for _ in range(np.random.randint(1, 4)): + num_vertex = np.random.randint(min_num_vertex, max_num_vertex) + angle_min = mean_angle - np.random.uniform(0, angle_range) + angle_max = mean_angle + np.random.uniform(0, angle_range) + angles = [] + vertex = [] + for i in range(num_vertex): + if i % 2 == 0: + angles.append(2*math.pi - np.random.uniform(angle_min, angle_max)) + else: + angles.append(np.random.uniform(angle_min, angle_max)) + + h, w = mask.size + vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h)))) + for i in range(num_vertex): + r = np.clip( + np.random.normal(loc=average_radius, scale=average_radius//2), + 0, 2*average_radius) + new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w) + new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h) + vertex.append((int(new_x), int(new_y))) + + draw = ImageDraw.Draw(mask) + width = int(np.random.uniform(min_width, max_width)) + draw.line(vertex, fill=color, width=width) + for v in vertex: + draw.ellipse((v[0] - width//2, + v[1] - width//2, + v[0] + width//2, + v[1] + width//2), + fill=color) + + return mask + + width, height = img.size + mask = generate_mask(height, width, img) + return mask + + +def random_ff_mask(shape, max_angle = 10, max_len = 100, max_width = 70, times = 10): + """Generate a random free form mask with configuration. + Args: + config: Config should have configuration including IMG_SHAPES, + VERTICAL_MARGIN, HEIGHT, HORIZONTAL_MARGIN, WIDTH. + Returns: + tuple: (top, left, height, width) + Link: + https://github.com/csqiangwen/DeepFillv2_Pytorch/blob/master/train_dataset.py + """ + height = shape[0] + width = shape[1] + mask = np.zeros((height, width), np.float32) + times = np.random.randint(times-5, times) + for i in range(times): + start_x = np.random.randint(width) + start_y = np.random.randint(height) + for j in range(1 + np.random.randint(5)): + angle = 0.01 + np.random.randint(max_angle) + if i % 2 == 0: + angle = 2 * 3.1415926 - angle + length = 10 + np.random.randint(max_len-20, max_len) + brush_w = 5 + np.random.randint(max_width-30, max_width) + end_x = (start_x + length * np.sin(angle)).astype(np.int32) + end_y = (start_y + length * np.cos(angle)).astype(np.int32) + cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w) + start_x, start_y = end_x, end_y + return mask.astype(np.float32) \ No newline at end of file diff --git a/Hallo2/hallo2/basicsr/data/gaussian_kernels.py b/Hallo2/hallo2/basicsr/data/gaussian_kernels.py new file mode 100644 index 00000000..a7c05a33 --- /dev/null +++ b/Hallo2/hallo2/basicsr/data/gaussian_kernels.py @@ -0,0 +1,690 @@ +import math +import numpy as np +import random +from scipy.ndimage.interpolation import shift +from scipy.stats import multivariate_normal + + +def sigma_matrix2(sig_x, sig_y, theta): + """Calculate the rotated sigma matrix (two dimensional matrix). + Args: + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + Returns: + ndarray: Rotated sigma matrix. + """ + D = np.array([[sig_x**2, 0], [0, sig_y**2]]) + U = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + return np.dot(U, np.dot(D, U.T)) + + +def mesh_grid(kernel_size): + """Generate the mesh grid, centering at zero. + Args: + kernel_size (int): + Returns: + xy (ndarray): with the shape (kernel_size, kernel_size, 2) + xx (ndarray): with the shape (kernel_size, kernel_size) + yy (ndarray): with the shape (kernel_size, kernel_size) + """ + ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.) + xx, yy = np.meshgrid(ax, ax) + xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), + yy.reshape(kernel_size * kernel_size, + 1))).reshape(kernel_size, kernel_size, 2) + return xy, xx, yy + + +def pdf2(sigma_matrix, grid): + """Calculate PDF of the bivariate Gaussian distribution. + Args: + sigma_matrix (ndarray): with the shape (2, 2) + grid (ndarray): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. + Returns: + kernel (ndarrray): un-normalized kernel. + """ + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2)) + return kernel + + +def cdf2(D, grid): + """Calculate the CDF of the standard bivariate Gaussian distribution. + Used in skewed Gaussian distribution. + Args: + D (ndarrasy): skew matrix. + grid (ndarray): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. + Returns: + cdf (ndarray): skewed cdf. + """ + rv = multivariate_normal([0, 0], [[1, 0], [0, 1]]) + grid = np.dot(grid, D) + cdf = rv.cdf(grid) + return cdf + + +def bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid=None): + """Generate a bivariate skew Gaussian kernel. + Described in `A multivariate skew normal distribution`_ by Shi et. al (2004). + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + D (ndarrasy): skew matrix. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + Returns: + kernel (ndarray): normalized kernel. + .. _A multivariate skew normal distribution: + https://www.sciencedirect.com/science/article/pii/S0047259X03001313 + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + pdf = pdf2(sigma_matrix, grid) + cdf = cdf2(D, grid) + kernel = pdf * cdf + kernel = kernel / np.sum(kernel) + return kernel + + +def mass_center_shift(kernel_size, kernel): + """Calculate the shift of the mass center of a kenrel. + Args: + kernel_size (int): + kernel (ndarray): normalized kernel. + Returns: + delta_h (float): + delta_w (float): + """ + ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.) + col_sum, row_sum = np.sum(kernel, axis=0), np.sum(kernel, axis=1) + delta_h = np.dot(row_sum, ax) + delta_w = np.dot(col_sum, ax) + return delta_h, delta_w + + +def bivariate_skew_Gaussian_center(kernel_size, + sig_x, + sig_y, + theta, + D, + grid=None): + """Generate a bivariate skew Gaussian kernel at center. Shift with nearest padding. + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + D (ndarrasy): skew matrix. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + Returns: + kernel (ndarray): centered and normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + kernel = bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid) + delta_h, delta_w = mass_center_shift(kernel_size, kernel) + kernel = shift(kernel, [-delta_h, -delta_w], mode='nearest') + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_anisotropic_Gaussian(kernel_size, + sig_x, + sig_y, + theta, + grid=None): + """Generate a bivariate anisotropic Gaussian kernel. + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + Returns: + kernel (ndarray): normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + kernel = pdf2(sigma_matrix, grid) + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_isotropic_Gaussian(kernel_size, sig, grid=None): + """Generate a bivariate isotropic Gaussian kernel. + Args: + kernel_size (int): + sig (float): + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + Returns: + kernel (ndarray): normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + sigma_matrix = np.array([[sig**2, 0], [0, sig**2]]) + kernel = pdf2(sigma_matrix, grid) + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_generalized_Gaussian(kernel_size, + sig_x, + sig_y, + theta, + beta, + grid=None): + """Generate a bivariate generalized Gaussian kernel. + Described in `Parameter Estimation For Multivariate Generalized Gaussian Distributions`_ + by Pascal et. al (2013). + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + beta (float): shape parameter, beta = 1 is the normal distribution. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + Returns: + kernel (ndarray): normalized kernel. + .. _Parameter Estimation For Multivariate Generalized Gaussian Distributions: + https://arxiv.org/abs/1302.6498 + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.exp( + -0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta)) + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_plateau_type1(kernel_size, sig_x, sig_y, theta, beta, grid=None): + """Generate a plateau-like anisotropic kernel. + 1 / (1+x^(beta)) + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + beta (float): shape parameter, beta = 1 is the normal distribution. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + Returns: + kernel (ndarray): normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.reciprocal( + np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1) + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_plateau_type1_iso(kernel_size, sig, beta, grid=None): + """Generate a plateau-like isotropic kernel. + 1 / (1+x^(beta)) + Args: + kernel_size (int): + sig (float): + beta (float): shape parameter, beta = 1 is the normal distribution. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + Returns: + kernel (ndarray): normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + sigma_matrix = np.array([[sig**2, 0], [0, sig**2]]) + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.reciprocal( + np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1) + kernel = kernel / np.sum(kernel) + return kernel + + +def random_bivariate_skew_Gaussian_center(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + noise_range=None, + strict=False): + """Randomly generate bivariate skew Gaussian kernels at center. + Args: + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' + assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' + assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' + sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) + sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) + if strict: + sigma_max = np.max([sigma_x, sigma_y]) + sigma_min = np.min([sigma_x, sigma_y]) + sigma_x, sigma_y = sigma_max, sigma_min + rotation = np.random.uniform(rotation_range[0], rotation_range[1]) + + sigma_max = np.max([sigma_x, sigma_y]) + thres = 3 / sigma_max + D = [[np.random.uniform(-thres, thres), + np.random.uniform(-thres, thres)], + [np.random.uniform(-thres, thres), + np.random.uniform(-thres, thres)]] + + kernel = bivariate_skew_Gaussian_center(kernel_size, sigma_x, sigma_y, + rotation, D) + + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform( + noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + if strict: + return kernel, sigma_x, sigma_y, rotation, D + else: + return kernel + + +def random_bivariate_anisotropic_Gaussian(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + noise_range=None, + strict=False): + """Randomly generate bivariate anisotropic Gaussian kernels. + Args: + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' + assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' + assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' + sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) + sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) + if strict: + sigma_max = np.max([sigma_x, sigma_y]) + sigma_min = np.min([sigma_x, sigma_y]) + sigma_x, sigma_y = sigma_max, sigma_min + rotation = np.random.uniform(rotation_range[0], rotation_range[1]) + + kernel = bivariate_anisotropic_Gaussian(kernel_size, sigma_x, sigma_y, + rotation) + + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform( + noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + if strict: + return kernel, sigma_x, sigma_y, rotation + else: + return kernel + + +def random_bivariate_isotropic_Gaussian(kernel_size, + sigma_range, + noise_range=None, + strict=False): + """Randomly generate bivariate isotropic Gaussian kernels. + Args: + kernel_size (int): + sigma_range (tuple): [0.6, 5] + noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_range[0] < sigma_range[1], 'Wrong sigma_x_range.' + sigma = np.random.uniform(sigma_range[0], sigma_range[1]) + + kernel = bivariate_isotropic_Gaussian(kernel_size, sigma) + + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform( + noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + if strict: + return kernel, sigma + else: + return kernel + + +def random_bivariate_generalized_Gaussian(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + beta_range, + noise_range=None, + strict=False): + """Randomly generate bivariate generalized Gaussian kernels. + Args: + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + beta_range (tuple): [0.5, 8] + noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' + assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' + assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' + sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) + sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) + if strict: + sigma_max = np.max([sigma_x, sigma_y]) + sigma_min = np.min([sigma_x, sigma_y]) + sigma_x, sigma_y = sigma_max, sigma_min + rotation = np.random.uniform(rotation_range[0], rotation_range[1]) + if np.random.uniform() < 0.5: + beta = np.random.uniform(beta_range[0], 1) + else: + beta = np.random.uniform(1, beta_range[1]) + + kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, + rotation, beta) + + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform( + noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + if strict: + return kernel, sigma_x, sigma_y, rotation, beta + else: + return kernel + + +def random_bivariate_plateau_type1(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + beta_range, + noise_range=None, + strict=False): + """Randomly generate bivariate plateau type1 kernels. + Args: + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi/2, math.pi/2] + beta_range (tuple): [1, 4] + noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' + assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' + assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' + sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) + sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) + if strict: + sigma_max = np.max([sigma_x, sigma_y]) + sigma_min = np.min([sigma_x, sigma_y]) + sigma_x, sigma_y = sigma_max, sigma_min + rotation = np.random.uniform(rotation_range[0], rotation_range[1]) + if np.random.uniform() < 0.5: + beta = np.random.uniform(beta_range[0], 1) + else: + beta = np.random.uniform(1, beta_range[1]) + + kernel = bivariate_plateau_type1(kernel_size, sigma_x, sigma_y, rotation, + beta) + + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform( + noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + if strict: + return kernel, sigma_x, sigma_y, rotation, beta + else: + return kernel + + +def random_bivariate_plateau_type1_iso(kernel_size, + sigma_range, + beta_range, + noise_range=None, + strict=False): + """Randomly generate bivariate plateau type1 kernels (iso). + Args: + kernel_size (int): + sigma_range (tuple): [0.6, 5] + beta_range (tuple): [1, 4] + noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_range[0] < sigma_range[1], 'Wrong sigma_x_range.' + sigma = np.random.uniform(sigma_range[0], sigma_range[1]) + beta = np.random.uniform(beta_range[0], beta_range[1]) + + kernel = bivariate_plateau_type1_iso(kernel_size, sigma, beta) + + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform( + noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + if strict: + return kernel, sigma, beta + else: + return kernel + + +def random_mixed_kernels(kernel_list, + kernel_prob, + kernel_size=21, + sigma_x_range=[0.6, 5], + sigma_y_range=[0.6, 5], + rotation_range=[-math.pi, math.pi], + beta_range=[0.5, 8], + noise_range=None): + """Randomly generate mixed kernels. + Args: + kernel_list (tuple): a list name of kenrel types, + support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso', 'plateau_aniso'] + kernel_prob (tuple): corresponding kernel probability for each kernel type + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + beta_range (tuple): [0.5, 8] + noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None + Returns: + kernel (ndarray): + """ + kernel_type = random.choices(kernel_list, kernel_prob)[0] + if kernel_type == 'iso': + kernel = random_bivariate_isotropic_Gaussian( + kernel_size, sigma_x_range, noise_range=noise_range) + elif kernel_type == 'aniso': + kernel = random_bivariate_anisotropic_Gaussian( + kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + noise_range=noise_range) + elif kernel_type == 'skew': + kernel = random_bivariate_skew_Gaussian_center( + kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + noise_range=noise_range) + elif kernel_type == 'generalized': + kernel = random_bivariate_generalized_Gaussian( + kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + beta_range, + noise_range=noise_range) + elif kernel_type == 'plateau_iso': + kernel = random_bivariate_plateau_type1_iso( + kernel_size, sigma_x_range, beta_range, noise_range=noise_range) + elif kernel_type == 'plateau_aniso': + kernel = random_bivariate_plateau_type1( + kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + beta_range, + noise_range=noise_range) + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform( + noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + return kernel + + +def show_one_kernel(): + import matplotlib.pyplot as plt + kernel_size = 21 + + # bivariate skew Gaussian + D = [[0, 0], [0, 0]] + D = [[3 / 4, 0], [0, 0.5]] + kernel = bivariate_skew_Gaussian_center(kernel_size, 2, 4, -math.pi / 4, D) + # bivariate anisotropic Gaussian + kernel = bivariate_anisotropic_Gaussian(kernel_size, 2, 4, -math.pi / 4) + # bivariate anisotropic Gaussian + kernel = bivariate_isotropic_Gaussian(kernel_size, 1) + # bivariate generalized Gaussian + kernel = bivariate_generalized_Gaussian( + kernel_size, 2, 4, -math.pi / 4, beta=4) + + delta_h, delta_w = mass_center_shift(kernel_size, kernel) + print(delta_h, delta_w) + + fig, axs = plt.subplots(nrows=2, ncols=2) + # axs.set_axis_off() + ax = axs[0][0] + im = ax.matshow(kernel, cmap='jet', origin='upper') + fig.colorbar(im, ax=ax) + + # image + ax = axs[0][1] + kernel_vis = kernel - np.min(kernel) + kernel_vis = kernel_vis / np.max(kernel_vis) * 255. + ax.imshow(kernel_vis, interpolation='nearest') + + _, xx, yy = mesh_grid(kernel_size) + # contour + ax = axs[1][0] + CS = ax.contour(xx, yy, kernel, origin='upper') + ax.clabel(CS, inline=1, fontsize=3) + + # contourf + ax = axs[1][1] + kernel = kernel / np.max(kernel) + p = ax.contourf( + xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10)) + fig.colorbar(p) + + plt.show() + + +# def show_plateau_kernel(): +# import matplotlib.pyplot as plt +# kernel_size = 21 + +# kernel = plateau_type1(kernel_size, 2, 4, -math.pi / 8, 2, grid=None) +# kernel_norm = bivariate_isotropic_Gaussian(kernel_size, 5) +# kernel_gau = bivariate_generalized_Gaussian( +# kernel_size, 2, 4, -math.pi / 8, 2, grid=None) +# delta_h, delta_w = mass_center_shift(kernel_size, kernel) +# print(delta_h, delta_w) + + # kernel_slice = kernel[10, :] + # kernel_gau_slice = kernel_gau[10, :] + # kernel_norm_slice = kernel_norm[10, :] + # fig, ax = plt.subplots() + # t = list(range(1, 22)) + + # ax.plot(t, kernel_gau_slice) + # ax.plot(t, kernel_slice) + # ax.plot(t, kernel_norm_slice) + + # t = np.arange(0, 10, 0.1) + # y = np.exp(-0.5 * t) + # y2 = np.reciprocal(1 + t) + # print(t.shape) + # print(y.shape) + # ax.plot(t, y) + # ax.plot(t, y2) + # plt.show() + + # fig, axs = plt.subplots(nrows=2, ncols=2) + # # axs.set_axis_off() + # ax = axs[0][0] + # im = ax.matshow(kernel, cmap='jet', origin='upper') + # fig.colorbar(im, ax=ax) + + # # image + # ax = axs[0][1] + # kernel_vis = kernel - np.min(kernel) + # kernel_vis = kernel_vis / np.max(kernel_vis) * 255. + # ax.imshow(kernel_vis, interpolation='nearest') + + # _, xx, yy = mesh_grid(kernel_size) + # # contour + # ax = axs[1][0] + # CS = ax.contour(xx, yy, kernel, origin='upper') + # ax.clabel(CS, inline=1, fontsize=3) + + # # contourf + # ax = axs[1][1] + # kernel = kernel / np.max(kernel) + # p = ax.contourf( + # xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10)) + # fig.colorbar(p) + + # plt.show() diff --git a/Hallo2/hallo2/basicsr/data/prefetch_dataloader.py b/Hallo2/hallo2/basicsr/data/prefetch_dataloader.py new file mode 100644 index 00000000..50884250 --- /dev/null +++ b/Hallo2/hallo2/basicsr/data/prefetch_dataloader.py @@ -0,0 +1,125 @@ +import queue as Queue +import threading +import torch +from torch.utils.data import DataLoader + + +class PrefetchGenerator(threading.Thread): + """A general prefetch generator. + + Ref: + https://stackoverflow.com/questions/7323664/python-generator-pre-fetch + + Args: + generator: Python generator. + num_prefetch_queue (int): Number of prefetch queue. + """ + + def __init__(self, generator, num_prefetch_queue): + threading.Thread.__init__(self) + self.queue = Queue.Queue(num_prefetch_queue) + self.generator = generator + self.daemon = True + self.start() + + def run(self): + for item in self.generator: + self.queue.put(item) + self.queue.put(None) + + def __next__(self): + next_item = self.queue.get() + if next_item is None: + raise StopIteration + return next_item + + def __iter__(self): + return self + + +class PrefetchDataLoader(DataLoader): + """Prefetch version of dataloader. + + Ref: + https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# + + TODO: + Need to test on single gpu and ddp (multi-gpu). There is a known issue in + ddp. + + Args: + num_prefetch_queue (int): Number of prefetch queue. + kwargs (dict): Other arguments for dataloader. + """ + + def __init__(self, num_prefetch_queue, **kwargs): + self.num_prefetch_queue = num_prefetch_queue + super(PrefetchDataLoader, self).__init__(**kwargs) + + def __iter__(self): + return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) + + +class CPUPrefetcher(): + """CPU prefetcher. + + Args: + loader: Dataloader. + """ + + def __init__(self, loader): + self.ori_loader = loader + self.loader = iter(loader) + + def next(self): + try: + return next(self.loader) + except StopIteration: + return None + + def reset(self): + self.loader = iter(self.ori_loader) + + +class CUDAPrefetcher(): + """CUDA prefetcher. + + Ref: + https://github.com/NVIDIA/apex/issues/304# + + It may consums more GPU memory. + + Args: + loader: Dataloader. + opt (dict): Options. + """ + + def __init__(self, loader, opt): + self.ori_loader = loader + self.loader = iter(loader) + self.opt = opt + self.stream = torch.cuda.Stream() + self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') + self.preload() + + def preload(self): + try: + self.batch = next(self.loader) # self.batch is a dict + except StopIteration: + self.batch = None + return None + # put tensors to gpu + with torch.cuda.stream(self.stream): + for k, v in self.batch.items(): + if torch.is_tensor(v): + self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) + + def next(self): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.batch + self.preload() + return batch + + def reset(self): + self.loader = iter(self.ori_loader) + self.preload() diff --git a/Hallo2/hallo2/basicsr/data/transforms.py b/Hallo2/hallo2/basicsr/data/transforms.py new file mode 100644 index 00000000..aead9dc7 --- /dev/null +++ b/Hallo2/hallo2/basicsr/data/transforms.py @@ -0,0 +1,165 @@ +import cv2 +import random + + +def mod_crop(img, scale): + """Mod crop images, used during testing. + + Args: + img (ndarray): Input image. + scale (int): Scale factor. + + Returns: + ndarray: Result image. + """ + img = img.copy() + if img.ndim in (2, 3): + h, w = img.shape[0], img.shape[1] + h_remainder, w_remainder = h % scale, w % scale + img = img[:h - h_remainder, :w - w_remainder, ...] + else: + raise ValueError(f'Wrong img ndim: {img.ndim}.') + return img + + +def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path): + """Paired random crop. + + It crops lists of lq and gt images with corresponding locations. + + Args: + img_gts (list[ndarray] | ndarray): GT images. Note that all images + should have the same shape. If the input is an ndarray, it will + be transformed to a list containing itself. + img_lqs (list[ndarray] | ndarray): LQ images. Note that all images + should have the same shape. If the input is an ndarray, it will + be transformed to a list containing itself. + gt_patch_size (int): GT patch size. + scale (int): Scale factor. + gt_path (str): Path to ground-truth. + + Returns: + list[ndarray] | ndarray: GT images and LQ images. If returned results + only have one element, just return ndarray. + """ + + if not isinstance(img_gts, list): + img_gts = [img_gts] + if not isinstance(img_lqs, list): + img_lqs = [img_lqs] + + h_lq, w_lq, _ = img_lqs[0].shape + h_gt, w_gt, _ = img_gts[0].shape + lq_patch_size = gt_patch_size // scale + + if h_gt != h_lq * scale or w_gt != w_lq * scale: + raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', + f'multiplication of LQ ({h_lq}, {w_lq}).') + if h_lq < lq_patch_size or w_lq < lq_patch_size: + raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' + f'({lq_patch_size}, {lq_patch_size}). ' + f'Please remove {gt_path}.') + + # randomly choose top and left coordinates for lq patch + top = random.randint(0, h_lq - lq_patch_size) + left = random.randint(0, w_lq - lq_patch_size) + + # crop lq patch + img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs] + + # crop corresponding gt patch + top_gt, left_gt = int(top * scale), int(left * scale) + img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts] + if len(img_gts) == 1: + img_gts = img_gts[0] + if len(img_lqs) == 1: + img_lqs = img_lqs[0] + return img_gts, img_lqs + + +def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): + """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). + + We use vertical flip and transpose for rotation implementation. + All the images in the list use the same augmentation. + + Args: + imgs (list[ndarray] | ndarray): Images to be augmented. If the input + is an ndarray, it will be transformed to a list. + hflip (bool): Horizontal flip. Default: True. + rotation (bool): Ratotation. Default: True. + flows (list[ndarray]: Flows to be augmented. If the input is an + ndarray, it will be transformed to a list. + Dimension is (h, w, 2). Default: None. + return_status (bool): Return the status of flip and rotation. + Default: False. + + Returns: + list[ndarray] | ndarray: Augmented images and flows. If returned + results only have one element, just return ndarray. + + """ + hflip = hflip and random.random() < 0.5 + vflip = rotation and random.random() < 0.5 + rot90 = rotation and random.random() < 0.5 + + def _augment(img): + if hflip: # horizontal + cv2.flip(img, 1, img) + if vflip: # vertical + cv2.flip(img, 0, img) + if rot90: + img = img.transpose(1, 0, 2) + return img + + def _augment_flow(flow): + if hflip: # horizontal + cv2.flip(flow, 1, flow) + flow[:, :, 0] *= -1 + if vflip: # vertical + cv2.flip(flow, 0, flow) + flow[:, :, 1] *= -1 + if rot90: + flow = flow.transpose(1, 0, 2) + flow = flow[:, :, [1, 0]] + return flow + + if not isinstance(imgs, list): + imgs = [imgs] + imgs = [_augment(img) for img in imgs] + if len(imgs) == 1: + imgs = imgs[0] + + if flows is not None: + if not isinstance(flows, list): + flows = [flows] + flows = [_augment_flow(flow) for flow in flows] + if len(flows) == 1: + flows = flows[0] + return imgs, flows + else: + if return_status: + return imgs, (hflip, vflip, rot90) + else: + return imgs + + +def img_rotate(img, angle, center=None, scale=1.0): + """Rotate image. + + Args: + img (ndarray): Image to be rotated. + angle (float): Rotation angle in degrees. Positive values mean + counter-clockwise rotation. + center (tuple[int]): Rotation center. If the center is None, + initialize it as the center of the image. Default: None. + scale (float): Isotropic scale factor. Default: 1.0. + """ + (h, w) = img.shape[:2] + + if center is None: + center = (w // 2, h // 2) + + matrix = cv2.getRotationMatrix2D(center, angle, scale) + rotated_img = cv2.warpAffine(img, matrix, (w, h)) + return rotated_img diff --git a/Hallo2/hallo2/basicsr/data/vfhq_dataset.py b/Hallo2/hallo2/basicsr/data/vfhq_dataset.py new file mode 100644 index 00000000..eadbdeb5 --- /dev/null +++ b/Hallo2/hallo2/basicsr/data/vfhq_dataset.py @@ -0,0 +1,298 @@ +import cv2 +import math +import random +import numpy as np +import os.path as osp +from scipy.io import loadmat +from PIL import Image +import torch +import torch.utils.data as data +from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, + adjust_hue, adjust_saturation, normalize) +from basicsr.data import gaussian_kernels as gaussian_kernels +from basicsr.data.transforms import augment +from basicsr.data.data_util import paths_from_folder, brush_stroke_mask, random_ff_mask +from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor +from basicsr.utils.registry import DATASET_REGISTRY + +from pathlib import Path +import torchvision.transforms as transforms + +@DATASET_REGISTRY.register() +class VFHQBlindDataset(data.Dataset): + + def __init__(self, opt): + super(VFHQBlindDataset, self).__init__() + logger = get_root_logger() + self.opt = opt + # file client (io backend) + self.file_client = None + self.io_backend_opt = opt['io_backend'] + self.video_length = opt['video_length'] + + self.gt_folder = opt['dataroot_gt'] + self.gt_size = opt.get('gt_size', 512) + self.in_size = opt.get('in_size', 512) + assert self.gt_size >= self.in_size, 'Wrong setting.' + + self.mean = opt.get('mean', [0.5, 0.5, 0.5]) + self.std = opt.get('std', [0.5, 0.5, 0.5]) + + self.component_path = opt.get('component_path', None) + self.latent_gt_path = opt.get('latent_gt_path', None) + + if self.component_path is not None: + self.crop_components = True + self.components_dict = torch.load(self.component_path) + self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1.4) + self.nose_enlarge_ratio = opt.get('nose_enlarge_ratio', 1.1) + self.mouth_enlarge_ratio = opt.get('mouth_enlarge_ratio', 1.3) + else: + self.crop_components = False + + if self.latent_gt_path is not None: + self.load_latent_gt = True + self.latent_gt_dict = torch.load(self.latent_gt_path) + else: + self.load_latent_gt = False + + if self.io_backend_opt['type'] == 'lmdb': + self.io_backend_opt['db_paths'] = self.gt_folder + if not self.gt_folder.endswith('.lmdb'): + raise ValueError("'dataroot_gt' should end with '.lmdb', "f'but received {self.gt_folder}') + with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: + self.paths = [line.split('.')[0] for line in fin] + else: + gt_folder = Path(self.gt_folder) + sub_dir = gt_folder.iterdir() + self.paths = [] + for p in sub_dir: + if p.is_dir(): + l = list(p.glob('*.png')) + if len(l) > self.video_length: + self.paths.append(str(p)) + + + # inpainting mask + self.gen_inpaint_mask = opt.get('gen_inpaint_mask', False) + if self.gen_inpaint_mask: + logger.info(f'generate mask ...') + + + # perform corrupt + self.use_corrupt = opt.get('use_corrupt', True) + self.use_motion_kernel = False + + if self.use_motion_kernel: + self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001) + motion_kernel_path = opt.get('motion_kernel_path', 'basicsr/data/motion-blur-kernels-32.pth') + self.motion_kernels = torch.load(motion_kernel_path) + + if self.use_corrupt and not self.gen_inpaint_mask: + # degradation configurations + self.blur_kernel_size = opt['blur_kernel_size'] + self.blur_sigma = opt['blur_sigma'] + self.kernel_list = opt['kernel_list'] + self.kernel_prob = opt['kernel_prob'] + self.downsample_range = opt['downsample_range'] + self.noise_range = opt['noise_range'] + self.jpeg_range = opt['jpeg_range'] + # print + logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]') + logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]') + logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]') + logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]') + + # color jitter + self.color_jitter_prob = opt.get('color_jitter_prob', None) + self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob', None) + self.color_jitter_shift = opt.get('color_jitter_shift', 20) + if self.color_jitter_prob is not None: + logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}') + + # to gray + self.gray_prob = opt.get('gray_prob', 0.0) + if self.gray_prob is not None: + logger.info(f'Use random gray. Prob: {self.gray_prob}') + self.color_jitter_shift /= 255. + + + @staticmethod + def color_jitter(img, shift): + """jitter color: randomly jitter the RGB values, in numpy formats""" + jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32) + img = img + jitter_val + img = np.clip(img, 0, 1) + return img + + @staticmethod + def color_jitter_pt(img, brightness, contrast, saturation, hue): + """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats""" + fn_idx = torch.randperm(4) + for fn_id in fn_idx: + if fn_id == 0 and brightness is not None: + brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item() + img = adjust_brightness(img, brightness_factor) + + if fn_id == 1 and contrast is not None: + contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item() + img = adjust_contrast(img, contrast_factor) + + if fn_id == 2 and saturation is not None: + saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item() + img = adjust_saturation(img, saturation_factor) + + if fn_id == 3 and hue is not None: + hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item() + img = adjust_hue(img, hue_factor) + return img + + + def get_component_locations(self, name, status): + components_bbox = self.components_dict[name] + if status[0]: # hflip + # exchange right and left eye + tmp = components_bbox['left_eye'] + components_bbox['left_eye'] = components_bbox['right_eye'] + components_bbox['right_eye'] = tmp + # modify the width coordinate + components_bbox['left_eye'][0] = self.gt_size - components_bbox['left_eye'][0] + components_bbox['right_eye'][0] = self.gt_size - components_bbox['right_eye'][0] + components_bbox['nose'][0] = self.gt_size - components_bbox['nose'][0] + components_bbox['mouth'][0] = self.gt_size - components_bbox['mouth'][0] + + locations_gt = {} + locations_in = {} + for part in ['left_eye', 'right_eye', 'nose', 'mouth']: + mean = components_bbox[part][0:2] + half_len = components_bbox[part][2] + if 'eye' in part: + half_len *= self.eye_enlarge_ratio + elif part == 'nose': + half_len *= self.nose_enlarge_ratio + elif part == 'mouth': + half_len *= self.mouth_enlarge_ratio + loc = np.hstack((mean - half_len + 1, mean + half_len)) + loc = torch.from_numpy(loc).float() + locations_gt[part] = loc + loc_in = loc/(self.gt_size//self.in_size) + locations_in[part] = loc_in + return locations_gt, locations_in + + + def __getitem__(self, index): + if self.file_client is None: + self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) + + # load gt image + gt_path = self.paths[index] + + image_list = list(Path(gt_path).glob('*.png')) + lenght = len(image_list) + + start_idx = random.randint(0, lenght-self.video_length-1) + in_list = [] + gt_list = [] + + for i in range(start_idx, start_idx+self.video_length): + gt_path_idx = image_list[i] + + img_bytes = self.file_client.get(gt_path_idx) + img_gt = imfrombytes(img_bytes, float32=True) + + # random horizontal flip + img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True) + + + # generate in image + img_in = img_gt + if self.use_corrupt and not self.gen_inpaint_mask: + # motion blur + if self.use_motion_kernel and random.random() < self.motion_kernel_prob: + m_i = random.randint(0,31) + k = self.motion_kernels[f'{m_i:02d}'] + img_in = cv2.filter2D(img_in,-1,k) + + # gaussian blur + kernel = gaussian_kernels.random_mixed_kernels( + self.kernel_list, + self.kernel_prob, + self.blur_kernel_size, + self.blur_sigma, + self.blur_sigma, + [-math.pi, math.pi], + noise_range=None) + img_in = cv2.filter2D(img_in, -1, kernel) + + # downsample + scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1]) + img_in = cv2.resize(img_in, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR) + + # noise + if self.noise_range is not None: + noise_sigma = np.random.uniform(self.noise_range[0] / 255., self.noise_range[1] / 255.) + noise = np.float32(np.random.randn(*(img_in.shape))) * noise_sigma + img_in = img_in + noise + img_in = np.clip(img_in, 0, 1) + + # jpeg + if self.jpeg_range is not None: + jpeg_p = np.random.uniform(self.jpeg_range[0], self.jpeg_range[1]) + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)] + _, encimg = cv2.imencode('.jpg', img_in * 255., encode_param) + img_in = np.float32(cv2.imdecode(encimg, 1)) / 255. + + # resize to in_size + img_in = cv2.resize(img_in, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR) + + + if self.gen_inpaint_mask: + img_in = (img_in*255).astype('uint8') + img_in = brush_stroke_mask(Image.fromarray(img_in)) + img_in = np.array(img_in) / 255. + + # random color jitter (only for lq) + if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob): + img_in = self.color_jitter(img_in, self.color_jitter_shift) + # random to gray (only for lq) + if self.gray_prob and np.random.uniform() < self.gray_prob: + img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY) + img_in = np.tile(img_in[:, :, None], [1, 1, 3]) + + # BGR to RGB, HWC to CHW, numpy to tensor + img_in, img_gt = img2tensor([img_in, img_gt], bgr2rgb=True, float32=True) + + # random color jitter (pytorch version) (only for lq) + if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob): + brightness = self.opt.get('brightness', (0.5, 1.5)) + contrast = self.opt.get('contrast', (0.5, 1.5)) + saturation = self.opt.get('saturation', (0, 1.5)) + hue = self.opt.get('hue', (-0.1, 0.1)) + img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue) + + # round and clip + img_in = np.clip((img_in * 255.0).round(), 0, 255) / 255. + + # Set vgg range_norm=True if use the normalization here + # normalize + normalize(img_in, self.mean, self.std, inplace=True) + normalize(img_gt, self.mean, self.std, inplace=True) + + img_in = img_in.unsqueeze(0) + img_gt = img_gt.unsqueeze(0) + + in_list.append(img_in) + gt_list.append(img_gt) + + in_video = torch.cat(in_list, dim=0) + gt_video = torch.cat(gt_list, dim=0) + + return_dict = {'in': in_video, 'gt': gt_video} + + return return_dict + + + def __len__(self): + return len(self.paths) + + \ No newline at end of file diff --git a/Hallo2/hallo2/basicsr/losses/__init__.py b/Hallo2/hallo2/basicsr/losses/__init__.py new file mode 100644 index 00000000..2b184e74 --- /dev/null +++ b/Hallo2/hallo2/basicsr/losses/__init__.py @@ -0,0 +1,26 @@ +from copy import deepcopy + +from basicsr.utils import get_root_logger +from basicsr.utils.registry import LOSS_REGISTRY +from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize, + gradient_penalty_loss, r1_penalty) + +__all__ = [ + 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss', + 'r1_penalty', 'g_path_regularize' +] + + +def build_loss(opt): + """Build loss from options. + + Args: + opt (dict): Configuration. It must constain: + type (str): Model type. + """ + opt = deepcopy(opt) + loss_type = opt.pop('type') + loss = LOSS_REGISTRY.get(loss_type)(**opt) + logger = get_root_logger() + logger.info(f'Loss [{loss.__class__.__name__}] is created.') + return loss diff --git a/Hallo2/hallo2/basicsr/losses/loss_util.py b/Hallo2/hallo2/basicsr/losses/loss_util.py new file mode 100644 index 00000000..744eeb46 --- /dev/null +++ b/Hallo2/hallo2/basicsr/losses/loss_util.py @@ -0,0 +1,95 @@ +import functools +from torch.nn import functional as F + + +def reduce_loss(loss, reduction): + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are 'none', 'mean' and 'sum'. + + Returns: + Tensor: Reduced loss tensor. + """ + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + else: + return loss.sum() + + +def weight_reduce_loss(loss, weight=None, reduction='mean'): + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Tensor): Element-wise weights. Default: None. + reduction (str): Same as built-in losses of PyTorch. Options are + 'none', 'mean' and 'sum'. Default: 'mean'. + + Returns: + Tensor: Loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + assert weight.dim() == loss.dim() + assert weight.size(1) == 1 or weight.size(1) == loss.size(1) + loss = loss * weight + + # if weight is not specified or reduction is sum, just reduce the loss + if weight is None or reduction == 'sum': + loss = reduce_loss(loss, reduction) + # if reduction is mean, then compute mean over weight region + elif reduction == 'mean': + if weight.size(1) > 1: + weight = weight.sum() + else: + weight = weight.sum() * loss.size(1) + loss = loss.sum() / weight + + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + **kwargs)`. + + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.5000) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, reduction='sum') + tensor(3.) + """ + + @functools.wraps(loss_func) + def wrapper(pred, target, weight=None, reduction='mean', **kwargs): + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction) + return loss + + return wrapper diff --git a/Hallo2/hallo2/basicsr/losses/losses.py b/Hallo2/hallo2/basicsr/losses/losses.py new file mode 100644 index 00000000..1bcf272c --- /dev/null +++ b/Hallo2/hallo2/basicsr/losses/losses.py @@ -0,0 +1,455 @@ +import math +import lpips +import torch +from torch import autograd as autograd +from torch import nn as nn +from torch.nn import functional as F + +from basicsr.archs.vgg_arch import VGGFeatureExtractor +from basicsr.utils.registry import LOSS_REGISTRY +from .loss_util import weighted_loss + +_reduction_modes = ['none', 'mean', 'sum'] + + +@weighted_loss +def l1_loss(pred, target): + return F.l1_loss(pred, target, reduction='none') + + +@weighted_loss +def mse_loss(pred, target): + return F.mse_loss(pred, target, reduction='none') + + +@weighted_loss +def charbonnier_loss(pred, target, eps=1e-12): + return torch.sqrt((pred - target)**2 + eps) + + +@LOSS_REGISTRY.register() +class L1Loss(nn.Module): + """L1 (mean absolute error, MAE) loss. + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(L1Loss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise + weights. Default: None. + """ + return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class MSELoss(nn.Module): + """MSE (L2) loss. + + Args: + loss_weight (float): Loss weight for MSE loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(MSELoss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise + weights. Default: None. + """ + return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class CharbonnierLoss(nn.Module): + """Charbonnier loss (one variant of Robust L1Loss, a differentiable + variant of L1Loss). + + Described in "Deep Laplacian Pyramid Networks for Fast and Accurate + Super-Resolution". + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + eps (float): A value used to control the curvature near zero. + Default: 1e-12. + """ + + def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12): + super(CharbonnierLoss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + self.eps = eps + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise + weights. Default: None. + """ + return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class WeightedTVLoss(L1Loss): + """Weighted TV loss. + + Args: + loss_weight (float): Loss weight. Default: 1.0. + """ + + def __init__(self, loss_weight=1.0): + super(WeightedTVLoss, self).__init__(loss_weight=loss_weight) + + def forward(self, pred, weight=None): + y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=weight[:, :, :-1, :]) + x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=weight[:, :, :, :-1]) + + loss = x_diff + y_diff + + return loss + + +@LOSS_REGISTRY.register() +class PerceptualLoss(nn.Module): + """Perceptual loss with commonly used style loss. + + Args: + layer_weights (dict): The weight for each layer of vgg feature. + Here is an example: {'conv5_4': 1.}, which means the conv5_4 + feature layer (before relu5_4) will be extracted with weight + 1.0 in calculting losses. + vgg_type (str): The type of vgg network used as feature extractor. + Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image in vgg. + Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + perceptual_weight (float): If `perceptual_weight > 0`, the perceptual + loss will be calculated and the loss will multiplied by the + weight. Default: 1.0. + style_weight (float): If `style_weight > 0`, the style loss will be + calculated and the loss will multiplied by the weight. + Default: 0. + criterion (str): Criterion used for perceptual loss. Default: 'l1'. + """ + + def __init__(self, + layer_weights, + vgg_type='vgg19', + use_input_norm=True, + range_norm=False, + perceptual_weight=1.0, + style_weight=0., + criterion='l1'): + super(PerceptualLoss, self).__init__() + self.perceptual_weight = perceptual_weight + self.style_weight = style_weight + self.layer_weights = layer_weights + self.vgg = VGGFeatureExtractor( + layer_name_list=list(layer_weights.keys()), + vgg_type=vgg_type, + use_input_norm=use_input_norm, + range_norm=range_norm) + + self.criterion_type = criterion + if self.criterion_type == 'l1': + self.criterion = torch.nn.L1Loss() + elif self.criterion_type == 'l2': + self.criterion = torch.nn.L2loss() + elif self.criterion_type == 'mse': + self.criterion = torch.nn.MSELoss(reduction='mean') + elif self.criterion_type == 'fro': + self.criterion = None + else: + raise NotImplementedError(f'{criterion} criterion has not been supported.') + + def forward(self, x, gt): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + gt (Tensor): Ground-truth tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + # extract vgg features + x_features = self.vgg(x) + gt_features = self.vgg(gt.detach()) + + # calculate perceptual loss + if self.perceptual_weight > 0: + percep_loss = 0 + for k in x_features.keys(): + if self.criterion_type == 'fro': + percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k] + else: + percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] + percep_loss *= self.perceptual_weight + else: + percep_loss = None + + # calculate style loss + if self.style_weight > 0: + style_loss = 0 + for k in x_features.keys(): + if self.criterion_type == 'fro': + style_loss += torch.norm( + self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k] + else: + style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat( + gt_features[k])) * self.layer_weights[k] + style_loss *= self.style_weight + else: + style_loss = None + + return percep_loss, style_loss + + def _gram_mat(self, x): + """Calculate Gram matrix. + + Args: + x (torch.Tensor): Tensor with shape of (n, c, h, w). + + Returns: + torch.Tensor: Gram matrix. + """ + n, c, h, w = x.size() + features = x.view(n, c, w * h) + features_t = features.transpose(1, 2) + gram = features.bmm(features_t) / (c * h * w) + return gram + + +@LOSS_REGISTRY.register() +class LPIPSLoss(nn.Module): + def __init__(self, + loss_weight=1.0, + use_input_norm=True, + range_norm=False,): + super(LPIPSLoss, self).__init__() + self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval() + self.loss_weight = loss_weight + self.use_input_norm = use_input_norm + self.range_norm = range_norm + + if self.use_input_norm: + # the mean is for image with range [0, 1] + self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + # the std is for image with range [0, 1] + self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + def forward(self, pred, target): + if self.range_norm: + pred = (pred + 1) / 2 + target = (target + 1) / 2 + if self.use_input_norm: + pred = (pred - self.mean) / self.std + target = (target - self.mean) / self.std + lpips_loss = self.perceptual(target.contiguous(), pred.contiguous()) + return self.loss_weight * lpips_loss.mean() + + +@LOSS_REGISTRY.register() +class GANLoss(nn.Module): + """Define GAN loss. + + Args: + gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'. + real_label_val (float): The value for real label. Default: 1.0. + fake_label_val (float): The value for fake label. Default: 0.0. + loss_weight (float): Loss weight. Default: 1.0. + Note that loss_weight is only for generators; and it is always 1.0 + for discriminators. + """ + + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): + super(GANLoss, self).__init__() + self.gan_type = gan_type + self.loss_weight = loss_weight + self.real_label_val = real_label_val + self.fake_label_val = fake_label_val + + if self.gan_type == 'vanilla': + self.loss = nn.BCEWithLogitsLoss() + elif self.gan_type == 'lsgan': + self.loss = nn.MSELoss() + elif self.gan_type == 'wgan': + self.loss = self._wgan_loss + elif self.gan_type == 'wgan_softplus': + self.loss = self._wgan_softplus_loss + elif self.gan_type == 'hinge': + self.loss = nn.ReLU() + else: + raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.') + + def _wgan_loss(self, input, target): + """wgan loss. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + return -input.mean() if target else input.mean() + + def _wgan_softplus_loss(self, input, target): + """wgan loss with soft plus. softplus is a smooth approximation to the + ReLU function. + + In StyleGAN2, it is called: + Logistic loss for discriminator; + Non-saturating loss for generator. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + return F.softplus(-input).mean() if target else F.softplus(input).mean() + + def get_target_label(self, input, target_is_real): + """Get target label. + + Args: + input (Tensor): Input tensor. + target_is_real (bool): Whether the target is real or fake. + + Returns: + (bool | Tensor): Target tensor. Return bool for wgan, otherwise, + return Tensor. + """ + + if self.gan_type in ['wgan', 'wgan_softplus']: + return target_is_real + target_val = (self.real_label_val if target_is_real else self.fake_label_val) + return input.new_ones(input.size()) * target_val + + def forward(self, input, target_is_real, is_disc=False): + """ + Args: + input (Tensor): The input for the loss module, i.e., the network + prediction. + target_is_real (bool): Whether the targe is real or fake. + is_disc (bool): Whether the loss for discriminators or not. + Default: False. + + Returns: + Tensor: GAN loss value. + """ + if self.gan_type == 'hinge': + if is_disc: # for discriminators in hinge-gan + input = -input if target_is_real else input + loss = self.loss(1 + input).mean() + else: # for generators in hinge-gan + loss = -input.mean() + else: # other gan types + target_label = self.get_target_label(input, target_is_real) + loss = self.loss(input, target_label) + + # loss_weight is always 1.0 for discriminators + return loss if is_disc else loss * self.loss_weight + + +def r1_penalty(real_pred, real_img): + """R1 regularization for discriminator. The core idea is to + penalize the gradient on real data alone: when the + generator distribution produces the true data distribution + and the discriminator is equal to 0 on the data manifold, the + gradient penalty ensures that the discriminator cannot create + a non-zero gradient orthogonal to the data manifold without + suffering a loss in the GAN game. + + Ref: + Eq. 9 in Which training methods for GANs do actually converge. + """ + grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0] + grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() + return grad_penalty + + +def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): + noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3]) + grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0] + path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) + + path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) + + path_penalty = (path_lengths - path_mean).pow(2).mean() + + return path_penalty, path_lengths.detach().mean(), path_mean.detach() + + +def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None): + """Calculate gradient penalty for wgan-gp. + + Args: + discriminator (nn.Module): Network for the discriminator. + real_data (Tensor): Real input data. + fake_data (Tensor): Fake input data. + weight (Tensor): Weight tensor. Default: None. + + Returns: + Tensor: A tensor for gradient penalty. + """ + + batch_size = real_data.size(0) + alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1)) + + # interpolate between real_data and fake_data + interpolates = alpha * real_data + (1. - alpha) * fake_data + interpolates = autograd.Variable(interpolates, requires_grad=True) + + disc_interpolates = discriminator(interpolates) + gradients = autograd.grad( + outputs=disc_interpolates, + inputs=interpolates, + grad_outputs=torch.ones_like(disc_interpolates), + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + + if weight is not None: + gradients = gradients * weight + + gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean() + if weight is not None: + gradients_penalty /= torch.mean(weight) + + return gradients_penalty diff --git a/Hallo2/hallo2/basicsr/metrics/__init__.py b/Hallo2/hallo2/basicsr/metrics/__init__.py new file mode 100644 index 00000000..19d55cc8 --- /dev/null +++ b/Hallo2/hallo2/basicsr/metrics/__init__.py @@ -0,0 +1,19 @@ +from copy import deepcopy + +from basicsr.utils.registry import METRIC_REGISTRY +from .psnr_ssim import calculate_psnr, calculate_ssim + +__all__ = ['calculate_psnr', 'calculate_ssim'] + + +def calculate_metric(data, opt): + """Calculate metric from data and options. + + Args: + opt (dict): Configuration. It must constain: + type (str): Model type. + """ + opt = deepcopy(opt) + metric_type = opt.pop('type') + metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) + return metric diff --git a/Hallo2/hallo2/basicsr/metrics/metric_util.py b/Hallo2/hallo2/basicsr/metrics/metric_util.py new file mode 100644 index 00000000..4d18f0f7 --- /dev/null +++ b/Hallo2/hallo2/basicsr/metrics/metric_util.py @@ -0,0 +1,45 @@ +import numpy as np + +from basicsr.utils.matlab_functions import bgr2ycbcr + + +def reorder_image(img, input_order='HWC'): + """Reorder images to 'HWC' order. + + If the input_order is (h, w), return (h, w, 1); + If the input_order is (c, h, w), return (h, w, c); + If the input_order is (h, w, c), return as it is. + + Args: + img (ndarray): Input image. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + If the input image shape is (h, w), input_order will not have + effects. Default: 'HWC'. + + Returns: + ndarray: reordered image. + """ + + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'") + if len(img.shape) == 2: + img = img[..., None] + if input_order == 'CHW': + img = img.transpose(1, 2, 0) + return img + + +def to_y_channel(img): + """Change to Y channel of YCbCr. + + Args: + img (ndarray): Images with range [0, 255]. + + Returns: + (ndarray): Images with range [0, 255] (float type) without round. + """ + img = img.astype(np.float32) / 255. + if img.ndim == 3 and img.shape[2] == 3: + img = bgr2ycbcr(img, y_only=True) + img = img[..., None] + return img * 255. diff --git a/Hallo2/hallo2/basicsr/metrics/psnr_ssim.py b/Hallo2/hallo2/basicsr/metrics/psnr_ssim.py new file mode 100644 index 00000000..bbd95069 --- /dev/null +++ b/Hallo2/hallo2/basicsr/metrics/psnr_ssim.py @@ -0,0 +1,128 @@ +import cv2 +import numpy as np + +from basicsr.metrics.metric_util import reorder_image, to_y_channel +from basicsr.utils.registry import METRIC_REGISTRY + + +@METRIC_REGISTRY.register() +def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False): + """Calculate PSNR (Peak Signal-to-Noise Ratio). + + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + + Args: + img1 (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the PSNR calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: psnr result. + """ + + assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"') + img1 = reorder_image(img1, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + if crop_border != 0: + img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img1 = to_y_channel(img1) + img2 = to_y_channel(img2) + + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20. * np.log10(255. / np.sqrt(mse)) + + +def _ssim(img1, img2): + """Calculate SSIM (structural similarity) for one channel images. + + It is called by func:`calculate_ssim`. + + Args: + img1 (ndarray): Images with range [0, 255] with order 'HWC'. + img2 (ndarray): Images with range [0, 255] with order 'HWC'. + + Returns: + float: ssim result. + """ + + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +@METRIC_REGISTRY.register() +def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False): + """Calculate SSIM (structural similarity). + + Ref: + Image quality assessment: From error visibility to structural similarity + + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. + + For three-channel images, SSIM is calculated for each channel and then + averaged. + + Args: + img1 (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the SSIM calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: ssim result. + """ + + assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"') + img1 = reorder_image(img1, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + if crop_border != 0: + img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img1 = to_y_channel(img1) + img2 = to_y_channel(img2) + + ssims = [] + for i in range(img1.shape[2]): + ssims.append(_ssim(img1[..., i], img2[..., i])) + return np.array(ssims).mean() diff --git a/Hallo2/hallo2/basicsr/models/__init__.py b/Hallo2/hallo2/basicsr/models/__init__.py new file mode 100644 index 00000000..00bde45f --- /dev/null +++ b/Hallo2/hallo2/basicsr/models/__init__.py @@ -0,0 +1,30 @@ +import importlib +from copy import deepcopy +from os import path as osp + +from basicsr.utils import get_root_logger, scandir +from basicsr.utils.registry import MODEL_REGISTRY + +__all__ = ['build_model'] + +# automatically scan and import model modules for registry +# scan all the files under the 'models' folder and collect files ending with +# '_model.py' +model_folder = osp.dirname(osp.abspath(__file__)) +model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] +# import all the model modules +_model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames] + + +def build_model(opt): + """Build model from options. + + Args: + opt (dict): Configuration. It must constain: + model_type (str): Model type. + """ + opt = deepcopy(opt) + model = MODEL_REGISTRY.get(opt['model_type'])(opt) + logger = get_root_logger() + logger.info(f'Model [{model.__class__.__name__}] is created.') + return model diff --git a/Hallo2/hallo2/basicsr/models/base_model.py b/Hallo2/hallo2/basicsr/models/base_model.py new file mode 100644 index 00000000..bf1f90ac --- /dev/null +++ b/Hallo2/hallo2/basicsr/models/base_model.py @@ -0,0 +1,322 @@ +import logging +import os +import torch +from collections import OrderedDict +from copy import deepcopy +from torch.nn.parallel import DataParallel, DistributedDataParallel + +from basicsr.models import lr_scheduler as lr_scheduler +from basicsr.utils.dist_util import master_only + +logger = logging.getLogger('basicsr') + + +class BaseModel(): + """Base model.""" + + def __init__(self, opt): + self.opt = opt + self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') + self.is_train = opt['is_train'] + self.schedulers = [] + self.optimizers = [] + + def feed_data(self, data): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + pass + + def save(self, epoch, current_iter): + """Save networks and training state.""" + pass + + def validation(self, dataloader, current_iter, tb_logger, save_img=False): + """Validation function. + + Args: + dataloader (torch.utils.data.DataLoader): Validation dataloader. + current_iter (int): Current iteration. + tb_logger (tensorboard logger): Tensorboard logger. + save_img (bool): Whether to save images. Default: False. + """ + if self.opt['dist']: + self.dist_validation(dataloader, current_iter, tb_logger, save_img) + else: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def model_ema(self, decay=0.999): + net_g = self.get_bare_model(self.net_g) + + net_g_params = dict(net_g.named_parameters()) + net_g_ema_params = dict(self.net_g_ema.named_parameters()) + + for k in net_g_ema_params.keys(): + net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay) + + def get_current_log(self): + return self.log_dict + + def model_to_device(self, net): + """Model to device. It also warps models with DistributedDataParallel + or DataParallel. + + Args: + net (nn.Module) + """ + net = net.to(self.device) + if self.opt['dist']: + find_unused_parameters = self.opt.get('find_unused_parameters', False) + net = DistributedDataParallel( + net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters) + elif self.opt['num_gpu'] > 1: + net = DataParallel(net) + return net + + def get_optimizer(self, optim_type, params, lr, **kwargs): + if optim_type == 'Adam': + optimizer = torch.optim.Adam(params, lr, **kwargs) + else: + raise NotImplementedError(f'optimizer {optim_type} is not supperted yet.') + return optimizer + + def setup_schedulers(self): + """Set up schedulers.""" + train_opt = self.opt['train'] + scheduler_type = train_opt['scheduler'].pop('type') + if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']: + for optimizer in self.optimizers: + self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler'])) + elif scheduler_type == 'CosineAnnealingRestartLR': + for optimizer in self.optimizers: + self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler'])) + else: + raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.') + + def get_bare_model(self, net): + """Get bare model, especially under wrapping with + DistributedDataParallel or DataParallel. + """ + if isinstance(net, (DataParallel, DistributedDataParallel)): + net = net.module + return net + + @master_only + def print_network(self, net): + """Print the str and parameter number of a network. + + Args: + net (nn.Module) + """ + if isinstance(net, (DataParallel, DistributedDataParallel)): + net_cls_str = (f'{net.__class__.__name__} - ' f'{net.module.__class__.__name__}') + else: + net_cls_str = f'{net.__class__.__name__}' + + net = self.get_bare_model(net) + net_str = str(net) + net_params = sum(map(lambda x: x.numel(), net.parameters())) + + logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}') + logger.info(net_str) + + def _set_lr(self, lr_groups_l): + """Set learning rate for warmup. + + Args: + lr_groups_l (list): List for lr_groups, each for an optimizer. + """ + for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): + for param_group, lr in zip(optimizer.param_groups, lr_groups): + param_group['lr'] = lr + + def _get_init_lr(self): + """Get the initial lr, which is set by the scheduler. + """ + init_lr_groups_l = [] + for optimizer in self.optimizers: + init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups]) + return init_lr_groups_l + + def update_learning_rate(self, current_iter, warmup_iter=-1): + """Update learning rate. + + Args: + current_iter (int): Current iteration. + warmup_iter (int): Warmup iter numbers. -1 for no warmup. + Default: -1. + """ + if current_iter > 1: + for scheduler in self.schedulers: + scheduler.step() + # set up warm-up learning rate + if current_iter < warmup_iter: + # get initial lr for each group + init_lr_g_l = self._get_init_lr() + # modify warming-up learning rates + # currently only support linearly warm up + warm_up_lr_l = [] + for init_lr_g in init_lr_g_l: + warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g]) + # set learning rate + self._set_lr(warm_up_lr_l) + + def get_current_learning_rate(self): + return [param_group['lr'] for param_group in self.optimizers[0].param_groups] + + @master_only + def save_network(self, net, net_label, current_iter, param_key='params'): + """Save networks. + + Args: + net (nn.Module | list[nn.Module]): Network(s) to be saved. + net_label (str): Network label. + current_iter (int): Current iter number. + param_key (str | list[str]): The parameter key(s) to save network. + Default: 'params'. + """ + if current_iter == -1: + current_iter = 'latest' + save_filename = f'{net_label}_{current_iter}.pth' + save_path = os.path.join(self.opt['path']['models'], save_filename) + + net = net if isinstance(net, list) else [net] + param_key = param_key if isinstance(param_key, list) else [param_key] + assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.' + + save_dict = {} + for net_, param_key_ in zip(net, param_key): + net_ = self.get_bare_model(net_) + state_dict = net_.state_dict() + for key, param in state_dict.items(): + if key.startswith('module.'): # remove unnecessary 'module.' + key = key[7:] + state_dict[key] = param.cpu() + save_dict[param_key_] = state_dict + + torch.save(save_dict, save_path) + + def _print_different_keys_loading(self, crt_net, load_net, strict=True): + """Print keys with differnet name or different size when loading models. + + 1. Print keys with differnet names. + 2. If strict=False, print the same key but with different tensor size. + It also ignore these keys with different sizes (not load). + + Args: + crt_net (torch model): Current network. + load_net (dict): Loaded network. + strict (bool): Whether strictly loaded. Default: True. + """ + crt_net = self.get_bare_model(crt_net) + crt_net = crt_net.state_dict() + crt_net_keys = set(crt_net.keys()) + load_net_keys = set(load_net.keys()) + + if crt_net_keys != load_net_keys: + logger.warning('Current net - loaded net:') + for v in sorted(list(crt_net_keys - load_net_keys)): + logger.warning(f' {v}') + logger.warning('Loaded net - current net:') + for v in sorted(list(load_net_keys - crt_net_keys)): + logger.warning(f' {v}') + + # check the size for the same keys + if not strict: + common_keys = crt_net_keys & load_net_keys + for k in common_keys: + if crt_net[k].size() != load_net[k].size(): + logger.warning(f'Size different, ignore [{k}]: crt_net: ' + f'{crt_net[k].shape}; load_net: {load_net[k].shape}') + load_net[k + '.ignore'] = load_net.pop(k) + + def load_network(self, net, load_path, strict=True, param_key='params'): + """Load network. + + Args: + load_path (str): The path of networks to be loaded. + net (nn.Module): Network. + strict (bool): Whether strictly loaded. + param_key (str): The parameter key of loaded network. If set to + None, use the root 'path'. + Default: 'params'. + """ + net = self.get_bare_model(net) + logger.info(f'Loading {net.__class__.__name__} model from {load_path}.') + load_net = torch.load(load_path, map_location=lambda storage, loc: storage) + if param_key is not None: + if param_key not in load_net and 'params' in load_net: + param_key = 'params' + logger.info('Loading: params_ema does not exist, use params.') + load_net = load_net[param_key] + # remove unnecessary 'module.' + for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + self._print_different_keys_loading(net, load_net, strict) + net.load_state_dict(load_net, strict=strict) + + @master_only + def save_training_state(self, epoch, current_iter): + """Save training states during training, which will be used for + resuming. + + Args: + epoch (int): Current epoch. + current_iter (int): Current iteration. + """ + if current_iter != -1: + state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []} + for o in self.optimizers: + state['optimizers'].append(o.state_dict()) + for s in self.schedulers: + state['schedulers'].append(s.state_dict()) + save_filename = f'{current_iter}.state' + save_path = os.path.join(self.opt['path']['training_states'], save_filename) + torch.save(state, save_path) + + def resume_training(self, resume_state): + """Reload the optimizers and schedulers for resumed training. + + Args: + resume_state (dict): Resume state. + """ + resume_optimizers = resume_state['optimizers'] + resume_schedulers = resume_state['schedulers'] + assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' + assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' + for i, o in enumerate(resume_optimizers): + self.optimizers[i].load_state_dict(o) + for i, s in enumerate(resume_schedulers): + self.schedulers[i].load_state_dict(s) + + def reduce_loss_dict(self, loss_dict): + """reduce loss dict. + + In distributed training, it averages the losses among different GPUs . + + Args: + loss_dict (OrderedDict): Loss dict. + """ + with torch.no_grad(): + if self.opt['dist']: + keys = [] + losses = [] + for name, value in loss_dict.items(): + keys.append(name) + losses.append(value) + losses = torch.stack(losses, 0) + torch.distributed.reduce(losses, dst=0) + if self.opt['rank'] == 0: + losses /= self.opt['world_size'] + loss_dict = {key: loss for key, loss in zip(keys, losses)} + + log_dict = OrderedDict() + for name, value in loss_dict.items(): + log_dict[name] = value.mean().item() + + return log_dict diff --git a/Hallo2/hallo2/basicsr/models/codeformer_temporal_model.py b/Hallo2/hallo2/basicsr/models/codeformer_temporal_model.py new file mode 100644 index 00000000..44ff2d96 --- /dev/null +++ b/Hallo2/hallo2/basicsr/models/codeformer_temporal_model.py @@ -0,0 +1,238 @@ +import torch +from collections import OrderedDict +from os import path as osp +from tqdm import tqdm +from einops import rearrange + +from basicsr.archs import build_network +from basicsr.metrics import calculate_metric +from basicsr.utils import get_root_logger, imwrite, tensor2img +from basicsr.utils.registry import MODEL_REGISTRY +import torch.nn.functional as F +from .sr_model import SRModel + +# from icecream import ic + +@MODEL_REGISTRY.register() +class CodeFormerTempModel(SRModel): + def feed_data(self, data): + self.gt = data['gt'].to(self.device) + self.input = data['in'].to(self.device) + self.b = self.gt.shape[0] + self.f = self.gt.shape[1] + self.bf = self.b * self.f + + if 'latent_gt' in data: + self.idx_gt = data['latent_gt'].to(self.device) + self.idx_gt = self.idx_gt.view(self.b, -1) + else: + self.idx_gt = None + + def init_training_settings(self): + logger = get_root_logger() + train_opt = self.opt['train'] + + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') + # define network net_g with Exponential Moving Average (EMA) + # net_g_ema is used only for testing on one GPU and saving + # There is no need to wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + self.net_g_ema.eval() + + if self.opt['datasets']['train'].get('latent_gt_path', None) is not None: + self.generate_idx_gt = False + elif self.opt.get('network_vqgan', None) is not None: + self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device) + self.hq_vqgan_fix.eval() + self.generate_idx_gt = True + for param in self.hq_vqgan_fix.parameters(): + param.requires_grad = False + else: + raise NotImplementedError(f'Shoule have network_vqgan config or pre-calculated latent code.') + + logger.info(f'Need to generate latent GT code: {self.generate_idx_gt}') + + self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True) + self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0) + self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True) + self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5) + + self.net_g.train() + self.net_g.requires_grad_(False) + + trainable_module = train_opt['trainable_para'] + for name, module in self.net_g.named_modules(): + if trainable_module in name : + for params in module.parameters(): + params.requires_grad_(True) + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + + def setup_optimizers(self): + train_opt = self.opt['train'] + # optimizer g + optim_params_g = [] + optim_name = [] + for k, v in self.net_g.named_parameters(): + if v.requires_grad: + optim_params_g.append(v) + optim_name.append(k) + + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + + # ic(optim_name) + + + def optimize_parameters(self, current_iter): + logger = get_root_logger() + # optimize net_g + self.optimizer_g.zero_grad() + + if self.generate_idx_gt: + x = rearrange(self.gt, "b f c h w -> (b f) c h w") + x = self.hq_vqgan_fix.encoder(x) + _, _, quant_stats = self.hq_vqgan_fix.quantize(x) + min_encoding_indices = quant_stats['min_encoding_indices'] + # ic(min_encoding_indices.shape) + self.idx_gt = min_encoding_indices.view(self.bf, -1) + # ic(self.idx_gt.shape) + + if self.hq_feat_loss: + # quant_feats + quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.bf,16,16,256]) + + logits, lq_feat = self.net_g(self.input, w=0, code_only=True) + # ic(logits.shape) + # ic(lq_feat.shape) + + l_g_total = 0 + loss_dict = OrderedDict() + # hq_feat_loss + if self.hq_feat_loss: # codebook loss + l_feat_encoder = torch.mean((quant_feat_gt.detach()-lq_feat)**2) * self.feat_loss_weight + l_g_total += l_feat_encoder + loss_dict['l_feat_encoder'] = l_feat_encoder + + # cross_entropy_loss + if self.cross_entropy_loss: + # b(hw)n -> bn(hw) + cross_entropy_loss = F.cross_entropy(logits.permute(0, 2, 1), self.idx_gt) * self.entropy_loss_weight + l_g_total += cross_entropy_loss + loss_dict['cross_entropy_loss'] = cross_entropy_loss + + l_g_total.backward() + self.optimizer_g.step() + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + self.log_dict = self.reduce_loss_dict(loss_dict) + + + def test(self): + with torch.no_grad(): + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + self.output, _, _ = self.net_g_ema(self.input, w=0) + else: + logger = get_root_logger() + logger.warning('Do not have self.net_g_ema, use self.net_g.') + self.net_g.eval() + self.output, _, _ = self.net_g(self.input, w=0) + self.net_g.train() + + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + if with_metrics: + self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + pbar = tqdm(total=len(dataloader), unit='image') + + for idx, val_data in enumerate(dataloader): + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + visuals = self.get_current_visuals() + sr_img = tensor2img([visuals['result']]) + if 'gt' in visuals: + gt_img = tensor2img([visuals['gt']]) + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.png') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.png') + imwrite(sr_img, save_img_path) + + if with_metrics: + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + metric_data = dict(img1=sr_img, img2=gt_img) + self.metric_results[name] += calculate_metric(metric_data, opt_) + pbar.update(1) + pbar.set_description(f'Test {img_name}') + pbar.close() + + if with_metrics: + for metric in self.metric_results.keys(): + self.metric_results[metric] /= (idx + 1) + + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + log_str = f'Validation {dataset_name}\n' + for metric, value in self.metric_results.items(): + log_str += f'\t # {metric}: {value:.4f}\n' + logger = get_root_logger() + logger.info(log_str) + if tb_logger: + for metric, value in self.metric_results.items(): + tb_logger.add_scalar(f'metrics/{metric}', value, current_iter) + + + def get_current_visuals(self): + out_dict = OrderedDict() + out_dict['gt'] = self.gt.detach().cpu() + out_dict['result'] = self.output.detach().cpu() + return out_dict + + + def save(self, epoch, current_iter): + if self.ema_decay > 0: + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + else: + self.save_network(self.net_g, 'net_g', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/Hallo2/hallo2/basicsr/models/lr_scheduler.py b/Hallo2/hallo2/basicsr/models/lr_scheduler.py new file mode 100644 index 00000000..a423ce65 --- /dev/null +++ b/Hallo2/hallo2/basicsr/models/lr_scheduler.py @@ -0,0 +1,96 @@ +import math +from collections import Counter +from torch.optim.lr_scheduler import _LRScheduler + + +class MultiStepRestartLR(_LRScheduler): + """ MultiStep with restarts learning rate scheme. + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + milestones (list): Iterations that will decrease learning rate. + gamma (float): Decrease ratio. Default: 0.1. + restarts (list): Restart iterations. Default: [0]. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + + def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1): + self.milestones = Counter(milestones) + self.gamma = gamma + self.restarts = restarts + self.restart_weights = restart_weights + assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.' + super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch in self.restarts: + weight = self.restart_weights[self.restarts.index(self.last_epoch)] + return [group['initial_lr'] * weight for group in self.optimizer.param_groups] + if self.last_epoch not in self.milestones: + return [group['lr'] for group in self.optimizer.param_groups] + return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups] + + +def get_position_from_periods(iteration, cumulative_period): + """Get the position from a period list. + + It will return the index of the right-closest number in the period list. + For example, the cumulative_period = [100, 200, 300, 400], + if iteration == 50, return 0; + if iteration == 210, return 2; + if iteration == 300, return 2. + + Args: + iteration (int): Current iteration. + cumulative_period (list[int]): Cumulative period list. + + Returns: + int: The position of the right-closest number in the period list. + """ + for i, period in enumerate(cumulative_period): + if iteration <= period: + return i + + +class CosineAnnealingRestartLR(_LRScheduler): + """ Cosine annealing with restarts learning rate scheme. + + An example of config: + periods = [10, 10, 10, 10] + restart_weights = [1, 0.5, 0.5, 0.5] + eta_min=1e-7 + + It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the + scheduler will restart with the weights in restart_weights. + + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + periods (list): Period for each cosine anneling cycle. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + eta_min (float): The mimimum lr. Default: 0. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + + def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1): + self.periods = periods + self.restart_weights = restart_weights + self.eta_min = eta_min + assert (len(self.periods) == len( + self.restart_weights)), 'periods and restart_weights should have the same length.' + self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))] + super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + idx = get_position_from_periods(self.last_epoch, self.cumulative_period) + current_weight = self.restart_weights[idx] + nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] + current_period = self.periods[idx] + + return [ + self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * + (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period))) + for base_lr in self.base_lrs + ] diff --git a/Hallo2/hallo2/basicsr/models/sr_model.py b/Hallo2/hallo2/basicsr/models/sr_model.py new file mode 100644 index 00000000..4f4b9fe0 --- /dev/null +++ b/Hallo2/hallo2/basicsr/models/sr_model.py @@ -0,0 +1,209 @@ +import torch +from collections import OrderedDict +from os import path as osp +from tqdm import tqdm + +from basicsr.archs import build_network +from basicsr.losses import build_loss +from basicsr.metrics import calculate_metric +from basicsr.utils import get_root_logger, imwrite, tensor2img +from basicsr.utils.registry import MODEL_REGISTRY +from .base_model import BaseModel + +@MODEL_REGISTRY.register() +class SRModel(BaseModel): + """Base SR model for single image super-resolution.""" + + def __init__(self, opt): + super(SRModel, self).__init__(opt) + + # define network + self.net_g = build_network(opt['network_g']) + self.net_g = self.model_to_device(self.net_g) + self.print_network(self.net_g) + + # load pretrained models + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + param_key = self.opt['path'].get('param_key_g', 'params') + self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key) + + if self.is_train: + self.init_training_settings() + + def init_training_settings(self): + self.net_g.train() + train_opt = self.opt['train'] + + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + logger = get_root_logger() + logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') + # define network net_g with Exponential Moving Average (EMA) + # net_g_ema is used only for testing on one GPU and saving + # There is no need to wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + self.net_g_ema.eval() + + # define losses + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if self.cri_pix is None and self.cri_perceptual is None: + raise ValueError('Both pixel and perceptual losses are None.') + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def setup_optimizers(self): + train_opt = self.opt['train'] + optim_params = [] + for k, v in self.net_g.named_parameters(): + if v.requires_grad: + optim_params.append(v) + else: + logger = get_root_logger() + logger.warning(f'Params {k} will not be optimized.') + + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + + def feed_data(self, data): + self.lq = data['lq'].to(self.device) + if 'gt' in data: + self.gt = data['gt'].to(self.device) + + def optimize_parameters(self, current_iter): + self.optimizer_g.zero_grad() + self.output = self.net_g(self.lq) + + l_total = 0 + loss_dict = OrderedDict() + # pixel loss + if self.cri_pix: + l_pix = self.cri_pix(self.output, self.gt) + l_total += l_pix + loss_dict['l_pix'] = l_pix + # perceptual loss + if self.cri_perceptual: + l_percep, l_style = self.cri_perceptual(self.output, self.gt) + if l_percep is not None: + l_total += l_percep + loss_dict['l_percep'] = l_percep + if l_style is not None: + l_total += l_style + loss_dict['l_style'] = l_style + + l_total.backward() + self.optimizer_g.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + def test(self): + if hasattr(self, 'ema_decay'): + self.net_g_ema.eval() + with torch.no_grad(): + self.output = self.net_g_ema(self.lq) + else: + self.net_g.eval() + with torch.no_grad(): + self.output = self.net_g(self.lq) + self.net_g.train() + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + if with_metrics: + self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + pbar = tqdm(total=len(dataloader), unit='image') + + for idx, val_data in enumerate(dataloader): + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + visuals = self.get_current_visuals() + sr_img = tensor2img([visuals['result']]) + if 'gt' in visuals: + gt_img = tensor2img([visuals['gt']]) + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.png') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.png') + imwrite(sr_img, save_img_path) + + if with_metrics: + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + metric_data = dict(img1=sr_img, img2=gt_img) + self.metric_results[name] += calculate_metric(metric_data, opt_) + pbar.update(1) + pbar.set_description(f'Test {img_name}') + pbar.close() + + if with_metrics: + for metric in self.metric_results.keys(): + self.metric_results[metric] /= (idx + 1) + + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + log_str = f'Validation {dataset_name}\n' + for metric, value in self.metric_results.items(): + log_str += f'\t # {metric}: {value:.4f}\n' + logger = get_root_logger() + logger.info(log_str) + if tb_logger: + for metric, value in self.metric_results.items(): + tb_logger.add_scalar(f'metrics/{metric}', value, current_iter) + + def get_current_visuals(self): + out_dict = OrderedDict() + out_dict['lq'] = self.lq.detach().cpu() + out_dict['result'] = self.output.detach().cpu() + if hasattr(self, 'gt'): + out_dict['gt'] = self.gt.detach().cpu() + return out_dict + + def save(self, epoch, current_iter): + if hasattr(self, 'ema_decay'): + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + else: + self.save_network(self.net_g, 'net_g', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/Hallo2/hallo2/basicsr/models/vqgan_model.py b/Hallo2/hallo2/basicsr/models/vqgan_model.py new file mode 100644 index 00000000..d345a6b0 --- /dev/null +++ b/Hallo2/hallo2/basicsr/models/vqgan_model.py @@ -0,0 +1,285 @@ +import torch +from collections import OrderedDict +from os import path as osp +from tqdm import tqdm + +from basicsr.archs import build_network +from basicsr.losses import build_loss +from basicsr.metrics import calculate_metric +from basicsr.utils import get_root_logger, imwrite, tensor2img +from basicsr.utils.registry import MODEL_REGISTRY +import torch.nn.functional as F +from .sr_model import SRModel + + +@MODEL_REGISTRY.register() +class VQGANModel(SRModel): + def feed_data(self, data): + self.gt = data['gt'].to(self.device) + self.b = self.gt.shape[0] + + + def init_training_settings(self): + logger = get_root_logger() + train_opt = self.opt['train'] + + self.ema_decay = train_opt.get('ema_decay', 0) + if self.ema_decay > 0: + logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') + # define network net_g with Exponential Moving Average (EMA) + # net_g_ema is used only for testing on one GPU and saving + # There is no need to wrap with DistributedDataParallel + self.net_g_ema = build_network(self.opt['network_g']).to(self.device) + # load pretrained model + load_path = self.opt['path'].get('pretrain_network_g', None) + if load_path is not None: + self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') + else: + self.model_ema(0) # copy net_g weight + self.net_g_ema.eval() + + # define network net_d + self.net_d = build_network(self.opt['network_d']) + self.net_d = self.model_to_device(self.net_d) + self.print_network(self.net_d) + + # load pretrained models + load_path = self.opt['path'].get('pretrain_network_d', None) + if load_path is not None: + self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True)) + + self.net_g.train() + self.net_d.train() + + # define losses + if train_opt.get('pixel_opt'): + self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) + else: + self.cri_pix = None + + if train_opt.get('perceptual_opt'): + self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) + else: + self.cri_perceptual = None + + if train_opt.get('gan_opt'): + self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) + + if train_opt.get('codebook_opt'): + self.l_weight_codebook = train_opt['codebook_opt'].get('loss_weight', 1.0) + else: + self.l_weight_codebook = 1.0 + + self.vqgan_quantizer = self.opt['network_g']['quantizer'] + logger.info(f'vqgan_quantizer: {self.vqgan_quantizer}') + + self.net_g_start_iter = train_opt.get('net_g_start_iter', 0) + self.net_d_iters = train_opt.get('net_d_iters', 1) + self.net_d_start_iter = train_opt.get('net_d_start_iter', 0) + self.disc_weight = train_opt.get('disc_weight', 0.8) + + # set up optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + + def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max): + recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + + d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach() + return d_weight + + def adopt_weight(self, weight, global_step, threshold=0, value=0.): + if global_step < threshold: + weight = value + return weight + + def setup_optimizers(self): + train_opt = self.opt['train'] + # optimizer g + optim_params_g = [] + for k, v in self.net_g.named_parameters(): + if v.requires_grad: + optim_params_g.append(v) + else: + logger = get_root_logger() + logger.warning(f'Params {k} will not be optimized.') + optim_type = train_opt['optim_g'].pop('type') + self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g']) + self.optimizers.append(self.optimizer_g) + # optimizer d + optim_type = train_opt['optim_d'].pop('type') + self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d']) + self.optimizers.append(self.optimizer_d) + + + def optimize_parameters(self, current_iter): + logger = get_root_logger() + loss_dict = OrderedDict() + if self.opt['network_g']['quantizer'] == 'gumbel': + self.net_g.module.quantize.temperature = max(1/16, ((-1/160000) * current_iter) + 1) + if current_iter%1000 == 0: + logger.info(f'temperature: {self.net_g.module.quantize.temperature}') + + # optimize net_g + for p in self.net_d.parameters(): + p.requires_grad = False + + self.optimizer_g.zero_grad() + self.output, l_codebook, quant_stats = self.net_g(self.gt) + + l_codebook = l_codebook*self.l_weight_codebook + + l_g_total = 0 + if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter: + # pixel loss + if self.cri_pix: + l_g_pix = self.cri_pix(self.output, self.gt) + l_g_total += l_g_pix + loss_dict['l_g_pix'] = l_g_pix + # perceptual loss + if self.cri_perceptual: + l_g_percep = self.cri_perceptual(self.output, self.gt) + l_g_total += l_g_percep + loss_dict['l_g_percep'] = l_g_percep + + # gan loss + if current_iter > self.net_d_start_iter: + # fake_g_pred = self.net_d(self.output_1024) + fake_g_pred = self.net_d(self.output) + l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) + recon_loss = l_g_total + last_layer = self.net_g.module.generator.blocks[-1].weight + d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0) + d_weight *= self.adopt_weight(1, current_iter, self.net_d_start_iter) + d_weight *= self.disc_weight # tamming setting 0.8 + l_g_total += d_weight * l_g_gan + loss_dict['l_g_gan'] = d_weight * l_g_gan + + l_g_total += l_codebook + loss_dict['l_codebook'] = l_codebook + + l_g_total.backward() + self.optimizer_g.step() + + # optimize net_d + if current_iter > self.net_d_start_iter: + for p in self.net_d.parameters(): + p.requires_grad = True + + self.optimizer_d.zero_grad() + # real + real_d_pred = self.net_d(self.gt) + l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) + loss_dict['l_d_real'] = l_d_real + loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) + l_d_real.backward() + # fake + fake_d_pred = self.net_d(self.output.detach()) + l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) + loss_dict['l_d_fake'] = l_d_fake + loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) + l_d_fake.backward() + self.optimizer_d.step() + + self.log_dict = self.reduce_loss_dict(loss_dict) + + if self.ema_decay > 0: + self.model_ema(decay=self.ema_decay) + + + def test(self): + with torch.no_grad(): + if hasattr(self, 'net_g_ema'): + self.net_g_ema.eval() + self.output, _, _ = self.net_g_ema(self.gt) + else: + logger = get_root_logger() + logger.warning('Do not have self.net_g_ema, use self.net_g.') + self.net_g.eval() + self.output, _, _ = self.net_g(self.gt) + self.net_g.train() + + + def dist_validation(self, dataloader, current_iter, tb_logger, save_img): + if self.opt['rank'] == 0: + self.nondist_validation(dataloader, current_iter, tb_logger, save_img) + + + def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): + dataset_name = dataloader.dataset.opt['name'] + with_metrics = self.opt['val'].get('metrics') is not None + if with_metrics: + self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} + pbar = tqdm(total=len(dataloader), unit='image') + + for idx, val_data in enumerate(dataloader): + img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] + self.feed_data(val_data) + self.test() + + visuals = self.get_current_visuals() + sr_img = tensor2img([visuals['result']]) + if 'gt' in visuals: + gt_img = tensor2img([visuals['gt']]) + del self.gt + + # tentative for out of GPU memory + del self.lq + del self.output + torch.cuda.empty_cache() + + if save_img: + if self.opt['is_train']: + save_img_path = osp.join(self.opt['path']['visualization'], img_name, + f'{img_name}_{current_iter}.png') + else: + if self.opt['val']['suffix']: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["val"]["suffix"]}.png') + else: + save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, + f'{img_name}_{self.opt["name"]}.png') + imwrite(sr_img, save_img_path) + + if with_metrics: + # calculate metrics + for name, opt_ in self.opt['val']['metrics'].items(): + metric_data = dict(img1=sr_img, img2=gt_img) + self.metric_results[name] += calculate_metric(metric_data, opt_) + pbar.update(1) + pbar.set_description(f'Test {img_name}') + pbar.close() + + if with_metrics: + for metric in self.metric_results.keys(): + self.metric_results[metric] /= (idx + 1) + + self._log_validation_metric_values(current_iter, dataset_name, tb_logger) + + + def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): + log_str = f'Validation {dataset_name}\n' + for metric, value in self.metric_results.items(): + log_str += f'\t # {metric}: {value:.4f}\n' + logger = get_root_logger() + logger.info(log_str) + if tb_logger: + for metric, value in self.metric_results.items(): + tb_logger.add_scalar(f'metrics/{metric}', value, current_iter) + + + def get_current_visuals(self): + out_dict = OrderedDict() + out_dict['gt'] = self.gt.detach().cpu() + out_dict['result'] = self.output.detach().cpu() + return out_dict + + def save(self, epoch, current_iter): + if self.ema_decay > 0: + self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) + else: + self.save_network(self.net_g, 'net_g', current_iter) + self.save_network(self.net_d, 'net_d', current_iter) + self.save_training_state(epoch, current_iter) diff --git a/Hallo2/hallo2/basicsr/ops/__init__.py b/Hallo2/hallo2/basicsr/ops/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Hallo2/hallo2/basicsr/ops/dcn/__init__.py b/Hallo2/hallo2/basicsr/ops/dcn/__init__.py new file mode 100644 index 00000000..32e3592f --- /dev/null +++ b/Hallo2/hallo2/basicsr/ops/dcn/__init__.py @@ -0,0 +1,7 @@ +from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv, + modulated_deform_conv) + +__all__ = [ + 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv', + 'modulated_deform_conv' +] diff --git a/Hallo2/hallo2/basicsr/ops/dcn/deform_conv.py b/Hallo2/hallo2/basicsr/ops/dcn/deform_conv.py new file mode 100644 index 00000000..734154f9 --- /dev/null +++ b/Hallo2/hallo2/basicsr/ops/dcn/deform_conv.py @@ -0,0 +1,377 @@ +import math +import torch +from torch import nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn import functional as F +from torch.nn.modules.utils import _pair, _single + +try: + from . import deform_conv_ext +except ImportError: + import os + BASICSR_JIT = os.getenv('BASICSR_JIT') + if BASICSR_JIT == 'True': + from torch.utils.cpp_extension import load + module_path = os.path.dirname(__file__) + deform_conv_ext = load( + 'deform_conv', + sources=[ + os.path.join(module_path, 'src', 'deform_conv_ext.cpp'), + os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'), + os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'), + ], + ) + + +class DeformConvFunction(Function): + + @staticmethod + def forward(ctx, + input, + offset, + weight, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + im2col_step=64): + if input is not None and input.dim() != 4: + raise ValueError(f'Expected 4D tensor as input, got {input.dim()}' 'D tensor instead.') + ctx.stride = _pair(stride) + ctx.padding = _pair(padding) + ctx.dilation = _pair(dilation) + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.im2col_step = im2col_step + + ctx.save_for_backward(input, offset, weight) + + output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride)) + + ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones + + if not input.is_cuda: + raise NotImplementedError + else: + cur_im2col_step = min(ctx.im2col_step, input.shape[0]) + assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' + deform_conv_ext.deform_conv_forward(input, weight, + offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3), + weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1], + ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups, + ctx.deformable_groups, cur_im2col_step) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, offset, weight = ctx.saved_tensors + + grad_input = grad_offset = grad_weight = None + + if not grad_output.is_cuda: + raise NotImplementedError + else: + cur_im2col_step = min(ctx.im2col_step, input.shape[0]) + assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input, + grad_offset, weight, ctx.bufs_[0], weight.size(3), + weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1], + ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups, + ctx.deformable_groups, cur_im2col_step) + + if ctx.needs_input_grad[2]: + grad_weight = torch.zeros_like(weight) + deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight, + ctx.bufs_[0], ctx.bufs_[1], weight.size(3), + weight.size(2), ctx.stride[1], ctx.stride[0], + ctx.padding[1], ctx.padding[0], ctx.dilation[1], + ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1, + cur_im2col_step) + + return (grad_input, grad_offset, grad_weight, None, None, None, None, None) + + @staticmethod + def _output_size(input, weight, padding, dilation, stride): + channels = weight.size(0) + output_size = (input.size(0), channels) + for d in range(input.dim() - 2): + in_size = input.size(d + 2) + pad = padding[d] + kernel = dilation[d] * (weight.size(d + 2) - 1) + 1 + stride_ = stride[d] + output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) + if not all(map(lambda s: s > 0, output_size)): + raise ValueError('convolution input is too small (output would be ' f'{"x".join(map(str, output_size))})') + return output_size + + +class ModulatedDeformConvFunction(Function): + + @staticmethod + def forward(ctx, + input, + offset, + mask, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1): + ctx.stride = stride + ctx.padding = padding + ctx.dilation = dilation + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.with_bias = bias is not None + if not ctx.with_bias: + bias = input.new_empty(1) # fake tensor + if not input.is_cuda: + raise NotImplementedError + if weight.requires_grad or mask.requires_grad or offset.requires_grad \ + or input.requires_grad: + ctx.save_for_backward(input, offset, mask, weight, bias) + output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight)) + ctx._bufs = [input.new_empty(0), input.new_empty(0)] + deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output, + ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride, + ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, + ctx.groups, ctx.deformable_groups, ctx.with_bias) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + if not grad_output.is_cuda: + raise NotImplementedError + input, offset, mask, weight, bias = ctx.saved_tensors + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + grad_mask = torch.zeros_like(mask) + grad_weight = torch.zeros_like(weight) + grad_bias = torch.zeros_like(bias) + deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1], + grad_input, grad_weight, grad_bias, grad_offset, grad_mask, + grad_output, weight.shape[2], weight.shape[3], ctx.stride, + ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, + ctx.groups, ctx.deformable_groups, ctx.with_bias) + if not ctx.with_bias: + grad_bias = None + + return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None) + + @staticmethod + def _infer_shape(ctx, input, weight): + n = input.size(0) + channels_out = weight.size(0) + height, width = input.shape[2:4] + kernel_h, kernel_w = weight.shape[2:4] + height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1 + width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1 + return n, channels_out, height_out, width_out + + +deform_conv = DeformConvFunction.apply +modulated_deform_conv = ModulatedDeformConvFunction.apply + + +class DeformConv(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=False): + super(DeformConv, self).__init__() + + assert not bias + assert in_channels % groups == 0, \ + f'in_channels {in_channels} is not divisible by groups {groups}' + assert out_channels % groups == 0, \ + f'out_channels {out_channels} is not divisible ' \ + f'by groups {groups}' + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.deformable_groups = deformable_groups + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size)) + + self.reset_parameters() + + def reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + + def forward(self, x, offset): + # To fix an assert error in deform_conv_cuda.cpp:128 + # input image is smaller than kernel + input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1]) + if input_pad: + pad_h = max(self.kernel_size[0] - x.size(2), 0) + pad_w = max(self.kernel_size[1] - x.size(3), 0) + x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous() + offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous() + out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups, + self.deformable_groups) + if input_pad: + out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous() + return out + + +class DeformConvPack(DeformConv): + """A Deformable Conv Encapsulation that acts as normal Conv layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. + """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super(DeformConvPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + dilation=_pair(self.dilation), + bias=True) + self.init_offset() + + def init_offset(self): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + offset = self.conv_offset(x) + return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups, + self.deformable_groups) + + +class ModulatedDeformConv(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=True): + super(ModulatedDeformConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.deformable_groups = deformable_groups + self.with_bias = bias + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.init_weights() + + def init_weights(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.zero_() + + def forward(self, x, offset, mask): + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, + self.groups, self.deformable_groups) + + +class ModulatedDeformConvPack(ModulatedDeformConv): + """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. + """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super(ModulatedDeformConvPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + dilation=_pair(self.dilation), + bias=True) + self.init_weights() + + def init_weights(self): + super(ModulatedDeformConvPack, self).init_weights() + if hasattr(self, 'conv_offset'): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + out = self.conv_offset(x) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, + self.groups, self.deformable_groups) diff --git a/Hallo2/hallo2/basicsr/ops/dcn/src/deform_conv_cuda.cpp b/Hallo2/hallo2/basicsr/ops/dcn/src/deform_conv_cuda.cpp new file mode 100644 index 00000000..5d942490 --- /dev/null +++ b/Hallo2/hallo2/basicsr/ops/dcn/src/deform_conv_cuda.cpp @@ -0,0 +1,685 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include + +#include +#include + +void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor data_col); + +void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor grad_im); + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const int channels, const int height, + const int width, const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, at::Tensor grad_offset); + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + at::Tensor data_col); + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + at::Tensor grad_im); + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, at::Tensor grad_offset, + at::Tensor grad_mask); + +void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput, + at::Tensor weight, int kH, int kW, int dH, int dW, int padH, + int padW, int dilationH, int dilationW, int group, + int deformable_group) { + TORCH_CHECK(weight.ndimension() == 4, + "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, " + "but got: %s", + weight.ndimension()); + + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + TORCH_CHECK(kW > 0 && kH > 0, + "kernel size should be greater than zero, but got kH: %d kW: %d", kH, + kW); + + TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW), + "kernel size should be consistent with weight, ", + "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH, + kW, weight.size(2), weight.size(3)); + + TORCH_CHECK(dW > 0 && dH > 0, + "stride should be greater than zero, but got dH: %d dW: %d", dH, dW); + + TORCH_CHECK( + dilationW > 0 && dilationH > 0, + "dilation should be greater than 0, but got dilationH: %d dilationW: %d", + dilationH, dilationW); + + int ndim = input.ndimension(); + int dimf = 0; + int dimh = 1; + int dimw = 2; + + if (ndim == 4) { + dimf++; + dimh++; + dimw++; + } + + TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s", + ndim); + + long nInputPlane = weight.size(1) * group; + long inputHeight = input.size(dimh); + long inputWidth = input.size(dimw); + long nOutputPlane = weight.size(0); + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + + TORCH_CHECK(nInputPlane % deformable_group == 0, + "input channels must divide deformable group size"); + + if (outputWidth < 1 || outputHeight < 1) + AT_ERROR( + "Given input size: (%ld x %ld x %ld). " + "Calculated output size: (%ld x %ld x %ld). Output size is too small", + nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, + outputWidth); + + TORCH_CHECK(input.size(1) == nInputPlane, + "invalid number of input planes, expected: %d, but got: %d", + nInputPlane, input.size(1)); + + TORCH_CHECK((inputHeight >= kH && inputWidth >= kW), + "input image is smaller than kernel"); + + TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth), + "invalid spatial size of offset, expected height: %d width: %d, but " + "got height: %d width: %d", + outputHeight, outputWidth, offset.size(2), offset.size(3)); + + TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW), + "invalid number of channels of offset"); + + if (gradOutput != NULL) { + TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane, + "invalid number of gradOutput planes, expected: %d, but got: %d", + nOutputPlane, gradOutput->size(dimf)); + + TORCH_CHECK((gradOutput->size(dimh) == outputHeight && + gradOutput->size(dimw) == outputWidth), + "invalid size of gradOutput, expected height: %d width: %d , but " + "got height: %d width: %d", + outputHeight, outputWidth, gradOutput->size(dimh), + gradOutput->size(dimw)); + } +} + +int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step) { + // todo: resize columns to include im2col: done + // todo: add im2col_step as input + // todo: add new output buffer and transpose it to output (or directly + // transpose output) todo: possibly change data indexing because of + // parallel_imgs + + shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + input = input.contiguous(); + offset = offset.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input.unsqueeze_(0); + offset.unsqueeze_(0); + } + + // todo: assert batchsize dividable by im2col_step + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane, + outputHeight, outputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < outputHeight * outputWidth) { + ones = at::ones({outputHeight, outputWidth}, input.options()); + } + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + at::Tensor output_buffer = + at::zeros({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}, + output.options()); + + output_buffer = output_buffer.view( + {output_buffer.size(0), group, output_buffer.size(1) / group, + output_buffer.size(2), output_buffer.size(3)}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + output_buffer[elt][g] = output_buffer[elt][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output_buffer[elt][g]); + } + } + + output_buffer = output_buffer.view( + {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2), + output_buffer.size(3), output_buffer.size(4)}); + + output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step, outputHeight, outputWidth}); + output_buffer.transpose_(1, 2); + output.copy_(output_buffer); + output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + output = output.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step) { + shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view({1, input.size(0), input.size(1), input.size(2)}); + offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)}); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset"); + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + // change order of grad output + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, + outputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + // divide into groups + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), group, gradOutput.size(1) / group, + gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)}); + + for (int g = 0; g < group; g++) { + columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + gradOutput[elt][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2), + gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)}); + + deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane, + inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, + dilationH, dilationW, im2col_step, deformable_group, + gradOffset[elt]); + + deformable_col2im(columns, offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, gradInput[elt]); + } + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + gradOffset = gradOffset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + gradOffset = + gradOffset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step) { + // todo: transpose and reshape outGrad + // todo: reshape columns + // todo: add im2col_step as input + + shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH, + padW, dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view( + at::IntList({1, input.size(0), input.size(1), input.size(2)})); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = gradWeight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + at::Tensor gradOutputBuffer = at::zeros_like(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step, + outputHeight, outputWidth}); + gradOutputBuffer.copy_(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}); + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + // divide into group + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group, + gradOutputBuffer.size(2), gradOutputBuffer.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + gradWeight = + gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3)}); + + for (int g = 0; g < group; g++) { + gradWeight[g] = gradWeight[g] + .flatten(1) + .addmm_(gradOutputBuffer[elt][g].flatten(1), + columns[g].transpose(1, 0), 1.0, scale) + .view_as(gradWeight[g]); + } + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), + gradOutputBuffer.size(1) * gradOutputBuffer.size(2), + gradOutputBuffer.size(3), gradOutputBuffer.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3), + gradWeight.size(4)}); + } + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + } + + return 1; +} + +void modulated_deform_conv_cuda_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias) { + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + at::DeviceGuard guard(input.device()); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_out = weight.size(0); + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + // resize output + output = output.view({batch, channels_out, height_out, width_out}).zero_(); + // resize temporary columns + columns = + at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, + input.options()); + + output = output.view({output.size(0), group, output.size(1) / group, + output.size(2), output.size(3)}); + + for (int b = 0; b < batch; b++) { + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + // divide into group + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + + for (int g = 0; g < group; g++) { + output[b][g] = output[b][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output[b][g]); + } + + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + } + + output = output.view({output.size(0), output.size(1) * output.size(2), + output.size(3), output.size(4)}); + + if (with_bias) { + output += bias.view({1, bias.size(0), 1, 1}); + } +} + +void modulated_deform_conv_cuda_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias) { + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + at::DeviceGuard guard(input.device()); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + grad_input = grad_input.view({batch, channels, height, width}); + columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out}, + input.options()); + + grad_output = + grad_output.view({grad_output.size(0), group, grad_output.size(1) / group, + grad_output.size(2), grad_output.size(3)}); + + for (int b = 0; b < batch; b++) { + // divide int group + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + grad_output[b][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + + // gradient w.r.t. input coordinate data + modulated_deformable_col2im_coord_cuda( + columns, input[b], offset[b], mask[b], 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b], + grad_mask[b]); + // gradient w.r.t. input data + modulated_deformable_col2im_cuda( + columns, offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, grad_input[b]); + + // gradient w.r.t. weight, dWeight should accumulate across the batch and + // group + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + grad_weight = grad_weight.view({group, grad_weight.size(0) / group, + grad_weight.size(1), grad_weight.size(2), + grad_weight.size(3)}); + if (with_bias) + grad_bias = grad_bias.view({group, grad_bias.size(0) / group}); + + for (int g = 0; g < group; g++) { + grad_weight[g] = + grad_weight[g] + .flatten(1) + .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)) + .view_as(grad_weight[g]); + if (with_bias) { + grad_bias[g] = + grad_bias[g] + .view({-1, 1}) + .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1})) + .view(-1); + } + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), + grad_weight.size(2), grad_weight.size(3), + grad_weight.size(4)}); + if (with_bias) + grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)}); + } + grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1), + grad_output.size(2), grad_output.size(3), + grad_output.size(4)}); +} diff --git a/Hallo2/hallo2/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu b/Hallo2/hallo2/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu new file mode 100644 index 00000000..98752dcc --- /dev/null +++ b/Hallo2/hallo2/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu @@ -0,0 +1,867 @@ +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer ******************** + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.cuh + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. + * \ref: https://arxiv.org/abs/1703.06211 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng + */ + +// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu + +#include +#include +#include +#include +#include +#include + +using namespace at; + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +const int kMaxGridNum = 65535; + +inline int GET_BLOCKS(const int N) +{ + return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS); +} + +template +__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width, + const int height, const int width, scalar_t h, scalar_t w) +{ + + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, + const int h, const int w, const int height, const int width) +{ + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, + const int height, const int width, const scalar_t *im_data, + const int data_width, const int bp_dir) +{ + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const scalar_t map_h = i * dilation_h + offset_h; + //const scalar_t map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val; + data_col_ptr += batch_size * height_col * width_col; + } + } + } +} + +void deformable_im2col( + const at::Tensor data_im, const at::Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, const int ksize_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + // todo: check parallel_imgs is correctly passed in + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *data_col_ = data_col.data_ptr(); + + deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + channel_per_deformable_group, parallel_imgs, channels, deformable_group, + height_col, width_col, data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_im2col: %s\n", cudaGetErrorString(err)); + } +} + +template +__global__ void deformable_col2im_gpu_kernel( + const int n, const scalar_t *data_col, const scalar_t *data_offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index]; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +void deformable_col2im( + const at::Tensor data_col, const at::Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor grad_im) +{ + + // todo: make sure parallel_imgs is passed in correctly + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, channels, height, width, ksize_h, + ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + parallel_imgs, deformable_group, height_col, width_col, grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_col2im: %s\n", cudaGetErrorString(err)); + } +} + +template +__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col, + const scalar_t *data_im, const scalar_t *data_offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, scalar_t *grad_offset) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * + batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + const scalar_t weight = get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos]; + cnt += 1; + } + + grad_offset[index] = val; + } +} + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, at::Tensor grad_offset) +{ + + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs; + int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + + deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, channels, height, width, + ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group, + height_col, width_col, grad_offset_); + })); +} + +template +__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width, + const int height, const int width, scalar_t h, scalar_t w) +{ + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, + const int h, const int w, const int height, const int width) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, + const int height, const int width, const scalar_t *im_data, + const int data_width, const int bp_dir) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void modulated_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const float map_h = i * dilation_h + offset_h; + //const float map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + //data_col_ptr += height_col * width_col; + } + } + } +} + +template +__global__ void modulated_deformable_col2im_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +template +__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_offset, scalar_t *grad_mask) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + else + { + mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); + } + const scalar_t weight = dmcn_get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + // KERNEL_ASSIGN(grad_offset[index], offset_req, val); + grad_offset[index] = val; + if (offset_c % 2 == 0) + // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; + } +} + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *data_col_ = data_col.data_ptr(); + + modulated_deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, + batch_size, channels, deformable_group, height_col, width_col, data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor grad_im) +{ + + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + modulated_deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, deformable_group, height_col, width_col, grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, + at::Tensor grad_offset, at::Tensor grad_mask) +{ + const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; + const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + scalar_t *grad_mask_ = grad_mask.data_ptr(); + + modulated_deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, + grad_offset_, grad_mask_); + })); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); + } +} diff --git a/Hallo2/hallo2/basicsr/ops/dcn/src/deform_conv_ext.cpp b/Hallo2/hallo2/basicsr/ops/dcn/src/deform_conv_ext.cpp new file mode 100644 index 00000000..41c6df6f --- /dev/null +++ b/Hallo2/hallo2/basicsr/ops/dcn/src/deform_conv_ext.cpp @@ -0,0 +1,164 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include + +#include +#include + +#define WITH_CUDA // always use cuda +#ifdef WITH_CUDA +int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step); + +int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step); + +int deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step); + +void modulated_deform_conv_cuda_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias); + +void modulated_deform_conv_cuda_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias); +#endif + +int deform_conv_forward(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_forward_cuda(input, weight, offset, output, columns, + ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group, + deformable_group, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +int deform_conv_backward_input(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_backward_input_cuda(input, offset, gradOutput, + gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH, + dilationW, dilationH, group, deformable_group, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +int deform_conv_backward_parameters( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_backward_parameters_cuda(input, offset, gradOutput, + gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW, + dilationH, group, deformable_group, scale, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +void modulated_deform_conv_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return modulated_deform_conv_cuda_forward(input, weight, bias, ones, + offset, mask, output, columns, kernel_h, kernel_w, stride_h, + stride_w, pad_h, pad_w, dilation_h, dilation_w, group, + deformable_group, with_bias); +#else + AT_ERROR("modulated deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("modulated deform conv is not implemented on CPU"); +} + +void modulated_deform_conv_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return modulated_deform_conv_cuda_backward(input, weight, bias, ones, + offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset, + grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, deformable_group, + with_bias); +#else + AT_ERROR("modulated deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("modulated deform conv is not implemented on CPU"); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("deform_conv_forward", &deform_conv_forward, + "deform forward"); + m.def("deform_conv_backward_input", &deform_conv_backward_input, + "deform_conv_backward_input"); + m.def("deform_conv_backward_parameters", + &deform_conv_backward_parameters, + "deform_conv_backward_parameters"); + m.def("modulated_deform_conv_forward", + &modulated_deform_conv_forward, + "modulated deform conv forward"); + m.def("modulated_deform_conv_backward", + &modulated_deform_conv_backward, + "modulated deform conv backward"); +} diff --git a/Hallo2/hallo2/basicsr/ops/fused_act/__init__.py b/Hallo2/hallo2/basicsr/ops/fused_act/__init__.py new file mode 100644 index 00000000..241dc075 --- /dev/null +++ b/Hallo2/hallo2/basicsr/ops/fused_act/__init__.py @@ -0,0 +1,3 @@ +from .fused_act import FusedLeakyReLU, fused_leaky_relu + +__all__ = ['FusedLeakyReLU', 'fused_leaky_relu'] diff --git a/Hallo2/hallo2/basicsr/ops/fused_act/fused_act.py b/Hallo2/hallo2/basicsr/ops/fused_act/fused_act.py new file mode 100644 index 00000000..588f815e --- /dev/null +++ b/Hallo2/hallo2/basicsr/ops/fused_act/fused_act.py @@ -0,0 +1,89 @@ +# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501 + +import torch +from torch import nn +from torch.autograd import Function + +try: + from . import fused_act_ext +except ImportError: + import os + BASICSR_JIT = os.getenv('BASICSR_JIT') + if BASICSR_JIT == 'True': + from torch.utils.cpp_extension import load + module_path = os.path.dirname(__file__) + fused_act_ext = load( + 'fused', + sources=[ + os.path.join(module_path, 'src', 'fused_bias_act.cpp'), + os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'), + ], + ) + + +class FusedLeakyReLUFunctionBackward(Function): + + @staticmethod + def forward(ctx, grad_output, out, negative_slope, scale): + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + empty = grad_output.new_empty(0) + + grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) + + dim = [0] + + if grad_input.ndim > 2: + dim += list(range(2, grad_input.ndim)) + + grad_bias = grad_input.sum(dim).detach() + + return grad_input, grad_bias + + @staticmethod + def backward(ctx, gradgrad_input, gradgrad_bias): + out, = ctx.saved_tensors + gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, + ctx.scale) + + return gradgrad_out, None, None, None + + +class FusedLeakyReLUFunction(Function): + + @staticmethod + def forward(ctx, input, bias, negative_slope, scale): + empty = input.new_empty(0) + out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + return out + + @staticmethod + def backward(ctx, grad_output): + out, = ctx.saved_tensors + + grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale) + + return grad_input, grad_bias, None, None + + +class FusedLeakyReLU(nn.Module): + + def __init__(self, channel, negative_slope=0.2, scale=2**0.5): + super().__init__() + + self.bias = nn.Parameter(torch.zeros(channel)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): + return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) diff --git a/Hallo2/hallo2/basicsr/ops/fused_act/src/fused_bias_act.cpp b/Hallo2/hallo2/basicsr/ops/fused_act/src/fused_bias_act.cpp new file mode 100644 index 00000000..85ed0a79 --- /dev/null +++ b/Hallo2/hallo2/basicsr/ops/fused_act/src/fused_bias_act.cpp @@ -0,0 +1,26 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp +#include + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, + const torch::Tensor& bias, + const torch::Tensor& refer, + int act, int grad, float alpha, float scale); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor fused_bias_act(const torch::Tensor& input, + const torch::Tensor& bias, + const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + CHECK_CUDA(input); + CHECK_CUDA(bias); + + return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); +} diff --git a/Hallo2/hallo2/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu b/Hallo2/hallo2/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu new file mode 100644 index 00000000..54c7ff53 --- /dev/null +++ b/Hallo2/hallo2/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu @@ -0,0 +1,100 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + + +template +static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, + int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { + int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; + + scalar_t zero = 0.0; + + for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { + scalar_t x = p_x[xi]; + + if (use_bias) { + x += p_b[(xi / step_b) % size_b]; + } + + scalar_t ref = use_ref ? p_ref[xi] : zero; + + scalar_t y; + + switch (act * 10 + grad) { + default: + case 10: y = x; break; + case 11: y = x; break; + case 12: y = 0.0; break; + + case 30: y = (x > 0.0) ? x : x * alpha; break; + case 31: y = (ref > 0.0) ? x : x * alpha; break; + case 32: y = 0.0; break; + } + + out[xi] = y * scale; + } +} + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + auto x = input.contiguous(); + auto b = bias.contiguous(); + auto ref = refer.contiguous(); + + int use_bias = b.numel() ? 1 : 0; + int use_ref = ref.numel() ? 1 : 0; + + int size_x = x.numel(); + int size_b = b.numel(); + int step_b = 1; + + for (int i = 1 + 1; i < x.dim(); i++) { + step_b *= x.size(i); + } + + int loop_x = 4; + int block_size = 4 * 32; + int grid_size = (size_x - 1) / (loop_x * block_size) + 1; + + auto y = torch::empty_like(x); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { + fused_bias_act_kernel<<>>( + y.data_ptr(), + x.data_ptr(), + b.data_ptr(), + ref.data_ptr(), + act, + grad, + alpha, + scale, + loop_x, + size_x, + step_b, + size_b, + use_bias, + use_ref + ); + }); + + return y; +} diff --git a/Hallo2/hallo2/basicsr/ops/upfirdn2d/__init__.py b/Hallo2/hallo2/basicsr/ops/upfirdn2d/__init__.py new file mode 100644 index 00000000..397e85be --- /dev/null +++ b/Hallo2/hallo2/basicsr/ops/upfirdn2d/__init__.py @@ -0,0 +1,3 @@ +from .upfirdn2d import upfirdn2d + +__all__ = ['upfirdn2d'] diff --git a/Hallo2/hallo2/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp b/Hallo2/hallo2/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp new file mode 100644 index 00000000..43d0b678 --- /dev/null +++ b/Hallo2/hallo2/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp @@ -0,0 +1,24 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp +#include + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + CHECK_CUDA(input); + CHECK_CUDA(kernel); + + return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); +} diff --git a/Hallo2/hallo2/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu b/Hallo2/hallo2/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu new file mode 100644 index 00000000..8870063b --- /dev/null +++ b/Hallo2/hallo2/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu @@ -0,0 +1,370 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + +static __host__ __device__ __forceinline__ int floor_div(int a, int b) { + int c = a / b; + + if (c * b > a) { + c--; + } + + return c; +} + +struct UpFirDn2DKernelParams { + int up_x; + int up_y; + int down_x; + int down_y; + int pad_x0; + int pad_x1; + int pad_y0; + int pad_y1; + + int major_dim; + int in_h; + int in_w; + int minor_dim; + int kernel_h; + int kernel_w; + int out_h; + int out_w; + int loop_major; + int loop_x; +}; + +template +__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; + int out_y = minor_idx / p.minor_dim; + minor_idx -= out_y * p.minor_dim; + int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; + int major_idx_base = blockIdx.z * p.loop_major; + + if (out_x_base >= p.out_w || out_y >= p.out_h || + major_idx_base >= p.major_dim) { + return; + } + + int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; + int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); + int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; + int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major && major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, out_x = out_x_base; + loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { + int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; + int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); + int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; + int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; + + const scalar_t *x_p = + &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + + minor_idx]; + const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; + int x_px = p.minor_dim; + int k_px = -p.up_x; + int x_py = p.in_w * p.minor_dim; + int k_py = -p.up_y * p.kernel_w; + + scalar_t v = 0.0f; + + for (int y = 0; y < h; y++) { + for (int x = 0; x < w; x++) { + v += static_cast(*x_p) * static_cast(*k_p); + x_p += x_px; + k_p += k_px; + } + + x_p += x_py - w * x_px; + k_p += k_py - w * k_px; + } + + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } +} + +template +__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; + const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; + + __shared__ volatile float sk[kernel_h][kernel_w]; + __shared__ volatile float sx[tile_in_h][tile_in_w]; + + int minor_idx = blockIdx.x; + int tile_out_y = minor_idx / p.minor_dim; + minor_idx -= tile_out_y * p.minor_dim; + tile_out_y *= tile_out_h; + int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; + int major_idx_base = blockIdx.z * p.loop_major; + + if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | + major_idx_base >= p.major_dim) { + return; + } + + for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; + tap_idx += blockDim.x) { + int ky = tap_idx / kernel_w; + int kx = tap_idx - ky * kernel_w; + scalar_t v = 0.0; + + if (kx < p.kernel_w & ky < p.kernel_h) { + v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; + } + + sk[ky][kx] = v; + } + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major & major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, tile_out_x = tile_out_x_base; + loop_x < p.loop_x & tile_out_x < p.out_w; + loop_x++, tile_out_x += tile_out_w) { + int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; + int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; + int tile_in_x = floor_div(tile_mid_x, up_x); + int tile_in_y = floor_div(tile_mid_y, up_y); + + __syncthreads(); + + for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; + in_idx += blockDim.x) { + int rel_in_y = in_idx / tile_in_w; + int rel_in_x = in_idx - rel_in_y * tile_in_w; + int in_x = rel_in_x + tile_in_x; + int in_y = rel_in_y + tile_in_y; + + scalar_t v = 0.0; + + if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { + v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * + p.minor_dim + + minor_idx]; + } + + sx[rel_in_y][rel_in_x] = v; + } + + __syncthreads(); + for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; + out_idx += blockDim.x) { + int rel_out_y = out_idx / tile_out_w; + int rel_out_x = out_idx - rel_out_y * tile_out_w; + int out_x = rel_out_x + tile_out_x; + int out_y = rel_out_y + tile_out_y; + + int mid_x = tile_mid_x + rel_out_x * down_x; + int mid_y = tile_mid_y + rel_out_y * down_y; + int in_x = floor_div(mid_x, up_x); + int in_y = floor_div(mid_y, up_y); + int rel_in_x = in_x - tile_in_x; + int rel_in_y = in_y - tile_in_y; + int kernel_x = (in_x + 1) * up_x - mid_x - 1; + int kernel_y = (in_y + 1) * up_y - mid_y - 1; + + scalar_t v = 0.0; + +#pragma unroll + for (int y = 0; y < kernel_h / up_y; y++) +#pragma unroll + for (int x = 0; x < kernel_w / up_x; x++) + v += sx[rel_in_y + y][rel_in_x + x] * + sk[kernel_y + y * up_y][kernel_x + x * up_x]; + + if (out_x < p.out_w & out_y < p.out_h) { + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } + } + } +} + +torch::Tensor upfirdn2d_op(const torch::Tensor &input, + const torch::Tensor &kernel, int up_x, int up_y, + int down_x, int down_y, int pad_x0, int pad_x1, + int pad_y0, int pad_y1) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + UpFirDn2DKernelParams p; + + auto x = input.contiguous(); + auto k = kernel.contiguous(); + + p.major_dim = x.size(0); + p.in_h = x.size(1); + p.in_w = x.size(2); + p.minor_dim = x.size(3); + p.kernel_h = k.size(0); + p.kernel_w = k.size(1); + p.up_x = up_x; + p.up_y = up_y; + p.down_x = down_x; + p.down_y = down_y; + p.pad_x0 = pad_x0; + p.pad_x1 = pad_x1; + p.pad_y0 = pad_y0; + p.pad_y1 = pad_y1; + + p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / + p.down_y; + p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / + p.down_x; + + auto out = + at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); + + int mode = -1; + + int tile_out_h = -1; + int tile_out_w = -1; + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 1; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 3 && p.kernel_w <= 3) { + mode = 2; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 3; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 4; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 5; + tile_out_h = 8; + tile_out_w = 32; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && + p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 6; + tile_out_h = 8; + tile_out_w = 32; + } + + dim3 block_size; + dim3 grid_size; + + if (tile_out_h > 0 && tile_out_w > 0) { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 1; + block_size = dim3(32 * 8, 1, 1); + grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, + (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } else { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 4; + block_size = dim3(4, 32, 1); + grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, + (p.out_w - 1) / (p.loop_x * block_size.y) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { + switch (mode) { + case 1: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 2: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 3: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 4: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 5: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 6: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + default: + upfirdn2d_kernel_large<<>>( + out.data_ptr(), x.data_ptr(), + k.data_ptr(), p); + } + }); + + return out; +} diff --git a/Hallo2/hallo2/basicsr/ops/upfirdn2d/upfirdn2d.py b/Hallo2/hallo2/basicsr/ops/upfirdn2d/upfirdn2d.py new file mode 100644 index 00000000..667f96e1 --- /dev/null +++ b/Hallo2/hallo2/basicsr/ops/upfirdn2d/upfirdn2d.py @@ -0,0 +1,186 @@ +# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501 + +import torch +from torch.autograd import Function +from torch.nn import functional as F + +try: + from . import upfirdn2d_ext +except ImportError: + import os + BASICSR_JIT = os.getenv('BASICSR_JIT') + if BASICSR_JIT == 'True': + from torch.utils.cpp_extension import load + module_path = os.path.dirname(__file__) + upfirdn2d_ext = load( + 'upfirdn2d', + sources=[ + os.path.join(module_path, 'src', 'upfirdn2d.cpp'), + os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'), + ], + ) + + +class UpFirDn2dBackward(Function): + + @staticmethod + def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size): + + up_x, up_y = up + down_x, down_y = down + g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad + + grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) + + grad_input = upfirdn2d_ext.upfirdn2d( + grad_output, + grad_kernel, + down_x, + down_y, + up_x, + up_y, + g_pad_x0, + g_pad_x1, + g_pad_y0, + g_pad_y1, + ) + grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) + + ctx.save_for_backward(kernel) + + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + ctx.up_x = up_x + ctx.up_y = up_y + ctx.down_x = down_x + ctx.down_y = down_y + ctx.pad_x0 = pad_x0 + ctx.pad_x1 = pad_x1 + ctx.pad_y0 = pad_y0 + ctx.pad_y1 = pad_y1 + ctx.in_size = in_size + ctx.out_size = out_size + + return grad_input + + @staticmethod + def backward(ctx, gradgrad_input): + kernel, = ctx.saved_tensors + + gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) + + gradgrad_out = upfirdn2d_ext.upfirdn2d( + gradgrad_input, + kernel, + ctx.up_x, + ctx.up_y, + ctx.down_x, + ctx.down_y, + ctx.pad_x0, + ctx.pad_x1, + ctx.pad_y0, + ctx.pad_y1, + ) + # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], + # ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + batch, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = ctx.saved_tensors + + grad_input = UpFirDn2dBackward.apply( + grad_output, + kernel, + grad_kernel, + ctx.up, + ctx.down, + ctx.pad, + ctx.g_pad, + ctx.in_size, + ctx.out_size, + ) + + return grad_input, None, None, None, None + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + if input.device.type == 'cpu': + out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + else: + out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])) + + return out + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) diff --git a/Hallo2/hallo2/basicsr/setup.py b/Hallo2/hallo2/basicsr/setup.py new file mode 100644 index 00000000..b24d0450 --- /dev/null +++ b/Hallo2/hallo2/basicsr/setup.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python + +from setuptools import find_packages, setup + +import os +import subprocess +import sys +import time +from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension +from utils.misc import gpu_is_available + +version_file = './basicsr/version.py' + + +def readme(): + with open('README.md', encoding='utf-8') as f: + content = f.read() + return content + + +def get_git_hash(): + + def _minimal_ext_cmd(cmd): + # construct minimal environment + env = {} + for k in ['SYSTEMROOT', 'PATH', 'HOME']: + v = os.environ.get(k) + if v is not None: + env[k] = v + # LANGUAGE is used on win32 + env['LANGUAGE'] = 'C' + env['LANG'] = 'C' + env['LC_ALL'] = 'C' + out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0] + return out + + try: + out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) + sha = out.strip().decode('ascii') + except OSError: + sha = 'unknown' + + return sha + + +def get_hash(): + if os.path.exists('.git'): + sha = get_git_hash()[:7] + elif os.path.exists(version_file): + try: + from version import __version__ + sha = __version__.split('+')[-1] + except ImportError: + raise ImportError('Unable to get git version') + else: + sha = 'unknown' + + return sha + + +def write_version_py(): + content = """# GENERATED VERSION FILE +# TIME: {} +__version__ = '{}' +__gitsha__ = '{}' +version_info = ({}) +""" + sha = get_hash() + with open('./basicsr/VERSION', 'r') as f: + SHORT_VERSION = f.read().strip() + VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) + + version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO) + with open(version_file, 'w') as f: + f.write(version_file_str) + + +def get_version(): + with open(version_file, 'r') as f: + exec(compile(f.read(), version_file, 'exec')) + return locals()['__version__'] + + +def make_cuda_ext(name, module, sources, sources_cuda=None): + if sources_cuda is None: + sources_cuda = [] + define_macros = [] + extra_compile_args = {'cxx': []} + + # if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': + if gpu_is_available or os.getenv('FORCE_CUDA', '0') == '1': + define_macros += [('WITH_CUDA', None)] + extension = CUDAExtension + extra_compile_args['nvcc'] = [ + '-D__CUDA_NO_HALF_OPERATORS__', + '-D__CUDA_NO_HALF_CONVERSIONS__', + '-D__CUDA_NO_HALF2_OPERATORS__', + ] + sources += sources_cuda + else: + print(f'Compiling {name} without CUDA') + extension = CppExtension + + return extension( + name=f'{module}.{name}', + sources=[os.path.join(*module.split('.'), p) for p in sources], + define_macros=define_macros, + extra_compile_args=extra_compile_args) + + +def get_requirements(filename='requirements.txt'): + with open(os.path.join('.', filename), 'r') as f: + requires = [line.replace('\n', '') for line in f.readlines()] + return requires + + +if __name__ == '__main__': + if '--cuda_ext' in sys.argv: + ext_modules = [ + make_cuda_ext( + name='deform_conv_ext', + module='ops.dcn', + sources=['src/deform_conv_ext.cpp'], + sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']), + make_cuda_ext( + name='fused_act_ext', + module='ops.fused_act', + sources=['src/fused_bias_act.cpp'], + sources_cuda=['src/fused_bias_act_kernel.cu']), + make_cuda_ext( + name='upfirdn2d_ext', + module='ops.upfirdn2d', + sources=['src/upfirdn2d.cpp'], + sources_cuda=['src/upfirdn2d_kernel.cu']), + ] + sys.argv.remove('--cuda_ext') + else: + ext_modules = [] + + write_version_py() + setup( + name='basicsr', + version=get_version(), + description='Open Source Image and Video Super-Resolution Toolbox', + long_description=readme(), + long_description_content_type='text/markdown', + author='Xintao Wang', + author_email='xintao.wang@outlook.com', + keywords='computer vision, restoration, super resolution', + url='https://github.com/xinntao/BasicSR', + include_package_data=True, + packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')), + classifiers=[ + 'Development Status :: 4 - Beta', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: OS Independent', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + ], + license='Apache License 2.0', + setup_requires=['cython', 'numpy'], + install_requires=get_requirements(), + ext_modules=ext_modules, + cmdclass={'build_ext': BuildExtension}, + zip_safe=False) diff --git a/Hallo2/hallo2/basicsr/train.py b/Hallo2/hallo2/basicsr/train.py new file mode 100644 index 00000000..84fe0a28 --- /dev/null +++ b/Hallo2/hallo2/basicsr/train.py @@ -0,0 +1,225 @@ +import argparse +import datetime +import logging +import math +import copy +import random +import time +import torch +from os import path as osp + +from basicsr.data import build_dataloader, build_dataset +from basicsr.data.data_sampler import EnlargedSampler +from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher +from basicsr.models import build_model +from basicsr.utils import (MessageLogger, check_resume, get_env_info, get_root_logger, init_tb_logger, + init_wandb_logger, make_exp_dirs, mkdir_and_rename, set_random_seed) +from basicsr.utils.dist_util import get_dist_info, init_dist +from basicsr.utils.options import dict2str, parse + +import warnings +# ignore UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. +warnings.filterwarnings("ignore", category=UserWarning) + +def parse_options(root_path, is_train=True): + parser = argparse.ArgumentParser() + parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.') + parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher') + parser.add_argument('--local-rank','--local_rank', type=int, default=0) + args = parser.parse_args() + opt = parse(args.opt, root_path, is_train=is_train) + + # distributed settings + if args.launcher == 'none': + opt['dist'] = False + print('Disable distributed.', flush=True) + else: + opt['dist'] = True + if args.launcher == 'slurm' and 'dist_params' in opt: + init_dist(args.launcher, **opt['dist_params']) + else: + init_dist(args.launcher) + + opt['rank'], opt['world_size'] = get_dist_info() + + # random seed + seed = opt.get('manual_seed') + if seed is None: + seed = random.randint(1, 10000) + opt['manual_seed'] = seed + set_random_seed(seed + opt['rank']) + + return opt + + +def init_loggers(opt): + log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log") + logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) + logger.info(get_env_info()) + logger.info(dict2str(opt)) + + # initialize wandb logger before tensorboard logger to allow proper sync: + if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project') is not None): + assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb') + init_wandb_logger(opt) + tb_logger = None + if opt['logger'].get('use_tb_logger'): + tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name'])) + return logger, tb_logger + + +def create_train_val_dataloader(opt, logger): + # create train and val dataloaders + train_loader, val_loader = None, None + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1) + train_set = build_dataset(dataset_opt) + train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio) + train_loader = build_dataloader( + train_set, + dataset_opt, + num_gpu=opt['num_gpu'], + dist=opt['dist'], + sampler=train_sampler, + seed=opt['manual_seed']) + + num_iter_per_epoch = math.ceil( + len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size'])) + total_iters = int(opt['train']['total_iter']) + total_epochs = math.ceil(total_iters / (num_iter_per_epoch)) + logger.info('Training statistics:' + f'\n\tNumber of train images: {len(train_set)}' + f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}' + f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}' + f'\n\tWorld size (gpu number): {opt["world_size"]}' + f'\n\tRequire iter number per epoch: {num_iter_per_epoch}' + f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.') + + elif phase == 'val': + val_set = build_dataset(dataset_opt) + val_loader = build_dataloader( + val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']) + logger.info(f'Number of val images/folders in {dataset_opt["name"]}: ' f'{len(val_set)}') + else: + raise ValueError(f'Dataset phase {phase} is not recognized.') + + return train_loader, train_sampler, val_loader, total_epochs, total_iters + + +def train_pipeline(root_path): + # parse options, set distributed setting, set ramdom seed + opt = parse_options(root_path, is_train=True) + + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + # load resume states if necessary + if opt['path'].get('resume_state'): + device_id = torch.cuda.current_device() + resume_state = torch.load( + opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id)) + else: + resume_state = None + + # mkdir for experiments and logger + if resume_state is None: + make_exp_dirs(opt) + if opt['logger'].get('use_tb_logger') and opt['rank'] == 0: + mkdir_and_rename(osp.join('tb_logger', opt['name'])) + + # initialize loggers + logger, tb_logger = init_loggers(opt) + + # create train and validation dataloaders + result = create_train_val_dataloader(opt, logger) + train_loader, train_sampler, val_loader, total_epochs, total_iters = result + + # create model + if resume_state: # resume training + check_resume(opt, resume_state['iter']) + model = build_model(opt) + model.resume_training(resume_state) # handle optimizers and schedulers + logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.") + start_epoch = resume_state['epoch'] + current_iter = resume_state['iter'] + else: + model = build_model(opt) + start_epoch = 0 + current_iter = 0 + + # create message logger (formatted outputs) + msg_logger = MessageLogger(opt, current_iter, tb_logger) + + # dataloader prefetcher + prefetch_mode = opt['datasets']['train'].get('prefetch_mode') + if prefetch_mode is None or prefetch_mode == 'cpu': + prefetcher = CPUPrefetcher(train_loader) + elif prefetch_mode == 'cuda': + prefetcher = CUDAPrefetcher(train_loader, opt) + logger.info(f'Use {prefetch_mode} prefetch dataloader') + if opt['datasets']['train'].get('pin_memory') is not True: + raise ValueError('Please set pin_memory=True for CUDAPrefetcher.') + else: + raise ValueError(f'Wrong prefetch_mode {prefetch_mode}.' "Supported ones are: None, 'cuda', 'cpu'.") + + # training + logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter+1}') + data_time, iter_time = time.time(), time.time() + start_time = time.time() + + for epoch in range(start_epoch, total_epochs + 1): + train_sampler.set_epoch(epoch) + prefetcher.reset() + train_data = prefetcher.next() + + while train_data is not None: + data_time = time.time() - data_time + + current_iter += 1 + if current_iter > total_iters: + break + # update learning rate + model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1)) + # training + model.feed_data(train_data) + model.optimize_parameters(current_iter) + iter_time = time.time() - iter_time + # log + if current_iter % opt['logger']['print_freq'] == 0: + log_vars = {'epoch': epoch, 'iter': current_iter} + log_vars.update({'lrs': model.get_current_learning_rate()}) + log_vars.update({'time': iter_time, 'data_time': data_time}) + log_vars.update(model.get_current_log()) + msg_logger(log_vars) + + # save models and training states + if current_iter % opt['logger']['save_checkpoint_freq'] == 0: + logger.info('Saving models and training states.') + model.save(epoch, current_iter) + + # validation + if opt.get('val') is not None and opt['datasets'].get('val') is not None \ + and (current_iter % opt['val']['val_freq'] == 0): + model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img']) + + data_time = time.time() + iter_time = time.time() + train_data = prefetcher.next() + # end of iter + + # end of epoch + + consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time))) + logger.info(f'End of training. Time consumed: {consumed_time}') + logger.info('Save the latest model.') + model.save(epoch=-1, current_iter=-1) # -1 stands for the latest + if opt.get('val') is not None and opt['datasets'].get('val'): + model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img']) + if tb_logger: + tb_logger.close() + + +if __name__ == '__main__': + root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) + train_pipeline(root_path) diff --git a/Hallo2/hallo2/basicsr/utils/__init__.py b/Hallo2/hallo2/basicsr/utils/__init__.py new file mode 100644 index 00000000..5fcc1d54 --- /dev/null +++ b/Hallo2/hallo2/basicsr/utils/__init__.py @@ -0,0 +1,29 @@ +from .file_client import FileClient +from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img +from .logger import MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger +from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt + +__all__ = [ + # file_client.py + 'FileClient', + # img_util.py + 'img2tensor', + 'tensor2img', + 'imfrombytes', + 'imwrite', + 'crop_border', + # logger.py + 'MessageLogger', + 'init_tb_logger', + 'init_wandb_logger', + 'get_root_logger', + 'get_env_info', + # misc.py + 'set_random_seed', + 'get_time_str', + 'mkdir_and_rename', + 'make_exp_dirs', + 'scandir', + 'check_resume', + 'sizeof_fmt' +] diff --git a/Hallo2/hallo2/basicsr/utils/dist_util.py b/Hallo2/hallo2/basicsr/utils/dist_util.py new file mode 100644 index 00000000..0fab887b --- /dev/null +++ b/Hallo2/hallo2/basicsr/utils/dist_util.py @@ -0,0 +1,82 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 +import functools +import os +import subprocess +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + + +def init_dist(launcher, backend='nccl', **kwargs): + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method('spawn') + if launcher == 'pytorch': + _init_dist_pytorch(backend, **kwargs) + elif launcher == 'slurm': + _init_dist_slurm(backend, **kwargs) + else: + raise ValueError(f'Invalid launcher type: {launcher}') + + +def _init_dist_pytorch(backend, **kwargs): + rank = int(os.environ['RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_slurm(backend, port=None): + """Initialize slurm distributed training environment. + + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. + """ + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') + # specify master port + if port is not None: + os.environ['MASTER_PORT'] = str(port) + elif 'MASTER_PORT' in os.environ: + pass # use MASTER_PORT in the environment variable + else: + # 29500 is torch.distributed default port + os.environ['MASTER_PORT'] = '29500' + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['RANK'] = str(proc_id) + dist.init_process_group(backend=backend) + + +def get_dist_info(): + if dist.is_available(): + initialized = dist.is_initialized() + else: + initialized = False + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def master_only(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper diff --git a/Hallo2/hallo2/basicsr/utils/download_util.py b/Hallo2/hallo2/basicsr/utils/download_util.py new file mode 100644 index 00000000..2a267915 --- /dev/null +++ b/Hallo2/hallo2/basicsr/utils/download_util.py @@ -0,0 +1,95 @@ +import math +import os +import requests +from torch.hub import download_url_to_file, get_dir +from tqdm import tqdm +from urllib.parse import urlparse + +from .misc import sizeof_fmt + + +def download_file_from_google_drive(file_id, save_path): + """Download files from google drive. + Ref: + https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 + Args: + file_id (str): File id. + save_path (str): Save path. + """ + + session = requests.Session() + URL = 'https://docs.google.com/uc?export=download' + params = {'id': file_id} + + response = session.get(URL, params=params, stream=True) + token = get_confirm_token(response) + if token: + params['confirm'] = token + response = session.get(URL, params=params, stream=True) + + # get file size + response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) + print(response_file_size) + if 'Content-Range' in response_file_size.headers: + file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) + else: + file_size = None + + save_response_content(response, save_path, file_size) + + +def get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith('download_warning'): + return value + return None + + +def save_response_content(response, destination, file_size=None, chunk_size=32768): + if file_size is not None: + pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') + + readable_file_size = sizeof_fmt(file_size) + else: + pbar = None + + with open(destination, 'wb') as f: + downloaded_size = 0 + for chunk in response.iter_content(chunk_size): + downloaded_size += chunk_size + if pbar is not None: + pbar.update(1) + pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') + if chunk: # filter out keep-alive new chunks + f.write(chunk) + if pbar is not None: + pbar.close() + + +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Load file form http url, will download models if necessary. + Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + Args: + url (str): URL to be downloaded. + model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. + Default: None. + progress (bool): Whether to show the download progress. Default: True. + file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. + Returns: + str: The path to the downloaded file. + """ + if model_dir is None: # use the pytorch hub_dir + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(model_dir, exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.abspath(os.path.join(model_dir, filename)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) + return cached_file \ No newline at end of file diff --git a/Hallo2/hallo2/basicsr/utils/file_client.py b/Hallo2/hallo2/basicsr/utils/file_client.py new file mode 100644 index 00000000..7f38d979 --- /dev/null +++ b/Hallo2/hallo2/basicsr/utils/file_client.py @@ -0,0 +1,167 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 +from abc import ABCMeta, abstractmethod + + +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: ``get()`` and ``get_text()``. + ``get()`` reads the file as a byte stream and ``get_text()`` reads the file + as texts. + """ + + @abstractmethod + def get(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass + + +class MemcachedBackend(BaseStorageBackend): + """Memcached storage backend. + + Attributes: + server_list_cfg (str): Config file for memcached server list. + client_cfg (str): Config file for memcached client. + sys_path (str | None): Additional path to be appended to `sys.path`. + Default: None. + """ + + def __init__(self, server_list_cfg, client_cfg, sys_path=None): + if sys_path is not None: + import sys + sys.path.append(sys_path) + try: + import mc + except ImportError: + raise ImportError('Please install memcached to enable MemcachedBackend.') + + self.server_list_cfg = server_list_cfg + self.client_cfg = client_cfg + self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) + # mc.pyvector servers as a point which points to a memory cache + self._mc_buffer = mc.pyvector() + + def get(self, filepath): + filepath = str(filepath) + import mc + self._client.Get(filepath, self._mc_buffer) + value_buf = mc.ConvertBuffer(self._mc_buffer) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class HardDiskBackend(BaseStorageBackend): + """Raw hard disks storage backend.""" + + def get(self, filepath): + filepath = str(filepath) + with open(filepath, 'rb') as f: + value_buf = f.read() + return value_buf + + def get_text(self, filepath): + filepath = str(filepath) + with open(filepath, 'r') as f: + value_buf = f.read() + return value_buf + + +class LmdbBackend(BaseStorageBackend): + """Lmdb storage backend. + + Args: + db_paths (str | list[str]): Lmdb database paths. + client_keys (str | list[str]): Lmdb client keys. Default: 'default'. + readonly (bool, optional): Lmdb environment parameter. If True, + disallow any write operations. Default: True. + lock (bool, optional): Lmdb environment parameter. If False, when + concurrent access occurs, do not lock the database. Default: False. + readahead (bool, optional): Lmdb environment parameter. If False, + disable the OS filesystem readahead mechanism, which may improve + random read performance when a database is larger than RAM. + Default: False. + + Attributes: + db_paths (list): Lmdb database path. + _client (list): A list of several lmdb envs. + """ + + def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): + try: + import lmdb + except ImportError: + raise ImportError('Please install lmdb to enable LmdbBackend.') + + if isinstance(client_keys, str): + client_keys = [client_keys] + + if isinstance(db_paths, list): + self.db_paths = [str(v) for v in db_paths] + elif isinstance(db_paths, str): + self.db_paths = [str(db_paths)] + assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' + f'but received {len(client_keys)} and {len(self.db_paths)}.') + + self._client = {} + for client, path in zip(client_keys, self.db_paths): + self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) + + def get(self, filepath, client_key): + """Get values according to the filepath from one lmdb named client_key. + + Args: + filepath (str | obj:`Path`): Here, filepath is the lmdb key. + client_key (str): Used for distinguishing differnet lmdb envs. + """ + filepath = str(filepath) + assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.') + client = self._client[client_key] + with client.begin(write=False) as txn: + value_buf = txn.get(filepath.encode('ascii')) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class FileClient(object): + """A general file client to access files in different backend. + + The client loads a file or text in a specified backend from its path + and return it as a binary file. it can also register other backend + accessor with a given name and backend class. + + Attributes: + backend (str): The storage backend type. Options are "disk", + "memcached" and "lmdb". + client (:obj:`BaseStorageBackend`): The backend object. + """ + + _backends = { + 'disk': HardDiskBackend, + 'memcached': MemcachedBackend, + 'lmdb': LmdbBackend, + } + + def __init__(self, backend='disk', **kwargs): + if backend not in self._backends: + raise ValueError(f'Backend {backend} is not supported. Currently supported ones' + f' are {list(self._backends.keys())}') + self.backend = backend + self.client = self._backends[backend](**kwargs) + + def get(self, filepath, client_key='default'): + # client_key is used only for lmdb, where different fileclients have + # different lmdb environments. + if self.backend == 'lmdb': + return self.client.get(filepath, client_key) + else: + return self.client.get(filepath) + + def get_text(self, filepath): + return self.client.get_text(filepath) diff --git a/Hallo2/hallo2/basicsr/utils/img_util.py b/Hallo2/hallo2/basicsr/utils/img_util.py new file mode 100644 index 00000000..5aba82ce --- /dev/null +++ b/Hallo2/hallo2/basicsr/utils/img_util.py @@ -0,0 +1,171 @@ +import cv2 +import math +import numpy as np +import os +import torch +from torchvision.utils import make_grid + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): + """Convert torch Tensors into image numpy arrays. + + After clamping to [min, max], values will be normalized to [0, 1]. + + Args: + tensor (Tensor or list[Tensor]): Accept shapes: + 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); + 2) 3D Tensor of shape (3/1 x H x W); + 3) 2D Tensor of shape (H x W). + Tensor channel should be in RGB order. + rgb2bgr (bool): Whether to change rgb to bgr. + out_type (numpy type): output types. If ``np.uint8``, transform outputs + to uint8 type with range [0, 255]; otherwise, float type with + range [0, 1]. Default: ``np.uint8``. + min_max (tuple[int]): min and max values for clamp. + + Returns: + (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of + shape (H x W). The channel order is BGR. + """ + if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): + raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') + + if torch.is_tensor(tensor): + tensor = [tensor] + result = [] + for _tensor in tensor: + _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) + + n_dim = _tensor.dim() + if n_dim == 4: + img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() + img_np = img_np.transpose(1, 2, 0) + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 3: + img_np = _tensor.numpy() + img_np = img_np.transpose(1, 2, 0) + if img_np.shape[2] == 1: # gray image + img_np = np.squeeze(img_np, axis=2) + else: + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 2: + img_np = _tensor.numpy() + else: + raise TypeError('Only support 4D, 3D or 2D tensor. ' f'But received with dimension: {n_dim}') + if out_type == np.uint8: + # Unlike MATLAB, numpy.unit8() WILL NOT round by default. + img_np = (img_np * 255.0).round() + img_np = img_np.astype(out_type) + result.append(img_np) + if len(result) == 1: + result = result[0] + return result + + +def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): + """This implementation is slightly faster than tensor2img. + It now only supports torch tensor with shape (1, c, h, w). + + Args: + tensor (Tensor): Now only support torch tensor with (1, c, h, w). + rgb2bgr (bool): Whether to change rgb to bgr. Default: True. + min_max (tuple[int]): min and max values for clamp. + """ + output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) + output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 + output = output.type(torch.uint8).cpu().numpy() + if rgb2bgr: + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output + + +def imfrombytes(content, flag='color', float32=False): + """Read an image from bytes. + + Args: + content (bytes): Image bytes got from files or other streams. + flag (str): Flags specifying the color type of a loaded image, + candidates are `color`, `grayscale` and `unchanged`. + float32 (bool): Whether to change to float32., If True, will also norm + to [0, 1]. Default: False. + + Returns: + ndarray: Loaded image array. + """ + img_np = np.frombuffer(content, np.uint8) + imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} + img = cv2.imdecode(img_np, imread_flags[flag]) + if float32: + img = img.astype(np.float32) / 255. + return img + + +def imwrite(img, file_path, params=None, auto_mkdir=True): + """Write image to file. + + Args: + img (ndarray): Image array to be written. + file_path (str): Image file path. + params (None or list): Same as opencv's :func:`imwrite` interface. + auto_mkdir (bool): If the parent folder of `file_path` does not exist, + whether to create it automatically. + + Returns: + bool: Successful or not. + """ + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + return cv2.imwrite(file_path, img, params) + + +def crop_border(imgs, crop_border): + """Crop borders of images. + + Args: + imgs (list[ndarray] | ndarray): Images with shape (h, w, c). + crop_border (int): Crop border for each end of height and weight. + + Returns: + list[ndarray]: Cropped images. + """ + if crop_border == 0: + return imgs + else: + if isinstance(imgs, list): + return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] + else: + return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] + \ No newline at end of file diff --git a/Hallo2/hallo2/basicsr/utils/lmdb_util.py b/Hallo2/hallo2/basicsr/utils/lmdb_util.py new file mode 100644 index 00000000..e0a10f60 --- /dev/null +++ b/Hallo2/hallo2/basicsr/utils/lmdb_util.py @@ -0,0 +1,196 @@ +import cv2 +import lmdb +import sys +from multiprocessing import Pool +from os import path as osp +from tqdm import tqdm + + +def make_lmdb_from_imgs(data_path, + lmdb_path, + img_path_list, + keys, + batch=5000, + compress_level=1, + multiprocessing_read=False, + n_thread=40, + map_size=None): + """Make lmdb from images. + + Contents of lmdb. The file structure is: + example.lmdb + ├── data.mdb + ├── lock.mdb + ├── meta_info.txt + + The data.mdb and lock.mdb are standard lmdb files and you can refer to + https://lmdb.readthedocs.io/en/release/ for more details. + + The meta_info.txt is a specified txt file to record the meta information + of our datasets. It will be automatically created when preparing + datasets by our provided dataset tools. + Each line in the txt file records 1)image name (with extension), + 2)image shape, and 3)compression level, separated by a white space. + + For example, the meta information could be: + `000_00000000.png (720,1280,3) 1`, which means: + 1) image name (with extension): 000_00000000.png; + 2) image shape: (720,1280,3); + 3) compression level: 1 + + We use the image name without extension as the lmdb key. + + If `multiprocessing_read` is True, it will read all the images to memory + using multiprocessing. Thus, your server needs to have enough memory. + + Args: + data_path (str): Data path for reading images. + lmdb_path (str): Lmdb save path. + img_path_list (str): Image path list. + keys (str): Used for lmdb keys. + batch (int): After processing batch images, lmdb commits. + Default: 5000. + compress_level (int): Compress level when encoding images. Default: 1. + multiprocessing_read (bool): Whether use multiprocessing to read all + the images to memory. Default: False. + n_thread (int): For multiprocessing. + map_size (int | None): Map size for lmdb env. If None, use the + estimated size from images. Default: None + """ + + assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, ' + f'but got {len(img_path_list)} and {len(keys)}') + print(f'Create lmdb for {data_path}, save to {lmdb_path}...') + print(f'Totoal images: {len(img_path_list)}') + if not lmdb_path.endswith('.lmdb'): + raise ValueError("lmdb_path must end with '.lmdb'.") + if osp.exists(lmdb_path): + print(f'Folder {lmdb_path} already exists. Exit.') + sys.exit(1) + + if multiprocessing_read: + # read all the images to memory (multiprocessing) + dataset = {} # use dict to keep the order for multiprocessing + shapes = {} + print(f'Read images with multiprocessing, #thread: {n_thread} ...') + pbar = tqdm(total=len(img_path_list), unit='image') + + def callback(arg): + """get the image data and update pbar.""" + key, dataset[key], shapes[key] = arg + pbar.update(1) + pbar.set_description(f'Read {key}') + + pool = Pool(n_thread) + for path, key in zip(img_path_list, keys): + pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback) + pool.close() + pool.join() + pbar.close() + print(f'Finish reading {len(img_path_list)} images.') + + # create lmdb environment + if map_size is None: + # obtain data size for one image + img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) + _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + data_size_per_img = img_byte.nbytes + print('Data size per image is: ', data_size_per_img) + data_size = data_size_per_img * len(img_path_list) + map_size = data_size * 10 + + env = lmdb.open(lmdb_path, map_size=map_size) + + # write data to lmdb + pbar = tqdm(total=len(img_path_list), unit='chunk') + txn = env.begin(write=True) + txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') + for idx, (path, key) in enumerate(zip(img_path_list, keys)): + pbar.update(1) + pbar.set_description(f'Write {key}') + key_byte = key.encode('ascii') + if multiprocessing_read: + img_byte = dataset[key] + h, w, c = shapes[key] + else: + _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level) + h, w, c = img_shape + + txn.put(key_byte, img_byte) + # write meta information + txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n') + if idx % batch == 0: + txn.commit() + txn = env.begin(write=True) + pbar.close() + txn.commit() + env.close() + txt_file.close() + print('\nFinish writing lmdb.') + + +def read_img_worker(path, key, compress_level): + """Read image worker. + + Args: + path (str): Image path. + key (str): Image key. + compress_level (int): Compress level when encoding images. + + Returns: + str: Image key. + byte: Image byte. + tuple[int]: Image shape. + """ + + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + if img.ndim == 2: + h, w = img.shape + c = 1 + else: + h, w, c = img.shape + _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + return (key, img_byte, (h, w, c)) + + +class LmdbMaker(): + """LMDB Maker. + + Args: + lmdb_path (str): Lmdb save path. + map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB. + batch (int): After processing batch images, lmdb commits. + Default: 5000. + compress_level (int): Compress level when encoding images. Default: 1. + """ + + def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1): + if not lmdb_path.endswith('.lmdb'): + raise ValueError("lmdb_path must end with '.lmdb'.") + if osp.exists(lmdb_path): + print(f'Folder {lmdb_path} already exists. Exit.') + sys.exit(1) + + self.lmdb_path = lmdb_path + self.batch = batch + self.compress_level = compress_level + self.env = lmdb.open(lmdb_path, map_size=map_size) + self.txn = self.env.begin(write=True) + self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') + self.counter = 0 + + def put(self, img_byte, key, img_shape): + self.counter += 1 + key_byte = key.encode('ascii') + self.txn.put(key_byte, img_byte) + # write meta information + h, w, c = img_shape + self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n') + if self.counter % self.batch == 0: + self.txn.commit() + self.txn = self.env.begin(write=True) + + def close(self): + self.txn.commit() + self.env.close() + self.txt_file.close() diff --git a/Hallo2/hallo2/basicsr/utils/logger.py b/Hallo2/hallo2/basicsr/utils/logger.py new file mode 100644 index 00000000..cb27be71 --- /dev/null +++ b/Hallo2/hallo2/basicsr/utils/logger.py @@ -0,0 +1,169 @@ +import datetime +import logging +import time + +from .dist_util import get_dist_info, master_only + +initialized_logger = {} + + +class MessageLogger(): + """Message logger for printing. + Args: + opt (dict): Config. It contains the following keys: + name (str): Exp name. + logger (dict): Contains 'print_freq' (str) for logger interval. + train (dict): Contains 'total_iter' (int) for total iters. + use_tb_logger (bool): Use tensorboard logger. + start_iter (int): Start iter. Default: 1. + tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. + """ + + def __init__(self, opt, start_iter=1, tb_logger=None): + self.exp_name = opt['name'] + self.interval = opt['logger']['print_freq'] + self.start_iter = start_iter + self.max_iters = opt['train']['total_iter'] + self.use_tb_logger = opt['logger']['use_tb_logger'] + self.tb_logger = tb_logger + self.start_time = time.time() + self.logger = get_root_logger() + + @master_only + def __call__(self, log_vars): + """Format logging message. + Args: + log_vars (dict): It contains the following keys: + epoch (int): Epoch number. + iter (int): Current iter. + lrs (list): List for learning rates. + time (float): Iter time. + data_time (float): Data time for each iter. + """ + # epoch, iter, learning rates + epoch = log_vars.pop('epoch') + current_iter = log_vars.pop('iter') + lrs = log_vars.pop('lrs') + + message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' f'iter:{current_iter:8,d}, lr:(') + for v in lrs: + message += f'{v:.3e},' + message += ')] ' + + # time and estimated time + if 'time' in log_vars.keys(): + iter_time = log_vars.pop('time') + data_time = log_vars.pop('data_time') + + total_time = time.time() - self.start_time + time_sec_avg = total_time / (current_iter - self.start_iter + 1) + eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) + eta_str = str(datetime.timedelta(seconds=int(eta_sec))) + message += f'[eta: {eta_str}, ' + message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' + + # other items, especially losses + for k, v in log_vars.items(): + message += f'{k}: {v:.4e} ' + # tensorboard logger + if self.use_tb_logger: + # if k.startswith('l_'): + # self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) + # else: + self.tb_logger.add_scalar(k, v, current_iter) + self.logger.info(message) + + +@master_only +def init_tb_logger(log_dir): + from torch.utils.tensorboard import SummaryWriter + tb_logger = SummaryWriter(log_dir=log_dir) + return tb_logger + + +@master_only +def init_wandb_logger(opt): + """We now only use wandb to sync tensorboard log.""" + import wandb + logger = logging.getLogger('basicsr') + + project = opt['logger']['wandb']['project'] + resume_id = opt['logger']['wandb'].get('resume_id') + if resume_id: + wandb_id = resume_id + resume = 'allow' + logger.warning(f'Resume wandb logger with id={wandb_id}.') + else: + wandb_id = wandb.util.generate_id() + resume = 'never' + + wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True) + + logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') + + +def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): + """Get the root logger. + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. + Args: + logger_name (str): root logger name. Default: 'basicsr'. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + Returns: + logging.Logger: The root logger. + """ + logger = logging.getLogger(logger_name) + # if the logger has been initialized, just return it + if logger_name in initialized_logger: + return logger + + format_str = '%(asctime)s %(levelname)s: %(message)s' + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(logging.Formatter(format_str)) + logger.addHandler(stream_handler) + logger.propagate = False + rank, _ = get_dist_info() + if rank != 0: + logger.setLevel('ERROR') + elif log_file is not None: + logger.setLevel(log_level) + # add file handler + # file_handler = logging.FileHandler(log_file, 'w') + file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log + file_handler.setFormatter(logging.Formatter(format_str)) + file_handler.setLevel(log_level) + logger.addHandler(file_handler) + initialized_logger[logger_name] = True + return logger + + +def get_env_info(): + """Get environment information. + Currently, only log the software version. + """ + import torch + import torchvision + + from basicsr.version import __version__ + msg = r""" + ____ _ _____ ____ + / __ ) ____ _ _____ (_)_____/ ___/ / __ \ + / __ |/ __ `// ___// // ___/\__ \ / /_/ / + / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ + /_____/ \__,_//____//_/ \___//____//_/ |_| + ______ __ __ __ __ + / ____/____ ____ ____/ / / / __ __ _____ / /__ / / + / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / + / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ + \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) + """ + msg += ('\nVersion Information: ' + f'\n\tBasicSR: {__version__}' + f'\n\tPyTorch: {torch.__version__}' + f'\n\tTorchVision: {torchvision.__version__}') + return msg \ No newline at end of file diff --git a/Hallo2/hallo2/basicsr/utils/matlab_functions.py b/Hallo2/hallo2/basicsr/utils/matlab_functions.py new file mode 100644 index 00000000..c6ce1004 --- /dev/null +++ b/Hallo2/hallo2/basicsr/utils/matlab_functions.py @@ -0,0 +1,347 @@ +import math +import numpy as np +import torch + + +def cubic(x): + """cubic function used for calculate_weights_indices.""" + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5 * absx3 - 2.5 * absx2 + 1) * ( + (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) * + (absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + """Calculate weights and indices, used for imresize function. + + Args: + in_length (int): Input length. + out_length (int): Output length. + scale (float): Scale factor. + kernel_width (int): Kernel width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + """ + + if (scale < 1) and antialiasing: + # Use a modified kernel (larger kernel width) to simultaneously + # interpolate and antialias + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5 + scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + p = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand( + out_length, p) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices + + # apply cubic kernel + if (scale < 1) and antialiasing: + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, p) + + # If a column in weights is all zero, get rid of it. only consider the + # first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, p - 2) + weights = weights.narrow(1, 1, p - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, p - 2) + weights = weights.narrow(1, 0, p - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +@torch.no_grad() +def imresize(img, scale, antialiasing=True): + """imresize function same as MATLAB. + + It now only supports bicubic. + The same scale applies for both height and width. + + Args: + img (Tensor | Numpy array): + Tensor: Input image with shape (c, h, w), [0, 1] range. + Numpy: Input image with shape (h, w, c), [0, 1] range. + scale (float): Scale factor. The same scale applies for both height + and width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + Default: True. + + Returns: + Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round. + """ + if type(img).__module__ == np.__name__: # numpy type + numpy_type = True + img = torch.from_numpy(img.transpose(2, 0, 1)).float() + else: + numpy_type = False + + in_c, in_h, in_w = img.size() + out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale) + kernel_width = 4 + kernel = 'cubic' + + # get weights and indices + weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width, + antialiasing) + weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width, + antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w) + img_aug.narrow(1, sym_len_hs, in_h).copy_(img) + + sym_patch = img[:, :sym_len_hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_he:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_c, out_h, in_w) + kernel_width = weights_h.size(1) + for i in range(out_h): + idx = int(indices_h[i][0]) + for j in range(in_c): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we) + out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_we:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_c, out_h, out_w) + kernel_width = weights_w.size(1) + for i in range(out_w): + idx = int(indices_w[i][0]) + for j in range(in_c): + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i]) + + if numpy_type: + out_2 = out_2.numpy().transpose(1, 2, 0) + return out_2 + + +def rgb2ycbcr(img, y_only=False): + """Convert a RGB image to YCbCr image. + + This function produces the same results as Matlab's `rgb2ycbcr` function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0 + else: + out_img = np.matmul( + img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def bgr2ycbcr(img, y_only=False): + """Convert a BGR image to YCbCr image. + + The bgr version of rgb2ycbcr. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 + else: + out_img = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2rgb(img): + """Convert a YCbCr image to RGB image. + + This function produces the same results as Matlab's ycbcr2rgb function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted RGB image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2bgr(img): + """Convert a YCbCr image to BGR image. + + The bgr version of ycbcr2rgb. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted BGR image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0], + [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def _convert_input_type_range(img): + """Convert the type and range of the input image. + + It converts the input image to np.float32 type and range of [0, 1]. + It is mainly used for pre-processing the input image in colorspace + convertion functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + (ndarray): The converted image with type of np.float32 and range of + [0, 1]. + """ + img_type = img.dtype + img = img.astype(np.float32) + if img_type == np.float32: + pass + elif img_type == np.uint8: + img /= 255. + else: + raise TypeError('The img type should be np.float32 or np.uint8, ' f'but got {img_type}') + return img + + +def _convert_output_type_range(img, dst_type): + """Convert the type and range of the image according to dst_type. + + It converts the image to desired type and range. If `dst_type` is np.uint8, + images will be converted to np.uint8 type with range [0, 255]. If + `dst_type` is np.float32, it converts the image to np.float32 type with + range [0, 1]. + It is mainly used for post-processing images in colorspace convertion + functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The image to be converted with np.float32 type and + range [0, 255]. + dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it + converts the image to np.uint8 type with range [0, 255]. If + dst_type is np.float32, it converts the image to np.float32 type + with range [0, 1]. + + Returns: + (ndarray): The converted image with desired type and range. + """ + if dst_type not in (np.uint8, np.float32): + raise TypeError('The dst_type should be np.float32 or np.uint8, ' f'but got {dst_type}') + if dst_type == np.uint8: + img = img.round() + else: + img /= 255. + return img.astype(dst_type) diff --git a/Hallo2/hallo2/basicsr/utils/misc.py b/Hallo2/hallo2/basicsr/utils/misc.py new file mode 100644 index 00000000..f425d68e --- /dev/null +++ b/Hallo2/hallo2/basicsr/utils/misc.py @@ -0,0 +1,157 @@ +import os +import re +import random +import time +import torch +import numpy as np +from os import path as osp + +from .dist_util import master_only +from .logger import get_root_logger + +IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\ + torch.__version__)[0][:3])] >= [1, 12, 0] + +def gpu_is_available(): + if IS_HIGH_VERSION: + if torch.backends.mps.is_available(): + return True + return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False + +def get_device(gpu_id=None): + if gpu_id is None: + gpu_str = '' + elif isinstance(gpu_id, int): + gpu_str = f':{gpu_id}' + else: + raise TypeError('Input should be int value.') + + if IS_HIGH_VERSION: + if torch.backends.mps.is_available(): + return torch.device('mps'+gpu_str) + return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu') + + +def set_random_seed(seed): + """Set random seeds.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_time_str(): + return time.strftime('%Y%m%d_%H%M%S', time.localtime()) + + +def mkdir_and_rename(path): + """mkdirs. If path exists, rename it with timestamp and create a new one. + + Args: + path (str): Folder path. + """ + if osp.exists(path): + new_name = path + '_archived_' + get_time_str() + print(f'Path already exists. Rename it to {new_name}', flush=True) + os.rename(path, new_name) + os.makedirs(path, exist_ok=True) + + +@master_only +def make_exp_dirs(opt): + """Make dirs for experiments.""" + path_opt = opt['path'].copy() + if opt['is_train']: + mkdir_and_rename(path_opt.pop('experiments_root')) + else: + mkdir_and_rename(path_opt.pop('results_root')) + for key, path in path_opt.items(): + if ('strict_load' not in key) and ('pretrain_network' not in key) and ('resume' not in key): + os.makedirs(path, exist_ok=True) + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + + Returns: + A generator for all the interested files with relative pathes. + """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir(entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) + + +def check_resume(opt, resume_iter): + """Check resume states and pretrain_network paths. + + Args: + opt (dict): Options. + resume_iter (int): Resume iteration. + """ + logger = get_root_logger() + if opt['path']['resume_state']: + # get all the networks + networks = [key for key in opt.keys() if key.startswith('network_')] + flag_pretrain = False + for network in networks: + if opt['path'].get(f'pretrain_{network}') is not None: + flag_pretrain = True + if flag_pretrain: + logger.warning('pretrain_network path will be ignored during resuming.') + # set pretrained model paths + for network in networks: + name = f'pretrain_{network}' + basename = network.replace('network_', '') + if opt['path'].get('ignore_resume_networks') is None or (basename + not in opt['path']['ignore_resume_networks']): + opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') + logger.info(f"Set {name} to {opt['path'][name]}") + + +def sizeof_fmt(size, suffix='B'): + """Get human readable file size. + + Args: + size (int): File size. + suffix (str): Suffix. Default: 'B'. + + Return: + str: Formated file siz. + """ + for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: + if abs(size) < 1024.0: + return f'{size:3.1f} {unit}{suffix}' + size /= 1024.0 + return f'{size:3.1f} Y{suffix}' diff --git a/Hallo2/hallo2/basicsr/utils/options.py b/Hallo2/hallo2/basicsr/utils/options.py new file mode 100644 index 00000000..db490e4a --- /dev/null +++ b/Hallo2/hallo2/basicsr/utils/options.py @@ -0,0 +1,108 @@ +import yaml +import time +from collections import OrderedDict +from os import path as osp +from basicsr.utils.misc import get_time_str + +def ordered_yaml(): + """Support OrderedDict for yaml. + + Returns: + yaml Loader and Dumper. + """ + try: + from yaml import CDumper as Dumper + from yaml import CLoader as Loader + except ImportError: + from yaml import Dumper, Loader + + _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG + + def dict_representer(dumper, data): + return dumper.represent_dict(data.items()) + + def dict_constructor(loader, node): + return OrderedDict(loader.construct_pairs(node)) + + Dumper.add_representer(OrderedDict, dict_representer) + Loader.add_constructor(_mapping_tag, dict_constructor) + return Loader, Dumper + + +def parse(opt_path, root_path, is_train=True): + """Parse option file. + + Args: + opt_path (str): Option file path. + is_train (str): Indicate whether in training or not. Default: True. + + Returns: + (dict): Options. + """ + with open(opt_path, mode='r') as f: + Loader, _ = ordered_yaml() + opt = yaml.load(f, Loader=Loader) + + opt['is_train'] = is_train + + # opt['name'] = f"{get_time_str()}_{opt['name']}" + if opt['path'].get('resume_state', None): # Shangchen added + resume_state_path = opt['path'].get('resume_state') + opt['name'] = resume_state_path.split("/")[-3] + else: + opt['name'] = f"{get_time_str()}_{opt['name']}" + + + # datasets + for phase, dataset in opt['datasets'].items(): + # for several datasets, e.g., test_1, test_2 + phase = phase.split('_')[0] + dataset['phase'] = phase + if 'scale' in opt: + dataset['scale'] = opt['scale'] + if dataset.get('dataroot_gt') is not None: + dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) + if dataset.get('dataroot_lq') is not None: + dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) + + # paths + for key, val in opt['path'].items(): + if (val is not None) and ('resume_state' in key or 'pretrain_network' in key): + opt['path'][key] = osp.expanduser(val) + + if is_train: + experiments_root = osp.join(root_path, 'experiments', opt['name']) + opt['path']['experiments_root'] = experiments_root + opt['path']['models'] = osp.join(experiments_root, 'models') + opt['path']['training_states'] = osp.join(experiments_root, 'training_states') + opt['path']['log'] = experiments_root + opt['path']['visualization'] = osp.join(experiments_root, 'visualization') + + else: # test + results_root = osp.join(root_path, 'results', opt['name']) + opt['path']['results_root'] = results_root + opt['path']['log'] = results_root + opt['path']['visualization'] = osp.join(results_root, 'visualization') + + return opt + + +def dict2str(opt, indent_level=1): + """dict to string for printing options. + + Args: + opt (dict): Option dict. + indent_level (int): Indent level. Default: 1. + + Return: + (str): Option string for printing. + """ + msg = '\n' + for k, v in opt.items(): + if isinstance(v, dict): + msg += ' ' * (indent_level * 2) + k + ':[' + msg += dict2str(v, indent_level + 1) + msg += ' ' * (indent_level * 2) + ']\n' + else: + msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' + return msg diff --git a/Hallo2/hallo2/basicsr/utils/realesrgan_utils.py b/Hallo2/hallo2/basicsr/utils/realesrgan_utils.py new file mode 100644 index 00000000..2757a058 --- /dev/null +++ b/Hallo2/hallo2/basicsr/utils/realesrgan_utils.py @@ -0,0 +1,302 @@ +import cv2 +import math +import numpy as np +import os +import queue +import threading +import torch +from torch.nn import functional as F +from basicsr.utils.download_util import load_file_from_url +from basicsr.utils.misc import get_device + +# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +class RealESRGANer(): + """A helper class for upsampling images with RealESRGAN. + + Args: + scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4. + model_path (str): The path to the pretrained model. It can be urls (will first download it automatically). + model (nn.Module): The defined network. Default: None. + tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop + input images into tiles, and then process each of them. Finally, they will be merged into one image. + 0 denotes for do not use tile. Default: 0. + tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10. + pre_pad (int): Pad the input images to avoid border artifacts. Default: 10. + half (float): Whether to use half precision during inference. Default: False. + """ + + def __init__(self, + scale, + model_path, + model=None, + tile=0, + tile_pad=10, + pre_pad=10, + half=False, + device=None, + gpu_id=None): + self.scale = scale + self.tile_size = tile + self.tile_pad = tile_pad + self.pre_pad = pre_pad + self.mod_scale = None + self.half = half + + # initialize model + # if gpu_id: + # self.device = torch.device( + # f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device + # else: + # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device + + self.device = get_device(gpu_id) if device is None else device + + # if the model_path starts with https, it will first download models to the folder: realesrgan/weights + if model_path.startswith('https://'): + model_path = load_file_from_url( + url=model_path, model_dir=os.path.join('weights/realesrgan'), progress=True, file_name=None) + loadnet = torch.load(model_path, map_location=torch.device('cpu')) + # prefer to use params_ema + if 'params_ema' in loadnet: + keyname = 'params_ema' + else: + keyname = 'params' + model.load_state_dict(loadnet[keyname], strict=True) + model.eval() + self.model = model.to(self.device) + if self.half: + self.model = self.model.half() + + def pre_process(self, img): + """Pre-process, such as pre-pad and mod pad, so that the images can be divisible + """ + img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float() + self.img = img.unsqueeze(0).to(self.device) + if self.half: + self.img = self.img.half() + + # pre_pad + if self.pre_pad != 0: + self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect') + # mod pad for divisible borders + if self.scale == 2: + self.mod_scale = 2 + elif self.scale == 1: + self.mod_scale = 4 + if self.mod_scale is not None: + self.mod_pad_h, self.mod_pad_w = 0, 0 + _, _, h, w = self.img.size() + if (h % self.mod_scale != 0): + self.mod_pad_h = (self.mod_scale - h % self.mod_scale) + if (w % self.mod_scale != 0): + self.mod_pad_w = (self.mod_scale - w % self.mod_scale) + self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect') + + def process(self): + # model inference + self.output = self.model(self.img) + + def tile_process(self): + """It will first crop input images to tiles, and then process each tile. + Finally, all the processed tiles are merged into one images. + + Modified from: https://github.com/ata4/esrgan-launcher + """ + batch, channel, height, width = self.img.shape + output_height = height * self.scale + output_width = width * self.scale + output_shape = (batch, channel, output_height, output_width) + + # start with black image + self.output = self.img.new_zeros(output_shape) + tiles_x = math.ceil(width / self.tile_size) + tiles_y = math.ceil(height / self.tile_size) + + # loop over all tiles + for y in range(tiles_y): + for x in range(tiles_x): + # extract tile from input image + ofs_x = x * self.tile_size + ofs_y = y * self.tile_size + # input tile area on total image + input_start_x = ofs_x + input_end_x = min(ofs_x + self.tile_size, width) + input_start_y = ofs_y + input_end_y = min(ofs_y + self.tile_size, height) + + # input tile area on total image with padding + input_start_x_pad = max(input_start_x - self.tile_pad, 0) + input_end_x_pad = min(input_end_x + self.tile_pad, width) + input_start_y_pad = max(input_start_y - self.tile_pad, 0) + input_end_y_pad = min(input_end_y + self.tile_pad, height) + + # input tile dimensions + input_tile_width = input_end_x - input_start_x + input_tile_height = input_end_y - input_start_y + tile_idx = y * tiles_x + x + 1 + input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad] + + # upscale tile + try: + with torch.no_grad(): + output_tile = self.model(input_tile) + except RuntimeError as error: + print('Error', error) + # print(f'\tTile {tile_idx}/{tiles_x * tiles_y}') + + # output tile area on total image + output_start_x = input_start_x * self.scale + output_end_x = input_end_x * self.scale + output_start_y = input_start_y * self.scale + output_end_y = input_end_y * self.scale + + # output tile area without padding + output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale + output_end_x_tile = output_start_x_tile + input_tile_width * self.scale + output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale + output_end_y_tile = output_start_y_tile + input_tile_height * self.scale + + # put tile into output image + self.output[:, :, output_start_y:output_end_y, + output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile, + output_start_x_tile:output_end_x_tile] + + def post_process(self): + # remove extra pad + if self.mod_scale is not None: + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale] + # remove prepad + if self.pre_pad != 0: + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale] + return self.output + + @torch.no_grad() + def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'): + h_input, w_input = img.shape[0:2] + # img: numpy + img = img.astype(np.float32) + if np.max(img) > 256: # 16-bit image + max_range = 65535 + print('\tInput is a 16-bit image') + else: + max_range = 255 + img = img / max_range + if len(img.shape) == 2: # gray image + img_mode = 'L' + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) + elif img.shape[2] == 4: # RGBA image with alpha channel + img_mode = 'RGBA' + alpha = img[:, :, 3] + img = img[:, :, 0:3] + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + if alpha_upsampler == 'realesrgan': + alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB) + else: + img_mode = 'RGB' + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + # ------------------- process image (without the alpha channel) ------------------- # + try: + with torch.no_grad(): + self.pre_process(img) + if self.tile_size > 0: + self.tile_process() + else: + self.process() + output_img_t = self.post_process() + output_img = output_img_t.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0)) + if img_mode == 'L': + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY) + del output_img_t + torch.cuda.empty_cache() + except RuntimeError as error: + print(f"Failed inference for RealESRGAN: {error}") + + # ------------------- process the alpha channel if necessary ------------------- # + if img_mode == 'RGBA': + if alpha_upsampler == 'realesrgan': + self.pre_process(alpha) + if self.tile_size > 0: + self.tile_process() + else: + self.process() + output_alpha = self.post_process() + output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0)) + output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY) + else: # use the cv2 resize for alpha channel + h, w = alpha.shape[0:2] + output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR) + + # merge the alpha channel + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA) + output_img[:, :, 3] = output_alpha + + # ------------------------------ return ------------------------------ # + if max_range == 65535: # 16-bit image + output = (output_img * 65535.0).round().astype(np.uint16) + else: + output = (output_img * 255.0).round().astype(np.uint8) + + if outscale is not None and outscale != float(self.scale): + output = cv2.resize( + output, ( + int(w_input * outscale), + int(h_input * outscale), + ), interpolation=cv2.INTER_LANCZOS4) + + return output, img_mode + + +class PrefetchReader(threading.Thread): + """Prefetch images. + + Args: + img_list (list[str]): A image list of image paths to be read. + num_prefetch_queue (int): Number of prefetch queue. + """ + + def __init__(self, img_list, num_prefetch_queue): + super().__init__() + self.que = queue.Queue(num_prefetch_queue) + self.img_list = img_list + + def run(self): + for img_path in self.img_list: + img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) + self.que.put(img) + + self.que.put(None) + + def __next__(self): + next_item = self.que.get() + if next_item is None: + raise StopIteration + return next_item + + def __iter__(self): + return self + + +class IOConsumer(threading.Thread): + + def __init__(self, opt, que, qid): + super().__init__() + self._queue = que + self.qid = qid + self.opt = opt + + def run(self): + while True: + msg = self._queue.get() + if isinstance(msg, str) and msg == 'quit': + break + + output = msg['output'] + save_path = msg['save_path'] + cv2.imwrite(save_path, output) + print(f'IO worker {self.qid} is done.') \ No newline at end of file diff --git a/Hallo2/hallo2/basicsr/utils/registry.py b/Hallo2/hallo2/basicsr/utils/registry.py new file mode 100644 index 00000000..655753b3 --- /dev/null +++ b/Hallo2/hallo2/basicsr/utils/registry.py @@ -0,0 +1,82 @@ +# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 + + +class Registry(): + """ + The registry that provides name -> object mapping, to support third-party + users' custom modules. + + To create a registry (e.g. a backbone registry): + + .. code-block:: python + + BACKBONE_REGISTRY = Registry('BACKBONE') + + To register an object: + + .. code-block:: python + + @BACKBONE_REGISTRY.register() + class MyBackbone(): + ... + + Or: + + .. code-block:: python + + BACKBONE_REGISTRY.register(MyBackbone) + """ + + def __init__(self, name): + """ + Args: + name (str): the name of this registry + """ + self._name = name + self._obj_map = {} + + def _do_register(self, name, obj): + assert (name not in self._obj_map), (f"An object named '{name}' was already registered " + f"in '{self._name}' registry!") + self._obj_map[name] = obj + + def register(self, obj=None): + """ + Register the given object under the the name `obj.__name__`. + Can be used as either a decorator or not. + See docstring of this class for usage. + """ + if obj is None: + # used as a decorator + def deco(func_or_class): + name = func_or_class.__name__ + self._do_register(name, func_or_class) + return func_or_class + + return deco + + # used as a function call + name = obj.__name__ + self._do_register(name, obj) + + def get(self, name): + ret = self._obj_map.get(name) + if ret is None: + raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") + return ret + + def __contains__(self, name): + return name in self._obj_map + + def __iter__(self): + return iter(self._obj_map.items()) + + def keys(self): + return self._obj_map.keys() + + +DATASET_REGISTRY = Registry('dataset') +ARCH_REGISTRY = Registry('arch') +MODEL_REGISTRY = Registry('model') +LOSS_REGISTRY = Registry('loss') +METRIC_REGISTRY = Registry('metric') diff --git a/Hallo2/hallo2/basicsr/utils/video_util.py b/Hallo2/hallo2/basicsr/utils/video_util.py new file mode 100644 index 00000000..20a2ff14 --- /dev/null +++ b/Hallo2/hallo2/basicsr/utils/video_util.py @@ -0,0 +1,125 @@ +''' +The code is modified from the Real-ESRGAN: +https://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan_video.py + +''' +import cv2 +import sys +import numpy as np + +try: + import ffmpeg +except ImportError: + import pip + pip.main(['install', '--user', 'ffmpeg-python']) + import ffmpeg + +def get_video_meta_info(video_path): + ret = {} + probe = ffmpeg.probe(video_path) + video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video'] + has_audio = any(stream['codec_type'] == 'audio' for stream in probe['streams']) + ret['width'] = video_streams[0]['width'] + ret['height'] = video_streams[0]['height'] + ret['fps'] = eval(video_streams[0]['avg_frame_rate']) + ret['audio'] = ffmpeg.input(video_path).audio if has_audio else None + ret['nb_frames'] = int(video_streams[0]['nb_frames']) + return ret + +class VideoReader: + def __init__(self, video_path): + self.paths = [] # for image&folder type + self.audio = None + try: + self.stream_reader = ( + ffmpeg.input(video_path).output('pipe:', format='rawvideo', pix_fmt='bgr24', + loglevel='error').run_async( + pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg')) + except FileNotFoundError: + print('Please install ffmpeg (not ffmpeg-python) by running\n', + '\t$ conda install -c conda-forge ffmpeg') + sys.exit(0) + + meta = get_video_meta_info(video_path) + self.width = meta['width'] + self.height = meta['height'] + self.input_fps = meta['fps'] + self.audio = meta['audio'] + self.nb_frames = meta['nb_frames'] + + self.idx = 0 + + def get_resolution(self): + return self.height, self.width + + def get_fps(self): + if self.input_fps is not None: + return self.input_fps + return 24 + + def get_audio(self): + return self.audio + + def __len__(self): + return self.nb_frames + + def get_frame_from_stream(self): + img_bytes = self.stream_reader.stdout.read(self.width * self.height * 3) # 3 bytes for one pixel + if not img_bytes: + return None + img = np.frombuffer(img_bytes, np.uint8).reshape([self.height, self.width, 3]) + return img + + def get_frame_from_list(self): + if self.idx >= self.nb_frames: + return None + img = cv2.imread(self.paths[self.idx]) + self.idx += 1 + return img + + def get_frame(self): + return self.get_frame_from_stream() + + + def close(self): + self.stream_reader.stdin.close() + self.stream_reader.wait() + + +class VideoWriter: + def __init__(self, video_save_path, height, width, fps, audio): + if height > 2160: + print('You are generating video that is larger than 4K, which will be very slow due to IO speed.', + 'We highly recommend to decrease the outscale(aka, -s).') + if audio is not None: + self.stream_writer = ( + ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{width}x{height}', + framerate=fps).output( + audio, + video_save_path, + pix_fmt='yuv420p', + vcodec='libx264', + loglevel='error', + acodec='copy').overwrite_output().run_async( + pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg')) + else: + self.stream_writer = ( + ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{width}x{height}', + framerate=fps).output( + video_save_path, pix_fmt='yuv420p', vcodec='libx264', + loglevel='error').overwrite_output().run_async( + pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg')) + + def write_frame(self, frame): + try: + frame = frame.astype(np.uint8).tobytes() + self.stream_writer.stdin.write(frame) + except BrokenPipeError: + print('Please re-install ffmpeg and libx264 by running\n', + '\t$ conda install -c conda-forge ffmpeg\n', + '\t$ conda install -c conda-forge x264') + sys.exit(0) + + def close(self): + self.stream_writer.stdin.close() + self.stream_writer.wait() \ No newline at end of file diff --git a/Hallo2/hallo2/basicsr/version.py b/Hallo2/hallo2/basicsr/version.py new file mode 100644 index 00000000..1fae6d28 --- /dev/null +++ b/Hallo2/hallo2/basicsr/version.py @@ -0,0 +1,5 @@ +# GENERATED VERSION FILE +# TIME: Sat Sep 14 04:53:53 2024 +__version__ = '1.3.2' +__gitsha__ = '' +version_info = (1, 3, 2) diff --git a/Hallo2/hallo2/configs/inference/.gitkeep b/Hallo2/hallo2/configs/inference/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/Hallo2/hallo2/configs/inference/long.yaml b/Hallo2/hallo2/configs/inference/long.yaml new file mode 100644 index 00000000..275dd1be --- /dev/null +++ b/Hallo2/hallo2/configs/inference/long.yaml @@ -0,0 +1,96 @@ +source_image: ./examples/reference_images/1.jpg +driving_audio: ./examples/driving_audios/1.wav + +weight_dtype: fp16 + +data: + n_motion_frames: 2 + n_sample_frames: 16 + source_image: + width: 512 + height: 512 + driving_audio: + sample_rate: 16000 + export_video: + fps: 25 + +inference_steps: 40 +cfg_scale: 3.5 + +use_mask: true +mask_rate: 0.25 +use_cut: true + +audio_ckpt_dir: pretrained_models/hallo2 + + +save_path: ./output_long/debug/ +cache_path: ./.cache + +base_model_path: ./pretrained_models/stable-diffusion-v1-5 + +motion_module_path: ./pretrained_models/motion_module/mm_sd_v15_v2.ckpt + +face_analysis: + model_path: ./pretrained_models/face_analysis + +wav2vec: + model_path: ./pretrained_models/wav2vec/wav2vec2-base-960h + features: all + +audio_separator: + model_path: ./pretrained_models/audio_separator/Kim_Vocal_2.onnx + +vae: + model_path: ./pretrained_models/sd-vae-ft-mse + +face_expand_ratio: 1.2 +pose_weight: 1.0 +face_weight: 1.0 +lip_weight: 1.0 + +unet_additional_kwargs: + use_inflated_groupnorm: true + unet_use_cross_frame_attention: false + unet_use_temporal_attention: false + use_motion_module: true + use_audio_module: true + motion_module_resolutions: + - 1 + - 2 + - 4 + - 8 + motion_module_mid_block: true + motion_module_decoder_only: false + motion_module_type: Vanilla + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: + - Temporal_Self + - Temporal_Self + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 + audio_attention_dim: 768 + stack_enable_blocks_name: + - "up" + - "down" + - "mid" + stack_enable_blocks_depth: [0,1,2,3] + + +enable_zero_snr: true + +noise_scheduler_kwargs: + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "linear" + clip_sample: false + steps_offset: 1 + ### Zero-SNR params + prediction_type: "v_prediction" + rescale_betas_zero_snr: True + timestep_spacing: "trailing" + +sampler: DDIM diff --git a/Hallo2/hallo2/configs/train/stage1.yaml b/Hallo2/hallo2/configs/train/stage1.yaml new file mode 100644 index 00000000..28760ed2 --- /dev/null +++ b/Hallo2/hallo2/configs/train/stage1.yaml @@ -0,0 +1,63 @@ +data: + train_bs: 8 + train_width: 512 + train_height: 512 + meta_paths: + - "./data/HDTF_meta.json" + # Margin of frame indexes between ref and tgt images + sample_margin: 30 + +solver: + gradient_accumulation_steps: 1 + mixed_precision: "no" + enable_xformers_memory_efficient_attention: True + gradient_checkpointing: False + max_train_steps: 30000 + max_grad_norm: 1.0 + # lr + learning_rate: 1.0e-5 + scale_lr: False + lr_warmup_steps: 1 + lr_scheduler: "constant" + + # optimizer + use_8bit_adam: False + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_weight_decay: 1.0e-2 + adam_epsilon: 1.0e-8 + +val: + validation_steps: 500 + +noise_scheduler_kwargs: + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "scaled_linear" + steps_offset: 1 + clip_sample: false + +base_model_path: "./pretrained_models/stable-diffusion-v1-5/" +vae_model_path: "./pretrained_models/sd-vae-ft-mse" +face_analysis_model_path: "./pretrained_models/face_analysis" + +weight_dtype: "fp16" # [fp16, fp32] +uncond_ratio: 0.1 +noise_offset: 0.05 +snr_gamma: 5.0 +enable_zero_snr: True +face_locator_pretrained: False + +seed: 42 +resume_from_checkpoint: "latest" +checkpointing_steps: 500 +exp_name: "stage1" +output_dir: "./exp_output" + +ref_image_paths: + - "examples/reference_images/1.jpg" + +mask_image_paths: + - "examples/masks/1.png" + diff --git a/Hallo2/hallo2/configs/train/stage2_long.yaml b/Hallo2/hallo2/configs/train/stage2_long.yaml new file mode 100644 index 00000000..73ad5187 --- /dev/null +++ b/Hallo2/hallo2/configs/train/stage2_long.yaml @@ -0,0 +1,125 @@ +data: + train_bs: 4 + val_bs: 1 + train_width: 512 + train_height: 512 + fps: 25 + sample_rate: 16000 + n_motion_frames: 2 + n_sample_frames: 14 + audio_margin: 2 + train_meta_paths: + - "./data/hdtf_split_stage2.json" + +wav2vec_config: + audio_type: "vocals" # audio vocals + model_scale: "base" # base large + features: "all" # last avg all + model_path: ./pretrained_models/wav2vec/wav2vec2-base-960h +audio_separator: + model_path: ./pretrained_models/audio_separator/Kim_Vocal_2.onnx +face_expand_ratio: 1.2 + +solver: + gradient_accumulation_steps: 1 + mixed_precision: "no" + enable_xformers_memory_efficient_attention: True + gradient_checkpointing: True + max_train_steps: 30000 + max_grad_norm: 1.0 + # lr + learning_rate: 1e-5 + scale_lr: False + lr_warmup_steps: 1 + lr_scheduler: "constant" + + # optimizer + use_8bit_adam: True + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_weight_decay: 1.0e-2 + adam_epsilon: 1.0e-8 + +val: + validation_steps: 1000 + +noise_scheduler_kwargs: + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "linear" + steps_offset: 1 + clip_sample: false + +unet_additional_kwargs: + use_inflated_groupnorm: true + unet_use_cross_frame_attention: false + unet_use_temporal_attention: false + use_motion_module: true + use_audio_module: true + motion_module_resolutions: + - 1 + - 2 + - 4 + - 8 + motion_module_mid_block: true + motion_module_decoder_only: false + motion_module_type: Vanilla + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: + - Temporal_Self + - Temporal_Self + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 + audio_attention_dim: 768 + stack_enable_blocks_name: + - "up" + - "down" + - "mid" + stack_enable_blocks_depth: [0,1,2,3] + + +trainable_para: + # - audio_modules + - motion_modules + +base_model_path: "./pretrained_models/stable-diffusion-v1-5/" +vae_model_path: "./pretrained_models/sd-vae-ft-mse" +face_analysis_model_path: "./pretrained_models/face_analysis" +mm_path: "./pretrained_models/motion_module/mm_sd_v15_v2.ckpt" + +weight_dtype: "fp16" # [fp16, fp32] +uncond_img_ratio: 0.05 +uncond_audio_ratio: 0.05 +uncond_ia_ratio: 0.05 +start_ratio: 0.05 +noise_offset: 0.05 +snr_gamma: 5.0 +enable_zero_snr: True + +audio_ckpt_dir: ./pretrained_models/hallo + + +single_inference_times: 10 +inference_steps: 40 +cfg_scale: 3.5 +use_mask: true +mask_rate: 0.25 + + +seed: 42 +resume_from_checkpoint: "latest" +checkpointing_steps: 500 + +exp_name: "stage2_long" +output_dir: "./exp_output" + +ref_img_path: + - "./examples/reference_images/1.jpg" +audio_path: + - "./examples/driving_audios/1.wav" + + diff --git a/Hallo2/hallo2/configs/train/video_sr.yaml b/Hallo2/hallo2/configs/train/video_sr.yaml new file mode 100644 index 00000000..28ba65db --- /dev/null +++ b/Hallo2/hallo2/configs/train/video_sr.yaml @@ -0,0 +1,148 @@ +# general settings +name: CodeFormer_temp +model_type: CodeFormerTempModel +num_gpu: 8 +manual_seed: 0 + +# dataset and data loader settings +datasets: + train: + name: VFHQ + type: VFHQBlindDataset + dataroot_gt: ./VFHQ/image + filename_tmpl: '{}' + io_backend: + type: disk + + in_size: 512 + gt_size: 512 + mean: [0.5, 0.5, 0.5] + std: [0.5, 0.5, 0.5] + use_hflip: true + use_corrupt: true + video_length: 16 + + # large degradation in stageII + blur_kernel_size: 41 + use_motion_kernel: false + motion_kernel_prob: 0.001 + kernel_list: ['iso', 'aniso'] + kernel_prob: [0.5, 0.5] + blur_sigma: [1, 15] + downsample_range: [4, 30] + noise_range: [0, 20] + jpeg_range: [30, 80] + + latent_gt_path: ~ # without pre-calculated latent code + + # data loader + num_worker_per_gpu: 8 + batch_size_per_gpu: 4 + dataset_enlarge_ratio: 1 + prefetch_mode: ~ + + # val: + # name: CelebA-HQ-512 + # type: PairedImageDataset + # dataroot_lq: datasets/faces/validation/lq + # dataroot_gt: datasets/faces/validation/gt + # io_backend: + # type: disk + # mean: [0.5, 0.5, 0.5] + # std: [0.5, 0.5, 0.5] + # scale: 1 + +# network structures +network_g: + type: CodeFormer + dim_embd: 512 + n_head: 8 + n_layers: 9 + codebook_size: 1024 + connect_list: ['32', '64', '128', '256'] + fix_modules: ['quantize','generator'] + vqgan_path: './pretrained_models/CodeFormer/vqgan_code1024.pth' # pretrained VQGAN + +network_vqgan: # this config is needed if no pre-calculated latent + type: VQAutoEncoder + img_size: 512 + nf: 64 + ch_mult: [1, 2, 2, 4, 4, 8] + quantizer: 'nearest' + codebook_size: 1024 + model_path: './pretrained_models/CodeFormer/vqgan_code1024.pth' + +# path +path: + pretrain_network_g: './pretrained_models/CodeFormer/codeformer.pth' + param_key_g: params_ema + strict_load_g: false + pretrain_network_d: ~ + strict_load_d: true + resume_state: ~ + +# base_lr(4.5e-6)*bach_size(4) +train: + use_hq_feat_loss: true + feat_loss_weight: 1.0 + cross_entropy_loss: true + entropy_loss_weight: 0.5 + fidelity_weight: 0 + + trainable_para: temp + + optim_g: + type: Adam + lr: !!float 1e-4 + weight_decay: 0 + betas: [0.9, 0.99] + + scheduler: + type: MultiStepLR + milestones: [400000, 450000] + gamma: 0.5 + + # scheduler: + # type: CosineAnnealingRestartLR + # periods: [500000] + # restart_weights: [1] + # eta_min: !!float 2e-5 # no lr reduce in official vqgan code + + total_iter: 500000 + + warmup_iter: -1 # no warm up + ema_decay: 0.995 + + use_adaptive_weight: true + + net_g_start_iter: 0 + net_d_iters: 1 + net_d_start_iter: 0 + manual_seed: 0 + +# validation settings +val: + val_freq: 1000 + save_img: true + + metrics: + psnr: # metric name, can be arbitrary + type: calculate_psnr + crop_border: 4 + test_y_channel: false + +# logging settings +logger: + print_freq: 1 + save_checkpoint_freq: 1000 + use_tb_logger: true + wandb: + project: ~ + resume_id: ~ + +# dist training settings +dist_params: + backend: nccl + port: 29412 + +find_unused_parameters: true diff --git a/Hallo2/hallo2/configs/unet/unet.yaml b/Hallo2/hallo2/configs/unet/unet.yaml new file mode 100644 index 00000000..aa4e3d9c --- /dev/null +++ b/Hallo2/hallo2/configs/unet/unet.yaml @@ -0,0 +1,44 @@ +unet_additional_kwargs: + use_inflated_groupnorm: true + unet_use_cross_frame_attention: false + unet_use_temporal_attention: false + use_motion_module: true + use_audio_module: true + motion_module_resolutions: + - 1 + - 2 + - 4 + - 8 + motion_module_mid_block: true + motion_module_decoder_only: false + motion_module_type: Vanilla + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: + - Temporal_Self + - Temporal_Self + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 + audio_attention_dim: 768 + stack_enable_blocks_name: + - "up" + - "down" + - "mid" + stack_enable_blocks_depth: [0,1,2,3] + +enable_zero_snr: true + +noise_scheduler_kwargs: + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "linear" + clip_sample: false + steps_offset: 1 + ### Zero-SNR params + prediction_type: "v_prediction" + rescale_betas_zero_snr: True + timestep_spacing: "trailing" + +sampler: DDIM diff --git a/Hallo2/hallo2/examples/driving_audios/1.wav b/Hallo2/hallo2/examples/driving_audios/1.wav new file mode 100644 index 00000000..9f3e2325 Binary files /dev/null and b/Hallo2/hallo2/examples/driving_audios/1.wav differ diff --git a/Hallo2/hallo2/examples/driving_audios/2.wav b/Hallo2/hallo2/examples/driving_audios/2.wav new file mode 100644 index 00000000..740a5f01 Binary files /dev/null and b/Hallo2/hallo2/examples/driving_audios/2.wav differ diff --git a/Hallo2/hallo2/examples/driving_audios/3.wav b/Hallo2/hallo2/examples/driving_audios/3.wav new file mode 100644 index 00000000..acafdf22 Binary files /dev/null and b/Hallo2/hallo2/examples/driving_audios/3.wav differ diff --git a/Hallo2/hallo2/examples/driving_audios/4.wav b/Hallo2/hallo2/examples/driving_audios/4.wav new file mode 100644 index 00000000..d42d5874 Binary files /dev/null and b/Hallo2/hallo2/examples/driving_audios/4.wav differ diff --git a/Hallo2/hallo2/examples/driving_audios/5.wav b/Hallo2/hallo2/examples/driving_audios/5.wav new file mode 100644 index 00000000..bf35af7f Binary files /dev/null and b/Hallo2/hallo2/examples/driving_audios/5.wav differ diff --git a/Hallo2/hallo2/examples/masks/1.png b/Hallo2/hallo2/examples/masks/1.png new file mode 100644 index 00000000..c63e0757 Binary files /dev/null and b/Hallo2/hallo2/examples/masks/1.png differ diff --git a/Hallo2/hallo2/examples/reference_images/1.jpg b/Hallo2/hallo2/examples/reference_images/1.jpg new file mode 100644 index 00000000..d7d4f55a Binary files /dev/null and b/Hallo2/hallo2/examples/reference_images/1.jpg differ diff --git a/Hallo2/hallo2/examples/reference_images/2.jpg b/Hallo2/hallo2/examples/reference_images/2.jpg new file mode 100644 index 00000000..7a9b0003 Binary files /dev/null and b/Hallo2/hallo2/examples/reference_images/2.jpg differ diff --git a/Hallo2/hallo2/examples/reference_images/3.jpg b/Hallo2/hallo2/examples/reference_images/3.jpg new file mode 100644 index 00000000..3c1161cb Binary files /dev/null and b/Hallo2/hallo2/examples/reference_images/3.jpg differ diff --git a/Hallo2/hallo2/examples/reference_images/4.jpg b/Hallo2/hallo2/examples/reference_images/4.jpg new file mode 100644 index 00000000..60069f8c Binary files /dev/null and b/Hallo2/hallo2/examples/reference_images/4.jpg differ diff --git a/Hallo2/hallo2/examples/reference_images/5.jpg b/Hallo2/hallo2/examples/reference_images/5.jpg new file mode 100644 index 00000000..d4f5e47b Binary files /dev/null and b/Hallo2/hallo2/examples/reference_images/5.jpg differ diff --git a/Hallo2/hallo2/examples/reference_images/6.jpg b/Hallo2/hallo2/examples/reference_images/6.jpg new file mode 100644 index 00000000..22f1940d Binary files /dev/null and b/Hallo2/hallo2/examples/reference_images/6.jpg differ diff --git a/Hallo2/hallo2/facelib/detection/__init__.py b/Hallo2/hallo2/facelib/detection/__init__.py new file mode 100644 index 00000000..e665ded1 --- /dev/null +++ b/Hallo2/hallo2/facelib/detection/__init__.py @@ -0,0 +1,100 @@ +import os +import torch +from torch import nn +from copy import deepcopy + +from facelib.utils import load_file_from_url +from facelib.utils import download_pretrained_models +from facelib.detection.yolov5face.models.common import Conv + +from .retinaface.retinaface import RetinaFace +from .yolov5face.face_detector import YoloDetector + + +def init_detection_model(model_name, half=False, device='cuda'): + if 'retinaface' in model_name: + model = init_retinaface_model(model_name, half, device) + elif 'YOLOv5' in model_name: + model = init_yolov5face_model(model_name, device) + else: + raise NotImplementedError(f'{model_name} is not implemented.') + + return model + + +def init_retinaface_model(model_name, half=False, device='cuda'): + if model_name == 'retinaface_resnet50': + model = RetinaFace(network_name='resnet50', half=half) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth' + elif model_name == 'retinaface_mobile0.25': + model = RetinaFace(network_name='mobile0.25', half=half) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth' + else: + raise NotImplementedError(f'{model_name} is not implemented.') + + model_path = load_file_from_url(url=model_url, model_dir='pretrained_models/facelib', progress=True, file_name=None) + load_net = torch.load(model_path, map_location=lambda storage, loc: storage) + # remove unnecessary 'module.' + for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + model.load_state_dict(load_net, strict=True) + model.eval() + model = model.to(device) + + return model + + +def init_yolov5face_model(model_name, device='cuda'): + if model_name == 'YOLOv5l': + model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth' + elif model_name == 'YOLOv5n': + model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth' + else: + raise NotImplementedError(f'{model_name} is not implemented.') + + model_path = load_file_from_url(url=model_url, model_dir='pretrained_models/facelib', progress=True, file_name=None) + load_net = torch.load(model_path, map_location=lambda storage, loc: storage) + model.detector.load_state_dict(load_net, strict=True) + model.detector.eval() + model.detector = model.detector.to(device).float() + + for m in model.detector.modules(): + if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: + m.inplace = True # pytorch 1.7.0 compatibility + elif isinstance(m, Conv): + m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility + + return model + + +# Download from Google Drive +# def init_yolov5face_model(model_name, device='cuda'): +# if model_name == 'YOLOv5l': +# model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device) +# f_id = {'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV'} +# elif model_name == 'YOLOv5n': +# model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device) +# f_id = {'yolov5n-face.pth': '1fhcpFvWZqghpGXjYPIne2sw1Fy4yhw6o'} +# else: +# raise NotImplementedError(f'{model_name} is not implemented.') + +# model_path = os.path.join('weights/facelib', list(f_id.keys())[0]) +# if not os.path.exists(model_path): +# download_pretrained_models(file_ids=f_id, save_path_root='weights/facelib') + +# load_net = torch.load(model_path, map_location=lambda storage, loc: storage) +# model.detector.load_state_dict(load_net, strict=True) +# model.detector.eval() +# model.detector = model.detector.to(device).float() + +# for m in model.detector.modules(): +# if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: +# m.inplace = True # pytorch 1.7.0 compatibility +# elif isinstance(m, Conv): +# m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility + +# return model \ No newline at end of file diff --git a/Hallo2/hallo2/facelib/detection/align_trans.py b/Hallo2/hallo2/facelib/detection/align_trans.py new file mode 100644 index 00000000..07f1eb36 --- /dev/null +++ b/Hallo2/hallo2/facelib/detection/align_trans.py @@ -0,0 +1,219 @@ +import cv2 +import numpy as np + +from .matlab_cp2tform import get_similarity_transform_for_cv2 + +# reference facial points, a list of coordinates (x,y) +REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], [65.53179932, 51.50139999], [48.02519989, 71.73660278], + [33.54930115, 92.3655014], [62.72990036, 92.20410156]] + +DEFAULT_CROP_SIZE = (96, 112) + + +class FaceWarpException(Exception): + + def __str__(self): + return 'In File {}:{}'.format(__file__, super.__str__(self)) + + +def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False): + """ + Function: + ---------- + get reference 5 key points according to crop settings: + 0. Set default crop_size: + if default_square: + crop_size = (112, 112) + else: + crop_size = (96, 112) + 1. Pad the crop_size by inner_padding_factor in each side; + 2. Resize crop_size into (output_size - outer_padding*2), + pad into output_size with outer_padding; + 3. Output reference_5point; + Parameters: + ---------- + @output_size: (w, h) or None + size of aligned face image + @inner_padding_factor: (w_factor, h_factor) + padding factor for inner (w, h) + @outer_padding: (w_pad, h_pad) + each row is a pair of coordinates (x, y) + @default_square: True or False + if True: + default crop_size = (112, 112) + else: + default crop_size = (96, 112); + !!! make sure, if output_size is not None: + (output_size - outer_padding) + = some_scale * (default crop_size * (1.0 + + inner_padding_factor)) + Returns: + ---------- + @reference_5point: 5x2 np.array + each row is a pair of transformed coordinates (x, y) + """ + + tmp_5pts = np.array(REFERENCE_FACIAL_POINTS) + tmp_crop_size = np.array(DEFAULT_CROP_SIZE) + + # 0) make the inner region a square + if default_square: + size_diff = max(tmp_crop_size) - tmp_crop_size + tmp_5pts += size_diff / 2 + tmp_crop_size += size_diff + + if (output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]): + + return tmp_5pts + + if (inner_padding_factor == 0 and outer_padding == (0, 0)): + if output_size is None: + return tmp_5pts + else: + raise FaceWarpException('No paddings to do, output_size must be None or {}'.format(tmp_crop_size)) + + # check output size + if not (0 <= inner_padding_factor <= 1.0): + raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)') + + if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None): + output_size = tmp_crop_size * \ + (1 + inner_padding_factor * 2).astype(np.int32) + output_size += np.array(outer_padding) + if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]): + raise FaceWarpException('Not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1])') + + # 1) pad the inner region according inner_padding_factor + if inner_padding_factor > 0: + size_diff = tmp_crop_size * inner_padding_factor * 2 + tmp_5pts += size_diff / 2 + tmp_crop_size += np.round(size_diff).astype(np.int32) + + # 2) resize the padded inner region + size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2 + + if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]: + raise FaceWarpException('Must have (output_size - outer_padding)' + '= some_scale * (crop_size * (1.0 + inner_padding_factor)') + + scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0] + tmp_5pts = tmp_5pts * scale_factor + # size_diff = tmp_crop_size * (scale_factor - min(scale_factor)) + # tmp_5pts = tmp_5pts + size_diff / 2 + tmp_crop_size = size_bf_outer_pad + + # 3) add outer_padding to make output_size + reference_5point = tmp_5pts + np.array(outer_padding) + tmp_crop_size = output_size + + return reference_5point + + +def get_affine_transform_matrix(src_pts, dst_pts): + """ + Function: + ---------- + get affine transform matrix 'tfm' from src_pts to dst_pts + Parameters: + ---------- + @src_pts: Kx2 np.array + source points matrix, each row is a pair of coordinates (x, y) + @dst_pts: Kx2 np.array + destination points matrix, each row is a pair of coordinates (x, y) + Returns: + ---------- + @tfm: 2x3 np.array + transform matrix from src_pts to dst_pts + """ + + tfm = np.float32([[1, 0, 0], [0, 1, 0]]) + n_pts = src_pts.shape[0] + ones = np.ones((n_pts, 1), src_pts.dtype) + src_pts_ = np.hstack([src_pts, ones]) + dst_pts_ = np.hstack([dst_pts, ones]) + + A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_) + + if rank == 3: + tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]]) + elif rank == 2: + tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]]) + + return tfm + + +def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type='smilarity'): + """ + Function: + ---------- + apply affine transform 'trans' to uv + Parameters: + ---------- + @src_img: 3x3 np.array + input image + @facial_pts: could be + 1)a list of K coordinates (x,y) + or + 2) Kx2 or 2xK np.array + each row or col is a pair of coordinates (x, y) + @reference_pts: could be + 1) a list of K coordinates (x,y) + or + 2) Kx2 or 2xK np.array + each row or col is a pair of coordinates (x, y) + or + 3) None + if None, use default reference facial points + @crop_size: (w, h) + output face image size + @align_type: transform type, could be one of + 1) 'similarity': use similarity transform + 2) 'cv2_affine': use the first 3 points to do affine transform, + by calling cv2.getAffineTransform() + 3) 'affine': use all points to do affine transform + Returns: + ---------- + @face_img: output face image with size (w, h) = @crop_size + """ + + if reference_pts is None: + if crop_size[0] == 96 and crop_size[1] == 112: + reference_pts = REFERENCE_FACIAL_POINTS + else: + default_square = False + inner_padding_factor = 0 + outer_padding = (0, 0) + output_size = crop_size + + reference_pts = get_reference_facial_points(output_size, inner_padding_factor, outer_padding, + default_square) + + ref_pts = np.float32(reference_pts) + ref_pts_shp = ref_pts.shape + if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: + raise FaceWarpException('reference_pts.shape must be (K,2) or (2,K) and K>2') + + if ref_pts_shp[0] == 2: + ref_pts = ref_pts.T + + src_pts = np.float32(facial_pts) + src_pts_shp = src_pts.shape + if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: + raise FaceWarpException('facial_pts.shape must be (K,2) or (2,K) and K>2') + + if src_pts_shp[0] == 2: + src_pts = src_pts.T + + if src_pts.shape != ref_pts.shape: + raise FaceWarpException('facial_pts and reference_pts must have the same shape') + + if align_type == 'cv2_affine': + tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3]) + elif align_type == 'affine': + tfm = get_affine_transform_matrix(src_pts, ref_pts) + else: + tfm = get_similarity_transform_for_cv2(src_pts, ref_pts) + + face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1])) + + return face_img diff --git a/Hallo2/hallo2/facelib/detection/matlab_cp2tform.py b/Hallo2/hallo2/facelib/detection/matlab_cp2tform.py new file mode 100644 index 00000000..b2a8b54a --- /dev/null +++ b/Hallo2/hallo2/facelib/detection/matlab_cp2tform.py @@ -0,0 +1,317 @@ +import numpy as np +from numpy.linalg import inv, lstsq +from numpy.linalg import matrix_rank as rank +from numpy.linalg import norm + + +class MatlabCp2tormException(Exception): + + def __str__(self): + return 'In File {}:{}'.format(__file__, super.__str__(self)) + + +def tformfwd(trans, uv): + """ + Function: + ---------- + apply affine transform 'trans' to uv + + Parameters: + ---------- + @trans: 3x3 np.array + transform matrix + @uv: Kx2 np.array + each row is a pair of coordinates (x, y) + + Returns: + ---------- + @xy: Kx2 np.array + each row is a pair of transformed coordinates (x, y) + """ + uv = np.hstack((uv, np.ones((uv.shape[0], 1)))) + xy = np.dot(uv, trans) + xy = xy[:, 0:-1] + return xy + + +def tforminv(trans, uv): + """ + Function: + ---------- + apply the inverse of affine transform 'trans' to uv + + Parameters: + ---------- + @trans: 3x3 np.array + transform matrix + @uv: Kx2 np.array + each row is a pair of coordinates (x, y) + + Returns: + ---------- + @xy: Kx2 np.array + each row is a pair of inverse-transformed coordinates (x, y) + """ + Tinv = inv(trans) + xy = tformfwd(Tinv, uv) + return xy + + +def findNonreflectiveSimilarity(uv, xy, options=None): + options = {'K': 2} + + K = options['K'] + M = xy.shape[0] + x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector + y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector + + tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1)))) + tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1)))) + X = np.vstack((tmp1, tmp2)) + + u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector + v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector + U = np.vstack((u, v)) + + # We know that X * r = U + if rank(X) >= 2 * K: + r, _, _, _ = lstsq(X, U, rcond=-1) + r = np.squeeze(r) + else: + raise Exception('cp2tform:twoUniquePointsReq') + sc = r[0] + ss = r[1] + tx = r[2] + ty = r[3] + + Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]]) + T = inv(Tinv) + T[:, 2] = np.array([0, 0, 1]) + + return T, Tinv + + +def findSimilarity(uv, xy, options=None): + options = {'K': 2} + + # uv = np.array(uv) + # xy = np.array(xy) + + # Solve for trans1 + trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options) + + # Solve for trans2 + + # manually reflect the xy data across the Y-axis + xyR = xy + xyR[:, 0] = -1 * xyR[:, 0] + + trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options) + + # manually reflect the tform to undo the reflection done on xyR + TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]]) + + trans2 = np.dot(trans2r, TreflectY) + + # Figure out if trans1 or trans2 is better + xy1 = tformfwd(trans1, uv) + norm1 = norm(xy1 - xy) + + xy2 = tformfwd(trans2, uv) + norm2 = norm(xy2 - xy) + + if norm1 <= norm2: + return trans1, trans1_inv + else: + trans2_inv = inv(trans2) + return trans2, trans2_inv + + +def get_similarity_transform(src_pts, dst_pts, reflective=True): + """ + Function: + ---------- + Find Similarity Transform Matrix 'trans': + u = src_pts[:, 0] + v = src_pts[:, 1] + x = dst_pts[:, 0] + y = dst_pts[:, 1] + [x, y, 1] = [u, v, 1] * trans + + Parameters: + ---------- + @src_pts: Kx2 np.array + source points, each row is a pair of coordinates (x, y) + @dst_pts: Kx2 np.array + destination points, each row is a pair of transformed + coordinates (x, y) + @reflective: True or False + if True: + use reflective similarity transform + else: + use non-reflective similarity transform + + Returns: + ---------- + @trans: 3x3 np.array + transform matrix from uv to xy + trans_inv: 3x3 np.array + inverse of trans, transform matrix from xy to uv + """ + + if reflective: + trans, trans_inv = findSimilarity(src_pts, dst_pts) + else: + trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts) + + return trans, trans_inv + + +def cvt_tform_mat_for_cv2(trans): + """ + Function: + ---------- + Convert Transform Matrix 'trans' into 'cv2_trans' which could be + directly used by cv2.warpAffine(): + u = src_pts[:, 0] + v = src_pts[:, 1] + x = dst_pts[:, 0] + y = dst_pts[:, 1] + [x, y].T = cv_trans * [u, v, 1].T + + Parameters: + ---------- + @trans: 3x3 np.array + transform matrix from uv to xy + + Returns: + ---------- + @cv2_trans: 2x3 np.array + transform matrix from src_pts to dst_pts, could be directly used + for cv2.warpAffine() + """ + cv2_trans = trans[:, 0:2].T + + return cv2_trans + + +def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True): + """ + Function: + ---------- + Find Similarity Transform Matrix 'cv2_trans' which could be + directly used by cv2.warpAffine(): + u = src_pts[:, 0] + v = src_pts[:, 1] + x = dst_pts[:, 0] + y = dst_pts[:, 1] + [x, y].T = cv_trans * [u, v, 1].T + + Parameters: + ---------- + @src_pts: Kx2 np.array + source points, each row is a pair of coordinates (x, y) + @dst_pts: Kx2 np.array + destination points, each row is a pair of transformed + coordinates (x, y) + reflective: True or False + if True: + use reflective similarity transform + else: + use non-reflective similarity transform + + Returns: + ---------- + @cv2_trans: 2x3 np.array + transform matrix from src_pts to dst_pts, could be directly used + for cv2.warpAffine() + """ + trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective) + cv2_trans = cvt_tform_mat_for_cv2(trans) + + return cv2_trans + + +if __name__ == '__main__': + """ + u = [0, 6, -2] + v = [0, 3, 5] + x = [-1, 0, 4] + y = [-1, -10, 4] + + # In Matlab, run: + # + # uv = [u'; v']; + # xy = [x'; y']; + # tform_sim=cp2tform(uv,xy,'similarity'); + # + # trans = tform_sim.tdata.T + # ans = + # -0.0764 -1.6190 0 + # 1.6190 -0.0764 0 + # -3.2156 0.0290 1.0000 + # trans_inv = tform_sim.tdata.Tinv + # ans = + # + # -0.0291 0.6163 0 + # -0.6163 -0.0291 0 + # -0.0756 1.9826 1.0000 + # xy_m=tformfwd(tform_sim, u,v) + # + # xy_m = + # + # -3.2156 0.0290 + # 1.1833 -9.9143 + # 5.0323 2.8853 + # uv_m=tforminv(tform_sim, x,y) + # + # uv_m = + # + # 0.5698 1.3953 + # 6.0872 2.2733 + # -2.6570 4.3314 + """ + u = [0, 6, -2] + v = [0, 3, 5] + x = [-1, 0, 4] + y = [-1, -10, 4] + + uv = np.array((u, v)).T + xy = np.array((x, y)).T + + print('\n--->uv:') + print(uv) + print('\n--->xy:') + print(xy) + + trans, trans_inv = get_similarity_transform(uv, xy) + + print('\n--->trans matrix:') + print(trans) + + print('\n--->trans_inv matrix:') + print(trans_inv) + + print('\n---> apply transform to uv') + print('\nxy_m = uv_augmented * trans') + uv_aug = np.hstack((uv, np.ones((uv.shape[0], 1)))) + xy_m = np.dot(uv_aug, trans) + print(xy_m) + + print('\nxy_m = tformfwd(trans, uv)') + xy_m = tformfwd(trans, uv) + print(xy_m) + + print('\n---> apply inverse transform to xy') + print('\nuv_m = xy_augmented * trans_inv') + xy_aug = np.hstack((xy, np.ones((xy.shape[0], 1)))) + uv_m = np.dot(xy_aug, trans_inv) + print(uv_m) + + print('\nuv_m = tformfwd(trans_inv, xy)') + uv_m = tformfwd(trans_inv, xy) + print(uv_m) + + uv_m = tforminv(trans, xy) + print('\nuv_m = tforminv(trans, xy)') + print(uv_m) diff --git a/Hallo2/hallo2/facelib/detection/retinaface/retinaface.py b/Hallo2/hallo2/facelib/detection/retinaface/retinaface.py new file mode 100644 index 00000000..c9c0b5a1 --- /dev/null +++ b/Hallo2/hallo2/facelib/detection/retinaface/retinaface.py @@ -0,0 +1,372 @@ +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from torchvision.models._utils import IntermediateLayerGetter as IntermediateLayerGetter + +from facelib.detection.align_trans import get_reference_facial_points, warp_and_crop_face +from facelib.detection.retinaface.retinaface_net import FPN, SSH, MobileNetV1, make_bbox_head, make_class_head, make_landmark_head +from facelib.detection.retinaface.retinaface_utils import (PriorBox, batched_decode, batched_decode_landm, decode, decode_landm, + py_cpu_nms) + +from basicsr.utils.misc import get_device +# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +device = get_device() + + +def generate_config(network_name): + + cfg_mnet = { + 'name': 'mobilenet0.25', + 'min_sizes': [[16, 32], [64, 128], [256, 512]], + 'steps': [8, 16, 32], + 'variance': [0.1, 0.2], + 'clip': False, + 'loc_weight': 2.0, + 'gpu_train': True, + 'batch_size': 32, + 'ngpu': 1, + 'epoch': 250, + 'decay1': 190, + 'decay2': 220, + 'image_size': 640, + 'return_layers': { + 'stage1': 1, + 'stage2': 2, + 'stage3': 3 + }, + 'in_channel': 32, + 'out_channel': 64 + } + + cfg_re50 = { + 'name': 'Resnet50', + 'min_sizes': [[16, 32], [64, 128], [256, 512]], + 'steps': [8, 16, 32], + 'variance': [0.1, 0.2], + 'clip': False, + 'loc_weight': 2.0, + 'gpu_train': True, + 'batch_size': 24, + 'ngpu': 4, + 'epoch': 100, + 'decay1': 70, + 'decay2': 90, + 'image_size': 840, + 'return_layers': { + 'layer2': 1, + 'layer3': 2, + 'layer4': 3 + }, + 'in_channel': 256, + 'out_channel': 256 + } + + if network_name == 'mobile0.25': + return cfg_mnet + elif network_name == 'resnet50': + return cfg_re50 + else: + raise NotImplementedError(f'network_name={network_name}') + + +class RetinaFace(nn.Module): + + def __init__(self, network_name='resnet50', half=False, phase='test'): + super(RetinaFace, self).__init__() + self.half_inference = half + cfg = generate_config(network_name) + self.backbone = cfg['name'] + + self.model_name = f'retinaface_{network_name}' + self.cfg = cfg + self.phase = phase + self.target_size, self.max_size = 1600, 2150 + self.resize, self.scale, self.scale1 = 1., None, None + self.mean_tensor = torch.tensor([[[[104.]], [[117.]], [[123.]]]]).to(device) + self.reference = get_reference_facial_points(default_square=True) + # Build network. + backbone = None + if cfg['name'] == 'mobilenet0.25': + backbone = MobileNetV1() + self.body = IntermediateLayerGetter(backbone, cfg['return_layers']) + elif cfg['name'] == 'Resnet50': + import torchvision.models as models + backbone = models.resnet50(pretrained=False) + self.body = IntermediateLayerGetter(backbone, cfg['return_layers']) + + in_channels_stage2 = cfg['in_channel'] + in_channels_list = [ + in_channels_stage2 * 2, + in_channels_stage2 * 4, + in_channels_stage2 * 8, + ] + + out_channels = cfg['out_channel'] + self.fpn = FPN(in_channels_list, out_channels) + self.ssh1 = SSH(out_channels, out_channels) + self.ssh2 = SSH(out_channels, out_channels) + self.ssh3 = SSH(out_channels, out_channels) + + self.ClassHead = make_class_head(fpn_num=3, inchannels=cfg['out_channel']) + self.BboxHead = make_bbox_head(fpn_num=3, inchannels=cfg['out_channel']) + self.LandmarkHead = make_landmark_head(fpn_num=3, inchannels=cfg['out_channel']) + + self.to(device) + self.eval() + if self.half_inference: + self.half() + + def forward(self, inputs): + out = self.body(inputs) + + if self.backbone == 'mobilenet0.25' or self.backbone == 'Resnet50': + out = list(out.values()) + # FPN + fpn = self.fpn(out) + + # SSH + feature1 = self.ssh1(fpn[0]) + feature2 = self.ssh2(fpn[1]) + feature3 = self.ssh3(fpn[2]) + features = [feature1, feature2, feature3] + + bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1) + classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1) + tmp = [self.LandmarkHead[i](feature) for i, feature in enumerate(features)] + ldm_regressions = (torch.cat(tmp, dim=1)) + + if self.phase == 'train': + output = (bbox_regressions, classifications, ldm_regressions) + else: + output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions) + return output + + def __detect_faces(self, inputs): + # get scale + height, width = inputs.shape[2:] + self.scale = torch.tensor([width, height, width, height], dtype=torch.float32).to(device) + tmp = [width, height, width, height, width, height, width, height, width, height] + self.scale1 = torch.tensor(tmp, dtype=torch.float32).to(device) + + # forawrd + inputs = inputs.to(device) + if self.half_inference: + inputs = inputs.half() + loc, conf, landmarks = self(inputs) + + # get priorbox + priorbox = PriorBox(self.cfg, image_size=inputs.shape[2:]) + priors = priorbox.forward().to(device) + + return loc, conf, landmarks, priors + + # single image detection + def transform(self, image, use_origin_size): + # convert to opencv format + if isinstance(image, Image.Image): + image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) + image = image.astype(np.float32) + + # testing scale + im_size_min = np.min(image.shape[0:2]) + im_size_max = np.max(image.shape[0:2]) + resize = float(self.target_size) / float(im_size_min) + + # prevent bigger axis from being more than max_size + if np.round(resize * im_size_max) > self.max_size: + resize = float(self.max_size) / float(im_size_max) + resize = 1 if use_origin_size else resize + + # resize + if resize != 1: + image = cv2.resize(image, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR) + + # convert to torch.tensor format + # image -= (104, 117, 123) + image = image.transpose(2, 0, 1) + image = torch.from_numpy(image).unsqueeze(0) + + return image, resize + + def detect_faces( + self, + image, + conf_threshold=0.8, + nms_threshold=0.4, + use_origin_size=True, + ): + """ + Params: + imgs: BGR image + """ + image, self.resize = self.transform(image, use_origin_size) + image = image.to(device) + if self.half_inference: + image = image.half() + image = image - self.mean_tensor + + loc, conf, landmarks, priors = self.__detect_faces(image) + + boxes = decode(loc.data.squeeze(0), priors.data, self.cfg['variance']) + boxes = boxes * self.scale / self.resize + boxes = boxes.cpu().numpy() + + scores = conf.squeeze(0).data.cpu().numpy()[:, 1] + + landmarks = decode_landm(landmarks.squeeze(0), priors, self.cfg['variance']) + landmarks = landmarks * self.scale1 / self.resize + landmarks = landmarks.cpu().numpy() + + # ignore low scores + inds = np.where(scores > conf_threshold)[0] + boxes, landmarks, scores = boxes[inds], landmarks[inds], scores[inds] + + # sort + order = scores.argsort()[::-1] + boxes, landmarks, scores = boxes[order], landmarks[order], scores[order] + + # do NMS + bounding_boxes = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) + keep = py_cpu_nms(bounding_boxes, nms_threshold) + bounding_boxes, landmarks = bounding_boxes[keep, :], landmarks[keep] + # self.t['forward_pass'].toc() + # print(self.t['forward_pass'].average_time) + # import sys + # sys.stdout.flush() + return np.concatenate((bounding_boxes, landmarks), axis=1) + + def __align_multi(self, image, boxes, landmarks, limit=None): + + if len(boxes) < 1: + return [], [] + + if limit: + boxes = boxes[:limit] + landmarks = landmarks[:limit] + + faces = [] + for landmark in landmarks: + facial5points = [[landmark[2 * j], landmark[2 * j + 1]] for j in range(5)] + + warped_face = warp_and_crop_face(np.array(image), facial5points, self.reference, crop_size=(112, 112)) + faces.append(warped_face) + + return np.concatenate((boxes, landmarks), axis=1), faces + + def align_multi(self, img, conf_threshold=0.8, limit=None): + + rlt = self.detect_faces(img, conf_threshold=conf_threshold) + boxes, landmarks = rlt[:, 0:5], rlt[:, 5:] + + return self.__align_multi(img, boxes, landmarks, limit) + + # batched detection + def batched_transform(self, frames, use_origin_size): + """ + Arguments: + frames: a list of PIL.Image, or torch.Tensor(shape=[n, h, w, c], + type=np.float32, BGR format). + use_origin_size: whether to use origin size. + """ + from_PIL = True if isinstance(frames[0], Image.Image) else False + + # convert to opencv format + if from_PIL: + frames = [cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR) for frame in frames] + frames = np.asarray(frames, dtype=np.float32) + + # testing scale + im_size_min = np.min(frames[0].shape[0:2]) + im_size_max = np.max(frames[0].shape[0:2]) + resize = float(self.target_size) / float(im_size_min) + + # prevent bigger axis from being more than max_size + if np.round(resize * im_size_max) > self.max_size: + resize = float(self.max_size) / float(im_size_max) + resize = 1 if use_origin_size else resize + + # resize + if resize != 1: + if not from_PIL: + frames = F.interpolate(frames, scale_factor=resize) + else: + frames = [ + cv2.resize(frame, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR) + for frame in frames + ] + + # convert to torch.tensor format + if not from_PIL: + frames = frames.transpose(1, 2).transpose(1, 3).contiguous() + else: + frames = frames.transpose((0, 3, 1, 2)) + frames = torch.from_numpy(frames) + + return frames, resize + + def batched_detect_faces(self, frames, conf_threshold=0.8, nms_threshold=0.4, use_origin_size=True): + """ + Arguments: + frames: a list of PIL.Image, or np.array(shape=[n, h, w, c], + type=np.uint8, BGR format). + conf_threshold: confidence threshold. + nms_threshold: nms threshold. + use_origin_size: whether to use origin size. + Returns: + final_bounding_boxes: list of np.array ([n_boxes, 5], + type=np.float32). + final_landmarks: list of np.array ([n_boxes, 10], type=np.float32). + """ + # self.t['forward_pass'].tic() + frames, self.resize = self.batched_transform(frames, use_origin_size) + frames = frames.to(device) + frames = frames - self.mean_tensor + + b_loc, b_conf, b_landmarks, priors = self.__detect_faces(frames) + + final_bounding_boxes, final_landmarks = [], [] + + # decode + priors = priors.unsqueeze(0) + b_loc = batched_decode(b_loc, priors, self.cfg['variance']) * self.scale / self.resize + b_landmarks = batched_decode_landm(b_landmarks, priors, self.cfg['variance']) * self.scale1 / self.resize + b_conf = b_conf[:, :, 1] + + # index for selection + b_indice = b_conf > conf_threshold + + # concat + b_loc_and_conf = torch.cat((b_loc, b_conf.unsqueeze(-1)), dim=2).float() + + for pred, landm, inds in zip(b_loc_and_conf, b_landmarks, b_indice): + + # ignore low scores + pred, landm = pred[inds, :], landm[inds, :] + if pred.shape[0] == 0: + final_bounding_boxes.append(np.array([], dtype=np.float32)) + final_landmarks.append(np.array([], dtype=np.float32)) + continue + + # sort + # order = score.argsort(descending=True) + # box, landm, score = box[order], landm[order], score[order] + + # to CPU + bounding_boxes, landm = pred.cpu().numpy(), landm.cpu().numpy() + + # NMS + keep = py_cpu_nms(bounding_boxes, nms_threshold) + bounding_boxes, landmarks = bounding_boxes[keep, :], landm[keep] + + # append + final_bounding_boxes.append(bounding_boxes) + final_landmarks.append(landmarks) + # self.t['forward_pass'].toc(average=True) + # self.batch_time += self.t['forward_pass'].diff + # self.total_frame += len(frames) + # print(self.batch_time / self.total_frame) + + return final_bounding_boxes, final_landmarks diff --git a/Hallo2/hallo2/facelib/detection/retinaface/retinaface_net.py b/Hallo2/hallo2/facelib/detection/retinaface/retinaface_net.py new file mode 100644 index 00000000..ab6aa82d --- /dev/null +++ b/Hallo2/hallo2/facelib/detection/retinaface/retinaface_net.py @@ -0,0 +1,196 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def conv_bn(inp, oup, stride=1, leaky=0): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True)) + + +def conv_bn_no_relu(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + ) + + +def conv_bn1X1(inp, oup, stride, leaky=0): + return nn.Sequential( + nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True)) + + +def conv_dw(inp, oup, stride, leaky=0.1): + return nn.Sequential( + nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), + nn.BatchNorm2d(inp), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + ) + + +class SSH(nn.Module): + + def __init__(self, in_channel, out_channel): + super(SSH, self).__init__() + assert out_channel % 4 == 0 + leaky = 0 + if (out_channel <= 64): + leaky = 0.1 + self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1) + + self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky) + self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1) + + self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky) + self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1) + + def forward(self, input): + conv3X3 = self.conv3X3(input) + + conv5X5_1 = self.conv5X5_1(input) + conv5X5 = self.conv5X5_2(conv5X5_1) + + conv7X7_2 = self.conv7X7_2(conv5X5_1) + conv7X7 = self.conv7x7_3(conv7X7_2) + + out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) + out = F.relu(out) + return out + + +class FPN(nn.Module): + + def __init__(self, in_channels_list, out_channels): + super(FPN, self).__init__() + leaky = 0 + if (out_channels <= 64): + leaky = 0.1 + self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky) + self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky) + self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky) + + self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky) + self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky) + + def forward(self, input): + # names = list(input.keys()) + # input = list(input.values()) + + output1 = self.output1(input[0]) + output2 = self.output2(input[1]) + output3 = self.output3(input[2]) + + up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode='nearest') + output2 = output2 + up3 + output2 = self.merge2(output2) + + up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode='nearest') + output1 = output1 + up2 + output1 = self.merge1(output1) + + out = [output1, output2, output3] + return out + + +class MobileNetV1(nn.Module): + + def __init__(self): + super(MobileNetV1, self).__init__() + self.stage1 = nn.Sequential( + conv_bn(3, 8, 2, leaky=0.1), # 3 + conv_dw(8, 16, 1), # 7 + conv_dw(16, 32, 2), # 11 + conv_dw(32, 32, 1), # 19 + conv_dw(32, 64, 2), # 27 + conv_dw(64, 64, 1), # 43 + ) + self.stage2 = nn.Sequential( + conv_dw(64, 128, 2), # 43 + 16 = 59 + conv_dw(128, 128, 1), # 59 + 32 = 91 + conv_dw(128, 128, 1), # 91 + 32 = 123 + conv_dw(128, 128, 1), # 123 + 32 = 155 + conv_dw(128, 128, 1), # 155 + 32 = 187 + conv_dw(128, 128, 1), # 187 + 32 = 219 + ) + self.stage3 = nn.Sequential( + conv_dw(128, 256, 2), # 219 +3 2 = 241 + conv_dw(256, 256, 1), # 241 + 64 = 301 + ) + self.avg = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(256, 1000) + + def forward(self, x): + x = self.stage1(x) + x = self.stage2(x) + x = self.stage3(x) + x = self.avg(x) + # x = self.model(x) + x = x.view(-1, 256) + x = self.fc(x) + return x + + +class ClassHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(ClassHead, self).__init__() + self.num_anchors = num_anchors + self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 2) + + +class BboxHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(BboxHead, self).__init__() + self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 4) + + +class LandmarkHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(LandmarkHead, self).__init__() + self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 10) + + +def make_class_head(fpn_num=3, inchannels=64, anchor_num=2): + classhead = nn.ModuleList() + for i in range(fpn_num): + classhead.append(ClassHead(inchannels, anchor_num)) + return classhead + + +def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2): + bboxhead = nn.ModuleList() + for i in range(fpn_num): + bboxhead.append(BboxHead(inchannels, anchor_num)) + return bboxhead + + +def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2): + landmarkhead = nn.ModuleList() + for i in range(fpn_num): + landmarkhead.append(LandmarkHead(inchannels, anchor_num)) + return landmarkhead diff --git a/Hallo2/hallo2/facelib/detection/retinaface/retinaface_utils.py b/Hallo2/hallo2/facelib/detection/retinaface/retinaface_utils.py new file mode 100644 index 00000000..8c357757 --- /dev/null +++ b/Hallo2/hallo2/facelib/detection/retinaface/retinaface_utils.py @@ -0,0 +1,421 @@ +import numpy as np +import torch +import torchvision +from itertools import product as product +from math import ceil + + +class PriorBox(object): + + def __init__(self, cfg, image_size=None, phase='train'): + super(PriorBox, self).__init__() + self.min_sizes = cfg['min_sizes'] + self.steps = cfg['steps'] + self.clip = cfg['clip'] + self.image_size = image_size + self.feature_maps = [[ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)] for step in self.steps] + self.name = 's' + + def forward(self): + anchors = [] + for k, f in enumerate(self.feature_maps): + min_sizes = self.min_sizes[k] + for i, j in product(range(f[0]), range(f[1])): + for min_size in min_sizes: + s_kx = min_size / self.image_size[1] + s_ky = min_size / self.image_size[0] + dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]] + dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]] + for cy, cx in product(dense_cy, dense_cx): + anchors += [cx, cy, s_kx, s_ky] + + # back to torch land + output = torch.Tensor(anchors).view(-1, 4) + if self.clip: + output.clamp_(max=1, min=0) + return output + + +def py_cpu_nms(dets, thresh): + """Pure Python NMS baseline.""" + keep = torchvision.ops.nms( + boxes=torch.Tensor(dets[:, :4]), + scores=torch.Tensor(dets[:, 4]), + iou_threshold=thresh, + ) + + return list(keep) + + +def point_form(boxes): + """ Convert prior_boxes to (xmin, ymin, xmax, ymax) + representation for comparison to point form ground truth data. + Args: + boxes: (tensor) center-size default boxes from priorbox layers. + Return: + boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. + """ + return torch.cat( + ( + boxes[:, :2] - boxes[:, 2:] / 2, # xmin, ymin + boxes[:, :2] + boxes[:, 2:] / 2), + 1) # xmax, ymax + + +def center_size(boxes): + """ Convert prior_boxes to (cx, cy, w, h) + representation for comparison to center-size form ground truth data. + Args: + boxes: (tensor) point_form boxes + Return: + boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. + """ + return torch.cat( + (boxes[:, 2:] + boxes[:, :2]) / 2, # cx, cy + boxes[:, 2:] - boxes[:, :2], + 1) # w, h + + +def intersect(box_a, box_b): + """ We resize both tensors to [A,B,2] without new malloc: + [A,2] -> [A,1,2] -> [A,B,2] + [B,2] -> [1,B,2] -> [A,B,2] + Then we compute the area of intersect between box_a and box_b. + Args: + box_a: (tensor) bounding boxes, Shape: [A,4]. + box_b: (tensor) bounding boxes, Shape: [B,4]. + Return: + (tensor) intersection area, Shape: [A,B]. + """ + A = box_a.size(0) + B = box_b.size(0) + max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) + min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2)) + inter = torch.clamp((max_xy - min_xy), min=0) + return inter[:, :, 0] * inter[:, :, 1] + + +def jaccard(box_a, box_b): + """Compute the jaccard overlap of two sets of boxes. The jaccard overlap + is simply the intersection over union of two boxes. Here we operate on + ground truth boxes and default boxes. + E.g.: + A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) + Args: + box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4] + box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4] + Return: + jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)] + """ + inter = intersect(box_a, box_b) + area_a = ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B] + area_b = ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B] + union = area_a + area_b - inter + return inter / union # [A,B] + + +def matrix_iou(a, b): + """ + return iou of a and b, numpy version for data augenmentation + """ + lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) + rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) + + area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2) + area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) + area_b = np.prod(b[:, 2:] - b[:, :2], axis=1) + return area_i / (area_a[:, np.newaxis] + area_b - area_i) + + +def matrix_iof(a, b): + """ + return iof of a and b, numpy version for data augenmentation + """ + lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) + rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) + + area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2) + area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) + return area_i / np.maximum(area_a[:, np.newaxis], 1) + + +def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx): + """Match each prior box with the ground truth box of the highest jaccard + overlap, encode the bounding boxes, then return the matched indices + corresponding to both confidence and location preds. + Args: + threshold: (float) The overlap threshold used when matching boxes. + truths: (tensor) Ground truth boxes, Shape: [num_obj, 4]. + priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4]. + variances: (tensor) Variances corresponding to each prior coord, + Shape: [num_priors, 4]. + labels: (tensor) All the class labels for the image, Shape: [num_obj]. + landms: (tensor) Ground truth landms, Shape [num_obj, 10]. + loc_t: (tensor) Tensor to be filled w/ encoded location targets. + conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds. + landm_t: (tensor) Tensor to be filled w/ encoded landm targets. + idx: (int) current batch index + Return: + The matched indices corresponding to 1)location 2)confidence + 3)landm preds. + """ + # jaccard index + overlaps = jaccard(truths, point_form(priors)) + # (Bipartite Matching) + # [1,num_objects] best prior for each ground truth + best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) + + # ignore hard gt + valid_gt_idx = best_prior_overlap[:, 0] >= 0.2 + best_prior_idx_filter = best_prior_idx[valid_gt_idx, :] + if best_prior_idx_filter.shape[0] <= 0: + loc_t[idx] = 0 + conf_t[idx] = 0 + return + + # [1,num_priors] best ground truth for each prior + best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) + best_truth_idx.squeeze_(0) + best_truth_overlap.squeeze_(0) + best_prior_idx.squeeze_(1) + best_prior_idx_filter.squeeze_(1) + best_prior_overlap.squeeze_(1) + best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior + # TODO refactor: index best_prior_idx with long tensor + # ensure every gt matches with its prior of max overlap + for j in range(best_prior_idx.size(0)): # 判别此anchor是预测哪一个boxes + best_truth_idx[best_prior_idx[j]] = j + matches = truths[best_truth_idx] # Shape: [num_priors,4] 此处为每一个anchor对应的bbox取出来 + conf = labels[best_truth_idx] # Shape: [num_priors] 此处为每一个anchor对应的label取出来 + conf[best_truth_overlap < threshold] = 0 # label as background overlap<0.35的全部作为负样本 + loc = encode(matches, priors, variances) + + matches_landm = landms[best_truth_idx] + landm = encode_landm(matches_landm, priors, variances) + loc_t[idx] = loc # [num_priors,4] encoded offsets to learn + conf_t[idx] = conf # [num_priors] top class label for each prior + landm_t[idx] = landm + + +def encode(matched, priors, variances): + """Encode the variances from the priorbox layers into the ground truth boxes + we have matched (based on jaccard overlap) with the prior boxes. + Args: + matched: (tensor) Coords of ground truth for each prior in point-form + Shape: [num_priors, 4]. + priors: (tensor) Prior boxes in center-offset form + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + encoded boxes (tensor), Shape: [num_priors, 4] + """ + + # dist b/t match center and prior's center + g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] + # encode variance + g_cxcy /= (variances[0] * priors[:, 2:]) + # match wh / prior wh + g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] + g_wh = torch.log(g_wh) / variances[1] + # return target for smooth_l1_loss + return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] + + +def encode_landm(matched, priors, variances): + """Encode the variances from the priorbox layers into the ground truth boxes + we have matched (based on jaccard overlap) with the prior boxes. + Args: + matched: (tensor) Coords of ground truth for each prior in point-form + Shape: [num_priors, 10]. + priors: (tensor) Prior boxes in center-offset form + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + encoded landm (tensor), Shape: [num_priors, 10] + """ + + # dist b/t match center and prior's center + matched = torch.reshape(matched, (matched.size(0), 5, 2)) + priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2) + g_cxcy = matched[:, :, :2] - priors[:, :, :2] + # encode variance + g_cxcy /= (variances[0] * priors[:, :, 2:]) + # g_cxcy /= priors[:, :, 2:] + g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1) + # return target for smooth_l1_loss + return g_cxcy + + +# Adapted from https://github.com/Hakuyume/chainer-ssd +def decode(loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + loc (tensor): location predictions for loc layers, + Shape: [num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = torch.cat((priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], + priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) + boxes[:, :2] -= boxes[:, 2:] / 2 + boxes[:, 2:] += boxes[:, :2] + return boxes + + +def decode_landm(pre, priors, variances): + """Decode landm from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + pre (tensor): landm predictions for loc layers, + Shape: [num_priors,10] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded landm predictions + """ + tmp = ( + priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:], + ) + landms = torch.cat(tmp, dim=1) + return landms + + +def batched_decode(b_loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + b_loc (tensor): location predictions for loc layers, + Shape: [num_batches,num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [1,num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + boxes = ( + priors[:, :, :2] + b_loc[:, :, :2] * variances[0] * priors[:, :, 2:], + priors[:, :, 2:] * torch.exp(b_loc[:, :, 2:] * variances[1]), + ) + boxes = torch.cat(boxes, dim=2) + + boxes[:, :, :2] -= boxes[:, :, 2:] / 2 + boxes[:, :, 2:] += boxes[:, :, :2] + return boxes + + +def batched_decode_landm(pre, priors, variances): + """Decode landm from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + pre (tensor): landm predictions for loc layers, + Shape: [num_batches,num_priors,10] + priors (tensor): Prior boxes in center-offset form. + Shape: [1,num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded landm predictions + """ + landms = ( + priors[:, :, :2] + pre[:, :, :2] * variances[0] * priors[:, :, 2:], + priors[:, :, :2] + pre[:, :, 2:4] * variances[0] * priors[:, :, 2:], + priors[:, :, :2] + pre[:, :, 4:6] * variances[0] * priors[:, :, 2:], + priors[:, :, :2] + pre[:, :, 6:8] * variances[0] * priors[:, :, 2:], + priors[:, :, :2] + pre[:, :, 8:10] * variances[0] * priors[:, :, 2:], + ) + landms = torch.cat(landms, dim=2) + return landms + + +def log_sum_exp(x): + """Utility function for computing log_sum_exp while determining + This will be used to determine unaveraged confidence loss across + all examples in a batch. + Args: + x (Variable(tensor)): conf_preds from conf layers + """ + x_max = x.data.max() + return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max + + +# Original author: Francisco Massa: +# https://github.com/fmassa/object-detection.torch +# Ported to PyTorch by Max deGroot (02/01/2017) +def nms(boxes, scores, overlap=0.5, top_k=200): + """Apply non-maximum suppression at test time to avoid detecting too many + overlapping bounding boxes for a given object. + Args: + boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. + scores: (tensor) The class predscores for the img, Shape:[num_priors]. + overlap: (float) The overlap thresh for suppressing unnecessary boxes. + top_k: (int) The Maximum number of box preds to consider. + Return: + The indices of the kept boxes with respect to num_priors. + """ + + keep = torch.Tensor(scores.size(0)).fill_(0).long() + if boxes.numel() == 0: + return keep + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + area = torch.mul(x2 - x1, y2 - y1) + v, idx = scores.sort(0) # sort in ascending order + # I = I[v >= 0.01] + idx = idx[-top_k:] # indices of the top-k largest vals + xx1 = boxes.new() + yy1 = boxes.new() + xx2 = boxes.new() + yy2 = boxes.new() + w = boxes.new() + h = boxes.new() + + # keep = torch.Tensor() + count = 0 + while idx.numel() > 0: + i = idx[-1] # index of current largest val + # keep.append(i) + keep[count] = i + count += 1 + if idx.size(0) == 1: + break + idx = idx[:-1] # remove kept element from view + # load bboxes of next highest vals + torch.index_select(x1, 0, idx, out=xx1) + torch.index_select(y1, 0, idx, out=yy1) + torch.index_select(x2, 0, idx, out=xx2) + torch.index_select(y2, 0, idx, out=yy2) + # store element-wise max with next highest score + xx1 = torch.clamp(xx1, min=x1[i]) + yy1 = torch.clamp(yy1, min=y1[i]) + xx2 = torch.clamp(xx2, max=x2[i]) + yy2 = torch.clamp(yy2, max=y2[i]) + w.resize_as_(xx2) + h.resize_as_(yy2) + w = xx2 - xx1 + h = yy2 - yy1 + # check sizes of xx1 and xx2.. after each iteration + w = torch.clamp(w, min=0.0) + h = torch.clamp(h, min=0.0) + inter = w * h + # IoU = i / (area(a) + area(b) - i) + rem_areas = torch.index_select(area, 0, idx) # load remaining areas) + union = (rem_areas - inter) + area[i] + IoU = inter / union # store result in iou + # keep only elements with an IoU <= overlap + idx = idx[IoU.le(overlap)] + return keep, count diff --git a/Hallo2/hallo2/facelib/detection/yolov5face/__init__.py b/Hallo2/hallo2/facelib/detection/yolov5face/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Hallo2/hallo2/facelib/detection/yolov5face/face_detector.py b/Hallo2/hallo2/facelib/detection/yolov5face/face_detector.py new file mode 100644 index 00000000..1b27e970 --- /dev/null +++ b/Hallo2/hallo2/facelib/detection/yolov5face/face_detector.py @@ -0,0 +1,141 @@ +import cv2 +import copy +import re +import torch +import numpy as np + +from pathlib import Path +from facelib.detection.yolov5face.models.yolo import Model +from facelib.detection.yolov5face.utils.datasets import letterbox +from facelib.detection.yolov5face.utils.general import ( + check_img_size, + non_max_suppression_face, + scale_coords, + scale_coords_landmarks, +) + +# IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.')[:2])) >= (1, 9) +IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\ + torch.__version__)[0][:3])] >= [1, 9, 0] + + +def isListempty(inList): + if isinstance(inList, list): # Is a list + return all(map(isListempty, inList)) + return False # Not a list + +class YoloDetector: + def __init__( + self, + config_name, + min_face=10, + target_size=None, + device='cuda', + ): + """ + config_name: name of .yaml config with network configuration from models/ folder. + min_face : minimal face size in pixels. + target_size : target size of smaller image axis (choose lower for faster work). e.g. 480, 720, 1080. + None for original resolution. + """ + self._class_path = Path(__file__).parent.absolute() + self.target_size = target_size + self.min_face = min_face + self.detector = Model(cfg=config_name) + self.device = device + + + def _preprocess(self, imgs): + """ + Preprocessing image before passing through the network. Resize and conversion to torch tensor. + """ + pp_imgs = [] + for img in imgs: + h0, w0 = img.shape[:2] # orig hw + if self.target_size: + r = self.target_size / min(h0, w0) # resize image to img_size + if r < 1: + img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=cv2.INTER_LINEAR) + + imgsz = check_img_size(max(img.shape[:2]), s=self.detector.stride.max()) # check img_size + img = letterbox(img, new_shape=imgsz)[0] + pp_imgs.append(img) + pp_imgs = np.array(pp_imgs) + pp_imgs = pp_imgs.transpose(0, 3, 1, 2) + pp_imgs = torch.from_numpy(pp_imgs).to(self.device) + pp_imgs = pp_imgs.float() # uint8 to fp16/32 + return pp_imgs / 255.0 # 0 - 255 to 0.0 - 1.0 + + def _postprocess(self, imgs, origimgs, pred, conf_thres, iou_thres): + """ + Postprocessing of raw pytorch model output. + Returns: + bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2. + points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners). + """ + bboxes = [[] for _ in range(len(origimgs))] + landmarks = [[] for _ in range(len(origimgs))] + + pred = non_max_suppression_face(pred, conf_thres, iou_thres) + + for image_id, origimg in enumerate(origimgs): + img_shape = origimg.shape + image_height, image_width = img_shape[:2] + gn = torch.tensor(img_shape)[[1, 0, 1, 0]] # normalization gain whwh + gn_lks = torch.tensor(img_shape)[[1, 0, 1, 0, 1, 0, 1, 0, 1, 0]] # normalization gain landmarks + det = pred[image_id].cpu() + scale_coords(imgs[image_id].shape[1:], det[:, :4], img_shape).round() + scale_coords_landmarks(imgs[image_id].shape[1:], det[:, 5:15], img_shape).round() + + for j in range(det.size()[0]): + box = (det[j, :4].view(1, 4) / gn).view(-1).tolist() + box = list( + map(int, [box[0] * image_width, box[1] * image_height, box[2] * image_width, box[3] * image_height]) + ) + if box[3] - box[1] < self.min_face: + continue + lm = (det[j, 5:15].view(1, 10) / gn_lks).view(-1).tolist() + lm = list(map(int, [i * image_width if j % 2 == 0 else i * image_height for j, i in enumerate(lm)])) + lm = [lm[i : i + 2] for i in range(0, len(lm), 2)] + bboxes[image_id].append(box) + landmarks[image_id].append(lm) + return bboxes, landmarks + + def detect_faces(self, imgs, conf_thres=0.7, iou_thres=0.5): + """ + Get bbox coordinates and keypoints of faces on original image. + Params: + imgs: image or list of images to detect faces on with BGR order (convert to RGB order for inference) + conf_thres: confidence threshold for each prediction + iou_thres: threshold for NMS (filter of intersecting bboxes) + Returns: + bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2. + points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners). + """ + # Pass input images through face detector + images = imgs if isinstance(imgs, list) else [imgs] + images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images] + origimgs = copy.deepcopy(images) + + images = self._preprocess(images) + + if IS_HIGH_VERSION: + with torch.inference_mode(): # for pytorch>=1.9 + pred = self.detector(images)[0] + else: + with torch.no_grad(): # for pytorch<1.9 + pred = self.detector(images)[0] + + bboxes, points = self._postprocess(images, origimgs, pred, conf_thres, iou_thres) + + # return bboxes, points + if not isListempty(points): + bboxes = np.array(bboxes).reshape(-1,4) + points = np.array(points).reshape(-1,10) + padding = bboxes[:,0].reshape(-1,1) + return np.concatenate((bboxes, padding, points), axis=1) + else: + return None + + def __call__(self, *args): + return self.predict(*args) diff --git a/Hallo2/hallo2/facelib/detection/yolov5face/models/__init__.py b/Hallo2/hallo2/facelib/detection/yolov5face/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Hallo2/hallo2/facelib/detection/yolov5face/models/common.py b/Hallo2/hallo2/facelib/detection/yolov5face/models/common.py new file mode 100644 index 00000000..497a0044 --- /dev/null +++ b/Hallo2/hallo2/facelib/detection/yolov5face/models/common.py @@ -0,0 +1,299 @@ +# This file contains modules common to various models + +import math + +import numpy as np +import torch +from torch import nn + +from facelib.detection.yolov5face.utils.datasets import letterbox +from facelib.detection.yolov5face.utils.general import ( + make_divisible, + non_max_suppression, + scale_coords, + xyxy2xywh, +) + + +def autopad(k, p=None): # kernel, padding + # Pad to 'same' + if p is None: + p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad + return p + + +def channel_shuffle(x, groups): + batchsize, num_channels, height, width = x.data.size() + channels_per_group = torch.div(num_channels, groups, rounding_mode="trunc") + + # reshape + x = x.view(batchsize, groups, channels_per_group, height, width) + x = torch.transpose(x, 1, 2).contiguous() + + # flatten + return x.view(batchsize, -1, height, width) + + +def DWConv(c1, c2, k=1, s=1, act=True): + # Depthwise convolution + return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act) + + +class Conv(nn.Module): + # Standard convolution + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups + super().__init__() + self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) + self.bn = nn.BatchNorm2d(c2) + self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) + + def forward(self, x): + return self.act(self.bn(self.conv(x))) + + def fuseforward(self, x): + return self.act(self.conv(x)) + + +class StemBlock(nn.Module): + def __init__(self, c1, c2, k=3, s=2, p=None, g=1, act=True): + super().__init__() + self.stem_1 = Conv(c1, c2, k, s, p, g, act) + self.stem_2a = Conv(c2, c2 // 2, 1, 1, 0) + self.stem_2b = Conv(c2 // 2, c2, 3, 2, 1) + self.stem_2p = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True) + self.stem_3 = Conv(c2 * 2, c2, 1, 1, 0) + + def forward(self, x): + stem_1_out = self.stem_1(x) + stem_2a_out = self.stem_2a(stem_1_out) + stem_2b_out = self.stem_2b(stem_2a_out) + stem_2p_out = self.stem_2p(stem_1_out) + return self.stem_3(torch.cat((stem_2b_out, stem_2p_out), 1)) + + +class Bottleneck(nn.Module): + # Standard bottleneck + def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c_, c2, 3, 1, g=g) + self.add = shortcut and c1 == c2 + + def forward(self, x): + return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) + + +class BottleneckCSP(nn.Module): + # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False) + self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False) + self.cv4 = Conv(2 * c_, c2, 1, 1) + self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3) + self.act = nn.LeakyReLU(0.1, inplace=True) + self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) + + def forward(self, x): + y1 = self.cv3(self.m(self.cv1(x))) + y2 = self.cv2(x) + return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1)))) + + +class C3(nn.Module): + # CSP Bottleneck with 3 convolutions + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c1, c_, 1, 1) + self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2) + self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) + + def forward(self, x): + return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1)) + + +class ShuffleV2Block(nn.Module): + def __init__(self, inp, oup, stride): + super().__init__() + + if not 1 <= stride <= 3: + raise ValueError("illegal stride value") + self.stride = stride + + branch_features = oup // 2 + + if self.stride > 1: + self.branch1 = nn.Sequential( + self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1), + nn.BatchNorm2d(inp), + nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(branch_features), + nn.SiLU(), + ) + else: + self.branch1 = nn.Sequential() + + self.branch2 = nn.Sequential( + nn.Conv2d( + inp if (self.stride > 1) else branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ), + nn.BatchNorm2d(branch_features), + nn.SiLU(), + self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1), + nn.BatchNorm2d(branch_features), + nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(branch_features), + nn.SiLU(), + ) + + @staticmethod + def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): + return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) + + def forward(self, x): + if self.stride == 1: + x1, x2 = x.chunk(2, dim=1) + out = torch.cat((x1, self.branch2(x2)), dim=1) + else: + out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) + out = channel_shuffle(out, 2) + return out + + +class SPP(nn.Module): + # Spatial pyramid pooling layer used in YOLOv3-SPP + def __init__(self, c1, c2, k=(5, 9, 13)): + super().__init__() + c_ = c1 // 2 # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1) + self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k]) + + def forward(self, x): + x = self.cv1(x) + return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1)) + + +class Focus(nn.Module): + # Focus wh information into c-space + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups + super().__init__() + self.conv = Conv(c1 * 4, c2, k, s, p, g, act) + + def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2) + return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)) + + +class Concat(nn.Module): + # Concatenate a list of tensors along dimension + def __init__(self, dimension=1): + super().__init__() + self.d = dimension + + def forward(self, x): + return torch.cat(x, self.d) + + +class NMS(nn.Module): + # Non-Maximum Suppression (NMS) module + conf = 0.25 # confidence threshold + iou = 0.45 # IoU threshold + classes = None # (optional list) filter by class + + def forward(self, x): + return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) + + +class AutoShape(nn.Module): + # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS + img_size = 640 # inference size (pixels) + conf = 0.25 # NMS confidence threshold + iou = 0.45 # NMS IoU threshold + classes = None # (optional list) filter by class + + def __init__(self, model): + super().__init__() + self.model = model.eval() + + def autoshape(self): + print("autoShape already enabled, skipping... ") # model already converted to model.autoshape() + return self + + def forward(self, imgs, size=640, augment=False, profile=False): + # Inference from various sources. For height=720, width=1280, RGB images example inputs are: + # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3) + # PIL: = Image.open('image.jpg') # HWC x(720,1280,3) + # numpy: = np.zeros((720,1280,3)) # HWC + # torch: = torch.zeros(16,3,720,1280) # BCHW + # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images + + p = next(self.model.parameters()) # for device and type + if isinstance(imgs, torch.Tensor): # torch + return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference + + # Pre-process + n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images + shape0, shape1 = [], [] # image and inference shapes + for i, im in enumerate(imgs): + im = np.array(im) # to numpy + if im.shape[0] < 5: # image in CHW + im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1) + im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3) # enforce 3ch input + s = im.shape[:2] # HWC + shape0.append(s) # image shape + g = size / max(s) # gain + shape1.append([y * g for y in s]) + imgs[i] = im # update + shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape + x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad + x = np.stack(x, 0) if n > 1 else x[0][None] # stack + x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW + x = torch.from_numpy(x).to(p.device).type_as(p) / 255.0 # uint8 to fp16/32 + + # Inference + with torch.no_grad(): + y = self.model(x, augment, profile)[0] # forward + y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS + + # Post-process + for i in range(n): + scale_coords(shape1, y[i][:, :4], shape0[i]) + + return Detections(imgs, y, self.names) + + +class Detections: + # detections class for YOLOv5 inference results + def __init__(self, imgs, pred, names=None): + super().__init__() + d = pred[0].device # device + gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1.0, 1.0], device=d) for im in imgs] # normalizations + self.imgs = imgs # list of images as numpy arrays + self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls) + self.names = names # class names + self.xyxy = pred # xyxy pixels + self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels + self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized + self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized + self.n = len(self.pred) + + def __len__(self): + return self.n + + def tolist(self): + # return a list of Detections objects, i.e. 'for result in results.tolist():' + x = [Detections([self.imgs[i]], [self.pred[i]], self.names) for i in range(self.n)] + for d in x: + for k in ["imgs", "pred", "xyxy", "xyxyn", "xywh", "xywhn"]: + setattr(d, k, getattr(d, k)[0]) # pop out of list + return x diff --git a/Hallo2/hallo2/facelib/detection/yolov5face/models/experimental.py b/Hallo2/hallo2/facelib/detection/yolov5face/models/experimental.py new file mode 100644 index 00000000..37ba4c44 --- /dev/null +++ b/Hallo2/hallo2/facelib/detection/yolov5face/models/experimental.py @@ -0,0 +1,45 @@ +# # This file contains experimental modules + +import numpy as np +import torch +from torch import nn + +from facelib.detection.yolov5face.models.common import Conv + + +class CrossConv(nn.Module): + # Cross Convolution Downsample + def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False): + # ch_in, ch_out, kernel, stride, groups, expansion, shortcut + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, (1, k), (1, s)) + self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g) + self.add = shortcut and c1 == c2 + + def forward(self, x): + return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) + + +class MixConv2d(nn.Module): + # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595 + def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): + super().__init__() + groups = len(k) + if equal_ch: # equal c_ per group + i = torch.linspace(0, groups - 1e-6, c2).floor() # c2 indices + c_ = [(i == g).sum() for g in range(groups)] # intermediate channels + else: # equal weight.numel() per group + b = [c2] + [0] * groups + a = np.eye(groups + 1, groups, k=-1) + a -= np.roll(a, 1, axis=1) + a *= np.array(k) ** 2 + a[0] = 1 + c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b + + self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)]) + self.bn = nn.BatchNorm2d(c2) + self.act = nn.LeakyReLU(0.1, inplace=True) + + def forward(self, x): + return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1))) diff --git a/Hallo2/hallo2/facelib/detection/yolov5face/models/yolo.py b/Hallo2/hallo2/facelib/detection/yolov5face/models/yolo.py new file mode 100644 index 00000000..70845d97 --- /dev/null +++ b/Hallo2/hallo2/facelib/detection/yolov5face/models/yolo.py @@ -0,0 +1,235 @@ +import math +from copy import deepcopy +from pathlib import Path + +import torch +import yaml # for torch hub +from torch import nn + +from facelib.detection.yolov5face.models.common import ( + C3, + NMS, + SPP, + AutoShape, + Bottleneck, + BottleneckCSP, + Concat, + Conv, + DWConv, + Focus, + ShuffleV2Block, + StemBlock, +) +from facelib.detection.yolov5face.models.experimental import CrossConv, MixConv2d +from facelib.detection.yolov5face.utils.autoanchor import check_anchor_order +from facelib.detection.yolov5face.utils.general import make_divisible +from facelib.detection.yolov5face.utils.torch_utils import copy_attr, fuse_conv_and_bn + + +class Detect(nn.Module): + stride = None # strides computed during build + export = False # onnx export + + def __init__(self, nc=80, anchors=(), ch=()): # detection layer + super().__init__() + self.nc = nc # number of classes + self.no = nc + 5 + 10 # number of outputs per anchor + + self.nl = len(anchors) # number of detection layers + self.na = len(anchors[0]) // 2 # number of anchors + self.grid = [torch.zeros(1)] * self.nl # init grid + a = torch.tensor(anchors).float().view(self.nl, -1, 2) + self.register_buffer("anchors", a) # shape(nl,na,2) + self.register_buffer("anchor_grid", a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2) + self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv + + def forward(self, x): + z = [] # inference output + if self.export: + for i in range(self.nl): + x[i] = self.m[i](x[i]) + return x + for i in range(self.nl): + x[i] = self.m[i](x[i]) # conv + bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85) + x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous() + + if not self.training: # inference + if self.grid[i].shape[2:4] != x[i].shape[2:4]: + self.grid[i] = self._make_grid(nx, ny).to(x[i].device) + + y = torch.full_like(x[i], 0) + y[..., [0, 1, 2, 3, 4, 15]] = x[i][..., [0, 1, 2, 3, 4, 15]].sigmoid() + y[..., 5:15] = x[i][..., 5:15] + + y[..., 0:2] = (y[..., 0:2] * 2.0 - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy + y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh + + y[..., 5:7] = ( + y[..., 5:7] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] + ) # landmark x1 y1 + y[..., 7:9] = ( + y[..., 7:9] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] + ) # landmark x2 y2 + y[..., 9:11] = ( + y[..., 9:11] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] + ) # landmark x3 y3 + y[..., 11:13] = ( + y[..., 11:13] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] + ) # landmark x4 y4 + y[..., 13:15] = ( + y[..., 13:15] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] + ) # landmark x5 y5 + + z.append(y.view(bs, -1, self.no)) + + return x if self.training else (torch.cat(z, 1), x) + + @staticmethod + def _make_grid(nx=20, ny=20): + # yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)], indexing="ij") # for pytorch>=1.10 + yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)]) + return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float() + + +class Model(nn.Module): + def __init__(self, cfg="yolov5s.yaml", ch=3, nc=None): # model, input channels, number of classes + super().__init__() + self.yaml_file = Path(cfg).name + with Path(cfg).open(encoding="utf8") as f: + self.yaml = yaml.safe_load(f) # model dict + + # Define model + ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels + if nc and nc != self.yaml["nc"]: + self.yaml["nc"] = nc # override yaml value + + self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist + self.names = [str(i) for i in range(self.yaml["nc"])] # default names + + # Build strides, anchors + m = self.model[-1] # Detect() + if isinstance(m, Detect): + s = 128 # 2x min stride + m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward + m.anchors /= m.stride.view(-1, 1, 1) + check_anchor_order(m) + self.stride = m.stride + self._initialize_biases() # only run once + + def forward(self, x): + return self.forward_once(x) # single-scale inference, train + + def forward_once(self, x): + y = [] # outputs + for m in self.model: + if m.f != -1: # if not from previous layer + x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers + + x = m(x) # run + y.append(x if m.i in self.save else None) # save output + + return x + + def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency + # https://arxiv.org/abs/1708.02002 section 3.3 + m = self.model[-1] # Detect() module + for mi, s in zip(m.m, m.stride): # from + b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85) + b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image) + b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls + mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True) + + def _print_biases(self): + m = self.model[-1] # Detect() module + for mi in m.m: # from + b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85) + print(("%6g Conv2d.bias:" + "%10.3g" * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean())) + + def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers + print("Fusing layers... ") + for m in self.model.modules(): + if isinstance(m, Conv) and hasattr(m, "bn"): + m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv + delattr(m, "bn") # remove batchnorm + m.forward = m.fuseforward # update forward + elif type(m) is nn.Upsample: + m.recompute_scale_factor = None # torch 1.11.0 compatibility + return self + + def nms(self, mode=True): # add or remove NMS module + present = isinstance(self.model[-1], NMS) # last layer is NMS + if mode and not present: + print("Adding NMS... ") + m = NMS() # module + m.f = -1 # from + m.i = self.model[-1].i + 1 # index + self.model.add_module(name=str(m.i), module=m) # add + self.eval() + elif not mode and present: + print("Removing NMS... ") + self.model = self.model[:-1] # remove + return self + + def autoshape(self): # add autoShape module + print("Adding autoShape... ") + m = AutoShape(self) # wrap model + copy_attr(m, self, include=("yaml", "nc", "hyp", "names", "stride"), exclude=()) # copy attributes + return m + + +def parse_model(d, ch): # model_dict, input_channels(3) + anchors, nc, gd, gw = d["anchors"], d["nc"], d["depth_multiple"], d["width_multiple"] + na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors + no = na * (nc + 5) # number of outputs = anchors * (classes + 5) + + layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out + for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args + m = eval(m) if isinstance(m, str) else m # eval strings + for j, a in enumerate(args): + try: + args[j] = eval(a) if isinstance(a, str) else a # eval strings + except: + pass + + n = max(round(n * gd), 1) if n > 1 else n # depth gain + if m in [ + Conv, + Bottleneck, + SPP, + DWConv, + MixConv2d, + Focus, + CrossConv, + BottleneckCSP, + C3, + ShuffleV2Block, + StemBlock, + ]: + c1, c2 = ch[f], args[0] + + c2 = make_divisible(c2 * gw, 8) if c2 != no else c2 + + args = [c1, c2, *args[1:]] + if m in [BottleneckCSP, C3]: + args.insert(2, n) + n = 1 + elif m is nn.BatchNorm2d: + args = [ch[f]] + elif m is Concat: + c2 = sum(ch[-1 if x == -1 else x + 1] for x in f) + elif m is Detect: + args.append([ch[x + 1] for x in f]) + if isinstance(args[1], int): # number of anchors + args[1] = [list(range(args[1] * 2))] * len(f) + else: + c2 = ch[f] + + m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module + t = str(m)[8:-2].replace("__main__.", "") # module type + np = sum(x.numel() for x in m_.parameters()) # number params + m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params + save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist + layers.append(m_) + ch.append(c2) + return nn.Sequential(*layers), sorted(save) diff --git a/Hallo2/hallo2/facelib/detection/yolov5face/models/yolov5l.yaml b/Hallo2/hallo2/facelib/detection/yolov5face/models/yolov5l.yaml new file mode 100644 index 00000000..0532b0e2 --- /dev/null +++ b/Hallo2/hallo2/facelib/detection/yolov5face/models/yolov5l.yaml @@ -0,0 +1,47 @@ +# parameters +nc: 1 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# anchors +anchors: + - [4,5, 8,10, 13,16] # P3/8 + - [23,29, 43,55, 73,105] # P4/16 + - [146,217, 231,300, 335,433] # P5/32 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [[-1, 1, StemBlock, [64, 3, 2]], # 0-P1/2 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 2-P3/8 + [-1, 9, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 4-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 6-P5/32 + [-1, 1, SPP, [1024, [3,5,7]]], + [-1, 3, C3, [1024, False]], # 8 + ] + +# YOLOv5 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 5], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 12 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 3], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 16 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 13], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 19 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 9], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 22 (P5/32-large) + + [[16, 19, 22], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] \ No newline at end of file diff --git a/Hallo2/hallo2/facelib/detection/yolov5face/models/yolov5n.yaml b/Hallo2/hallo2/facelib/detection/yolov5face/models/yolov5n.yaml new file mode 100644 index 00000000..caba6bed --- /dev/null +++ b/Hallo2/hallo2/facelib/detection/yolov5face/models/yolov5n.yaml @@ -0,0 +1,45 @@ +# parameters +nc: 1 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# anchors +anchors: + - [4,5, 8,10, 13,16] # P3/8 + - [23,29, 43,55, 73,105] # P4/16 + - [146,217, 231,300, 335,433] # P5/32 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [[-1, 1, StemBlock, [32, 3, 2]], # 0-P2/4 + [-1, 1, ShuffleV2Block, [128, 2]], # 1-P3/8 + [-1, 3, ShuffleV2Block, [128, 1]], # 2 + [-1, 1, ShuffleV2Block, [256, 2]], # 3-P4/16 + [-1, 7, ShuffleV2Block, [256, 1]], # 4 + [-1, 1, ShuffleV2Block, [512, 2]], # 5-P5/32 + [-1, 3, ShuffleV2Block, [512, 1]], # 6 + ] + +# YOLOv5 head +head: + [[-1, 1, Conv, [128, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P4 + [-1, 1, C3, [128, False]], # 10 + + [-1, 1, Conv, [128, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 2], 1, Concat, [1]], # cat backbone P3 + [-1, 1, C3, [128, False]], # 14 (P3/8-small) + + [-1, 1, Conv, [128, 3, 2]], + [[-1, 11], 1, Concat, [1]], # cat head P4 + [-1, 1, C3, [128, False]], # 17 (P4/16-medium) + + [-1, 1, Conv, [128, 3, 2]], + [[-1, 7], 1, Concat, [1]], # cat head P5 + [-1, 1, C3, [128, False]], # 20 (P5/32-large) + + [[14, 17, 20], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/Hallo2/hallo2/facelib/detection/yolov5face/utils/__init__.py b/Hallo2/hallo2/facelib/detection/yolov5face/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Hallo2/hallo2/facelib/detection/yolov5face/utils/autoanchor.py b/Hallo2/hallo2/facelib/detection/yolov5face/utils/autoanchor.py new file mode 100644 index 00000000..a4eba3e9 --- /dev/null +++ b/Hallo2/hallo2/facelib/detection/yolov5face/utils/autoanchor.py @@ -0,0 +1,12 @@ +# Auto-anchor utils + + +def check_anchor_order(m): + # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary + a = m.anchor_grid.prod(-1).view(-1) # anchor area + da = a[-1] - a[0] # delta a + ds = m.stride[-1] - m.stride[0] # delta s + if da.sign() != ds.sign(): # same order + print("Reversing anchor order") + m.anchors[:] = m.anchors.flip(0) + m.anchor_grid[:] = m.anchor_grid.flip(0) diff --git a/Hallo2/hallo2/facelib/detection/yolov5face/utils/datasets.py b/Hallo2/hallo2/facelib/detection/yolov5face/utils/datasets.py new file mode 100644 index 00000000..e672b136 --- /dev/null +++ b/Hallo2/hallo2/facelib/detection/yolov5face/utils/datasets.py @@ -0,0 +1,35 @@ +import cv2 +import numpy as np + + +def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale_fill=False, scaleup=True): + # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232 + shape = img.shape[:2] # current shape [height, width] + if isinstance(new_shape, int): + new_shape = (new_shape, new_shape) + + # Scale ratio (new / old) + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) + if not scaleup: # only scale down, do not scale up (for better test mAP) + r = min(r, 1.0) + + # Compute padding + ratio = r, r # width, height ratios + new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) + dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding + if auto: # minimum rectangle + dw, dh = np.mod(dw, 64), np.mod(dh, 64) # wh padding + elif scale_fill: # stretch + dw, dh = 0.0, 0.0 + new_unpad = (new_shape[1], new_shape[0]) + ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios + + dw /= 2 # divide padding into 2 sides + dh /= 2 + + if shape[::-1] != new_unpad: # resize + img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) + top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) + img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border + return img, ratio, (dw, dh) diff --git a/Hallo2/hallo2/facelib/detection/yolov5face/utils/extract_ckpt.py b/Hallo2/hallo2/facelib/detection/yolov5face/utils/extract_ckpt.py new file mode 100644 index 00000000..4b8b6313 --- /dev/null +++ b/Hallo2/hallo2/facelib/detection/yolov5face/utils/extract_ckpt.py @@ -0,0 +1,5 @@ +import torch +import sys +sys.path.insert(0,'./facelib/detection/yolov5face') +model = torch.load('facelib/detection/yolov5face/yolov5n-face.pt', map_location='cpu')['model'] +torch.save(model.state_dict(),'weights/facelib/yolov5n-face.pth') \ No newline at end of file diff --git a/Hallo2/hallo2/facelib/detection/yolov5face/utils/general.py b/Hallo2/hallo2/facelib/detection/yolov5face/utils/general.py new file mode 100644 index 00000000..1c8e14f5 --- /dev/null +++ b/Hallo2/hallo2/facelib/detection/yolov5face/utils/general.py @@ -0,0 +1,271 @@ +import math +import time + +import numpy as np +import torch +import torchvision + + +def check_img_size(img_size, s=32): + # Verify img_size is a multiple of stride s + new_size = make_divisible(img_size, int(s)) # ceil gs-multiple + # if new_size != img_size: + # print(f"WARNING: --img-size {img_size:g} must be multiple of max stride {s:g}, updating to {new_size:g}") + return new_size + + +def make_divisible(x, divisor): + # Returns x evenly divisible by divisor + return math.ceil(x / divisor) * divisor + + +def xyxy2xywh(x): + # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center + y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center + y[:, 2] = x[:, 2] - x[:, 0] # width + y[:, 3] = x[:, 3] - x[:, 1] # height + return y + + +def xywh2xyxy(x): + # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x + y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y + y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x + y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y + return y + + +def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None): + # Rescale coords (xyxy) from img1_shape to img0_shape + if ratio_pad is None: # calculate from img0_shape + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new + pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding + else: + gain = ratio_pad[0][0] + pad = ratio_pad[1] + + coords[:, [0, 2]] -= pad[0] # x padding + coords[:, [1, 3]] -= pad[1] # y padding + coords[:, :4] /= gain + clip_coords(coords, img0_shape) + return coords + + +def clip_coords(boxes, img_shape): + # Clip bounding xyxy bounding boxes to image shape (height, width) + boxes[:, 0].clamp_(0, img_shape[1]) # x1 + boxes[:, 1].clamp_(0, img_shape[0]) # y1 + boxes[:, 2].clamp_(0, img_shape[1]) # x2 + boxes[:, 3].clamp_(0, img_shape[0]) # y2 + + +def box_iou(box1, box2): + # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py + """ + Return intersection-over-union (Jaccard index) of boxes. + Both sets of boxes are expected to be in (x1, y1, x2, y2) format. + Arguments: + box1 (Tensor[N, 4]) + box2 (Tensor[M, 4]) + Returns: + iou (Tensor[N, M]): the NxM matrix containing the pairwise + IoU values for every element in boxes1 and boxes2 + """ + + def box_area(box): + return (box[2] - box[0]) * (box[3] - box[1]) + + area1 = box_area(box1.T) + area2 = box_area(box2.T) + + inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2) + return inter / (area1[:, None] + area2 - inter) + + +def non_max_suppression_face(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()): + """Performs Non-Maximum Suppression (NMS) on inference results + Returns: + detections with shape: nx6 (x1, y1, x2, y2, conf, cls) + """ + + nc = prediction.shape[2] - 15 # number of classes + xc = prediction[..., 4] > conf_thres # candidates + + # Settings + # (pixels) maximum box width and height + max_wh = 4096 + time_limit = 10.0 # seconds to quit after + redundant = True # require redundant detections + multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img) + merge = False # use merge-NMS + + t = time.time() + output = [torch.zeros((0, 16), device=prediction.device)] * prediction.shape[0] + for xi, x in enumerate(prediction): # image index, image inference + # Apply constraints + x = x[xc[xi]] # confidence + + # Cat apriori labels if autolabelling + if labels and len(labels[xi]): + label = labels[xi] + v = torch.zeros((len(label), nc + 15), device=x.device) + v[:, :4] = label[:, 1:5] # box + v[:, 4] = 1.0 # conf + v[range(len(label)), label[:, 0].long() + 15] = 1.0 # cls + x = torch.cat((x, v), 0) + + # If none remain process next image + if not x.shape[0]: + continue + + # Compute conf + x[:, 15:] *= x[:, 4:5] # conf = obj_conf * cls_conf + + # Box (center x, center y, width, height) to (x1, y1, x2, y2) + box = xywh2xyxy(x[:, :4]) + + # Detections matrix nx6 (xyxy, conf, landmarks, cls) + if multi_label: + i, j = (x[:, 15:] > conf_thres).nonzero(as_tuple=False).T + x = torch.cat((box[i], x[i, j + 15, None], x[:, 5:15], j[:, None].float()), 1) + else: # best class only + conf, j = x[:, 15:].max(1, keepdim=True) + x = torch.cat((box, conf, x[:, 5:15], j.float()), 1)[conf.view(-1) > conf_thres] + + # Filter by class + if classes is not None: + x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] + + # If none remain process next image + n = x.shape[0] # number of boxes + if not n: + continue + + # Batched NMS + c = x[:, 15:16] * (0 if agnostic else max_wh) # classes + boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores + i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS + + if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean) + # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) + iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix + weights = iou * scores[None] # box weights + x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes + if redundant: + i = i[iou.sum(1) > 1] # require redundancy + + output[xi] = x[i] + if (time.time() - t) > time_limit: + break # time limit exceeded + + return output + + +def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()): + """Performs Non-Maximum Suppression (NMS) on inference results + + Returns: + detections with shape: nx6 (x1, y1, x2, y2, conf, cls) + """ + + nc = prediction.shape[2] - 5 # number of classes + xc = prediction[..., 4] > conf_thres # candidates + + # Settings + # (pixels) maximum box width and height + max_wh = 4096 + time_limit = 10.0 # seconds to quit after + redundant = True # require redundant detections + multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img) + merge = False # use merge-NMS + + t = time.time() + output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0] + for xi, x in enumerate(prediction): # image index, image inference + x = x[xc[xi]] # confidence + + # Cat apriori labels if autolabelling + if labels and len(labels[xi]): + label_id = labels[xi] + v = torch.zeros((len(label_id), nc + 5), device=x.device) + v[:, :4] = label_id[:, 1:5] # box + v[:, 4] = 1.0 # conf + v[range(len(label_id)), label_id[:, 0].long() + 5] = 1.0 # cls + x = torch.cat((x, v), 0) + + # If none remain process next image + if not x.shape[0]: + continue + + # Compute conf + x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf + + # Box (center x, center y, width, height) to (x1, y1, x2, y2) + box = xywh2xyxy(x[:, :4]) + + # Detections matrix nx6 (xyxy, conf, cls) + if multi_label: + i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T + x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1) + else: # best class only + conf, j = x[:, 5:].max(1, keepdim=True) + x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] + + # Filter by class + if classes is not None: + x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] + + # Check shape + n = x.shape[0] # number of boxes + if not n: # no boxes + continue + + x = x[x[:, 4].argsort(descending=True)] # sort by confidence + + # Batched NMS + c = x[:, 5:6] * (0 if agnostic else max_wh) # classes + boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores + i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS + if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean) + # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) + iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix + weights = iou * scores[None] # box weights + x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes + if redundant: + i = i[iou.sum(1) > 1] # require redundancy + + output[xi] = x[i] + if (time.time() - t) > time_limit: + print(f"WARNING: NMS time limit {time_limit}s exceeded") + break # time limit exceeded + + return output + + +def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None): + # Rescale coords (xyxy) from img1_shape to img0_shape + if ratio_pad is None: # calculate from img0_shape + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new + pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding + else: + gain = ratio_pad[0][0] + pad = ratio_pad[1] + + coords[:, [0, 2, 4, 6, 8]] -= pad[0] # x padding + coords[:, [1, 3, 5, 7, 9]] -= pad[1] # y padding + coords[:, :10] /= gain + coords[:, 0].clamp_(0, img0_shape[1]) # x1 + coords[:, 1].clamp_(0, img0_shape[0]) # y1 + coords[:, 2].clamp_(0, img0_shape[1]) # x2 + coords[:, 3].clamp_(0, img0_shape[0]) # y2 + coords[:, 4].clamp_(0, img0_shape[1]) # x3 + coords[:, 5].clamp_(0, img0_shape[0]) # y3 + coords[:, 6].clamp_(0, img0_shape[1]) # x4 + coords[:, 7].clamp_(0, img0_shape[0]) # y4 + coords[:, 8].clamp_(0, img0_shape[1]) # x5 + coords[:, 9].clamp_(0, img0_shape[0]) # y5 + return coords diff --git a/Hallo2/hallo2/facelib/detection/yolov5face/utils/torch_utils.py b/Hallo2/hallo2/facelib/detection/yolov5face/utils/torch_utils.py new file mode 100644 index 00000000..af2d0658 --- /dev/null +++ b/Hallo2/hallo2/facelib/detection/yolov5face/utils/torch_utils.py @@ -0,0 +1,40 @@ +import torch +from torch import nn + + +def fuse_conv_and_bn(conv, bn): + # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/ + fusedconv = ( + nn.Conv2d( + conv.in_channels, + conv.out_channels, + kernel_size=conv.kernel_size, + stride=conv.stride, + padding=conv.padding, + groups=conv.groups, + bias=True, + ) + .requires_grad_(False) + .to(conv.weight.device) + ) + + # prepare filters + w_conv = conv.weight.clone().view(conv.out_channels, -1) + w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) + fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size())) + + # prepare spatial bias + b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias + b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) + fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) + + return fusedconv + + +def copy_attr(a, b, include=(), exclude=()): + # Copy attributes from b to a, options to only include [...] and to exclude [...] + for k, v in b.__dict__.items(): + if (include and k not in include) or k.startswith("_") or k in exclude: + continue + + setattr(a, k, v) diff --git a/Hallo2/hallo2/facelib/parsing/__init__.py b/Hallo2/hallo2/facelib/parsing/__init__.py new file mode 100644 index 00000000..0474f53e --- /dev/null +++ b/Hallo2/hallo2/facelib/parsing/__init__.py @@ -0,0 +1,23 @@ +import torch + +from facelib.utils import load_file_from_url +from .bisenet import BiSeNet +from .parsenet import ParseNet + + +def init_parsing_model(model_name='bisenet', half=False, device='cuda'): + if model_name == 'bisenet': + model = BiSeNet(num_class=19) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth' + elif model_name == 'parsenet': + model = ParseNet(in_size=512, out_size=512, parsing_ch=19) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth' + else: + raise NotImplementedError(f'{model_name} is not implemented.') + + model_path = load_file_from_url(url=model_url, model_dir='pretrained_models/facelib', progress=True, file_name=None) + load_net = torch.load(model_path, map_location=lambda storage, loc: storage) + model.load_state_dict(load_net, strict=True) + model.eval() + model = model.to(device) + return model diff --git a/Hallo2/hallo2/facelib/parsing/bisenet.py b/Hallo2/hallo2/facelib/parsing/bisenet.py new file mode 100644 index 00000000..3898cab7 --- /dev/null +++ b/Hallo2/hallo2/facelib/parsing/bisenet.py @@ -0,0 +1,140 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .resnet import ResNet18 + + +class ConvBNReLU(nn.Module): + + def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1): + super(ConvBNReLU, self).__init__() + self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False) + self.bn = nn.BatchNorm2d(out_chan) + + def forward(self, x): + x = self.conv(x) + x = F.relu(self.bn(x)) + return x + + +class BiSeNetOutput(nn.Module): + + def __init__(self, in_chan, mid_chan, num_class): + super(BiSeNetOutput, self).__init__() + self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) + self.conv_out = nn.Conv2d(mid_chan, num_class, kernel_size=1, bias=False) + + def forward(self, x): + feat = self.conv(x) + out = self.conv_out(feat) + return out, feat + + +class AttentionRefinementModule(nn.Module): + + def __init__(self, in_chan, out_chan): + super(AttentionRefinementModule, self).__init__() + self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) + self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False) + self.bn_atten = nn.BatchNorm2d(out_chan) + self.sigmoid_atten = nn.Sigmoid() + + def forward(self, x): + feat = self.conv(x) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv_atten(atten) + atten = self.bn_atten(atten) + atten = self.sigmoid_atten(atten) + out = torch.mul(feat, atten) + return out + + +class ContextPath(nn.Module): + + def __init__(self): + super(ContextPath, self).__init__() + self.resnet = ResNet18() + self.arm16 = AttentionRefinementModule(256, 128) + self.arm32 = AttentionRefinementModule(512, 128) + self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) + + def forward(self, x): + feat8, feat16, feat32 = self.resnet(x) + h8, w8 = feat8.size()[2:] + h16, w16 = feat16.size()[2:] + h32, w32 = feat32.size()[2:] + + avg = F.avg_pool2d(feat32, feat32.size()[2:]) + avg = self.conv_avg(avg) + avg_up = F.interpolate(avg, (h32, w32), mode='nearest') + + feat32_arm = self.arm32(feat32) + feat32_sum = feat32_arm + avg_up + feat32_up = F.interpolate(feat32_sum, (h16, w16), mode='nearest') + feat32_up = self.conv_head32(feat32_up) + + feat16_arm = self.arm16(feat16) + feat16_sum = feat16_arm + feat32_up + feat16_up = F.interpolate(feat16_sum, (h8, w8), mode='nearest') + feat16_up = self.conv_head16(feat16_up) + + return feat8, feat16_up, feat32_up # x8, x8, x16 + + +class FeatureFusionModule(nn.Module): + + def __init__(self, in_chan, out_chan): + super(FeatureFusionModule, self).__init__() + self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) + self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False) + self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False) + self.relu = nn.ReLU(inplace=True) + self.sigmoid = nn.Sigmoid() + + def forward(self, fsp, fcp): + fcat = torch.cat([fsp, fcp], dim=1) + feat = self.convblk(fcat) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv1(atten) + atten = self.relu(atten) + atten = self.conv2(atten) + atten = self.sigmoid(atten) + feat_atten = torch.mul(feat, atten) + feat_out = feat_atten + feat + return feat_out + + +class BiSeNet(nn.Module): + + def __init__(self, num_class): + super(BiSeNet, self).__init__() + self.cp = ContextPath() + self.ffm = FeatureFusionModule(256, 256) + self.conv_out = BiSeNetOutput(256, 256, num_class) + self.conv_out16 = BiSeNetOutput(128, 64, num_class) + self.conv_out32 = BiSeNetOutput(128, 64, num_class) + + def forward(self, x, return_feat=False): + h, w = x.size()[2:] + feat_res8, feat_cp8, feat_cp16 = self.cp(x) # return res3b1 feature + feat_sp = feat_res8 # replace spatial path feature with res3b1 feature + feat_fuse = self.ffm(feat_sp, feat_cp8) + + out, feat = self.conv_out(feat_fuse) + out16, feat16 = self.conv_out16(feat_cp8) + out32, feat32 = self.conv_out32(feat_cp16) + + out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True) + out16 = F.interpolate(out16, (h, w), mode='bilinear', align_corners=True) + out32 = F.interpolate(out32, (h, w), mode='bilinear', align_corners=True) + + if return_feat: + feat = F.interpolate(feat, (h, w), mode='bilinear', align_corners=True) + feat16 = F.interpolate(feat16, (h, w), mode='bilinear', align_corners=True) + feat32 = F.interpolate(feat32, (h, w), mode='bilinear', align_corners=True) + return out, out16, out32, feat, feat16, feat32 + else: + return out, out16, out32 diff --git a/Hallo2/hallo2/facelib/parsing/parsenet.py b/Hallo2/hallo2/facelib/parsing/parsenet.py new file mode 100644 index 00000000..e178ebe4 --- /dev/null +++ b/Hallo2/hallo2/facelib/parsing/parsenet.py @@ -0,0 +1,194 @@ +"""Modified from https://github.com/chaofengc/PSFRGAN +""" +import numpy as np +import torch.nn as nn +from torch.nn import functional as F + + +class NormLayer(nn.Module): + """Normalization Layers. + + Args: + channels: input channels, for batch norm and instance norm. + input_size: input shape without batch size, for layer norm. + """ + + def __init__(self, channels, normalize_shape=None, norm_type='bn'): + super(NormLayer, self).__init__() + norm_type = norm_type.lower() + self.norm_type = norm_type + if norm_type == 'bn': + self.norm = nn.BatchNorm2d(channels, affine=True) + elif norm_type == 'in': + self.norm = nn.InstanceNorm2d(channels, affine=False) + elif norm_type == 'gn': + self.norm = nn.GroupNorm(32, channels, affine=True) + elif norm_type == 'pixel': + self.norm = lambda x: F.normalize(x, p=2, dim=1) + elif norm_type == 'layer': + self.norm = nn.LayerNorm(normalize_shape) + elif norm_type == 'none': + self.norm = lambda x: x * 1.0 + else: + assert 1 == 0, f'Norm type {norm_type} not support.' + + def forward(self, x, ref=None): + if self.norm_type == 'spade': + return self.norm(x, ref) + else: + return self.norm(x) + + +class ReluLayer(nn.Module): + """Relu Layer. + + Args: + relu type: type of relu layer, candidates are + - ReLU + - LeakyReLU: default relu slope 0.2 + - PRelu + - SELU + - none: direct pass + """ + + def __init__(self, channels, relu_type='relu'): + super(ReluLayer, self).__init__() + relu_type = relu_type.lower() + if relu_type == 'relu': + self.func = nn.ReLU(True) + elif relu_type == 'leakyrelu': + self.func = nn.LeakyReLU(0.2, inplace=True) + elif relu_type == 'prelu': + self.func = nn.PReLU(channels) + elif relu_type == 'selu': + self.func = nn.SELU(True) + elif relu_type == 'none': + self.func = lambda x: x * 1.0 + else: + assert 1 == 0, f'Relu type {relu_type} not support.' + + def forward(self, x): + return self.func(x) + + +class ConvLayer(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + scale='none', + norm_type='none', + relu_type='none', + use_pad=True, + bias=True): + super(ConvLayer, self).__init__() + self.use_pad = use_pad + self.norm_type = norm_type + if norm_type in ['bn']: + bias = False + + stride = 2 if scale == 'down' else 1 + + self.scale_func = lambda x: x + if scale == 'up': + self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest') + + self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.) / 2))) + self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias) + + self.relu = ReluLayer(out_channels, relu_type) + self.norm = NormLayer(out_channels, norm_type=norm_type) + + def forward(self, x): + out = self.scale_func(x) + if self.use_pad: + out = self.reflection_pad(out) + out = self.conv2d(out) + out = self.norm(out) + out = self.relu(out) + return out + + +class ResidualBlock(nn.Module): + """ + Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html + """ + + def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'): + super(ResidualBlock, self).__init__() + + if scale == 'none' and c_in == c_out: + self.shortcut_func = lambda x: x + else: + self.shortcut_func = ConvLayer(c_in, c_out, 3, scale) + + scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']} + scale_conf = scale_config_dict[scale] + + self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type) + self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type='none') + + def forward(self, x): + identity = self.shortcut_func(x) + + res = self.conv1(x) + res = self.conv2(res) + return identity + res + + +class ParseNet(nn.Module): + + def __init__(self, + in_size=128, + out_size=128, + min_feat_size=32, + base_ch=64, + parsing_ch=19, + res_depth=10, + relu_type='LeakyReLU', + norm_type='bn', + ch_range=[32, 256]): + super().__init__() + self.res_depth = res_depth + act_args = {'norm_type': norm_type, 'relu_type': relu_type} + min_ch, max_ch = ch_range + + ch_clip = lambda x: max(min_ch, min(x, max_ch)) # noqa: E731 + min_feat_size = min(in_size, min_feat_size) + + down_steps = int(np.log2(in_size // min_feat_size)) + up_steps = int(np.log2(out_size // min_feat_size)) + + # =============== define encoder-body-decoder ==================== + self.encoder = [] + self.encoder.append(ConvLayer(3, base_ch, 3, 1)) + head_ch = base_ch + for i in range(down_steps): + cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2) + self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args)) + head_ch = head_ch * 2 + + self.body = [] + for i in range(res_depth): + self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args)) + + self.decoder = [] + for i in range(up_steps): + cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2) + self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args)) + head_ch = head_ch // 2 + + self.encoder = nn.Sequential(*self.encoder) + self.body = nn.Sequential(*self.body) + self.decoder = nn.Sequential(*self.decoder) + self.out_img_conv = ConvLayer(ch_clip(head_ch), 3) + self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch) + + def forward(self, x): + feat = self.encoder(x) + x = feat + self.body(feat) + x = self.decoder(x) + out_img = self.out_img_conv(x) + out_mask = self.out_mask_conv(x) + return out_mask, out_img diff --git a/Hallo2/hallo2/facelib/parsing/resnet.py b/Hallo2/hallo2/facelib/parsing/resnet.py new file mode 100644 index 00000000..fec8e82c --- /dev/null +++ b/Hallo2/hallo2/facelib/parsing/resnet.py @@ -0,0 +1,69 @@ +import torch.nn as nn +import torch.nn.functional as F + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + + def __init__(self, in_chan, out_chan, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(in_chan, out_chan, stride) + self.bn1 = nn.BatchNorm2d(out_chan) + self.conv2 = conv3x3(out_chan, out_chan) + self.bn2 = nn.BatchNorm2d(out_chan) + self.relu = nn.ReLU(inplace=True) + self.downsample = None + if in_chan != out_chan or stride != 1: + self.downsample = nn.Sequential( + nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(out_chan), + ) + + def forward(self, x): + residual = self.conv1(x) + residual = F.relu(self.bn1(residual)) + residual = self.conv2(residual) + residual = self.bn2(residual) + + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(x) + + out = shortcut + residual + out = self.relu(out) + return out + + +def create_layer_basic(in_chan, out_chan, bnum, stride=1): + layers = [BasicBlock(in_chan, out_chan, stride=stride)] + for i in range(bnum - 1): + layers.append(BasicBlock(out_chan, out_chan, stride=1)) + return nn.Sequential(*layers) + + +class ResNet18(nn.Module): + + def __init__(self): + super(ResNet18, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) + self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) + self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) + self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(self.bn1(x)) + x = self.maxpool(x) + + x = self.layer1(x) + feat8 = self.layer2(x) # 1/8 + feat16 = self.layer3(feat8) # 1/16 + feat32 = self.layer4(feat16) # 1/32 + return feat8, feat16, feat32 diff --git a/Hallo2/hallo2/facelib/utils/__init__.py b/Hallo2/hallo2/facelib/utils/__init__.py new file mode 100644 index 00000000..f03b1c2b --- /dev/null +++ b/Hallo2/hallo2/facelib/utils/__init__.py @@ -0,0 +1,7 @@ +from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back +from .misc import img2tensor, load_file_from_url, download_pretrained_models, scandir + +__all__ = [ + 'align_crop_face_landmarks', 'compute_increased_bbox', 'get_valid_bboxes', 'load_file_from_url', + 'download_pretrained_models', 'paste_face_back', 'img2tensor', 'scandir' +] diff --git a/Hallo2/hallo2/facelib/utils/face_restoration_helper.py b/Hallo2/hallo2/facelib/utils/face_restoration_helper.py new file mode 100644 index 00000000..cf7ac424 --- /dev/null +++ b/Hallo2/hallo2/facelib/utils/face_restoration_helper.py @@ -0,0 +1,512 @@ +import cv2 +import numpy as np +import os +import torch +from torchvision.transforms.functional import normalize + +from facelib.detection import init_detection_model +from facelib.parsing import init_parsing_model +from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray, adain_npy +from basicsr.utils.download_util import load_file_from_url +from basicsr.utils.misc import get_device + + +dlib_model_url = { + 'face_detector': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/mmod_human_face_detector-4cb19393.dat', + 'shape_predictor_5': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/shape_predictor_5_face_landmarks-c4b1e980.dat' +} + +def get_largest_face(det_faces, h, w): + + def get_location(val, length): + if val < 0: + return 0 + elif val > length: + return length + else: + return val + + face_areas = [] + for det_face in det_faces: + left = get_location(det_face[0], w) + right = get_location(det_face[2], w) + top = get_location(det_face[1], h) + bottom = get_location(det_face[3], h) + face_area = (right - left) * (bottom - top) + face_areas.append(face_area) + largest_idx = face_areas.index(max(face_areas)) + return det_faces[largest_idx], largest_idx + + +def get_center_face(det_faces, h=0, w=0, center=None): + if center is not None: + center = np.array(center) + else: + center = np.array([w / 2, h / 2]) + center_dist = [] + for det_face in det_faces: + face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2]) + dist = np.linalg.norm(face_center - center) + center_dist.append(dist) + center_idx = center_dist.index(min(center_dist)) + return det_faces[center_idx], center_idx + + +class FaceRestoreHelper(object): + """Helper for the face restoration pipeline (base class).""" + + def __init__(self, + upscale_factor, + face_size=512, + crop_ratio=(1, 1), + det_model='retinaface_resnet50', + save_ext='png', + template_3points=False, + pad_blur=False, + use_parse=False, + device=None): + self.template_3points = template_3points # improve robustness + self.upscale_factor = int(upscale_factor) + # the cropped face ratio based on the square face + self.crop_ratio = crop_ratio # (h, w) + assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1' + self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0])) + self.det_model = det_model + + if self.det_model == 'dlib': + # standard 5 landmarks for FFHQ faces with 1024 x 1024 + self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941], + [337.91089109, 488.38613861], [437.95049505, 493.51485149], + [513.58415842, 678.5049505]]) + self.face_template = self.face_template / (1024 // face_size) + elif self.template_3points: + self.face_template = np.array([[192, 240], [319, 240], [257, 371]]) + else: + # standard 5 landmarks for FFHQ faces with 512 x 512 + # facexlib + self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935], + [201.26117, 371.41043], [313.08905, 371.15118]]) + + # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54 + # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894], + # [198.22603, 372.82502], [313.91018, 372.75659]]) + + self.face_template = self.face_template * (face_size / 512.0) + if self.crop_ratio[0] > 1: + self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2 + if self.crop_ratio[1] > 1: + self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2 + self.save_ext = save_ext + self.pad_blur = pad_blur + if self.pad_blur is True: + self.template_3points = False + + self.all_landmarks_5 = [] + self.det_faces = [] + self.affine_matrices = [] + self.inverse_affine_matrices = [] + self.cropped_faces = [] + self.restored_faces = [] + self.pad_input_imgs = [] + + if device is None: + # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.device = get_device() + else: + self.device = device + + # init face detection model + if self.det_model == 'dlib': + self.face_detector, self.shape_predictor_5 = self.init_dlib(dlib_model_url['face_detector'], dlib_model_url['shape_predictor_5']) + else: + self.face_detector = init_detection_model(det_model, half=False, device=self.device) + + # init face parsing model + self.use_parse = use_parse + self.face_parse = init_parsing_model(model_name='parsenet', device=self.device) + + def set_upscale_factor(self, upscale_factor): + self.upscale_factor = upscale_factor + + def read_image(self, img): + """img can be image path or cv2 loaded image.""" + # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255] + if isinstance(img, str): + img = cv2.imread(img) + + if np.max(img) > 256: # 16-bit image + img = img / 65535 * 255 + if len(img.shape) == 2: # gray image + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + elif img.shape[2] == 4: # BGRA image with alpha channel + img = img[:, :, 0:3] + + self.input_img = img + self.is_gray = is_gray(img, threshold=10) + if self.is_gray: + print('Grayscale input: True') + + if min(self.input_img.shape[:2])<512: + f = 512.0/min(self.input_img.shape[:2]) + self.input_img = cv2.resize(self.input_img, (0,0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR) + + def init_dlib(self, detection_path, landmark5_path): + """Initialize the dlib detectors and predictors.""" + try: + import dlib + except ImportError: + print('Please install dlib by running:' 'conda install -c conda-forge dlib') + detection_path = load_file_from_url(url=detection_path, model_dir='weights/dlib', progress=True, file_name=None) + landmark5_path = load_file_from_url(url=landmark5_path, model_dir='weights/dlib', progress=True, file_name=None) + face_detector = dlib.cnn_face_detection_model_v1(detection_path) + shape_predictor_5 = dlib.shape_predictor(landmark5_path) + return face_detector, shape_predictor_5 + + def get_face_landmarks_5_dlib(self, + only_keep_largest=False, + scale=1): + det_faces = self.face_detector(self.input_img, scale) + + if len(det_faces) == 0: + print('No face detected. Try to increase upsample_num_times.') + return 0 + else: + if only_keep_largest: + print('Detect several faces and only keep the largest.') + face_areas = [] + for i in range(len(det_faces)): + face_area = (det_faces[i].rect.right() - det_faces[i].rect.left()) * ( + det_faces[i].rect.bottom() - det_faces[i].rect.top()) + face_areas.append(face_area) + largest_idx = face_areas.index(max(face_areas)) + self.det_faces = [det_faces[largest_idx]] + else: + self.det_faces = det_faces + + if len(self.det_faces) == 0: + return 0 + + for face in self.det_faces: + shape = self.shape_predictor_5(self.input_img, face.rect) + landmark = np.array([[part.x, part.y] for part in shape.parts()]) + self.all_landmarks_5.append(landmark) + + return len(self.all_landmarks_5) + + + def get_face_landmarks_5(self, + only_keep_largest=False, + only_center_face=False, + resize=None, + blur_ratio=0.01, + eye_dist_threshold=None): + if self.det_model == 'dlib': + return self.get_face_landmarks_5_dlib(only_keep_largest) + + if resize is None: + scale = 1 + input_img = self.input_img + else: + h, w = self.input_img.shape[0:2] + scale = resize / min(h, w) + scale = max(1, scale) # always scale up + h, w = int(h * scale), int(w * scale) + interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR + input_img = cv2.resize(self.input_img, (w, h), interpolation=interp) + + with torch.no_grad(): + bboxes = self.face_detector.detect_faces(input_img) + # ic(bboxes.shape) + # cv2.imwrite('resized_image.jpg', input_img) + + if bboxes is None or bboxes.shape[0] == 0: + return 0 + else: + bboxes = bboxes / scale + + for bbox in bboxes: + # remove faces with too small eye distance: side faces or too small faces + eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]]) + if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold): + continue + + if self.template_3points: + landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)]) + else: + landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)]) + self.all_landmarks_5.append(landmark) + self.det_faces.append(bbox[0:5]) + + if len(self.det_faces) == 0: + return 0 + if only_keep_largest: + h, w, _ = self.input_img.shape + self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w) + self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]] + elif only_center_face: + h, w, _ = self.input_img.shape + self.det_faces, center_idx = get_center_face(self.det_faces, h, w) + self.all_landmarks_5 = [self.all_landmarks_5[center_idx]] + + # pad blurry images + if self.pad_blur: + self.pad_input_imgs = [] + for landmarks in self.all_landmarks_5: + # get landmarks + eye_left = landmarks[0, :] + eye_right = landmarks[1, :] + eye_avg = (eye_left + eye_right) * 0.5 + mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5 + eye_to_eye = eye_right - eye_left + eye_to_mouth = mouth_avg - eye_avg + + # Get the oriented crop rectangle + # x: half width of the oriented crop rectangle + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] + # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise + # norm with the hypotenuse: get the direction + x /= np.hypot(*x) # get the hypotenuse of a right triangle + rect_scale = 1.5 + x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale) + # y: half height of the oriented crop rectangle + y = np.flipud(x) * [-1, 1] + + # c: center + c = eye_avg + eye_to_mouth * 0.1 + # quad: (left_top, left_bottom, right_bottom, right_top) + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) + # qsize: side length of the square + qsize = np.hypot(*x) * 2 + border = max(int(np.rint(qsize * 0.1)), 3) + + # get pad + # pad: (width_left, height_top, width_right, height_bottom) + pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + pad = [ + max(-pad[0] + border, 1), + max(-pad[1] + border, 1), + max(pad[2] - self.input_img.shape[0] + border, 1), + max(pad[3] - self.input_img.shape[1] + border, 1) + ] + + if max(pad) > 1: + # pad image + pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') + # modify landmark coords + landmarks[:, 0] += pad[0] + landmarks[:, 1] += pad[1] + # blur pad images + h, w, _ = pad_img.shape + y, x, _ = np.ogrid[:h, :w, :1] + mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], + np.float32(w - 1 - x) / pad[2]), + 1.0 - np.minimum(np.float32(y) / pad[1], + np.float32(h - 1 - y) / pad[3])) + blur = int(qsize * blur_ratio) + if blur % 2 == 0: + blur += 1 + blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur)) + # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0) + + pad_img = pad_img.astype('float32') + pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0) + pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255] + self.pad_input_imgs.append(pad_img) + else: + self.pad_input_imgs.append(np.copy(self.input_img)) + + return len(bboxes) + + def align_warp_face(self, save_cropped_path=None, border_mode='constant'): + """Align and warp faces with face template. + """ + if self.pad_blur: + assert len(self.pad_input_imgs) == len( + self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}' + for idx, landmark in enumerate(self.all_landmarks_5[-1: ]): + # use 5 landmarks to get affine matrix + # use cv2.LMEDS method for the equivalence to skimage transform + # ref: https://blog.csdn.net/yichxi/article/details/115827338 + affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0] + self.affine_matrices.append(affine_matrix) + # warp and crop faces + if border_mode == 'constant': + border_mode = cv2.BORDER_CONSTANT + elif border_mode == 'reflect101': + border_mode = cv2.BORDER_REFLECT101 + elif border_mode == 'reflect': + border_mode = cv2.BORDER_REFLECT + if self.pad_blur: + input_img = self.pad_input_imgs[idx] + else: + input_img = self.input_img + cropped_face = cv2.warpAffine( + input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)) # gray + self.cropped_faces.append(cropped_face) + # save the cropped face + if save_cropped_path is not None: + path = os.path.splitext(save_cropped_path)[0] + save_path = f'{path}_{idx:02d}.{self.save_ext}' + imwrite(cropped_face, save_path) + + def get_inverse_affine(self, save_inverse_affine_path=None): + """Get inverse affine matrix.""" + for idx, affine_matrix in enumerate(self.affine_matrices): + inverse_affine = cv2.invertAffineTransform(affine_matrix) + inverse_affine *= self.upscale_factor + self.inverse_affine_matrices.append(inverse_affine) + # save inverse affine matrices + if save_inverse_affine_path is not None: + path, _ = os.path.splitext(save_inverse_affine_path) + save_path = f'{path}_{idx:02d}.pth' + torch.save(inverse_affine, save_path) + + + def add_restored_face(self, restored_face, input_face=None): + if self.is_gray: + restored_face = bgr2gray(restored_face) # convert img into grayscale + if input_face is not None: + restored_face = adain_npy(restored_face, input_face) # transfer the color + self.restored_faces.append(restored_face) + + + def paste_faces_to_input_image(self, save_path=None, upsample_img_list=None, draw_box=False, face_upsampler=None): + h, w, _ = self.input_img.shape + h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor) + + results = [] + + for i, upsample_img in enumerate(upsample_img_list): + if upsample_img is None: + # simply resize the background + # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4) + upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR) + raise Exception("upsample img couldn't be none") + else: + upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4) + + assert len(self.restored_faces) == len( + self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.') + + inv_mask_borders = [] + # for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices): + restored_face, inverse_affine = self.restored_faces[i], self.inverse_affine_matrices[i] + + if face_upsampler is not None: + restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0] + inverse_affine /= self.upscale_factor + inverse_affine[:, 2] *= self.upscale_factor + face_size = (self.face_size[0]*self.upscale_factor, self.face_size[1]*self.upscale_factor) + else: + # Add an offset to inverse affine matrix, for more precise back alignment + if self.upscale_factor > 1: + extra_offset = 0.5 * self.upscale_factor + else: + extra_offset = 0 + inverse_affine[:, 2] += extra_offset + face_size = self.face_size + inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up)) + + # always use square mask + mask = np.ones(face_size, dtype=np.float32) + inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up)) + # remove the black borders + inv_mask_erosion = cv2.erode( + inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8)) + pasted_face = inv_mask_erosion[:, :, None] * inv_restored + total_face_area = np.sum(inv_mask_erosion) # // 3 + + # add border + if draw_box: + h, w = face_size + mask_border = np.ones((h, w, 3), dtype=np.float32) + border = int(1400/np.sqrt(total_face_area)) + mask_border[border:h-border, border:w-border,:] = 0 + inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up)) + inv_mask_borders.append(inv_mask_border) + + # compute the fusion edge based on the area of face + w_edge = int(total_face_area**0.5) // 20 + erosion_radius = w_edge * 2 + inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) + blur_size = w_edge * 2 + inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0) + if len(upsample_img.shape) == 2: # upsample_img is gray image + upsample_img = upsample_img[:, :, None] + inv_soft_mask = inv_soft_mask[:, :, None] + + # parse mask + if self.use_parse: + # inference + face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR) + face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True) + normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + face_input = torch.unsqueeze(face_input, 0).to(self.device) + with torch.no_grad(): + out = self.face_parse(face_input)[0] + out = out.argmax(dim=1).squeeze().cpu().numpy() + + parse_mask = np.zeros(out.shape) + MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0] + for idx, color in enumerate(MASK_COLORMAP): + parse_mask[out == idx] = color + # blur the mask + parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11) + parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11) + # remove the black borders + thres = 10 + parse_mask[:thres, :] = 0 + parse_mask[-thres:, :] = 0 + parse_mask[:, :thres] = 0 + parse_mask[:, -thres:] = 0 + parse_mask = parse_mask / 255. + + parse_mask = cv2.resize(parse_mask, face_size) + parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3) + inv_soft_parse_mask = parse_mask[:, :, None] + # pasted_face = inv_restored + fuse_mask = (inv_soft_parse_mask 256: # 16-bit image + upsample_img = upsample_img.astype(np.uint16) + else: + upsample_img = upsample_img.astype(np.uint8) + + # draw bounding box + if draw_box: + # upsample_input_img = cv2.resize(input_img, (w_up, h_up)) + img_color = np.ones([*upsample_img.shape], dtype=np.float32) + img_color[:,:,0] = 0 + img_color[:,:,1] = 255 + img_color[:,:,2] = 0 + for inv_mask_border in inv_mask_borders: + upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img + # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img + + if save_path is not None: + path = os.path.splitext(save_path)[0] + save_path = f'{path}.{self.save_ext}' + imwrite(upsample_img, save_path) + + results.append(upsample_img) + + return results + + def clean_all(self): + self.all_landmarks_5 = [] + self.restored_faces = [] + self.affine_matrices = [] + self.cropped_faces = [] + self.inverse_affine_matrices = [] + self.det_faces = [] + self.pad_input_imgs = [] \ No newline at end of file diff --git a/Hallo2/hallo2/facelib/utils/face_utils.py b/Hallo2/hallo2/facelib/utils/face_utils.py new file mode 100644 index 00000000..f1474a2a --- /dev/null +++ b/Hallo2/hallo2/facelib/utils/face_utils.py @@ -0,0 +1,248 @@ +import cv2 +import numpy as np +import torch + + +def compute_increased_bbox(bbox, increase_area, preserve_aspect=True): + left, top, right, bot = bbox + width = right - left + height = bot - top + + if preserve_aspect: + width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width)) + height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height)) + else: + width_increase = height_increase = increase_area + left = int(left - width_increase * width) + top = int(top - height_increase * height) + right = int(right + width_increase * width) + bot = int(bot + height_increase * height) + return (left, top, right, bot) + + +def get_valid_bboxes(bboxes, h, w): + left = max(bboxes[0], 0) + top = max(bboxes[1], 0) + right = min(bboxes[2], w) + bottom = min(bboxes[3], h) + return (left, top, right, bottom) + + +def align_crop_face_landmarks(img, + landmarks, + output_size, + transform_size=None, + enable_padding=True, + return_inverse_affine=False, + shrink_ratio=(1, 1)): + """Align and crop face with landmarks. + + The output_size and transform_size are based on width. The height is + adjusted based on shrink_ratio_h/shring_ration_w. + + Modified from: + https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py + + Args: + img (Numpy array): Input image. + landmarks (Numpy array): 5 or 68 or 98 landmarks. + output_size (int): Output face size. + transform_size (ing): Transform size. Usually the four time of + output_size. + enable_padding (float): Default: True. + shrink_ratio (float | tuple[float] | list[float]): Shring the whole + face for height and width (crop larger area). Default: (1, 1). + + Returns: + (Numpy array): Cropped face. + """ + lm_type = 'retinaface_5' # Options: dlib_5, retinaface_5 + + if isinstance(shrink_ratio, (float, int)): + shrink_ratio = (shrink_ratio, shrink_ratio) + if transform_size is None: + transform_size = output_size * 4 + + # Parse landmarks + lm = np.array(landmarks) + if lm.shape[0] == 5 and lm_type == 'retinaface_5': + eye_left = lm[0] + eye_right = lm[1] + mouth_avg = (lm[3] + lm[4]) * 0.5 + elif lm.shape[0] == 5 and lm_type == 'dlib_5': + lm_eye_left = lm[2:4] + lm_eye_right = lm[0:2] + eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + mouth_avg = lm[4] + elif lm.shape[0] == 68: + lm_eye_left = lm[36:42] + lm_eye_right = lm[42:48] + eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + mouth_avg = (lm[48] + lm[54]) * 0.5 + elif lm.shape[0] == 98: + lm_eye_left = lm[60:68] + lm_eye_right = lm[68:76] + eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + mouth_avg = (lm[76] + lm[82]) * 0.5 + + eye_avg = (eye_left + eye_right) * 0.5 + eye_to_eye = eye_right - eye_left + eye_to_mouth = mouth_avg - eye_avg + + # Get the oriented crop rectangle + # x: half width of the oriented crop rectangle + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] + # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise + # norm with the hypotenuse: get the direction + x /= np.hypot(*x) # get the hypotenuse of a right triangle + rect_scale = 1 # TODO: you can edit it to get larger rect + x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale) + # y: half height of the oriented crop rectangle + y = np.flipud(x) * [-1, 1] + + x *= shrink_ratio[1] # width + y *= shrink_ratio[0] # height + + # c: center + c = eye_avg + eye_to_mouth * 0.1 + # quad: (left_top, left_bottom, right_bottom, right_top) + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) + # qsize: side length of the square + qsize = np.hypot(*x) * 2 + + quad_ori = np.copy(quad) + # Shrink, for large face + # TODO: do we really need shrink + shrink = int(np.floor(qsize / output_size * 0.5)) + if shrink > 1: + h, w = img.shape[0:2] + rsize = (int(np.rint(float(w) / shrink)), int(np.rint(float(h) / shrink))) + img = cv2.resize(img, rsize, interpolation=cv2.INTER_AREA) + quad /= shrink + qsize /= shrink + + # Crop + h, w = img.shape[0:2] + border = max(int(np.rint(qsize * 0.1)), 3) + crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, w), min(crop[3] + border, h)) + if crop[2] - crop[0] < w or crop[3] - crop[1] < h: + img = img[crop[1]:crop[3], crop[0]:crop[2], :] + quad -= crop[0:2] + + # Pad + # pad: (width_left, height_top, width_right, height_bottom) + h, w = img.shape[0:2] + pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - w + border, 0), max(pad[3] - h + border, 0)) + if enable_padding and max(pad) > border - 4: + pad = np.maximum(pad, int(np.rint(qsize * 0.3))) + img = np.pad(img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') + h, w = img.shape[0:2] + y, x, _ = np.ogrid[:h, :w, :1] + mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], + np.float32(w - 1 - x) / pad[2]), + 1.0 - np.minimum(np.float32(y) / pad[1], + np.float32(h - 1 - y) / pad[3])) + blur = int(qsize * 0.02) + if blur % 2 == 0: + blur += 1 + blur_img = cv2.boxFilter(img, 0, ksize=(blur, blur)) + + img = img.astype('float32') + img += (blur_img - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) + img = np.clip(img, 0, 255) # float32, [0, 255] + quad += pad[:2] + + # Transform use cv2 + h_ratio = shrink_ratio[0] / shrink_ratio[1] + dst_h, dst_w = int(transform_size * h_ratio), transform_size + template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]]) + # use cv2.LMEDS method for the equivalence to skimage transform + # ref: https://blog.csdn.net/yichxi/article/details/115827338 + affine_matrix = cv2.estimateAffinePartial2D(quad, template, method=cv2.LMEDS)[0] + cropped_face = cv2.warpAffine( + img, affine_matrix, (dst_w, dst_h), borderMode=cv2.BORDER_CONSTANT, borderValue=(135, 133, 132)) # gray + + if output_size < transform_size: + cropped_face = cv2.resize( + cropped_face, (output_size, int(output_size * h_ratio)), interpolation=cv2.INTER_LINEAR) + + if return_inverse_affine: + dst_h, dst_w = int(output_size * h_ratio), output_size + template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]]) + # use cv2.LMEDS method for the equivalence to skimage transform + # ref: https://blog.csdn.net/yichxi/article/details/115827338 + affine_matrix = cv2.estimateAffinePartial2D( + quad_ori, np.array([[0, 0], [0, output_size], [dst_w, dst_h], [dst_w, 0]]), method=cv2.LMEDS)[0] + inverse_affine = cv2.invertAffineTransform(affine_matrix) + else: + inverse_affine = None + return cropped_face, inverse_affine + + +def paste_face_back(img, face, inverse_affine): + h, w = img.shape[0:2] + face_h, face_w = face.shape[0:2] + inv_restored = cv2.warpAffine(face, inverse_affine, (w, h)) + mask = np.ones((face_h, face_w, 3), dtype=np.float32) + inv_mask = cv2.warpAffine(mask, inverse_affine, (w, h)) + # remove the black borders + inv_mask_erosion = cv2.erode(inv_mask, np.ones((2, 2), np.uint8)) + inv_restored_remove_border = inv_mask_erosion * inv_restored + total_face_area = np.sum(inv_mask_erosion) // 3 + # compute the fusion edge based on the area of face + w_edge = int(total_face_area**0.5) // 20 + erosion_radius = w_edge * 2 + inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) + blur_size = w_edge * 2 + inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0) + img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * img + # float32, [0, 255] + return img + + +if __name__ == '__main__': + import os + + from facelib.detection import init_detection_model + from facelib.utils.face_restoration_helper import get_largest_face + + img_path = '/home/wxt/datasets/ffhq/ffhq_wild/00009.png' + img_name = os.splitext(os.path.basename(img_path))[0] + + # initialize model + det_net = init_detection_model('retinaface_resnet50', half=False) + img_ori = cv2.imread(img_path) + h, w = img_ori.shape[0:2] + # if larger than 800, scale it + scale = max(h / 800, w / 800) + if scale > 1: + img = cv2.resize(img_ori, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_LINEAR) + + with torch.no_grad(): + bboxes = det_net.detect_faces(img, 0.97) + if scale > 1: + bboxes *= scale # the score is incorrect + bboxes = get_largest_face(bboxes, h, w)[0] + + landmarks = np.array([[bboxes[i], bboxes[i + 1]] for i in range(5, 15, 2)]) + + cropped_face, inverse_affine = align_crop_face_landmarks( + img_ori, + landmarks, + output_size=512, + transform_size=None, + enable_padding=True, + return_inverse_affine=True, + shrink_ratio=(1, 1)) + + cv2.imwrite(f'tmp/{img_name}_cropeed_face.png', cropped_face) + img = paste_face_back(img_ori, cropped_face, inverse_affine) + cv2.imwrite(f'tmp/{img_name}_back.png', img) diff --git a/Hallo2/hallo2/facelib/utils/misc.py b/Hallo2/hallo2/facelib/utils/misc.py new file mode 100644 index 00000000..18755792 --- /dev/null +++ b/Hallo2/hallo2/facelib/utils/misc.py @@ -0,0 +1,202 @@ +import cv2 +import os +import os.path as osp +import numpy as np +from PIL import Image +import torch +from torch.hub import download_url_to_file, get_dir +from urllib.parse import urlparse +# from basicsr.utils.download_util import download_file_from_google_drive + +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +def download_pretrained_models(file_ids, save_path_root): + import gdown + + os.makedirs(save_path_root, exist_ok=True) + + for file_name, file_id in file_ids.items(): + file_url = 'https://drive.google.com/uc?id='+file_id + save_path = osp.abspath(osp.join(save_path_root, file_name)) + if osp.exists(save_path): + user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n') + if user_response.lower() == 'y': + print(f'Covering {file_name} to {save_path}') + gdown.download(file_url, save_path, quiet=False) + # download_file_from_google_drive(file_id, save_path) + elif user_response.lower() == 'n': + print(f'Skipping {file_name}') + else: + raise ValueError('Wrong input. Only accepts Y/N.') + else: + print(f'Downloading {file_name} to {save_path}') + gdown.download(file_url, save_path, quiet=False) + # download_file_from_google_drive(file_id, save_path) + + +def imwrite(img, file_path, params=None, auto_mkdir=True): + """Write image to file. + + Args: + img (ndarray): Image array to be written. + file_path (str): Image file path. + params (None or list): Same as opencv's :func:`imwrite` interface. + auto_mkdir (bool): If the parent folder of `file_path` does not exist, + whether to create it automatically. + + Returns: + bool: Successful or not. + """ + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + return cv2.imwrite(file_path, img, params) + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + """ + if model_dir is None: + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) + return cached_file + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + Returns: + A generator for all the interested files with relative paths. + """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir(entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) + + +def is_gray(img, threshold=10): + img = Image.fromarray(img) + if len(img.getbands()) == 1: + return True + img1 = np.asarray(img.getchannel(channel=0), dtype=np.int16) + img2 = np.asarray(img.getchannel(channel=1), dtype=np.int16) + img3 = np.asarray(img.getchannel(channel=2), dtype=np.int16) + diff1 = (img1 - img2).var() + diff2 = (img2 - img3).var() + diff3 = (img3 - img1).var() + diff_sum = (diff1 + diff2 + diff3) / 3.0 + if diff_sum <= threshold: + return True + else: + return False + +def rgb2gray(img, out_channel=3): + r, g, b = img[:,:,0], img[:,:,1], img[:,:,2] + gray = 0.2989 * r + 0.5870 * g + 0.1140 * b + if out_channel == 3: + gray = gray[:,:,np.newaxis].repeat(3, axis=2) + return gray + +def bgr2gray(img, out_channel=3): + b, g, r = img[:,:,0], img[:,:,1], img[:,:,2] + gray = 0.2989 * r + 0.5870 * g + 0.1140 * b + if out_channel == 3: + gray = gray[:,:,np.newaxis].repeat(3, axis=2) + return gray + + +def calc_mean_std(feat, eps=1e-5): + """ + Args: + feat (numpy): 3D [w h c]s + """ + size = feat.shape + assert len(size) == 3, 'The input feature should be 3D tensor.' + c = size[2] + feat_var = feat.reshape(-1, c).var(axis=0) + eps + feat_std = np.sqrt(feat_var).reshape(1, 1, c) + feat_mean = feat.reshape(-1, c).mean(axis=0).reshape(1, 1, c) + return feat_mean, feat_std + + +def adain_npy(content_feat, style_feat): + """Adaptive instance normalization for numpy. + + Args: + content_feat (numpy): The input feature. + style_feat (numpy): The reference feature. + """ + size = content_feat.shape + style_mean, style_std = calc_mean_std(style_feat) + content_mean, content_std = calc_mean_std(content_feat) + normalized_feat = (content_feat - np.broadcast_to(content_mean, size)) / np.broadcast_to(content_std, size) + return normalized_feat * np.broadcast_to(style_std, size) + np.broadcast_to(style_mean, size) \ No newline at end of file diff --git a/Hallo2/hallo2/hallo/__init__.py b/Hallo2/hallo2/hallo/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Hallo2/hallo2/hallo/animate/__init__.py b/Hallo2/hallo2/hallo/animate/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Hallo2/hallo2/hallo/animate/face_animate.py b/Hallo2/hallo2/hallo/animate/face_animate.py new file mode 100644 index 00000000..5c38d15d --- /dev/null +++ b/Hallo2/hallo2/hallo/animate/face_animate.py @@ -0,0 +1,442 @@ +# pylint: disable=R0801 +""" +This module is responsible for animating faces in videos using a combination of deep learning techniques. +It provides a pipeline for generating face animations by processing video frames and extracting face features. +The module utilizes various schedulers and utilities for efficient face animation and supports different types + of latents for more control over the animation process. + +Functions and Classes: +- FaceAnimatePipeline: A class that extends the DiffusionPipeline class from the diffusers library to handle face animation tasks. + - __init__: Initializes the pipeline with the necessary components (VAE, UNets, face locator, etc.). + - prepare_latents: Generates or loads latents for the animation process, scaling them according to the scheduler's requirements. + - prepare_extra_step_kwargs: Prepares extra keyword arguments for the scheduler step, ensuring compatibility with different schedulers. + - decode_latents: Decodes the latents into video frames, ready for animation. + +Usage: +- Import the necessary packages and classes. +- Create a FaceAnimatePipeline instance with the required components. +- Prepare the latents for the animation process. +- Use the pipeline to generate the animated video. + +Note: +- This module is designed to work with the diffusers library, which provides the underlying framework for face animation using deep learning. +- The module is intended for research and development purposes, and further optimization and customization may be required for specific use cases. +""" + +import inspect +from dataclasses import dataclass +from typing import Callable, List, Optional, Union + +import numpy as np +import torch +from diffusers import (DDIMScheduler, DiffusionPipeline, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, + LMSDiscreteScheduler, PNDMScheduler) +from diffusers.image_processor import VaeImageProcessor +from diffusers.utils import BaseOutput +from diffusers.utils.torch_utils import randn_tensor +from einops import rearrange, repeat +from tqdm import tqdm + +from hallo.models.mutual_self_attention import ReferenceAttentionControl + + +@dataclass +class FaceAnimatePipelineOutput(BaseOutput): + """ + FaceAnimatePipelineOutput is a custom class that inherits from BaseOutput and represents the output of the FaceAnimatePipeline. + + Attributes: + videos (Union[torch.Tensor, np.ndarray]): A tensor or numpy array containing the generated video frames. + + Methods: + __init__(self, videos: Union[torch.Tensor, np.ndarray]): Initializes the FaceAnimatePipelineOutput object with the generated video frames. + """ + videos: Union[torch.Tensor, np.ndarray] + +class FaceAnimatePipeline(DiffusionPipeline): + """ + FaceAnimatePipeline is a custom DiffusionPipeline for animating faces. + + It inherits from the DiffusionPipeline class and is used to animate faces by + utilizing a variational autoencoder (VAE), a reference UNet, a denoising UNet, + a face locator, and an image processor. The pipeline is responsible for generating + and animating face latents, and decoding the latents to produce the final video output. + + Attributes: + vae (VaeImageProcessor): Variational autoencoder for processing images. + reference_unet (nn.Module): Reference UNet for mutual self-attention. + denoising_unet (nn.Module): Denoising UNet for image denoising. + face_locator (nn.Module): Face locator for detecting and cropping faces. + image_proj (nn.Module): Image projector for processing images. + scheduler (Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, + EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler]): Diffusion scheduler for + controlling the noise level. + + Methods: + __init__(self, vae, reference_unet, denoising_unet, face_locator, + image_proj, scheduler): Initializes the FaceAnimatePipeline + with the given components and scheduler. + prepare_latents(self, batch_size, num_channels_latents, width, height, + video_length, dtype, device, generator=None, latents=None): + Prepares the initial latents for video generation. + prepare_extra_step_kwargs(self, generator, eta): Prepares extra keyword + arguments for the scheduler step. + decode_latents(self, latents): Decodes the latents to produce the final + video output. + """ + def __init__( + self, + vae, + reference_unet, + denoising_unet, + face_locator, + image_proj, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + ) -> None: + super().__init__() + + self.register_modules( + vae=vae, + reference_unet=reference_unet, + denoising_unet=denoising_unet, + face_locator=face_locator, + scheduler=scheduler, + image_proj=image_proj, + ) + + self.vae_scale_factor: int = 2 ** (len(self.vae.config.block_out_channels) - 1) + + self.ref_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, + ) + + @property + def _execution_device(self): + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def prepare_latents( + self, + batch_size: int, # Number of videos to generate in parallel + num_channels_latents: int, # Number of channels in the latents + width: int, # Width of the video frame + height: int, # Height of the video frame + video_length: int, # Length of the video in frames + dtype: torch.dtype, # Data type of the latents + device: torch.device, # Device to store the latents on + generator: Optional[torch.Generator] = None, # Random number generator for reproducibility + latents: Optional[torch.Tensor] = None # Pre-generated latents (optional) + ): + """ + Prepares the initial latents for video generation. + + Args: + batch_size (int): Number of videos to generate in parallel. + num_channels_latents (int): Number of channels in the latents. + width (int): Width of the video frame. + height (int): Height of the video frame. + video_length (int): Length of the video in frames. + dtype (torch.dtype): Data type of the latents. + device (torch.device): Device to store the latents on. + generator (Optional[torch.Generator]): Random number generator for reproducibility. + latents (Optional[torch.Tensor]): Pre-generated latents (optional). + + Returns: + latents (torch.Tensor): Tensor of shape (batch_size, num_channels_latents, width, height) + containing the initial latents for video generation. + """ + shape = ( + batch_size, + num_channels_latents, + video_length, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor( + shape, generator=generator, device=device, dtype=dtype + ) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_extra_step_kwargs(self, generator, eta): + """ + Prepares extra keyword arguments for the scheduler step. + + Args: + generator (Optional[torch.Generator]): Random number generator for reproducibility. + eta (float): The eta (η) parameter used with the DDIMScheduler. + It corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502) and should be between [0, 1]. + + Returns: + dict: A dictionary containing the extra keyword arguments for the scheduler step. + """ + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def decode_latents(self, latents): + """ + Decode the latents to produce a video. + + Parameters: + latents (torch.Tensor): The latents to be decoded. + + Returns: + video (torch.Tensor): The decoded video. + video_length (int): The length of the video in frames. + """ + video_length = latents.shape[2] + latents = 1 / 0.18215 * latents + latents = rearrange(latents, "b c f h w -> (b f) c h w") + # video = self.vae.decode(latents).sample + video = [] + for frame_idx in tqdm(range(latents.shape[0])): + video.append(self.vae.decode( + latents[frame_idx: frame_idx + 1]).sample) + video = torch.cat(video) + video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) + video = (video / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + video = video.cpu().float().numpy() + return video + + + @torch.no_grad() + def __call__( + self, + ref_image, + face_emb, + audio_tensor, + face_mask, + pixel_values_full_mask, + pixel_values_face_mask, + pixel_values_lip_mask, + width, + height, + video_length, + num_inference_steps, + guidance_scale, + num_images_per_prompt=1, + eta: float = 0.0, + motion_scale: Optional[List[torch.Tensor]] = None, + generator: Optional[Union[torch.Generator, + List[torch.Generator]]] = None, + output_type: Optional[str] = "tensor", + return_dict: bool = True, + callback: Optional[Callable[[ + int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + # Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + device = self._execution_device + + do_classifier_free_guidance = guidance_scale > 1.0 + + # Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + batch_size = 1 + + # prepare clip image embeddings + clip_image_embeds = face_emb + clip_image_embeds = clip_image_embeds.to(self.image_proj.device, self.image_proj.dtype) + + encoder_hidden_states = self.image_proj(clip_image_embeds) + uncond_encoder_hidden_states = self.image_proj(torch.zeros_like(clip_image_embeds)) + + if do_classifier_free_guidance: + encoder_hidden_states = torch.cat([uncond_encoder_hidden_states, encoder_hidden_states], dim=0) + + reference_control_writer = ReferenceAttentionControl( + self.reference_unet, + do_classifier_free_guidance=do_classifier_free_guidance, + mode="write", + batch_size=batch_size, + fusion_blocks="full", + ) + reference_control_reader = ReferenceAttentionControl( + self.denoising_unet, + do_classifier_free_guidance=do_classifier_free_guidance, + mode="read", + batch_size=batch_size, + fusion_blocks="full", + ) + + num_channels_latents = self.denoising_unet.in_channels + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + width, + height, + video_length, + clip_image_embeds.dtype, + device, + generator, + ) + + # Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # Prepare ref image latents + ref_image_tensor = rearrange(ref_image, "b f c h w -> (b f) c h w") + ref_image_tensor = self.ref_image_processor.preprocess(ref_image_tensor, height=height, width=width) # (bs, c, width, height) + ref_image_tensor = ref_image_tensor.to(dtype=self.vae.dtype, device=self.vae.device) + ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean + ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w) + + + face_mask = face_mask.unsqueeze(1).to(dtype=self.face_locator.dtype, device=self.face_locator.device) # (bs, f, c, H, W) + face_mask = repeat(face_mask, "b f c h w -> b (repeat f) c h w", repeat=video_length) + face_mask = face_mask.transpose(1, 2) # (bs, c, f, H, W) + face_mask = self.face_locator(face_mask) + face_mask = torch.cat([torch.zeros_like(face_mask), face_mask], dim=0) if do_classifier_free_guidance else face_mask + + pixel_values_full_mask = ( + [torch.cat([mask] * 2) for mask in pixel_values_full_mask] + if do_classifier_free_guidance + else pixel_values_full_mask + ) + pixel_values_face_mask = ( + [torch.cat([mask] * 2) for mask in pixel_values_face_mask] + if do_classifier_free_guidance + else pixel_values_face_mask + ) + pixel_values_lip_mask = ( + [torch.cat([mask] * 2) for mask in pixel_values_lip_mask] + if do_classifier_free_guidance + else pixel_values_lip_mask + ) + pixel_values_face_mask_ = [] + for mask in pixel_values_face_mask: + pixel_values_face_mask_.append( + mask.to(device=self.denoising_unet.device, dtype=self.denoising_unet.dtype)) + pixel_values_face_mask = pixel_values_face_mask_ + pixel_values_lip_mask_ = [] + for mask in pixel_values_lip_mask: + pixel_values_lip_mask_.append( + mask.to(device=self.denoising_unet.device, dtype=self.denoising_unet.dtype)) + pixel_values_lip_mask = pixel_values_lip_mask_ + pixel_values_full_mask_ = [] + for mask in pixel_values_full_mask: + pixel_values_full_mask_.append( + mask.to(device=self.denoising_unet.device, dtype=self.denoising_unet.dtype)) + pixel_values_full_mask = pixel_values_full_mask_ + + + uncond_audio_tensor = torch.zeros_like(audio_tensor) + audio_tensor = torch.cat([uncond_audio_tensor, audio_tensor], dim=0) + audio_tensor = audio_tensor.to(dtype=self.denoising_unet.dtype, device=self.denoising_unet.device) + + # denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Forward reference image + if i == 0: + self.reference_unet( + ref_image_latents.repeat( + (2 if do_classifier_free_guidance else 1), 1, 1, 1 + ), + torch.zeros_like(t), + encoder_hidden_states=encoder_hidden_states, + return_dict=False, + ) + reference_control_reader.update(reference_control_writer) + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + noise_pred = self.denoising_unet( + latent_model_input, + t, + encoder_hidden_states=encoder_hidden_states, + mask_cond_fea=face_mask, + full_mask=pixel_values_full_mask, + face_mask=pixel_values_face_mask, + lip_mask=pixel_values_lip_mask, + audio_embedding=audio_tensor, + motion_scale=motion_scale, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + reference_control_reader.clear() + reference_control_writer.clear() + + # Post-processing + images = self.decode_latents(latents) # (b, c, f, h, w) + + # Convert to tensor + if output_type == "tensor": + images = torch.from_numpy(images) + + if not return_dict: + return images + + return FaceAnimatePipelineOutput(videos=images) diff --git a/Hallo2/hallo2/hallo/animate/face_animate_static.py b/Hallo2/hallo2/hallo/animate/face_animate_static.py new file mode 100644 index 00000000..42c0fd53 --- /dev/null +++ b/Hallo2/hallo2/hallo/animate/face_animate_static.py @@ -0,0 +1,481 @@ +# pylint: disable=R0801 +""" +This module is responsible for handling the animation of faces using a combination of deep learning models and image processing techniques. +It provides a pipeline to generate realistic face animations by incorporating user-provided conditions such as facial expressions and environments. +The module utilizes various schedulers and utilities to optimize the animation process and ensure efficient performance. + +Functions and Classes: +- StaticPipelineOutput: A class that represents the output of the animation pipeline, c + ontaining properties and methods related to the generated images. +- prepare_latents: A function that prepares the initial noise for the animation process, + scaling it according to the scheduler's requirements. +- prepare_condition: A function that processes the user-provided conditions + (e.g., facial expressions) and prepares them for use in the animation pipeline. +- decode_latents: A function that decodes the latent representations of the face animations into + their corresponding image formats. +- prepare_extra_step_kwargs: A function that prepares additional parameters for each step of + the animation process, such as the generator and eta values. + +Dependencies: +- numpy: A library for numerical computing. +- torch: A machine learning library based on PyTorch. +- diffusers: A library for image-to-image diffusion models. +- transformers: A library for pre-trained transformer models. + +Usage: +- To create an instance of the animation pipeline, provide the necessary components such as + the VAE, reference UNET, denoising UNET, face locator, and image processor. +- Use the pipeline's methods to prepare the latents, conditions, and extra step arguments as + required for the animation process. +- Generate the face animations by decoding the latents and processing the conditions. + +Note: +- The module is designed to work with the diffusers library, which is based on + the paper "Diffusion Models for Image-to-Image Translation" (https://arxiv.org/abs/2102.02765). +- The face animations generated by this module should be used for entertainment purposes + only and should respect the rights and privacy of the individuals involved. +""" +import inspect +from dataclasses import dataclass +from typing import Callable, List, Optional, Union + +import numpy as np +import torch +from diffusers import DiffusionPipeline +from diffusers.image_processor import VaeImageProcessor +from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, LMSDiscreteScheduler, + PNDMScheduler) +from diffusers.utils import BaseOutput, is_accelerate_available +from diffusers.utils.torch_utils import randn_tensor +from einops import rearrange +from tqdm import tqdm +from transformers import CLIPImageProcessor + +from hallo.models.mutual_self_attention import ReferenceAttentionControl + +if is_accelerate_available(): + from accelerate import cpu_offload +else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + +@dataclass +class StaticPipelineOutput(BaseOutput): + """ + StaticPipelineOutput is a class that represents the output of the static pipeline. + It contains the images generated by the pipeline as a union of torch.Tensor and np.ndarray. + + Attributes: + images (Union[torch.Tensor, np.ndarray]): The generated images. + """ + images: Union[torch.Tensor, np.ndarray] + + +class StaticPipeline(DiffusionPipeline): + """ + StaticPipelineOutput is a class that represents the output of the static pipeline. + It contains the images generated by the pipeline as a union of torch.Tensor and np.ndarray. + + Attributes: + images (Union[torch.Tensor, np.ndarray]): The generated images. + """ + _optional_components = [] + + def __init__( + self, + vae, + reference_unet, + denoising_unet, + face_locator, + imageproj, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + ): + super().__init__() + + self.register_modules( + vae=vae, + reference_unet=reference_unet, + denoising_unet=denoising_unet, + face_locator=face_locator, + scheduler=scheduler, + imageproj=imageproj, + ) + self.vae_scale_factor = 2 ** ( + len(self.vae.config.block_out_channels) - 1) + self.clip_image_processor = CLIPImageProcessor() + self.ref_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True + ) + self.cond_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, + do_convert_rgb=True, + do_normalize=False, + ) + + def enable_vae_slicing(self): + """ + Enable VAE slicing. + + This method enables slicing for the VAE model, which can help improve the performance of decoding latents when working with large images. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + """ + Disable vae slicing. + + This function disables the vae slicing for the StaticPipeline object. + It calls the `disable_slicing()` method of the vae model. + This is useful when you want to use the entire vae model for decoding latents + instead of slicing it for better performance. + """ + self.vae.disable_slicing() + + def enable_sequential_cpu_offload(self, gpu_id=0): + """ + Offloads selected models to the GPU for increased performance. + + Args: + gpu_id (int, optional): The ID of the GPU to offload models to. Defaults to 0. + """ + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + def _execution_device(self): + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def decode_latents(self, latents): + """ + Decode the given latents to video frames. + + Parameters: + latents (torch.Tensor): The latents to be decoded. Shape: (batch_size, num_channels_latents, video_length, height, width). + + Returns: + video (torch.Tensor): The decoded video frames. Shape: (batch_size, num_channels_latents, video_length, height, width). + """ + video_length = latents.shape[2] + latents = 1 / 0.18215 * latents + latents = rearrange(latents, "b c f h w -> (b f) c h w") + # video = self.vae.decode(latents).sample + video = [] + for frame_idx in tqdm(range(latents.shape[0])): + video.append(self.vae.decode( + latents[frame_idx: frame_idx + 1]).sample) + video = torch.cat(video) + video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) + video = (video / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + video = video.cpu().float().numpy() + return video + + def prepare_extra_step_kwargs(self, generator, eta): + """ + Prepare extra keyword arguments for the scheduler step. + + Since not all schedulers have the same signature, this function helps to create a consistent interface for the scheduler. + + Args: + generator (Optional[torch.Generator]): A random number generator for reproducibility. + eta (float): The eta parameter used with the DDIMScheduler. It should be between 0 and 1. + + Returns: + dict: A dictionary containing the extra keyword arguments for the scheduler step. + """ + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def prepare_latents( + self, + batch_size, + num_channels_latents, + width, + height, + dtype, + device, + generator, + latents=None, + ): + """ + Prepares the initial latents for the diffusion pipeline. + + Args: + batch_size (int): The number of images to generate in one forward pass. + num_channels_latents (int): The number of channels in the latents tensor. + width (int): The width of the latents tensor. + height (int): The height of the latents tensor. + dtype (torch.dtype): The data type of the latents tensor. + device (torch.device): The device to place the latents tensor on. + generator (Optional[torch.Generator], optional): A random number generator + for reproducibility. Defaults to None. + latents (Optional[torch.Tensor], optional): Pre-computed latents to use as + initial conditions for the diffusion process. Defaults to None. + + Returns: + torch.Tensor: The prepared latents tensor. + """ + shape = ( + batch_size, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor( + shape, generator=generator, device=device, dtype=dtype + ) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_condition( + self, + cond_image, + width, + height, + device, + dtype, + do_classififer_free_guidance=False, + ): + """ + Prepares the condition for the face animation pipeline. + + Args: + cond_image (torch.Tensor): The conditional image tensor. + width (int): The width of the output image. + height (int): The height of the output image. + device (torch.device): The device to run the pipeline on. + dtype (torch.dtype): The data type of the tensor. + do_classififer_free_guidance (bool, optional): Whether to use classifier-free guidance or not. Defaults to False. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple of processed condition and mask tensors. + """ + image = self.cond_image_processor.preprocess( + cond_image, height=height, width=width + ).to(dtype=torch.float32) + + image = image.to(device=device, dtype=dtype) + + if do_classififer_free_guidance: + image = torch.cat([image] * 2) + + return image + + @torch.no_grad() + def __call__( + self, + ref_image, + face_mask, + width, + height, + num_inference_steps, + guidance_scale, + face_embedding, + num_images_per_prompt=1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, + List[torch.Generator]]] = None, + output_type: Optional[str] = "tensor", + return_dict: bool = True, + callback: Optional[Callable[[ + int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + # Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + device = self._execution_device + + do_classifier_free_guidance = guidance_scale > 1.0 + + # Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + batch_size = 1 + + image_prompt_embeds = self.imageproj(face_embedding) + uncond_image_prompt_embeds = self.imageproj( + torch.zeros_like(face_embedding)) + + if do_classifier_free_guidance: + image_prompt_embeds = torch.cat( + [uncond_image_prompt_embeds, image_prompt_embeds], dim=0 + ) + + reference_control_writer = ReferenceAttentionControl( + self.reference_unet, + do_classifier_free_guidance=do_classifier_free_guidance, + mode="write", + batch_size=batch_size, + fusion_blocks="full", + ) + reference_control_reader = ReferenceAttentionControl( + self.denoising_unet, + do_classifier_free_guidance=do_classifier_free_guidance, + mode="read", + batch_size=batch_size, + fusion_blocks="full", + ) + + num_channels_latents = self.denoising_unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + width, + height, + face_embedding.dtype, + device, + generator, + ) + latents = latents.unsqueeze(2) # (bs, c, 1, h', w') + # latents_dtype = latents.dtype + + # Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # Prepare ref image latents + ref_image_tensor = self.ref_image_processor.preprocess( + ref_image, height=height, width=width + ) # (bs, c, width, height) + ref_image_tensor = ref_image_tensor.to( + dtype=self.vae.dtype, device=self.vae.device + ) + ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean + ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w) + + # Prepare face mask image + face_mask_tensor = self.cond_image_processor.preprocess( + face_mask, height=height, width=width + ) + face_mask_tensor = face_mask_tensor.unsqueeze(2) # (bs, c, 1, h, w) + face_mask_tensor = face_mask_tensor.to( + device=device, dtype=self.face_locator.dtype + ) + mask_fea = self.face_locator(face_mask_tensor) + mask_fea = ( + torch.cat( + [mask_fea] * 2) if do_classifier_free_guidance else mask_fea + ) + + # denoising loop + num_warmup_steps = len(timesteps) - \ + num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # 1. Forward reference image + if i == 0: + self.reference_unet( + ref_image_latents.repeat( + (2 if do_classifier_free_guidance else 1), 1, 1, 1 + ), + torch.zeros_like(t), + encoder_hidden_states=image_prompt_embeds, + return_dict=False, + ) + + # 2. Update reference unet feature into denosing net + reference_control_reader.update(reference_control_writer) + + # 3.1 expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat( + [latents] * 2) if do_classifier_free_guidance else latents + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) + + noise_pred = self.denoising_unet( + latent_model_input, + t, + encoder_hidden_states=image_prompt_embeds, + mask_cond_fea=mask_fea, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + reference_control_reader.clear() + reference_control_writer.clear() + + # Post-processing + image = self.decode_latents(latents) # (b, c, 1, h, w) + + # Convert to tensor + if output_type == "tensor": + image = torch.from_numpy(image) + + if not return_dict: + return image + + return StaticPipelineOutput(images=image) diff --git a/Hallo2/hallo2/hallo/datasets/__init__.py b/Hallo2/hallo2/hallo/datasets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Hallo2/hallo2/hallo/datasets/audio_processor.py b/Hallo2/hallo2/hallo/datasets/audio_processor.py new file mode 100644 index 00000000..2fddb3ad --- /dev/null +++ b/Hallo2/hallo2/hallo/datasets/audio_processor.py @@ -0,0 +1,182 @@ +# pylint: disable=C0301 +''' +This module contains the AudioProcessor class and related functions for processing audio data. +It utilizes various libraries and models to perform tasks such as preprocessing, feature extraction, +and audio separation. The class is initialized with configuration parameters and can process +audio files using the provided models. +''' +import math +import os + +import librosa +import numpy as np +import torch +from audio_separator.separator import Separator +from einops import rearrange +from transformers import Wav2Vec2FeatureExtractor + +from hallo.models.wav2vec import Wav2VecModel +from hallo.utils.util import resample_audio + + +class AudioProcessor: + """ + AudioProcessor is a class that handles the processing of audio files. + It takes care of preprocessing the audio files, extracting features + using wav2vec models, and separating audio signals if needed. + + :param sample_rate: Sampling rate of the audio file + :param fps: Frames per second for the extracted features + :param wav2vec_model_path: Path to the wav2vec model + :param only_last_features: Whether to only use the last features + :param audio_separator_model_path: Path to the audio separator model + :param audio_separator_model_name: Name of the audio separator model + :param cache_dir: Directory to cache the intermediate results + :param device: Device to run the processing on + """ + def __init__( + self, + sample_rate, + fps, + wav2vec_model_path, + only_last_features, + audio_separator_model_path:str=None, + audio_separator_model_name:str=None, + cache_dir:str='', + device="cuda:0", + ) -> None: + self.sample_rate = sample_rate + self.fps = fps + self.device = device + + self.audio_encoder = Wav2VecModel.from_pretrained(wav2vec_model_path, local_files_only=True).to(device=device) + self.audio_encoder.feature_extractor._freeze_parameters() + self.only_last_features = only_last_features + + if audio_separator_model_name is not None: + try: + os.makedirs(cache_dir, exist_ok=True) + except OSError as _: + print("Fail to create the output cache dir.") + self.audio_separator = Separator( + output_dir=cache_dir, + output_single_stem="vocals", + model_file_dir=audio_separator_model_path, + ) + self.audio_separator.load_model(audio_separator_model_name) + assert self.audio_separator.model_instance is not None, "Fail to load audio separate model." + else: + self.audio_separator=None + print("Use audio directly without vocals seperator.") + + + self.wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True) + + + def preprocess(self, wav_file: str, + clip_length: int=-1, + padding=False, + processed_length=0): + """ + Preprocess a WAV audio file by separating the vocals from the background and resampling it to a 16 kHz sample rate. + The separated vocal track is then converted into wav2vec2 for further processing or analysis. + + Args: + wav_file (str): The path to the WAV file to be processed. This file should be accessible and in WAV format. + + Raises: + RuntimeError: Raises an exception if the WAV file cannot be processed. This could be due to issues + such as file not found, unsupported file format, or errors during the audio processing steps. + + Returns: + torch.tensor: Returns an audio embedding as a torch.tensor + """ + if self.audio_separator is not None: + # 1. separate vocals + # TODO: process in memory + outputs = self.audio_separator.separate(wav_file) + if len(outputs) <= 0: + raise RuntimeError("Audio separate failed.") + + vocal_audio_file = outputs[0] + vocal_audio_name, _ = os.path.splitext(vocal_audio_file) + vocal_audio_file = os.path.join(self.audio_separator.output_dir, vocal_audio_file) + vocal_audio_file = resample_audio(vocal_audio_file, os.path.join(self.audio_separator.output_dir, f"{vocal_audio_name}-16k.wav"), self.sample_rate) + else: + vocal_audio_file=wav_file + + # 2. extract wav2vec features + speech_array, sampling_rate = librosa.load(vocal_audio_file, sr=self.sample_rate) + audio_feature = np.squeeze(self.wav2vec_feature_extractor(speech_array, sampling_rate=sampling_rate).input_values) + seq_len = math.ceil(len(audio_feature) / self.sample_rate * self.fps) + audio_length = seq_len + + audio_feature = torch.from_numpy(audio_feature).float().to(device=self.device) + + if padding: + if clip_length>0 and seq_len % clip_length != 0: + all_len = seq_len + processed_length + audio_feature = torch.nn.functional.pad(audio_feature, (0, (clip_length - all_len % clip_length) * (self.sample_rate // self.fps)), 'constant', 0.0) + seq_len += clip_length - all_len % clip_length + audio_feature = audio_feature.unsqueeze(0) + + with torch.no_grad(): + embeddings = self.audio_encoder(audio_feature, seq_len=seq_len, output_hidden_states=True) + assert len(embeddings) > 0, "Fail to extract audio embedding" + if self.only_last_features: + audio_emb = embeddings.last_hidden_state.squeeze() + else: + audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0) + audio_emb = rearrange(audio_emb, "b s d -> s b d") + + audio_emb = audio_emb.cpu().detach() + + return audio_emb, audio_length + + def get_embedding(self, wav_file: str): + """preprocess wav audio file convert to embeddings + + Args: + wav_file (str): The path to the WAV file to be processed. This file should be accessible and in WAV format. + + Returns: + torch.tensor: Returns an audio embedding as a torch.tensor + """ + speech_array, sampling_rate = librosa.load( + wav_file, sr=self.sample_rate) + assert sampling_rate == 16000, "The audio sample rate must be 16000" + audio_feature = np.squeeze(self.wav2vec_feature_extractor( + speech_array, sampling_rate=sampling_rate).input_values) + seq_len = math.ceil(len(audio_feature) / self.sample_rate * self.fps) + + audio_feature = torch.from_numpy( + audio_feature).float().to(device=self.device) + audio_feature = audio_feature.unsqueeze(0) + + with torch.no_grad(): + embeddings = self.audio_encoder( + audio_feature, seq_len=seq_len, output_hidden_states=True) + assert len(embeddings) > 0, "Fail to extract audio embedding" + + if self.only_last_features: + audio_emb = embeddings.last_hidden_state.squeeze() + else: + audio_emb = torch.stack( + embeddings.hidden_states[1:], dim=1).squeeze(0) + audio_emb = rearrange(audio_emb, "b s d -> s b d") + + audio_emb = audio_emb.cpu().detach() + + return audio_emb + + def close(self): + """ + TODO: to be implemented + """ + return self + + def __enter__(self): + return self + + def __exit__(self, _exc_type, _exc_val, _exc_tb): + self.close() diff --git a/Hallo2/hallo2/hallo/datasets/image_processor.py b/Hallo2/hallo2/hallo/datasets/image_processor.py new file mode 100644 index 00000000..16515226 --- /dev/null +++ b/Hallo2/hallo2/hallo/datasets/image_processor.py @@ -0,0 +1,346 @@ +# pylint: disable=W0718 +""" +This module is responsible for processing images, particularly for face-related tasks. +It uses various libraries such as OpenCV, NumPy, and InsightFace to perform tasks like +face detection, augmentation, and mask rendering. The ImageProcessor class encapsulates +the functionality for these operations. +""" +import os +from typing import List + +import cv2 +import mediapipe as mp +import numpy as np +import torch +from insightface.app import FaceAnalysis +from PIL import Image +from torchvision import transforms + +from ..utils.util import (blur_mask, get_landmark_overframes, get_mask, + get_union_face_mask, get_union_lip_mask) + +MEAN = 0.5 +STD = 0.5 + +class ImageProcessor: + """ + ImageProcessor is a class responsible for processing images, particularly for face-related tasks. + It takes in an image and performs various operations such as augmentation, face detection, + face embedding extraction, and rendering a face mask. The processed images are then used for + further analysis or recognition purposes. + + Attributes: + img_size (int): The size of the image to be processed. + face_analysis_model_path (str): The path to the face analysis model. + + Methods: + preprocess(source_image_path, cache_dir): + Preprocesses the input image by performing augmentation, face detection, + face embedding extraction, and rendering a face mask. + + close(): + Closes the ImageProcessor and releases any resources being used. + + _augmentation(images, transform, state=None): + Applies image augmentation to the input images using the given transform and state. + + __enter__(): + Enters a runtime context and returns the ImageProcessor object. + + __exit__(_exc_type, _exc_val, _exc_tb): + Exits a runtime context and handles any exceptions that occurred during the processing. + """ + def __init__(self, img_size, face_analysis_model_path) -> None: + self.img_size = img_size + + self.pixel_transform = transforms.Compose( + [ + transforms.Resize(self.img_size), + transforms.ToTensor(), + transforms.Normalize([MEAN], [STD]), + ] + ) + + self.cond_transform = transforms.Compose( + [ + transforms.Resize(self.img_size), + transforms.ToTensor(), + ] + ) + + self.attn_transform_64 = transforms.Compose( + [ + transforms.Resize( + (self.img_size[0] // 8, self.img_size[0] // 8)), + transforms.ToTensor(), + ] + ) + self.attn_transform_32 = transforms.Compose( + [ + transforms.Resize( + (self.img_size[0] // 16, self.img_size[0] // 16)), + transforms.ToTensor(), + ] + ) + self.attn_transform_16 = transforms.Compose( + [ + transforms.Resize( + (self.img_size[0] // 32, self.img_size[0] // 32)), + transforms.ToTensor(), + ] + ) + self.attn_transform_8 = transforms.Compose( + [ + transforms.Resize( + (self.img_size[0] // 64, self.img_size[0] // 64)), + transforms.ToTensor(), + ] + ) + + self.face_analysis = FaceAnalysis( + name="", + root=face_analysis_model_path, + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], + ) + self.face_analysis.prepare(ctx_id=0, det_size=(640, 640)) + + def preprocess(self, source_image_path: str, cache_dir: str, face_region_ratio: float): + """ + Apply preprocessing to the source image to prepare for face analysis. + + Parameters: + source_image_path (str): The path to the source image. + cache_dir (str): The directory to cache intermediate results. + + Returns: + None + """ + source_image = Image.open(source_image_path) + ref_image_pil = source_image.convert("RGB") + # 1. image augmentation + pixel_values_ref_img = self._augmentation(ref_image_pil, self.pixel_transform) + + # 2.1 detect face + faces = self.face_analysis.get(cv2.cvtColor(np.array(ref_image_pil.copy()), cv2.COLOR_RGB2BGR)) + if not faces: + print("No faces detected in the image. Using the entire image as the face region.") + # Use the entire image as the face region + face = { + "bbox": [0, 0, ref_image_pil.width, ref_image_pil.height], + "embedding": np.zeros(512) + } + else: + # Sort faces by size and select the largest one + faces_sorted = sorted(faces, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]), reverse=True) + face = faces_sorted[0] # Select the largest face + + # 2.2 face embedding + face_emb = face["embedding"] + + # 2.3 render face mask + get_mask(source_image_path, cache_dir, face_region_ratio) + file_name = os.path.basename(source_image_path).split(".")[0] + face_mask_pil = Image.open( + os.path.join(cache_dir, f"{file_name}_face_mask.png")).convert("RGB") + + face_mask = self._augmentation(face_mask_pil, self.cond_transform) + + # 2.4 detect and expand lip, face mask + sep_background_mask = Image.open( + os.path.join(cache_dir, f"{file_name}_sep_background.png")) + sep_face_mask = Image.open( + os.path.join(cache_dir, f"{file_name}_sep_face.png")) + sep_lip_mask = Image.open( + os.path.join(cache_dir, f"{file_name}_sep_lip.png")) + + pixel_values_face_mask = [ + self._augmentation(sep_face_mask, self.attn_transform_64), + self._augmentation(sep_face_mask, self.attn_transform_32), + self._augmentation(sep_face_mask, self.attn_transform_16), + self._augmentation(sep_face_mask, self.attn_transform_8), + ] + pixel_values_lip_mask = [ + self._augmentation(sep_lip_mask, self.attn_transform_64), + self._augmentation(sep_lip_mask, self.attn_transform_32), + self._augmentation(sep_lip_mask, self.attn_transform_16), + self._augmentation(sep_lip_mask, self.attn_transform_8), + ] + pixel_values_full_mask = [ + self._augmentation(sep_background_mask, self.attn_transform_64), + self._augmentation(sep_background_mask, self.attn_transform_32), + self._augmentation(sep_background_mask, self.attn_transform_16), + self._augmentation(sep_background_mask, self.attn_transform_8), + ] + + pixel_values_full_mask = [mask.view(1, -1) + for mask in pixel_values_full_mask] + pixel_values_face_mask = [mask.view(1, -1) + for mask in pixel_values_face_mask] + pixel_values_lip_mask = [mask.view(1, -1) + for mask in pixel_values_lip_mask] + + return pixel_values_ref_img, face_mask, face_emb, pixel_values_full_mask, pixel_values_face_mask, pixel_values_lip_mask + + def close(self): + """ + Closes the ImageProcessor and releases any resources held by the FaceAnalysis instance. + + Args: + self: The ImageProcessor instance. + + Returns: + None. + """ + for _, model in self.face_analysis.models.items(): + if hasattr(model, "Dispose"): + model.Dispose() + + def _augmentation(self, images, transform, state=None): + if state is not None: + torch.set_rng_state(state) + if isinstance(images, List): + transformed_images = [transform(img) for img in images] + ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w) + else: + ret_tensor = transform(images) # (c, h, w) + return ret_tensor + + def __enter__(self): + return self + + def __exit__(self, _exc_type, _exc_val, _exc_tb): + self.close() + + +class ImageProcessorForDataProcessing(): + """ + ImageProcessor is a class responsible for processing images, particularly for face-related tasks. + It takes in an image and performs various operations such as augmentation, face detection, + face embedding extraction, and rendering a face mask. The processed images are then used for + further analysis or recognition purposes. + + Attributes: + img_size (int): The size of the image to be processed. + face_analysis_model_path (str): The path to the face analysis model. + + Methods: + preprocess(source_image_path, cache_dir): + Preprocesses the input image by performing augmentation, face detection, + face embedding extraction, and rendering a face mask. + + close(): + Closes the ImageProcessor and releases any resources being used. + + _augmentation(images, transform, state=None): + Applies image augmentation to the input images using the given transform and state. + + __enter__(): + Enters a runtime context and returns the ImageProcessor object. + + __exit__(_exc_type, _exc_val, _exc_tb): + Exits a runtime context and handles any exceptions that occurred during the processing. + """ + def __init__(self, face_analysis_model_path, landmark_model_path, step) -> None: + if step == 2: + self.face_analysis = FaceAnalysis( + name="", + root=face_analysis_model_path, + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], + ) + self.face_analysis.prepare(ctx_id=0, det_size=(640, 640)) + self.landmarker = None + else: + BaseOptions = mp.tasks.BaseOptions + FaceLandmarker = mp.tasks.vision.FaceLandmarker + FaceLandmarkerOptions = mp.tasks.vision.FaceLandmarkerOptions + VisionRunningMode = mp.tasks.vision.RunningMode + # Create a face landmarker instance with the video mode: + options = FaceLandmarkerOptions( + base_options=BaseOptions(model_asset_path=landmark_model_path), + running_mode=VisionRunningMode.IMAGE, + ) + self.landmarker = FaceLandmarker.create_from_options(options) + self.face_analysis = None + + def preprocess(self, source_image_path: str): + """ + Apply preprocessing to the source image to prepare for face analysis. + + Parameters: + source_image_path (str): The path to the source image. + cache_dir (str): The directory to cache intermediate results. + + Returns: + None + """ + # 1. get face embdeding + face_mask, face_emb, sep_pose_mask, sep_face_mask, sep_lip_mask = None, None, None, None, None + if self.face_analysis: + for frame in sorted(os.listdir(source_image_path)): + try: + source_image = Image.open( + os.path.join(source_image_path, frame)) + ref_image_pil = source_image.convert("RGB") + # 2.1 detect face + faces = self.face_analysis.get(cv2.cvtColor( + np.array(ref_image_pil.copy()), cv2.COLOR_RGB2BGR)) + # use max size face + face = sorted(faces, key=lambda x: ( + x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))[-1] + # 2.2 face embedding + face_emb = face["embedding"] + if face_emb is not None: + break + except Exception as _: + continue + + if self.landmarker: + # 3.1 get landmark + landmarks, height, width = get_landmark_overframes( + self.landmarker, source_image_path) + assert len(landmarks) == len(os.listdir(source_image_path)) + + # 3 render face and lip mask + face_mask = get_union_face_mask(landmarks, height, width) + lip_mask = get_union_lip_mask(landmarks, height, width) + + # 4 gaussian blur + blur_face_mask = blur_mask(face_mask, (64, 64), (51, 51)) + blur_lip_mask = blur_mask(lip_mask, (64, 64), (31, 31)) + + # 5 seperate mask + sep_face_mask = cv2.subtract(blur_face_mask, blur_lip_mask) + sep_pose_mask = 255.0 - blur_face_mask + sep_lip_mask = blur_lip_mask + + return face_mask, face_emb, sep_pose_mask, sep_face_mask, sep_lip_mask + + def close(self): + """ + Closes the ImageProcessor and releases any resources held by the FaceAnalysis instance. + + Args: + self: The ImageProcessor instance. + + Returns: + None. + """ + for _, model in self.face_analysis.models.items(): + if hasattr(model, "Dispose"): + model.Dispose() + + def _augmentation(self, images, transform, state=None): + if state is not None: + torch.set_rng_state(state) + if isinstance(images, List): + transformed_images = [transform(img) for img in images] + ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w) + else: + ret_tensor = transform(images) # (c, h, w) + return ret_tensor + + def __enter__(self): + return self + + def __exit__(self, _exc_type, _exc_val, _exc_tb): + self.close() diff --git a/Hallo2/hallo2/hallo/datasets/mask_image.py b/Hallo2/hallo2/hallo/datasets/mask_image.py new file mode 100644 index 00000000..2d0c94ff --- /dev/null +++ b/Hallo2/hallo2/hallo/datasets/mask_image.py @@ -0,0 +1,154 @@ +# pylint: disable=R0801 +""" +This module contains the code for a dataset class called FaceMaskDataset, which is used to process and +load image data related to face masks. The dataset class inherits from the PyTorch Dataset class and +provides methods for data augmentation, getting items from the dataset, and determining the length of the +dataset. The module also includes imports for necessary libraries such as json, random, pathlib, torch, +PIL, and transformers. +""" + +import json +import random +from pathlib import Path + +import torch +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +from transformers import CLIPImageProcessor + + +class FaceMaskDataset(Dataset): + """ + FaceMaskDataset is a custom dataset for face mask images. + + Args: + img_size (int): The size of the input images. + drop_ratio (float, optional): The ratio of dropped pixels during data augmentation. Defaults to 0.1. + data_meta_paths (list, optional): The paths to the metadata files containing image paths and labels. Defaults to ["./data/HDTF_meta.json"]. + sample_margin (int, optional): The margin for sampling regions in the image. Defaults to 30. + + Attributes: + img_size (int): The size of the input images. + drop_ratio (float): The ratio of dropped pixels during data augmentation. + data_meta_paths (list): The paths to the metadata files containing image paths and labels. + sample_margin (int): The margin for sampling regions in the image. + processor (CLIPImageProcessor): The image processor for preprocessing images. + transform (transforms.Compose): The image augmentation transform. + """ + + def __init__( + self, + img_size, + drop_ratio=0.1, + data_meta_paths=None, + sample_margin=30, + ): + super().__init__() + + self.img_size = img_size + self.sample_margin = sample_margin + + vid_meta = [] + for data_meta_path in data_meta_paths: + with open(data_meta_path, "r", encoding="utf-8") as f: + vid_meta.extend(json.load(f)) + self.vid_meta = vid_meta + self.length = len(self.vid_meta) + + self.clip_image_processor = CLIPImageProcessor() + + self.transform = transforms.Compose( + [ + transforms.Resize(self.img_size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + self.cond_transform = transforms.Compose( + [ + transforms.Resize(self.img_size), + transforms.ToTensor(), + ] + ) + + self.drop_ratio = drop_ratio + + def augmentation(self, image, transform, state=None): + """ + Apply data augmentation to the input image. + + Args: + image (PIL.Image): The input image. + transform (torchvision.transforms.Compose): The data augmentation transforms. + state (dict, optional): The random state for reproducibility. Defaults to None. + + Returns: + PIL.Image: The augmented image. + """ + if state is not None: + torch.set_rng_state(state) + return transform(image) + + def __getitem__(self, index): + video_meta = self.vid_meta[index] + video_path = video_meta["image_path"] + mask_path = video_meta["mask_path"] + face_emb_path = video_meta["face_emb"] + + video_frames = sorted(Path(video_path).iterdir()) + video_length = len(video_frames) + + margin = min(self.sample_margin, video_length) + + ref_img_idx = random.randint(0, video_length - 1) + if ref_img_idx + margin < video_length: + tgt_img_idx = random.randint( + ref_img_idx + margin, video_length - 1) + elif ref_img_idx - margin > 0: + tgt_img_idx = random.randint(0, ref_img_idx - margin) + else: + tgt_img_idx = random.randint(0, video_length - 1) + + ref_img_pil = Image.open(video_frames[ref_img_idx]) + tgt_img_pil = Image.open(video_frames[tgt_img_idx]) + + tgt_mask_pil = Image.open(mask_path) + + assert ref_img_pil is not None, "Fail to load reference image." + assert tgt_img_pil is not None, "Fail to load target image." + assert tgt_mask_pil is not None, "Fail to load target mask." + + state = torch.get_rng_state() + tgt_img = self.augmentation(tgt_img_pil, self.transform, state) + tgt_mask_img = self.augmentation( + tgt_mask_pil, self.cond_transform, state) + tgt_mask_img = tgt_mask_img.repeat(3, 1, 1) + ref_img_vae = self.augmentation( + ref_img_pil, self.transform, state) + face_emb = torch.load(face_emb_path) + + + sample = { + "video_dir": video_path, + "img": tgt_img, + "tgt_mask": tgt_mask_img, + "ref_img": ref_img_vae, + "face_emb": face_emb, + } + + return sample + + def __len__(self): + return len(self.vid_meta) + + +if __name__ == "__main__": + data = FaceMaskDataset(img_size=(512, 512)) + train_dataloader = torch.utils.data.DataLoader( + data, batch_size=4, shuffle=True, num_workers=1 + ) + for step, batch in enumerate(train_dataloader): + print(batch["tgt_mask"].shape) + break diff --git a/Hallo2/hallo2/hallo/datasets/talk_video.py b/Hallo2/hallo2/hallo/datasets/talk_video.py new file mode 100644 index 00000000..25c3ab81 --- /dev/null +++ b/Hallo2/hallo2/hallo/datasets/talk_video.py @@ -0,0 +1,316 @@ +# pylint: disable=R0801 +""" +talking_video_dataset.py + +This module defines the TalkingVideoDataset class, a custom PyTorch dataset +for handling talking video data. The dataset uses video files, masks, and +embeddings to prepare data for tasks such as video generation and +speech-driven video animation. + +Classes: + TalkingVideoDataset + +Dependencies: + json + random + torch + decord.VideoReader, decord.cpu + PIL.Image + torch.utils.data.Dataset + torchvision.transforms + +Example: + from talking_video_dataset import TalkingVideoDataset + from torch.utils.data import DataLoader + + # Example configuration for the Wav2Vec model + class Wav2VecConfig: + def __init__(self, audio_type, model_scale, features): + self.audio_type = audio_type + self.model_scale = model_scale + self.features = features + + wav2vec_cfg = Wav2VecConfig(audio_type="wav2vec2", model_scale="base", features="feature") + + # Initialize dataset + dataset = TalkingVideoDataset( + img_size=(512, 512), + sample_rate=16000, + audio_margin=2, + n_motion_frames=0, + n_sample_frames=16, + data_meta_paths=["path/to/meta1.json", "path/to/meta2.json"], + wav2vec_cfg=wav2vec_cfg, + ) + + # Initialize dataloader + dataloader = DataLoader(dataset, batch_size=4, shuffle=True) + + # Fetch one batch of data + batch = next(iter(dataloader)) + print(batch["pixel_values_vid"].shape) # Example output: (4, 16, 3, 512, 512) + +The TalkingVideoDataset class provides methods for loading video frames, masks, +audio embeddings, and other relevant data, applying transformations, and preparing +the data for training and evaluation in a deep learning pipeline. + +Attributes: + img_size (tuple): The dimensions to resize the video frames to. + sample_rate (int): The audio sample rate. + audio_margin (int): The margin for audio sampling. + n_motion_frames (int): The number of motion frames. + n_sample_frames (int): The number of sample frames. + data_meta_paths (list): List of paths to the JSON metadata files. + wav2vec_cfg (object): Configuration for the Wav2Vec model. + +Methods: + augmentation(images, transform, state=None): Apply transformation to input images. + __getitem__(index): Get a sample from the dataset at the specified index. + __len__(): Return the length of the dataset. +""" + +import json +import random +from typing import List + +import torch +from decord import VideoReader, cpu +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + + +class TalkingVideoDataset(Dataset): + """ + A dataset class for processing talking video data. + + Args: + img_size (tuple, optional): The size of the output images. Defaults to (512, 512). + sample_rate (int, optional): The sample rate of the audio data. Defaults to 16000. + audio_margin (int, optional): The margin for the audio data. Defaults to 2. + n_motion_frames (int, optional): The number of motion frames. Defaults to 0. + n_sample_frames (int, optional): The number of sample frames. Defaults to 16. + data_meta_paths (list, optional): The paths to the data metadata. Defaults to None. + wav2vec_cfg (dict, optional): The configuration for the wav2vec model. Defaults to None. + + Attributes: + img_size (tuple): The size of the output images. + sample_rate (int): The sample rate of the audio data. + audio_margin (int): The margin for the audio data. + n_motion_frames (int): The number of motion frames. + n_sample_frames (int): The number of sample frames. + data_meta_paths (list): The paths to the data metadata. + wav2vec_cfg (dict): The configuration for the wav2vec model. + """ + + def __init__( + self, + img_size=(512, 512), + sample_rate=16000, + audio_margin=2, + n_motion_frames=0, + n_sample_frames=16, + data_meta_paths=None, + wav2vec_cfg=None, + ): + super().__init__() + self.sample_rate = sample_rate + self.img_size = img_size + self.audio_margin = audio_margin + self.n_motion_frames = n_motion_frames + self.n_sample_frames = n_sample_frames + self.audio_type = wav2vec_cfg.audio_type + self.audio_model = wav2vec_cfg.model_scale + self.audio_features = wav2vec_cfg.features + + vid_meta = [] + for data_meta_path in data_meta_paths: + with open(data_meta_path, "r", encoding="utf-8") as f: + vid_meta.extend(json.load(f)) + self.vid_meta = vid_meta + self.length = len(self.vid_meta) + self.pixel_transform = transforms.Compose( + [ + transforms.Resize(self.img_size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + self.cond_transform = transforms.Compose( + [ + transforms.Resize(self.img_size), + transforms.ToTensor(), + ] + ) + self.attn_transform_64 = transforms.Compose( + [ + transforms.Resize( + (self.img_size[0] // 8, self.img_size[0] // 8)), + transforms.ToTensor(), + ] + ) + self.attn_transform_32 = transforms.Compose( + [ + transforms.Resize( + (self.img_size[0] // 16, self.img_size[0] // 16)), + transforms.ToTensor(), + ] + ) + self.attn_transform_16 = transforms.Compose( + [ + transforms.Resize( + (self.img_size[0] // 32, self.img_size[0] // 32)), + transforms.ToTensor(), + ] + ) + self.attn_transform_8 = transforms.Compose( + [ + transforms.Resize( + (self.img_size[0] // 64, self.img_size[0] // 64)), + transforms.ToTensor(), + ] + ) + + def augmentation(self, images, transform, state=None): + """ + Apply the given transformation to the input images. + + Args: + images (List[PIL.Image] or PIL.Image): The input images to be transformed. + transform (torchvision.transforms.Compose): The transformation to be applied to the images. + state (torch.ByteTensor, optional): The state of the random number generator. + If provided, it will set the RNG state to this value before applying the transformation. Defaults to None. + + Returns: + torch.Tensor: The transformed images as a tensor. + If the input was a list of images, the tensor will have shape (f, c, h, w), + where f is the number of images, c is the number of channels, h is the height, and w is the width. + If the input was a single image, the tensor will have shape (c, h, w), + where c is the number of channels, h is the height, and w is the width. + """ + if state is not None: + torch.set_rng_state(state) + if isinstance(images, List): + transformed_images = [transform(img) for img in images] + ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w) + else: + ret_tensor = transform(images) # (c, h, w) + return ret_tensor + + def __getitem__(self, index): + video_meta = self.vid_meta[index] + video_path = video_meta["video_path"] + mask_path = video_meta["mask_path"] + lip_mask_union_path = video_meta.get("sep_mask_lip", None) + face_mask_union_path = video_meta.get("sep_mask_face", None) + full_mask_union_path = video_meta.get("sep_mask_border", None) + face_emb_path = video_meta["face_emb_path"] + audio_emb_path = video_meta[ + f"{self.audio_type}_emb_{self.audio_model}_{self.audio_features}" + ] + tgt_mask_pil = Image.open(mask_path) + video_frames = VideoReader(video_path, ctx=cpu(0)) + assert tgt_mask_pil is not None, "Fail to load target mask." + assert (video_frames is not None and len(video_frames) > 0), "Fail to load video frames." + video_length = len(video_frames) + + assert ( + video_length + > self.n_sample_frames + self.n_motion_frames + 2 * self.audio_margin + ) + start_idx = random.randint( + self.n_motion_frames, + video_length - self.n_sample_frames - self.audio_margin - 1, + ) + + videos = video_frames[start_idx : start_idx + self.n_sample_frames] + + frame_list = [ + Image.fromarray(video).convert("RGB") for video in videos.asnumpy() + ] + + face_masks_list = [Image.open(face_mask_union_path)] * self.n_sample_frames + lip_masks_list = [Image.open(lip_mask_union_path)] * self.n_sample_frames + full_masks_list = [Image.open(full_mask_union_path)] * self.n_sample_frames + assert face_masks_list[0] is not None, "Fail to load face mask." + assert lip_masks_list[0] is not None, "Fail to load lip mask." + assert full_masks_list[0] is not None, "Fail to load full mask." + + + face_emb = torch.load(face_emb_path) + audio_emb = torch.load(audio_emb_path) + indices = ( + torch.arange(2 * self.audio_margin + 1) - self.audio_margin + ) # Generates [-2, -1, 0, 1, 2] + center_indices = torch.arange( + start_idx, + start_idx + self.n_sample_frames, + ).unsqueeze(1) + indices.unsqueeze(0) + audio_tensor = audio_emb[center_indices] + + ref_img_idx = random.randint( + self.n_motion_frames, + video_length - self.n_sample_frames - self.audio_margin - 1, + ) + ref_img = video_frames[ref_img_idx].asnumpy() + ref_img = Image.fromarray(ref_img) + + if self.n_motion_frames > 0: + motions = video_frames[start_idx - self.n_motion_frames : start_idx] + motion_list = [ + Image.fromarray(motion).convert("RGB") for motion in motions.asnumpy() + ] + + # transform + state = torch.get_rng_state() + pixel_values_vid = self.augmentation(frame_list, self.pixel_transform, state) + + pixel_values_mask = self.augmentation(tgt_mask_pil, self.cond_transform, state) + pixel_values_mask = pixel_values_mask.repeat(3, 1, 1) + + pixel_values_face_mask = [ + self.augmentation(face_masks_list, self.attn_transform_64, state), + self.augmentation(face_masks_list, self.attn_transform_32, state), + self.augmentation(face_masks_list, self.attn_transform_16, state), + self.augmentation(face_masks_list, self.attn_transform_8, state), + ] + pixel_values_lip_mask = [ + self.augmentation(lip_masks_list, self.attn_transform_64, state), + self.augmentation(lip_masks_list, self.attn_transform_32, state), + self.augmentation(lip_masks_list, self.attn_transform_16, state), + self.augmentation(lip_masks_list, self.attn_transform_8, state), + ] + pixel_values_full_mask = [ + self.augmentation(full_masks_list, self.attn_transform_64, state), + self.augmentation(full_masks_list, self.attn_transform_32, state), + self.augmentation(full_masks_list, self.attn_transform_16, state), + self.augmentation(full_masks_list, self.attn_transform_8, state), + ] + + pixel_values_ref_img = self.augmentation(ref_img, self.pixel_transform, state) + pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0) + if self.n_motion_frames > 0: + pixel_values_motion = self.augmentation( + motion_list, self.pixel_transform, state + ) + pixel_values_ref_img = torch.cat( + [pixel_values_ref_img, pixel_values_motion], dim=0 + ) + + sample = { + "video_dir": video_path, + "pixel_values_vid": pixel_values_vid, + "pixel_values_mask": pixel_values_mask, + "pixel_values_face_mask": pixel_values_face_mask, + "pixel_values_lip_mask": pixel_values_lip_mask, + "pixel_values_full_mask": pixel_values_full_mask, + "audio_tensor": audio_tensor, + "pixel_values_ref_img": pixel_values_ref_img, + "face_emb": face_emb, + } + + return sample + + def __len__(self): + return len(self.vid_meta) diff --git a/Hallo2/hallo2/hallo/models/__init__.py b/Hallo2/hallo2/hallo/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Hallo2/hallo2/hallo/models/attention.py b/Hallo2/hallo2/hallo/models/attention.py new file mode 100644 index 00000000..d7feec9c --- /dev/null +++ b/Hallo2/hallo2/hallo/models/attention.py @@ -0,0 +1,921 @@ +# pylint: disable=R0801 +# pylint: disable=C0303 + +""" +This module contains various transformer blocks for different applications, such as BasicTransformerBlock, +TemporalBasicTransformerBlock, and AudioTemporalBasicTransformerBlock. These blocks are used in various models, +such as GLIGEN, UNet, and others. The transformer blocks implement self-attention, cross-attention, feed-forward +networks, and other related functions. + +Functions and classes included in this module are: +- BasicTransformerBlock: A basic transformer block with self-attention, cross-attention, and feed-forward layers. +- TemporalBasicTransformerBlock: A transformer block with additional temporal attention mechanisms for video data. +- AudioTemporalBasicTransformerBlock: A transformer block with additional audio-specific mechanisms for audio data. +- zero_module: A function to zero out the parameters of a given module. + +For more information on each specific class and function, please refer to the respective docstrings. +""" + +from typing import Any, Dict, List, Optional + +import torch +from diffusers.models.attention import (AdaLayerNorm, AdaLayerNormZero, + Attention, FeedForward) +from diffusers.models.embeddings import SinusoidalPositionalEmbedding +from einops import rearrange +from torch import nn + + +class GatedSelfAttentionDense(nn.Module): + """ + A gated self-attention dense layer that combines visual features and object features. + + Parameters: + query_dim (`int`): The number of channels in the query. + context_dim (`int`): The number of channels in the context. + n_heads (`int`): The number of heads to use for attention. + d_head (`int`): The number of channels in each head. + """ + + def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): + super().__init__() + + # we need a linear projection since we need cat visual feature and obj feature + self.linear = nn.Linear(context_dim, query_dim) + + self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) + self.ff = FeedForward(query_dim, activation_fn="geglu") + + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0))) + self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0))) + + self.enabled = True + + def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: + """ + Apply the Gated Self-Attention mechanism to the input tensor `x` and object tensor `objs`. + + Args: + x (torch.Tensor): The input tensor. + objs (torch.Tensor): The object tensor. + + Returns: + torch.Tensor: The output tensor after applying Gated Self-Attention. + """ + if not self.enabled: + return x + + n_visual = x.shape[1] + objs = self.linear(objs) + + x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :] + x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) + + return x + +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' + norm_type: str = "layer_norm", + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = ( + num_embeds_ada_norm is not None + ) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = ( + num_embeds_ada_norm is not None + ) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding( + dim, max_seq_length=num_positional_embeddings + ) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm( + dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps + ) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm( + dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps + ) + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=( + cross_attention_dim if not double_self_attention else None + ), + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + if not self.use_ada_layer_norm_single: + self.norm3 = nn.LayerNorm( + dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps + ) + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + ) + + # 4. Fuser + if attention_type in {"gated", "gated-text-image"}: # Updated line + self.fuser = GatedSelfAttentionDense( + dim, cross_attention_dim, num_attention_heads, attention_head_dim + ) + + # 5. Scale-shift for PixArt-Alpha. + if self.use_ada_layer_norm_single: + self.scale_shift_table = nn.Parameter( + torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + """ + Sets the chunk size for feed-forward processing in the transformer block. + + Args: + chunk_size (Optional[int]): The size of the chunks to process in feed-forward layers. + If None, the chunk size is set to the maximum possible value. + dim (int, optional): The dimension along which to split the input tensor into chunks. Defaults to 0. + + Returns: + None. + """ + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + ) -> torch.FloatTensor: + """ + This function defines the forward pass of the BasicTransformerBlock. + + Args: + self (BasicTransformerBlock): + An instance of the BasicTransformerBlock class. + hidden_states (torch.FloatTensor): + A tensor containing the hidden states. + attention_mask (Optional[torch.FloatTensor], optional): + A tensor containing the attention mask. Defaults to None. + encoder_hidden_states (Optional[torch.FloatTensor], optional): + A tensor containing the encoder hidden states. Defaults to None. + encoder_attention_mask (Optional[torch.FloatTensor], optional): + A tensor containing the encoder attention mask. Defaults to None. + timestep (Optional[torch.LongTensor], optional): + A tensor containing the timesteps. Defaults to None. + cross_attention_kwargs (Dict[str, Any], optional): + Additional cross-attention arguments. Defaults to None. + class_labels (Optional[torch.LongTensor], optional): + A tensor containing the class labels. Defaults to None. + + Returns: + torch.FloatTensor: + A tensor containing the transformed hidden states. + """ + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + gate_msa = None + scale_mlp = None + shift_mlp = None + gate_mlp = None + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.use_layer_norm: + norm_hidden_states = self.norm1(hidden_states) + elif self.use_ada_layer_norm_single: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * \ + (1 + scale_msa) + shift_msa + norm_hidden_states = norm_hidden_states.squeeze(1) + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + # 1. Retrieve lora scale. + lora_scale = ( + cross_attention_kwargs.get("scale", 1.0) + if cross_attention_kwargs is not None + else 1.0 + ) + + # 2. Prepare GLIGEN inputs + cross_attention_kwargs = ( + cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + ) + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=( + encoder_hidden_states if self.only_cross_attention else None + ), + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.use_ada_layer_norm_single: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 2.5 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. Cross-Attention + if self.attn2 is not None: + if self.use_ada_layer_norm: + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.use_ada_layer_norm_zero or self.use_layer_norm: + norm_hidden_states = self.norm2(hidden_states) + elif self.use_ada_layer_norm_single: + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.use_ada_layer_norm_single is False: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + if not self.use_ada_layer_norm_single: + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = ( + norm_hidden_states * + (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ) + + if self.use_ada_layer_norm_single: + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * \ + (1 + scale_mlp) + shift_mlp + + ff_output = self.ff(norm_hidden_states, scale=lora_scale) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.use_ada_layer_norm_single: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class TemporalBasicTransformerBlock(nn.Module): + """ + A PyTorch module that extends the BasicTransformerBlock to include temporal attention mechanisms. + This class is particularly useful for video-related tasks where capturing temporal information within the sequence of frames is necessary. + + Attributes: + dim (int): The dimension of the input and output embeddings. + num_attention_heads (int): The number of attention heads in the multi-head self-attention mechanism. + attention_head_dim (int): The dimension of each attention head. + dropout (float): The dropout probability for the attention scores. + cross_attention_dim (Optional[int]): The dimension of the cross-attention mechanism. + activation_fn (str): The activation function used in the feed-forward layer. + num_embeds_ada_norm (Optional[int]): The number of embeddings for adaptive normalization. + attention_bias (bool): If True, uses bias in the attention mechanism. + only_cross_attention (bool): If True, only uses cross-attention. + upcast_attention (bool): If True, upcasts the attention mechanism for better performance. + unet_use_cross_frame_attention (Optional[bool]): If True, uses cross-frame attention in the UNet model. + unet_use_temporal_attention (Optional[bool]): If True, uses temporal attention in the UNet model. + """ + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + unet_use_cross_frame_attention=None, + unet_use_temporal_attention=None, + ): + """ + The TemporalBasicTransformerBlock class is a PyTorch module that extends the BasicTransformerBlock to include temporal attention mechanisms. + This is particularly useful for video-related tasks, where the model needs to capture the temporal information within the sequence of frames. + The block consists of self-attention, cross-attention, feed-forward, and temporal attention mechanisms. + + dim (int): The dimension of the input and output embeddings. + num_attention_heads (int): The number of attention heads in the multi-head self-attention mechanism. + attention_head_dim (int): The dimension of each attention head. + dropout (float, optional): The dropout probability for the attention scores. Defaults to 0.0. + cross_attention_dim (int, optional): The dimension of the cross-attention mechanism. Defaults to None. + activation_fn (str, optional): The activation function used in the feed-forward layer. Defaults to "geglu". + num_embeds_ada_norm (int, optional): The number of embeddings for adaptive normalization. Defaults to None. + attention_bias (bool, optional): If True, uses bias in the attention mechanism. Defaults to False. + only_cross_attention (bool, optional): If True, only uses cross-attention. Defaults to False. + upcast_attention (bool, optional): If True, upcasts the attention mechanism for better performance. Defaults to False. + unet_use_cross_frame_attention (bool, optional): If True, uses cross-frame attention in the UNet model. Defaults to None. + unet_use_temporal_attention (bool, optional): If True, uses temporal attention in the UNet model. Defaults to None. + + Forward method: + hidden_states (torch.FloatTensor): The input hidden states. + encoder_hidden_states (torch.FloatTensor, optional): The encoder hidden states. Defaults to None. + timestep (torch.LongTensor, optional): The current timestep for the transformer model. Defaults to None. + attention_mask (torch.FloatTensor, optional): The attention mask for the self-attention mechanism. Defaults to None. + video_length (int, optional): The length of the video sequence. Defaults to None. + + Returns: + torch.FloatTensor: The output hidden states after passing through the TemporalBasicTransformerBlock. + """ + super().__init__() + self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm = num_embeds_ada_norm is not None + self.unet_use_cross_frame_attention = unet_use_cross_frame_attention + self.unet_use_temporal_attention = unet_use_temporal_attention + + # SC-Attn + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + self.norm1 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim) + ) + + # Cross-Attn + if cross_attention_dim is not None: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + else: + self.attn2 = None + + if cross_attention_dim is not None: + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim) + ) + else: + self.norm2 = None + + # Feed-forward + self.ff = FeedForward(dim, dropout=dropout, + activation_fn=activation_fn) + self.norm3 = nn.LayerNorm(dim) + self.use_ada_layer_norm_zero = False + + # Temp-Attn + # assert unet_use_temporal_attention is not None + if unet_use_temporal_attention is None: + unet_use_temporal_attention = False + if unet_use_temporal_attention: + self.attn_temp = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + nn.init.zeros_(self.attn_temp.to_out[0].weight.data) + self.norm_temp = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim) + ) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + attention_mask=None, + video_length=None, + ): + """ + Forward pass for the TemporalBasicTransformerBlock. + + Args: + hidden_states (torch.FloatTensor): The input hidden states with shape (batch_size, seq_len, dim). + encoder_hidden_states (torch.FloatTensor, optional): The encoder hidden states with shape (batch_size, src_seq_len, dim). + timestep (torch.LongTensor, optional): The timestep for the transformer block. + attention_mask (torch.FloatTensor, optional): The attention mask with shape (batch_size, seq_len, seq_len). + video_length (int, optional): The length of the video sequence. + + Returns: + torch.FloatTensor: The output tensor after passing through the transformer block with shape (batch_size, seq_len, dim). + """ + norm_hidden_states = ( + self.norm1(hidden_states, timestep) + if self.use_ada_layer_norm + else self.norm1(hidden_states) + ) + + if self.unet_use_cross_frame_attention: + hidden_states = ( + self.attn1( + norm_hidden_states, + attention_mask=attention_mask, + video_length=video_length, + ) + + hidden_states + ) + else: + hidden_states = ( + self.attn1(norm_hidden_states, attention_mask=attention_mask) + + hidden_states + ) + + if self.attn2 is not None: + # Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) + if self.use_ada_layer_norm + else self.norm2(hidden_states) + ) + hidden_states = ( + self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) + + hidden_states + ) + + # Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + # Temporal-Attention + if self.unet_use_temporal_attention: + d = hidden_states.shape[1] + hidden_states = rearrange( + hidden_states, "(b f) d c -> (b d) f c", f=video_length + ) + norm_hidden_states = ( + self.norm_temp(hidden_states, timestep) + if self.use_ada_layer_norm + else self.norm_temp(hidden_states) + ) + hidden_states = self.attn_temp(norm_hidden_states) + hidden_states + hidden_states = rearrange( + hidden_states, "(b d) f c -> (b f) d c", d=d) + + return hidden_states + + +class AudioTemporalBasicTransformerBlock(nn.Module): + """ + A PyTorch module designed to handle audio data within a transformer framework, including temporal attention mechanisms. + + Attributes: + dim (int): The dimension of the input and output embeddings. + num_attention_heads (int): The number of attention heads. + attention_head_dim (int): The dimension of each attention head. + dropout (float): The dropout probability. + cross_attention_dim (Optional[int]): The dimension of the cross-attention mechanism. + activation_fn (str): The activation function for the feed-forward network. + num_embeds_ada_norm (Optional[int]): The number of embeddings for adaptive normalization. + attention_bias (bool): If True, uses bias in the attention mechanism. + only_cross_attention (bool): If True, only uses cross-attention. + upcast_attention (bool): If True, upcasts the attention mechanism to float32. + unet_use_cross_frame_attention (Optional[bool]): If True, uses cross-frame attention in UNet. + unet_use_temporal_attention (Optional[bool]): If True, uses temporal attention in UNet. + depth (int): The depth of the transformer block. + unet_block_name (Optional[str]): The name of the UNet block. + stack_enable_blocks_name (Optional[List[str]]): The list of enabled blocks in the stack. + stack_enable_blocks_depth (Optional[List[int]]): The list of depths for the enabled blocks in the stack. + """ + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + unet_use_cross_frame_attention=None, + unet_use_temporal_attention=None, + depth=0, + unet_block_name=None, + stack_enable_blocks_name: Optional[List[str]] = None, + stack_enable_blocks_depth: Optional[List[int]] = None, + ): + """ + Initializes the AudioTemporalBasicTransformerBlock module. + + Args: + dim (int): The dimension of the input and output embeddings. + num_attention_heads (int): The number of attention heads in the multi-head self-attention mechanism. + attention_head_dim (int): The dimension of each attention head. + dropout (float, optional): The dropout probability for the attention mechanism. Defaults to 0.0. + cross_attention_dim (Optional[int], optional): The dimension of the cross-attention mechanism. Defaults to None. + activation_fn (str, optional): The activation function to be used in the feed-forward network. Defaults to "geglu". + num_embeds_ada_norm (Optional[int], optional): The number of embeddings for adaptive normalization. Defaults to None. + attention_bias (bool, optional): If True, uses bias in the attention mechanism. Defaults to False. + only_cross_attention (bool, optional): If True, only uses cross-attention. Defaults to False. + upcast_attention (bool, optional): If True, upcasts the attention mechanism to float32. Defaults to False. + unet_use_cross_frame_attention (Optional[bool], optional): If True, uses cross-frame attention in UNet. Defaults to None. + unet_use_temporal_attention (Optional[bool], optional): If True, uses temporal attention in UNet. Defaults to None. + depth (int, optional): The depth of the transformer block. Defaults to 0. + unet_block_name (Optional[str], optional): The name of the UNet block. Defaults to None. + stack_enable_blocks_name (Optional[List[str]], optional): The list of enabled blocks in the stack. Defaults to None. + stack_enable_blocks_depth (Optional[List[int]], optional): The list of depths for the enabled blocks in the stack. Defaults to None. + """ + super().__init__() + self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm = num_embeds_ada_norm is not None + self.unet_use_cross_frame_attention = unet_use_cross_frame_attention + self.unet_use_temporal_attention = unet_use_temporal_attention + self.unet_block_name = unet_block_name + self.depth = depth + + zero_conv_full = nn.Conv2d( + dim, dim, kernel_size=1) + self.zero_conv_full = zero_module(zero_conv_full) + + zero_conv_face = nn.Conv2d( + dim, dim, kernel_size=1) + self.zero_conv_face = zero_module(zero_conv_face) + + zero_conv_lip = nn.Conv2d( + dim, dim, kernel_size=1) + self.zero_conv_lip = zero_module(zero_conv_lip) + # SC-Attn + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + self.norm1 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim) + ) + + # Cross-Attn + if cross_attention_dim is not None: + if (stack_enable_blocks_name is not None and + stack_enable_blocks_depth is not None and + self.unet_block_name in stack_enable_blocks_name and + self.depth in stack_enable_blocks_depth): + self.attn2_0 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + self.attn2_1 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + self.attn2_2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + self.attn2 = None + + else: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + self.attn2_0=None + else: + self.attn2 = None + self.attn2_0 = None + + if cross_attention_dim is not None: + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim) + ) + else: + self.norm2 = None + + # Feed-forward + self.ff = FeedForward(dim, dropout=dropout, + activation_fn=activation_fn) + self.norm3 = nn.LayerNorm(dim) + self.use_ada_layer_norm_zero = False + + + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + attention_mask=None, + full_mask=None, + face_mask=None, + lip_mask=None, + motion_scale=None, + video_length=None, + ): + """ + Forward pass for the AudioTemporalBasicTransformerBlock. + + Args: + hidden_states (torch.FloatTensor): The input hidden states. + encoder_hidden_states (torch.FloatTensor, optional): The encoder hidden states. Defaults to None. + timestep (torch.LongTensor, optional): The timestep for the transformer block. Defaults to None. + attention_mask (torch.FloatTensor, optional): The attention mask. Defaults to None. + full_mask (torch.FloatTensor, optional): The full mask. Defaults to None. + face_mask (torch.FloatTensor, optional): The face mask. Defaults to None. + lip_mask (torch.FloatTensor, optional): The lip mask. Defaults to None. + video_length (int, optional): The length of the video. Defaults to None. + + Returns: + torch.FloatTensor: The output tensor after passing through the AudioTemporalBasicTransformerBlock. + """ + norm_hidden_states = ( + self.norm1(hidden_states, timestep) + if self.use_ada_layer_norm + else self.norm1(hidden_states) + ) + + if self.unet_use_cross_frame_attention: + hidden_states = ( + self.attn1( + norm_hidden_states, + attention_mask=attention_mask, + video_length=video_length, + ) + + hidden_states + ) + else: + hidden_states = ( + self.attn1(norm_hidden_states, attention_mask=attention_mask) + + hidden_states + ) + + if self.attn2 is not None: + # Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) + if self.use_ada_layer_norm + else self.norm2(hidden_states) + ) + hidden_states = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) + hidden_states + + elif self.attn2_0 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) + if self.use_ada_layer_norm + else self.norm2(hidden_states) + ) + + level = self.depth + full_hidden_states = ( + self.attn2_0( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) * full_mask[level][:, :, None] + ) + bz, sz, c = full_hidden_states.shape + sz_sqrt = int(sz ** 0.5) + full_hidden_states = full_hidden_states.reshape( + bz, sz_sqrt, sz_sqrt, c).permute(0, 3, 1, 2) + full_hidden_states = self.zero_conv_full(full_hidden_states).permute(0, 2, 3, 1).reshape(bz, -1, c) + + face_hidden_state = ( + self.attn2_1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) * face_mask[level][:, :, None] + ) + face_hidden_state = face_hidden_state.reshape( + bz, sz_sqrt, sz_sqrt, c).permute(0, 3, 1, 2) + face_hidden_state = self.zero_conv_face( + face_hidden_state).permute(0, 2, 3, 1).reshape(bz, -1, c) + + lip_hidden_state = ( + self.attn2_2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) * lip_mask[level][:, :, None] + + ) # [32, 4096, 320] + lip_hidden_state = lip_hidden_state.reshape( + bz, sz_sqrt, sz_sqrt, c).permute(0, 3, 1, 2) + lip_hidden_state = self.zero_conv_lip( + lip_hidden_state).permute(0, 2, 3, 1).reshape(bz, -1, c) + + if motion_scale is not None: + hidden_states = ( + motion_scale[0] * full_hidden_states + + motion_scale[1] * face_hidden_state + + motion_scale[2] * lip_hidden_state + hidden_states + ) + else: + hidden_states = ( + full_hidden_states + + face_hidden_state + + lip_hidden_state + hidden_states + ) + # Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + return hidden_states + +def zero_module(module): + """ + Zeroes out the parameters of a given module. + + Args: + module (nn.Module): The module whose parameters need to be zeroed out. + + Returns: + None. + """ + for p in module.parameters(): + nn.init.zeros_(p) + return module diff --git a/Hallo2/hallo2/hallo/models/audio_proj.py b/Hallo2/hallo2/hallo/models/audio_proj.py new file mode 100644 index 00000000..9edf7d2e --- /dev/null +++ b/Hallo2/hallo2/hallo/models/audio_proj.py @@ -0,0 +1,124 @@ +""" +This module provides the implementation of an Audio Projection Model, which is designed for +audio processing tasks. The model takes audio embeddings as input and outputs context tokens +that can be used for various downstream applications, such as audio analysis or synthesis. + +The AudioProjModel class is based on the ModelMixin class from the diffusers library, which +provides a foundation for building custom models. This implementation includes multiple linear +layers with ReLU activation functions and a LayerNorm for normalization. + +Key Features: +- Audio embedding input with flexible sequence length and block structure. +- Multiple linear layers for feature transformation. +- ReLU activation for non-linear transformation. +- LayerNorm for stabilizing and speeding up training. +- Rearrangement of input embeddings to match the model's expected input shape. +- Customizable number of blocks, channels, and context tokens for adaptability. + +The module is structured to be easily integrated into larger systems or used as a standalone +component for audio feature extraction and processing. + +Classes: +- AudioProjModel: A class representing the audio projection model with configurable parameters. + +Functions: +- (none) + +Dependencies: +- torch: For tensor operations and neural network components. +- diffusers: For the ModelMixin base class. +- einops: For tensor rearrangement operations. + +""" + +import torch +from diffusers import ModelMixin +from einops import rearrange +from torch import nn + + +class AudioProjModel(ModelMixin): + """Audio Projection Model + + This class defines an audio projection model that takes audio embeddings as input + and produces context tokens as output. The model is based on the ModelMixin class + and consists of multiple linear layers and activation functions. It can be used + for various audio processing tasks. + + Attributes: + seq_len (int): The length of the audio sequence. + blocks (int): The number of blocks in the audio projection model. + channels (int): The number of channels in the audio projection model. + intermediate_dim (int): The intermediate dimension of the model. + context_tokens (int): The number of context tokens in the output. + output_dim (int): The output dimension of the context tokens. + + Methods: + __init__(self, seq_len=5, blocks=12, channels=768, intermediate_dim=512, context_tokens=32, output_dim=768): + Initializes the AudioProjModel with the given parameters. + forward(self, audio_embeds): + Defines the forward pass for the AudioProjModel. + Parameters: + audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, blocks, channels). + Returns: + context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim). + + """ + + def __init__( + self, + seq_len=5, + blocks=12, # add a new parameter blocks + channels=768, # add a new parameter channels + intermediate_dim=512, + output_dim=768, + context_tokens=32, + ): + super().__init__() + + self.seq_len = seq_len + self.blocks = blocks + self.channels = channels + self.input_dim = ( + seq_len * blocks * channels + ) # update input_dim to be the product of blocks and channels. + self.intermediate_dim = intermediate_dim + self.context_tokens = context_tokens + self.output_dim = output_dim + + # define multiple linear layers + self.proj1 = nn.Linear(self.input_dim, intermediate_dim) + self.proj2 = nn.Linear(intermediate_dim, intermediate_dim) + self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim) + + self.norm = nn.LayerNorm(output_dim) + + def forward(self, audio_embeds): + """ + Defines the forward pass for the AudioProjModel. + + Parameters: + audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, blocks, channels). + + Returns: + context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim). + """ + # merge + video_length = audio_embeds.shape[1] + audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c") + batch_size, window_size, blocks, channels = audio_embeds.shape + audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels) + + audio_embeds = torch.relu(self.proj1(audio_embeds)) + audio_embeds = torch.relu(self.proj2(audio_embeds)) + + context_tokens = self.proj3(audio_embeds).reshape( + batch_size, self.context_tokens, self.output_dim + ) + + context_tokens = self.norm(context_tokens) + context_tokens = rearrange( + context_tokens, "(bz f) m c -> bz f m c", f=video_length + ) + + return context_tokens diff --git a/Hallo2/hallo2/hallo/models/face_locator.py b/Hallo2/hallo2/hallo/models/face_locator.py new file mode 100644 index 00000000..f138744c --- /dev/null +++ b/Hallo2/hallo2/hallo/models/face_locator.py @@ -0,0 +1,113 @@ +""" +This module implements the FaceLocator class, which is a neural network model designed to +locate and extract facial features from input images or tensors. It uses a series of +convolutional layers to progressively downsample and refine the facial feature map. + +The FaceLocator class is part of a larger system that may involve facial recognition or +similar tasks where precise location and extraction of facial features are required. + +Attributes: + conditioning_embedding_channels (int): The number of channels in the output embedding. + conditioning_channels (int): The number of input channels for the conditioning tensor. + block_out_channels (Tuple[int]): A tuple of integers representing the output channels + for each block in the model. + +The model uses the following components: +- InflatedConv3d: A convolutional layer that inflates the input to increase the depth. +- zero_module: A utility function that may set certain parameters to zero for regularization + or other purposes. + +The forward method of the FaceLocator class takes a conditioning tensor as input and +produces an embedding tensor as output, which can be used for further processing or analysis. +""" + +from typing import Tuple + +import torch.nn.functional as F +from diffusers.models.modeling_utils import ModelMixin +from torch import nn + +from .motion_module import zero_module +from .resnet import InflatedConv3d + + +class FaceLocator(ModelMixin): + """ + The FaceLocator class is a neural network model designed to process and extract facial + features from an input tensor. It consists of a series of convolutional layers that + progressively downsample the input while increasing the depth of the feature map. + + The model is built using InflatedConv3d layers, which are designed to inflate the + feature channels, allowing for more complex feature extraction. The final output is a + conditioning embedding that can be used for various tasks such as facial recognition or + feature-based image manipulation. + + Parameters: + conditioning_embedding_channels (int): The number of channels in the output embedding. + conditioning_channels (int, optional): The number of input channels for the conditioning tensor. Default is 3. + block_out_channels (Tuple[int], optional): A tuple of integers representing the output channels + for each block in the model. The default is (16, 32, 64, 128), which defines the + progression of the network's depth. + + Attributes: + conv_in (InflatedConv3d): The initial convolutional layer that starts the feature extraction process. + blocks (ModuleList[InflatedConv3d]): A list of convolutional layers that form the core of the model. + conv_out (InflatedConv3d): The final convolutional layer that produces the output embedding. + + The forward method applies the convolutional layers to the input conditioning tensor and + returns the resulting embedding tensor. + """ + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int] = (16, 32, 64, 128), + ): + super().__init__() + self.conv_in = InflatedConv3d( + conditioning_channels, block_out_channels[0], kernel_size=3, padding=1 + ) + + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append( + InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1) + ) + self.blocks.append( + InflatedConv3d( + channel_in, channel_out, kernel_size=3, padding=1, stride=2 + ) + ) + + self.conv_out = zero_module( + InflatedConv3d( + block_out_channels[-1], + conditioning_embedding_channels, + kernel_size=3, + padding=1, + ) + ) + + def forward(self, conditioning): + """ + Forward pass of the FaceLocator model. + + Args: + conditioning (Tensor): The input conditioning tensor. + + Returns: + Tensor: The output embedding tensor. + """ + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding diff --git a/Hallo2/hallo2/hallo/models/image_proj.py b/Hallo2/hallo2/hallo/models/image_proj.py new file mode 100644 index 00000000..d6522e0c --- /dev/null +++ b/Hallo2/hallo2/hallo/models/image_proj.py @@ -0,0 +1,76 @@ +""" +image_proj_model.py + +This module defines the ImageProjModel class, which is responsible for +projecting image embeddings into a different dimensional space. The model +leverages a linear transformation followed by a layer normalization to +reshape and normalize the input image embeddings for further processing in +cross-attention mechanisms or other downstream tasks. + +Classes: + ImageProjModel + +Dependencies: + torch + diffusers.ModelMixin + +""" + +import torch +from diffusers import ModelMixin + + +class ImageProjModel(ModelMixin): + """ + ImageProjModel is a class that projects image embeddings into a different + dimensional space. It inherits from ModelMixin, providing additional functionalities + specific to image projection. + + Attributes: + cross_attention_dim (int): The dimension of the cross attention. + clip_embeddings_dim (int): The dimension of the CLIP embeddings. + clip_extra_context_tokens (int): The number of extra context tokens in CLIP. + + Methods: + forward(image_embeds): Forward pass of the ImageProjModel, which takes in image + embeddings and returns the projected tokens. + + """ + + def __init__( + self, + cross_attention_dim=1024, + clip_embeddings_dim=1024, + clip_extra_context_tokens=4, + ): + super().__init__() + + self.generator = None + self.cross_attention_dim = cross_attention_dim + self.clip_extra_context_tokens = clip_extra_context_tokens + self.proj = torch.nn.Linear( + clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim + ) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds): + """ + Forward pass of the ImageProjModel, which takes in image embeddings and returns the + projected tokens after reshaping and normalization. + + Args: + image_embeds (torch.Tensor): The input image embeddings, with shape + batch_size x num_image_tokens x clip_embeddings_dim. + + Returns: + clip_extra_context_tokens (torch.Tensor): The projected tokens after reshaping + and normalization, with shape batch_size x (clip_extra_context_tokens * + cross_attention_dim). + + """ + embeds = image_embeds + clip_extra_context_tokens = self.proj(embeds).reshape( + -1, self.clip_extra_context_tokens, self.cross_attention_dim + ) + clip_extra_context_tokens = self.norm(clip_extra_context_tokens) + return clip_extra_context_tokens diff --git a/Hallo2/hallo2/hallo/models/motion_module.py b/Hallo2/hallo2/hallo/models/motion_module.py new file mode 100644 index 00000000..f62877d4 --- /dev/null +++ b/Hallo2/hallo2/hallo/models/motion_module.py @@ -0,0 +1,609 @@ +# pylint: disable=R0801 +# pylint: disable=W0613 +# pylint: disable=W0221 + +""" +temporal_transformers.py + +This module provides classes and functions for implementing Temporal Transformers +in PyTorch, designed for handling video data and temporal sequences within transformer-based models. + +Functions: + zero_module(module) + Zero out the parameters of a module and return it. + +Classes: + TemporalTransformer3DModelOutput(BaseOutput) + Dataclass for storing the output of TemporalTransformer3DModel. + + VanillaTemporalModule(nn.Module) + A Vanilla Temporal Module class for handling temporal data. + + TemporalTransformer3DModel(nn.Module) + A Temporal Transformer 3D Model class for transforming temporal data. + + TemporalTransformerBlock(nn.Module) + A Temporal Transformer Block class for building the transformer architecture. + + PositionalEncoding(nn.Module) + A Positional Encoding module for transformers to encode positional information. + +Dependencies: + math + dataclasses.dataclass + typing (Callable, Optional) + torch + diffusers (FeedForward, Attention, AttnProcessor) + diffusers.utils (BaseOutput) + diffusers.utils.import_utils (is_xformers_available) + einops (rearrange, repeat) + torch.nn + xformers + xformers.ops + +Example Usage: + >>> motion_module = get_motion_module(in_channels=512, motion_module_type="Vanilla", motion_module_kwargs={}) + >>> output = motion_module(input_tensor, temb, encoder_hidden_states) + +This module is designed to facilitate the creation, training, and inference of transformer models +that operate on temporal data, such as videos or time-series. It includes mechanisms for applying temporal attention, +managing positional encoding, and integrating with external libraries for efficient attention operations. +""" + +# This code is copied from https://github.com/guoyww/AnimateDiff. + +import math + +import torch +import xformers +import xformers.ops +from diffusers.models.attention import FeedForward +from diffusers.models.attention_processor import Attention, AttnProcessor +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available +from einops import rearrange, repeat +from torch import nn + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + + Args: + - module: A PyTorch module to zero out its parameters. + + Returns: + A zeroed out PyTorch module. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class TemporalTransformer3DModelOutput(BaseOutput): + """ + Output class for the TemporalTransformer3DModel. + + Attributes: + sample (torch.FloatTensor): The output sample tensor from the model. + """ + sample: torch.FloatTensor + + def get_sample_shape(self): + """ + Returns the shape of the sample tensor. + + Returns: + Tuple: The shape of the sample tensor. + """ + return self.sample.shape + + +def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict): + """ + This function returns a motion module based on the given type and parameters. + + Args: + - in_channels (int): The number of input channels for the motion module. + - motion_module_type (str): The type of motion module to create. Currently, only "Vanilla" is supported. + - motion_module_kwargs (dict): Additional keyword arguments to pass to the motion module constructor. + + Returns: + VanillaTemporalModule: The created motion module. + + Raises: + ValueError: If an unsupported motion_module_type is provided. + """ + if motion_module_type == "Vanilla": + return VanillaTemporalModule( + in_channels=in_channels, + **motion_module_kwargs, + ) + + raise ValueError + + +class VanillaTemporalModule(nn.Module): + """ + A Vanilla Temporal Module class. + + Args: + - in_channels (int): The number of input channels for the motion module. + - num_attention_heads (int): Number of attention heads. + - num_transformer_block (int): Number of transformer blocks. + - attention_block_types (tuple): Types of attention blocks. + - cross_frame_attention_mode: Mode for cross-frame attention. + - temporal_position_encoding (bool): Flag for temporal position encoding. + - temporal_position_encoding_max_len (int): Maximum length for temporal position encoding. + - temporal_attention_dim_div (int): Divisor for temporal attention dimension. + - zero_initialize (bool): Flag for zero initialization. + """ + + def __init__( + self, + in_channels, + num_attention_heads=8, + num_transformer_block=2, + attention_block_types=("Temporal_Self", "Temporal_Self"), + cross_frame_attention_mode=None, + temporal_position_encoding=False, + temporal_position_encoding_max_len=24, + temporal_attention_dim_div=1, + zero_initialize=True, + ): + super().__init__() + + self.temporal_transformer = TemporalTransformer3DModel( + in_channels=in_channels, + num_attention_heads=num_attention_heads, + attention_head_dim=in_channels + // num_attention_heads + // temporal_attention_dim_div, + num_layers=num_transformer_block, + attention_block_types=attention_block_types, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + + if zero_initialize: + self.temporal_transformer.proj_out = zero_module( + self.temporal_transformer.proj_out + ) + + def forward( + self, + input_tensor, + encoder_hidden_states, + attention_mask=None, + ): + """ + Forward pass of the TemporalTransformer3DModel. + + Args: + hidden_states (torch.Tensor): The hidden states of the model. + encoder_hidden_states (torch.Tensor, optional): The hidden states of the encoder. + attention_mask (torch.Tensor, optional): The attention mask. + + Returns: + torch.Tensor: The output tensor after the forward pass. + """ + hidden_states = input_tensor + hidden_states = self.temporal_transformer( + hidden_states, encoder_hidden_states + ) + + output = hidden_states + return output + + +class TemporalTransformer3DModel(nn.Module): + """ + A Temporal Transformer 3D Model class. + + Args: + - in_channels (int): The number of input channels. + - num_attention_heads (int): Number of attention heads. + - attention_head_dim (int): Dimension of attention heads. + - num_layers (int): Number of transformer layers. + - attention_block_types (tuple): Types of attention blocks. + - dropout (float): Dropout rate. + - norm_num_groups (int): Number of groups for normalization. + - cross_attention_dim (int): Dimension for cross-attention. + - activation_fn (str): Activation function. + - attention_bias (bool): Flag for attention bias. + - upcast_attention (bool): Flag for upcast attention. + - cross_frame_attention_mode: Mode for cross-frame attention. + - temporal_position_encoding (bool): Flag for temporal position encoding. + - temporal_position_encoding_max_len (int): Maximum length for temporal position encoding. + """ + def __init__( + self, + in_channels, + num_attention_heads, + attention_head_dim, + num_layers, + attention_block_types=( + "Temporal_Self", + "Temporal_Self", + ), + dropout=0.0, + norm_num_groups=32, + cross_attention_dim=768, + activation_fn="geglu", + attention_bias=False, + upcast_attention=False, + cross_frame_attention_mode=None, + temporal_position_encoding=False, + temporal_position_encoding_max_len=24, + ): + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + + self.norm = torch.nn.GroupNorm( + num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + TemporalTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + attention_block_types=attention_block_types, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + upcast_attention=upcast_attention, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + for d in range(num_layers) + ] + ) + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward(self, hidden_states, encoder_hidden_states=None): + """ + Forward pass for the TemporalTransformer3DModel. + + Args: + hidden_states (torch.Tensor): The input hidden states with shape (batch_size, sequence_length, in_channels). + encoder_hidden_states (torch.Tensor, optional): The encoder hidden states with shape (batch_size, encoder_sequence_length, in_channels). + + Returns: + torch.Tensor: The output hidden states with shape (batch_size, sequence_length, in_channels). + """ + assert ( + hidden_states.dim() == 5 + ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + + batch, _, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( + batch, height * weight, inner_dim + ) + hidden_states = self.proj_in(hidden_states) + + # Transformer Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + video_length=video_length, + ) + + # output + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + + output = hidden_states + residual + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + + return output + + +class TemporalTransformerBlock(nn.Module): + """ + A Temporal Transformer Block class. + + Args: + - dim (int): Dimension of the block. + - num_attention_heads (int): Number of attention heads. + - attention_head_dim (int): Dimension of attention heads. + - attention_block_types (tuple): Types of attention blocks. + - dropout (float): Dropout rate. + - cross_attention_dim (int): Dimension for cross-attention. + - activation_fn (str): Activation function. + - attention_bias (bool): Flag for attention bias. + - upcast_attention (bool): Flag for upcast attention. + - cross_frame_attention_mode: Mode for cross-frame attention. + - temporal_position_encoding (bool): Flag for temporal position encoding. + - temporal_position_encoding_max_len (int): Maximum length for temporal position encoding. + """ + def __init__( + self, + dim, + num_attention_heads, + attention_head_dim, + attention_block_types=( + "Temporal_Self", + "Temporal_Self", + ), + dropout=0.0, + cross_attention_dim=768, + activation_fn="geglu", + attention_bias=False, + upcast_attention=False, + cross_frame_attention_mode=None, + temporal_position_encoding=False, + temporal_position_encoding_max_len=24, + ): + super().__init__() + + attention_blocks = [] + norms = [] + + for block_name in attention_block_types: + attention_blocks.append( + VersatileAttention( + attention_mode=block_name.split("_", maxsplit=1)[0], + cross_attention_dim=cross_attention_dim + if block_name.endswith("_Cross") + else None, + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + ) + norms.append(nn.LayerNorm(dim)) + + self.attention_blocks = nn.ModuleList(attention_blocks) + self.norms = nn.ModuleList(norms) + + self.ff = FeedForward(dim, dropout=dropout, + activation_fn=activation_fn) + self.ff_norm = nn.LayerNorm(dim) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + video_length=None, + ): + """ + Forward pass for the TemporalTransformerBlock. + + Args: + hidden_states (torch.Tensor): The input hidden states with shape + (batch_size, video_length, in_channels). + encoder_hidden_states (torch.Tensor, optional): The encoder hidden states + with shape (batch_size, encoder_length, in_channels). + video_length (int, optional): The length of the video. + + Returns: + torch.Tensor: The output hidden states with shape + (batch_size, video_length, in_channels). + """ + for attention_block, norm in zip(self.attention_blocks, self.norms): + norm_hidden_states = norm(hidden_states) + hidden_states = ( + attention_block( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states + if attention_block.is_cross_attention + else None, + video_length=video_length, + ) + + hidden_states + ) + + hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states + + output = hidden_states + return output + + +class PositionalEncoding(nn.Module): + """ + Positional Encoding module for transformers. + + Args: + - d_model (int): Model dimension. + - dropout (float): Dropout rate. + - max_len (int): Maximum length for positional encoding. + """ + def __init__(self, d_model, dropout=0.0, max_len=24): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model) + ) + pe = torch.zeros(1, max_len, d_model) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) + + def forward(self, x): + """ + Forward pass of the PositionalEncoding module. + + This method takes an input tensor `x` and adds the positional encoding to it. The positional encoding is + generated based on the input tensor's shape and is added to the input tensor element-wise. + + Args: + x (torch.Tensor): The input tensor to be positionally encoded. + + Returns: + torch.Tensor: The positionally encoded tensor. + """ + x = x + self.pe[:, : x.size(1)] + return self.dropout(x) + + +class VersatileAttention(Attention): + """ + Versatile Attention class. + + Args: + - attention_mode: Attention mode. + - temporal_position_encoding (bool): Flag for temporal position encoding. + - temporal_position_encoding_max_len (int): Maximum length for temporal position encoding. + """ + def __init__( + self, + *args, + attention_mode=None, + cross_frame_attention_mode=None, + temporal_position_encoding=False, + temporal_position_encoding_max_len=24, + **kwargs, + ): + super().__init__(*args, **kwargs) + assert attention_mode == "Temporal" + + self.attention_mode = attention_mode + self.is_cross_attention = kwargs.get("cross_attention_dim") is not None + + self.pos_encoder = ( + PositionalEncoding( + kwargs["query_dim"], + dropout=0.0, + max_len=temporal_position_encoding_max_len, + ) + if (temporal_position_encoding and attention_mode == "Temporal") + else None + ) + + def extra_repr(self): + """ + Returns a string representation of the module with information about the attention mode and whether it is cross-attention. + + Returns: + str: A string representation of the module. + """ + return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" + + def set_use_memory_efficient_attention_xformers( + self, + use_memory_efficient_attention_xformers: bool, + attention_op = None, + ): + """ + Sets the use of memory-efficient attention xformers for the VersatileAttention class. + + Args: + use_memory_efficient_attention_xformers (bool): A boolean flag indicating whether to use memory-efficient attention xformers or not. + + Returns: + None + + """ + if use_memory_efficient_attention_xformers: + if not is_xformers_available(): + raise ModuleNotFoundError( + ( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers" + ), + name="xformers", + ) + + if not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" + " only available for GPU " + ) + + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + processor = AttnProcessor() + else: + processor = AttnProcessor() + + self.set_processor(processor) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + video_length=None, + **cross_attention_kwargs, + ): + """ + Args: + hidden_states (`torch.Tensor`): + The hidden states to be passed through the model. + encoder_hidden_states (`torch.Tensor`, optional): + The encoder hidden states to be passed through the model. + attention_mask (`torch.Tensor`, optional): + The attention mask to be used in the model. + video_length (`int`, optional): + The length of the video. + cross_attention_kwargs (`dict`, optional): + Additional keyword arguments to be used for cross-attention. + + Returns: + `torch.Tensor`: + The output tensor after passing through the model. + + """ + if self.attention_mode == "Temporal": + d = hidden_states.shape[1] # d means HxW + hidden_states = rearrange( + hidden_states, "(b f) d c -> (b d) f c", f=video_length + ) + + if self.pos_encoder is not None: + hidden_states = self.pos_encoder(hidden_states) + + encoder_hidden_states = ( + repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) + if encoder_hidden_states is not None + else encoder_hidden_states + ) + + else: + raise NotImplementedError + + hidden_states = self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + if self.attention_mode == "Temporal": + hidden_states = rearrange( + hidden_states, "(b d) f c -> (b f) d c", d=d) + + return hidden_states diff --git a/Hallo2/hallo2/hallo/models/mutual_self_attention.py b/Hallo2/hallo2/hallo/models/mutual_self_attention.py new file mode 100644 index 00000000..d00784e5 --- /dev/null +++ b/Hallo2/hallo2/hallo/models/mutual_self_attention.py @@ -0,0 +1,496 @@ +# pylint: disable=E1120 +""" +This module contains the implementation of mutual self-attention, +which is a type of attention mechanism used in deep learning models. +The module includes several classes and functions related to attention mechanisms, +such as BasicTransformerBlock and TemporalBasicTransformerBlock. +The main purpose of this module is to provide a comprehensive attention mechanism for various tasks in deep learning, +such as image and video processing, natural language processing, and so on. +""" + +from typing import Any, Dict, Optional + +import torch +from einops import rearrange + +from .attention import BasicTransformerBlock, TemporalBasicTransformerBlock + + +def torch_dfs(model: torch.nn.Module): + """ + Perform a depth-first search (DFS) traversal on a PyTorch model's neural network architecture. + + This function recursively traverses all the children modules of a given PyTorch model and returns a list + containing all the modules in the model's architecture. The DFS approach starts with the input model and + explores its children modules depth-wise before backtracking and exploring other branches. + + Args: + model (torch.nn.Module): The root module of the neural network to traverse. + + Returns: + list: A list of all the modules in the model's architecture. + """ + result = [model] + for child in model.children(): + result += torch_dfs(child) + return result + + +class ReferenceAttentionControl: + """ + This class is used to control the reference attention mechanism in a neural network model. + It is responsible for managing the guidance and fusion blocks, and modifying the self-attention + and group normalization mechanisms. The class also provides methods for registering reference hooks + and updating/clearing the internal state of the attention control object. + + Attributes: + unet: The UNet model associated with this attention control object. + mode: The operating mode of the attention control object, either 'write' or 'read'. + do_classifier_free_guidance: Whether to use classifier-free guidance in the attention mechanism. + attention_auto_machine_weight: The weight assigned to the attention auto-machine. + gn_auto_machine_weight: The weight assigned to the group normalization auto-machine. + style_fidelity: The style fidelity parameter for the attention mechanism. + reference_attn: Whether to use reference attention in the model. + reference_adain: Whether to use reference AdaIN in the model. + fusion_blocks: The type of fusion blocks to use in the model ('midup', 'late', or 'nofusion'). + batch_size: The batch size used for processing video frames. + + Methods: + register_reference_hooks: Registers the reference hooks for the attention control object. + hacked_basic_transformer_inner_forward: The modified inner forward method for the basic transformer block. + update: Updates the internal state of the attention control object using the provided writer and dtype. + clear: Clears the internal state of the attention control object. + """ + def __init__( + self, + unet, + mode="write", + do_classifier_free_guidance=False, + attention_auto_machine_weight=float("inf"), + gn_auto_machine_weight=1.0, + style_fidelity=1.0, + reference_attn=True, + reference_adain=False, + fusion_blocks="midup", + batch_size=1, + ) -> None: + """ + Initializes the ReferenceAttentionControl class. + + Args: + unet (torch.nn.Module): The UNet model. + mode (str, optional): The mode of operation. Defaults to "write". + do_classifier_free_guidance (bool, optional): Whether to do classifier-free guidance. Defaults to False. + attention_auto_machine_weight (float, optional): The weight for attention auto-machine. Defaults to infinity. + gn_auto_machine_weight (float, optional): The weight for group-norm auto-machine. Defaults to 1.0. + style_fidelity (float, optional): The style fidelity. Defaults to 1.0. + reference_attn (bool, optional): Whether to use reference attention. Defaults to True. + reference_adain (bool, optional): Whether to use reference AdaIN. Defaults to False. + fusion_blocks (str, optional): The fusion blocks to use. Defaults to "midup". + batch_size (int, optional): The batch size. Defaults to 1. + + Raises: + ValueError: If the mode is not recognized. + ValueError: If the fusion blocks are not recognized. + """ + # 10. Modify self attention and group norm + self.unet = unet + assert mode in ["read", "write"] + assert fusion_blocks in ["midup", "full"] + self.reference_attn = reference_attn + self.reference_adain = reference_adain + self.fusion_blocks = fusion_blocks + self.register_reference_hooks( + mode, + do_classifier_free_guidance, + attention_auto_machine_weight, + gn_auto_machine_weight, + style_fidelity, + reference_attn, + reference_adain, + fusion_blocks, + batch_size=batch_size, + ) + + def register_reference_hooks( + self, + mode, + do_classifier_free_guidance, + _attention_auto_machine_weight, + _gn_auto_machine_weight, + _style_fidelity, + _reference_attn, + _reference_adain, + _dtype=torch.float16, + batch_size=1, + num_images_per_prompt=1, + device=torch.device("cpu"), + _fusion_blocks="midup", + ): + """ + Registers reference hooks for the model. + + This function is responsible for registering reference hooks in the model, + which are used to modify the attention mechanism and group normalization layers. + It takes various parameters as input, such as mode, + do_classifier_free_guidance, _attention_auto_machine_weight, _gn_auto_machine_weight, _style_fidelity, + _reference_attn, _reference_adain, _dtype, batch_size, num_images_per_prompt, device, and _fusion_blocks. + + Args: + self: Reference to the instance of the class. + mode: The mode of operation for the reference hooks. + do_classifier_free_guidance: A boolean flag indicating whether to use classifier-free guidance. + _attention_auto_machine_weight: The weight for the attention auto-machine. + _gn_auto_machine_weight: The weight for the group normalization auto-machine. + _style_fidelity: The style fidelity for the reference hooks. + _reference_attn: A boolean flag indicating whether to use reference attention. + _reference_adain: A boolean flag indicating whether to use reference AdaIN. + _dtype: The data type for the reference hooks. + batch_size: The batch size for the reference hooks. + num_images_per_prompt: The number of images per prompt for the reference hooks. + device: The device for the reference hooks. + _fusion_blocks: The fusion blocks for the reference hooks. + + Returns: + None + """ + MODE = mode + if do_classifier_free_guidance: + uc_mask = ( + torch.Tensor( + [1] * batch_size * num_images_per_prompt * 16 + + [0] * batch_size * num_images_per_prompt * 16 + ) + .to(device) + .bool() + ) + else: + uc_mask = ( + torch.Tensor([0] * batch_size * num_images_per_prompt * 2) + .to(device) + .bool() + ) + + def hacked_basic_transformer_inner_forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + video_length=None, + ): + gate_msa = None + shift_mlp = None + scale_mlp = None + gate_mlp = None + + if self.use_ada_layer_norm: # False + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + ( + norm_hidden_states, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + ) = self.norm1( + hidden_states, + timestep, + class_labels, + hidden_dtype=hidden_states.dtype, + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + # 1. Self-Attention + # self.only_cross_attention = False + cross_attention_kwargs = ( + cross_attention_kwargs if cross_attention_kwargs is not None else {} + ) + if self.only_cross_attention: + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=( + encoder_hidden_states if self.only_cross_attention else None + ), + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + else: + if MODE == "write": + self.bank.append(norm_hidden_states.clone()) + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=( + encoder_hidden_states if self.only_cross_attention else None + ), + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if MODE == "read": + + bank_fea = [ + rearrange( + rearrange( + d, + "(b s) l c -> b s l c", + b=norm_hidden_states.shape[0] // video_length, + )[:, 0, :, :] + # .unsqueeze(1) + .repeat(1, video_length, 1, 1), + "b t l c -> (b t) l c", + ) + for d in self.bank + ] + motion_frames_fea = [rearrange( + d, + "(b s) l c -> b s l c", + b=norm_hidden_states.shape[0] // video_length, + )[:, 1:, :, :] for d in self.bank] + modify_norm_hidden_states = torch.cat( + [norm_hidden_states] + bank_fea, dim=1 + ) + hidden_states_uc = ( + self.attn1( + norm_hidden_states, + encoder_hidden_states=modify_norm_hidden_states, + attention_mask=attention_mask, + ) + + hidden_states + ) + if do_classifier_free_guidance: + hidden_states_c = hidden_states_uc.clone() + _uc_mask = uc_mask.clone() + if hidden_states.shape[0] != _uc_mask.shape[0]: + _uc_mask = ( + torch.Tensor( + [1] * (hidden_states.shape[0] // 2) + + [0] * (hidden_states.shape[0] // 2) + ) + .to(device) + .bool() + ) + hidden_states_c[_uc_mask] = ( + self.attn1( + norm_hidden_states[_uc_mask], + encoder_hidden_states=norm_hidden_states[_uc_mask], + attention_mask=attention_mask, + ) + + hidden_states[_uc_mask] + ) + hidden_states = hidden_states_c.clone() + else: + hidden_states = hidden_states_uc + + # self.bank.clear() + if self.attn2 is not None: + # Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) + if self.use_ada_layer_norm + else self.norm2(hidden_states) + ) + hidden_states = ( + self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) + + hidden_states + ) + + # Feed-forward + hidden_states = self.ff(self.norm3( + hidden_states)) + hidden_states + + # Temporal-Attention + if self.unet_use_temporal_attention: + d = hidden_states.shape[1] + hidden_states = rearrange( + hidden_states, "(b f) d c -> (b d) f c", f=video_length + ) + norm_hidden_states = ( + self.norm_temp(hidden_states, timestep) + if self.use_ada_layer_norm + else self.norm_temp(hidden_states) + ) + hidden_states = ( + self.attn_temp(norm_hidden_states) + hidden_states + ) + hidden_states = rearrange( + hidden_states, "(b d) f c -> (b f) d c", d=d + ) + + return hidden_states, motion_frames_fea + + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) + if self.use_ada_layer_norm + else self.norm2(hidden_states) + ) + + # 2. Cross-Attention + tmp = norm_hidden_states.shape[0] // encoder_hidden_states.shape[0] + attn_output = self.attn2( + norm_hidden_states, + # TODO: repeat这个地方需要斟酌一下 + encoder_hidden_states=encoder_hidden_states.repeat( + tmp, 1, 1), + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = ( + norm_hidden_states * + (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ) + + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states + + if self.reference_attn: + if self.fusion_blocks == "midup": + attn_modules = [ + module + for module in ( + torch_dfs(self.unet.mid_block) + + torch_dfs(self.unet.up_blocks) + ) + if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock)) + ] + elif self.fusion_blocks == "full": + attn_modules = [ + module + for module in torch_dfs(self.unet) + if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock)) + ] + attn_modules = sorted( + attn_modules, key=lambda x: -x.norm1.normalized_shape[0] + ) + + for i, module in enumerate(attn_modules): + module._original_inner_forward = module.forward + if isinstance(module, BasicTransformerBlock): + module.forward = hacked_basic_transformer_inner_forward.__get__( + module, + BasicTransformerBlock) + if isinstance(module, TemporalBasicTransformerBlock): + module.forward = hacked_basic_transformer_inner_forward.__get__( + module, + TemporalBasicTransformerBlock) + + module.bank = [] + module.attn_weight = float(i) / float(len(attn_modules)) + + def update(self, writer, dtype=torch.float16): + """ + Update the model's parameters. + + Args: + writer (torch.nn.Module): The model's writer object. + dtype (torch.dtype, optional): The data type to be used for the update. Defaults to torch.float16. + + Returns: + None. + """ + if self.reference_attn: + if self.fusion_blocks == "midup": + reader_attn_modules = [ + module + for module in ( + torch_dfs(self.unet.mid_block) + + torch_dfs(self.unet.up_blocks) + ) + if isinstance(module, TemporalBasicTransformerBlock) + ] + writer_attn_modules = [ + module + for module in ( + torch_dfs(writer.unet.mid_block) + + torch_dfs(writer.unet.up_blocks) + ) + if isinstance(module, BasicTransformerBlock) + ] + elif self.fusion_blocks == "full": + reader_attn_modules = [ + module + for module in torch_dfs(self.unet) + if isinstance(module, TemporalBasicTransformerBlock) + ] + writer_attn_modules = [ + module + for module in torch_dfs(writer.unet) + if isinstance(module, BasicTransformerBlock) + ] + + assert len(reader_attn_modules) == len(writer_attn_modules) + reader_attn_modules = sorted( + reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] + ) + writer_attn_modules = sorted( + writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] + ) + for r, w in zip(reader_attn_modules, writer_attn_modules): + r.bank = [v.clone().to(dtype) for v in w.bank] + + + def clear(self): + """ + Clears the attention bank of all reader attention modules. + + This method is used when the `reference_attn` attribute is set to `True`. + It clears the attention bank of all reader attention modules inside the UNet + model based on the selected `fusion_blocks` mode. + + If `fusion_blocks` is set to "midup", it searches for reader attention modules + in both the mid block and up blocks of the UNet model. If `fusion_blocks` is set + to "full", it searches for reader attention modules in the entire UNet model. + + It sorts the reader attention modules by the number of neurons in their + `norm1.normalized_shape[0]` attribute in descending order. This sorting ensures + that the modules with more neurons are cleared first. + + Finally, it iterates through the sorted list of reader attention modules and + calls the `clear()` method on each module's `bank` attribute to clear the + attention bank. + """ + if self.reference_attn: + if self.fusion_blocks == "midup": + reader_attn_modules = [ + module + for module in ( + torch_dfs(self.unet.mid_block) + + torch_dfs(self.unet.up_blocks) + ) + if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock)) + ] + elif self.fusion_blocks == "full": + reader_attn_modules = [ + module + for module in torch_dfs(self.unet) + if isinstance(module, (BasicTransformerBlock, TemporalBasicTransformerBlock)) + ] + reader_attn_modules = sorted( + reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0] + ) + for r in reader_attn_modules: + r.bank.clear() diff --git a/Hallo2/hallo2/hallo/models/resnet.py b/Hallo2/hallo2/hallo/models/resnet.py new file mode 100644 index 00000000..5593db7c --- /dev/null +++ b/Hallo2/hallo2/hallo/models/resnet.py @@ -0,0 +1,435 @@ +# pylint: disable=E1120 +# pylint: disable=E1102 +# pylint: disable=W0237 + +# src/models/resnet.py + +""" +This module defines various components used in the ResNet model, such as InflatedConv3D, InflatedGroupNorm, +Upsample3D, Downsample3D, ResnetBlock3D, and Mish activation function. These components are used to construct +a deep neural network model for image classification or other computer vision tasks. + +Classes: +- InflatedConv3d: An inflated 3D convolutional layer, inheriting from nn.Conv2d. +- InflatedGroupNorm: An inflated group normalization layer, inheriting from nn.GroupNorm. +- Upsample3D: A 3D upsampling module, used to increase the resolution of the input tensor. +- Downsample3D: A 3D downsampling module, used to decrease the resolution of the input tensor. +- ResnetBlock3D: A 3D residual block, commonly used in ResNet architectures. +- Mish: A Mish activation function, which is a smooth, non-monotonic activation function. + +To use this module, simply import the classes and functions you need and follow the instructions provided in +the respective class and function docstrings. +""" + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn + + +class InflatedConv3d(nn.Conv2d): + """ + InflatedConv3d is a class that inherits from torch.nn.Conv2d and overrides the forward method. + + This class is used to perform 3D convolution on input tensor x. It is a specialized type of convolutional layer + commonly used in deep learning models for computer vision tasks. The main difference between a regular Conv2d and + InflatedConv3d is that InflatedConv3d is designed to handle 3D input tensors, which are typically the result of + inflating 2D convolutional layers to 3D for use in 3D deep learning tasks. + + Attributes: + Same as torch.nn.Conv2d. + + Methods: + forward(self, x): + Performs 3D convolution on the input tensor x using the InflatedConv3d layer. + + Example: + conv_layer = InflatedConv3d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1) + output = conv_layer(input_tensor) + """ + def forward(self, x): + """ + Forward pass of the InflatedConv3d layer. + + Args: + x (torch.Tensor): Input tensor to the layer. + + Returns: + torch.Tensor: Output tensor after applying the InflatedConv3d layer. + """ + video_length = x.shape[2] + + x = rearrange(x, "b c f h w -> (b f) c h w") + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) + + return x + + +class InflatedGroupNorm(nn.GroupNorm): + """ + InflatedGroupNorm is a custom class that inherits from torch.nn.GroupNorm. + It is used to apply group normalization to 3D tensors. + + Args: + num_groups (int): The number of groups to divide the channels into. + num_channels (int): The number of channels in the input tensor. + eps (float, optional): A small constant to add to the variance to avoid division by zero. Defaults to 1e-5. + affine (bool, optional): If True, the module has learnable affine parameters. Defaults to True. + + Attributes: + weight (torch.Tensor): The learnable weight tensor for scale. + bias (torch.Tensor): The learnable bias tensor for shift. + + Forward method: + x (torch.Tensor): Input tensor to be normalized. + return (torch.Tensor): Normalized tensor. + """ + def forward(self, x): + """ + Performs a forward pass through the CustomClassName. + + :param x: Input tensor of shape (batch_size, channels, video_length, height, width). + :return: Output tensor of shape (batch_size, channels, video_length, height, width). + """ + video_length = x.shape[2] + + x = rearrange(x, "b c f h w -> (b f) c h w") + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) + + return x + + +class Upsample3D(nn.Module): + """ + Upsample3D is a PyTorch module that upsamples a 3D tensor. + + Args: + channels (int): The number of channels in the input tensor. + use_conv (bool): Whether to use a convolutional layer for upsampling. + use_conv_transpose (bool): Whether to use a transposed convolutional layer for upsampling. + out_channels (int): The number of channels in the output tensor. + name (str): The name of the convolutional layer. + """ + def __init__( + self, + channels, + use_conv=False, + use_conv_transpose=False, + out_channels=None, + name="conv", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + if use_conv_transpose: + raise NotImplementedError + if use_conv: + self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, hidden_states, output_size=None): + """ + Forward pass of the Upsample3D class. + + Args: + hidden_states (torch.Tensor): Input tensor to be upsampled. + output_size (tuple, optional): Desired output size of the upsampled tensor. + + Returns: + torch.Tensor: Upsampled tensor. + + Raises: + AssertionError: If the number of channels in the input tensor does not match the expected channels. + """ + assert hidden_states.shape[1] == self.channels + + if self.use_conv_transpose: + raise NotImplementedError + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate( + hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest" + ) + else: + hidden_states = F.interpolate( + hidden_states, size=output_size, mode="nearest" + ) + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + # if self.use_conv: + # if self.name == "conv": + # hidden_states = self.conv(hidden_states) + # else: + # hidden_states = self.Conv2d_0(hidden_states) + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class Downsample3D(nn.Module): + """ + The Downsample3D class is a PyTorch module for downsampling a 3D tensor, which is used to + reduce the spatial resolution of feature maps, commonly in the encoder part of a neural network. + + Attributes: + channels (int): Number of input channels. + use_conv (bool): Flag to use a convolutional layer for downsampling. + out_channels (int, optional): Number of output channels. Defaults to input channels if None. + padding (int): Padding added to the input. + name (str): Name of the convolutional layer used for downsampling. + + Methods: + forward(self, hidden_states): + Downsamples the input tensor hidden_states and returns the downsampled tensor. + """ + def __init__( + self, channels, use_conv=False, out_channels=None, padding=1, name="conv" + ): + """ + Downsamples the given input in the 3D space. + + Args: + channels: The number of input channels. + use_conv: Whether to use a convolutional layer for downsampling. + out_channels: The number of output channels. If None, the input channels are used. + padding: The amount of padding to be added to the input. + name: The name of the convolutional layer. + """ + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + self.conv = InflatedConv3d( + self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + raise NotImplementedError + + def forward(self, hidden_states): + """ + Forward pass for the Downsample3D class. + + Args: + hidden_states (torch.Tensor): Input tensor to be downsampled. + + Returns: + torch.Tensor: Downsampled tensor. + + Raises: + AssertionError: If the number of channels in the input tensor does not match the expected channels. + """ + assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: + raise NotImplementedError + + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + """ + The ResnetBlock3D class defines a 3D residual block, a common building block in ResNet + architectures for both image and video modeling tasks. + + Attributes: + in_channels (int): Number of input channels. + out_channels (int, optional): Number of output channels, defaults to in_channels if None. + conv_shortcut (bool): Flag to use a convolutional shortcut. + dropout (float): Dropout rate. + temb_channels (int): Number of channels in the time embedding tensor. + groups (int): Number of groups for the group normalization layers. + eps (float): Epsilon value for group normalization. + non_linearity (str): Type of nonlinearity to apply after convolutions. + time_embedding_norm (str): Type of normalization for the time embedding. + output_scale_factor (float): Scaling factor for the output tensor. + use_in_shortcut (bool): Flag to include the input tensor in the shortcut connection. + use_inflated_groupnorm (bool): Flag to use inflated group normalization layers. + + Methods: + forward(self, input_tensor, temb): + Passes the input tensor and time embedding through the residual block and + returns the output tensor. + """ + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + output_scale_factor=1.0, + use_in_shortcut=None, + use_inflated_groupnorm=None, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + assert use_inflated_groupnorm is not None + if use_inflated_groupnorm: + self.norm1 = InflatedGroupNorm( + num_groups=groups, num_channels=in_channels, eps=eps, affine=True + ) + else: + self.norm1 = torch.nn.GroupNorm( + num_groups=groups, num_channels=in_channels, eps=eps, affine=True + ) + + self.conv1 = InflatedConv3d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + time_emb_proj_out_channels = out_channels + elif self.time_embedding_norm == "scale_shift": + time_emb_proj_out_channels = out_channels * 2 + else: + raise ValueError( + f"unknown time_embedding_norm : {self.time_embedding_norm} " + ) + + self.time_emb_proj = torch.nn.Linear( + temb_channels, time_emb_proj_out_channels + ) + else: + self.time_emb_proj = None + + if use_inflated_groupnorm: + self.norm2 = InflatedGroupNorm( + num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True + ) + else: + self.norm2 = torch.nn.GroupNorm( + num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True + ) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = InflatedConv3d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + if non_linearity == "swish": + self.nonlinearity = F.silu() + elif non_linearity == "mish": + self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = ( + self.in_channels != self.out_channels + if use_in_shortcut is None + else use_in_shortcut + ) + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = InflatedConv3d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, input_tensor, temb): + """ + Forward pass for the ResnetBlock3D class. + + Args: + input_tensor (torch.Tensor): Input tensor to the ResnetBlock3D layer. + temb (torch.Tensor): Token embedding tensor. + + Returns: + torch.Tensor: Output tensor after passing through the ResnetBlock3D layer. + """ + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class Mish(torch.nn.Module): + """ + The Mish class implements the Mish activation function, a smooth, non-monotonic function + that can be used in neural networks as an alternative to traditional activation functions like ReLU. + + Methods: + forward(self, hidden_states): + Applies the Mish activation function to the input tensor hidden_states and + returns the resulting tensor. + """ + def forward(self, hidden_states): + """ + Mish activation function. + + Args: + hidden_states (torch.Tensor): The input tensor to apply the Mish activation function to. + + Returns: + hidden_states (torch.Tensor): The output tensor after applying the Mish activation function. + """ + return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) diff --git a/Hallo2/hallo2/hallo/models/transformer_2d.py b/Hallo2/hallo2/hallo/models/transformer_2d.py new file mode 100644 index 00000000..e9c5bbb9 --- /dev/null +++ b/Hallo2/hallo2/hallo/models/transformer_2d.py @@ -0,0 +1,431 @@ +# pylint: disable=E1101 +# src/models/transformer_2d.py + +""" +This module defines the Transformer2DModel, a PyTorch model that extends ModelMixin and ConfigMixin. It includes +methods for gradient checkpointing, forward propagation, and various utility functions. The model is designed for +2D image-related tasks and uses LoRa (Low-Rank All-Attention) compatible layers for efficient attention computation. + +The file includes the following import statements: + +- From dataclasses import dataclass +- From typing import Any, Dict, Optional +- Import torch +- From diffusers.configuration_utils import ConfigMixin, register_to_config +- From diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +- From diffusers.models.modeling_utils import ModelMixin +- From diffusers.models.normalization import AdaLayerNormSingle +- From diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, deprecate, + is_torch_version) +- From torch import nn +- From .attention import BasicTransformerBlock + +The file also includes the following classes and functions: + +- Transformer2DModel: A model class that extends ModelMixin and ConfigMixin. It includes methods for gradient + checkpointing, forward propagation, and various utility functions. +- _set_gradient_checkpointing: A utility function to set gradient checkpointing for a given module. +- forward: The forward propagation method for the Transformer2DModel. + +To use this module, you can import the Transformer2DModel class and create an instance of the model with the desired +configuration. Then, you can use the forward method to pass input tensors through the model and get the output tensors. +""" + +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +# from diffusers.models.embeddings import CaptionProjection +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormSingle +from diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, deprecate, + is_torch_version) +from torch import nn + +from .attention import BasicTransformerBlock + + +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` + or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + ref_feature: torch.FloatTensor + + +class Transformer2DModel(ModelMixin, ConfigMixin): + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + + # 1. Transformer2DModel can process both standard continuous images of + # shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of + # shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate( + "norm_type!=num_embeds_ada_norm", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + + if self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + + if ( + not self.is_input_continuous + and not self.is_input_vectorized + and not self.is_input_patches + ): + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. Define input layers + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm( + num_groups=norm_num_groups, + num_channels=in_channels, + eps=1e-6, + affine=True, + ) + if use_linear_projection: + self.proj_in = linear_cls(in_channels, inner_dim) + else: + self.proj_in = conv_cls( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = linear_cls(inner_dim, in_channels) + else: + self.proj_out = conv_cls( + inner_dim, in_channels, kernel_size=1, stride=1, padding=0 + ) + + # 5. PixArt-Alpha blocks. + self.adaln_single = None + self.use_additional_conditions = False + if norm_type == "ada_norm_single": + self.use_additional_conditions = self.config.sample_size == 128 + # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use + # additional conditions until we find better name + self.adaln_single = AdaLayerNormSingle( + inner_dim, use_additional_conditions=self.use_additional_conditions + ) + + self.caption_projection = None + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + _added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, + `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor] + (https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = ( + 1 - encoder_attention_mask.to(hidden_states.dtype) + ) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # Retrieve lora scale. + lora_scale = ( + cross_attention_kwargs.get("scale", 1.0) + if cross_attention_kwargs is not None + else 1.0 + ) + + # 1. Input + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = ( + self.proj_in(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_in(hidden_states) + ) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( + batch, height * width, inner_dim + ) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( + batch, height * width, inner_dim + ) + hidden_states = ( + self.proj_in(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_in(hidden_states) + ) + + # 2. Blocks + if self.caption_projection is not None: + batch_size = hidden_states.shape[0] + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view( + batch_size, -1, hidden_states.shape[-1] + ) + + ref_feature = hidden_states.reshape(batch, height, width, inner_dim) + for block in self.transformer_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, # shape [5, 4096, 320] + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, # shape [1,4,768] + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + output = None + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, width, inner_dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + hidden_states = ( + self.proj_out(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_out(hidden_states) + ) + else: + hidden_states = ( + self.proj_out(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_out(hidden_states) + ) + hidden_states = ( + hidden_states.reshape(batch, height, width, inner_dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + + output = hidden_states + residual + if not return_dict: + return (output, ref_feature) + + return Transformer2DModelOutput(sample=output, ref_feature=ref_feature) diff --git a/Hallo2/hallo2/hallo/models/transformer_3d.py b/Hallo2/hallo2/hallo/models/transformer_3d.py new file mode 100644 index 00000000..f2899a33 --- /dev/null +++ b/Hallo2/hallo2/hallo/models/transformer_3d.py @@ -0,0 +1,257 @@ +# pylint: disable=R0801 +""" +This module implements the Transformer3DModel, a PyTorch model designed for processing +3D data such as videos. It extends ModelMixin and ConfigMixin to provide a transformer +model with support for gradient checkpointing and various types of attention mechanisms. +The model can be configured with different parameters such as the number of attention heads, +attention head dimension, and the number of layers. It also supports the use of audio modules +for enhanced feature extraction from video data. +""" + +from dataclasses import dataclass +from typing import Optional + +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models import ModelMixin +from diffusers.utils import BaseOutput +from einops import rearrange, repeat +from torch import nn + +from .attention import (AudioTemporalBasicTransformerBlock, + TemporalBasicTransformerBlock) + + +@dataclass +class Transformer3DModelOutput(BaseOutput): + """ + The output of the [`Transformer3DModel`]. + + Attributes: + sample (`torch.FloatTensor`): + The output tensor from the transformer model, which is the result of processing the input + hidden states through the transformer blocks and any subsequent layers. + """ + sample: torch.FloatTensor + + +class Transformer3DModel(ModelMixin, ConfigMixin): + """ + Transformer3DModel is a PyTorch model that extends `ModelMixin` and `ConfigMixin` to create a 3D transformer model. + It implements the forward pass for processing input hidden states, encoder hidden states, and various types of attention masks. + The model supports gradient checkpointing, which can be enabled by calling the `enable_gradient_checkpointing()` method. + """ + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + unet_use_cross_frame_attention=None, + unet_use_temporal_attention=None, + use_audio_module=False, + depth=0, + unet_block_name=None, + stack_enable_blocks_name = None, + stack_enable_blocks_depth = None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.use_audio_module = use_audio_module + # Define input layers + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm( + num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + + if use_audio_module: + self.transformer_blocks = nn.ModuleList( + [ + AudioTemporalBasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + depth=depth, + unet_block_name=unet_block_name, + stack_enable_blocks_name=stack_enable_blocks_name, + stack_enable_blocks_depth=stack_enable_blocks_depth, + ) + for d in range(num_layers) + ] + ) + else: + # Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + TemporalBasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d( + inner_dim, in_channels, kernel_size=1, stride=1, padding=0 + ) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + full_mask=None, + face_mask=None, + lip_mask=None, + motion_scale=None, + timestep=None, + return_dict: bool = True, + ): + """ + Forward pass for the Transformer3DModel. + + Args: + hidden_states (torch.Tensor): The input hidden states. + encoder_hidden_states (torch.Tensor, optional): The input encoder hidden states. + attention_mask (torch.Tensor, optional): The attention mask. + full_mask (torch.Tensor, optional): The full mask. + face_mask (torch.Tensor, optional): The face mask. + lip_mask (torch.Tensor, optional): The lip mask. + timestep (int, optional): The current timestep. + return_dict (bool, optional): Whether to return a dictionary or a tuple. + + Returns: + output (Union[Tuple, BaseOutput]): The output of the Transformer3DModel. + """ + # Input + assert ( + hidden_states.dim() == 5 + ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + + # TODO + if self.use_audio_module: + encoder_hidden_states = rearrange( + encoder_hidden_states, + "bs f margin dim -> (bs f) margin dim", + ) + else: + if encoder_hidden_states.shape[0] != hidden_states.shape[0]: + encoder_hidden_states = repeat( + encoder_hidden_states, "b n c -> (b f) n c", f=video_length + ) + + batch, _, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( + batch, height * weight, inner_dim + ) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( + batch, height * weight, inner_dim + ) + hidden_states = self.proj_in(hidden_states) + + # Blocks + motion_frames = [] + for _, block in enumerate(self.transformer_blocks): + if isinstance(block, TemporalBasicTransformerBlock): + hidden_states, motion_frame_fea = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + video_length=video_length, + ) + motion_frames.append(motion_frame_fea) + else: + hidden_states = block( + hidden_states, # shape [2, 4096, 320] + encoder_hidden_states=encoder_hidden_states, # shape [2, 20, 640] + attention_mask=attention_mask, + full_mask=full_mask, + face_mask=face_mask, + lip_mask=lip_mask, + timestep=timestep, + video_length=video_length, + motion_scale=motion_scale, + ) + + # Output + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + + output = hidden_states + residual + + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + if not return_dict: + return (output, motion_frames) + + return Transformer3DModelOutput(sample=output) diff --git a/Hallo2/hallo2/hallo/models/unet_2d_blocks.py b/Hallo2/hallo2/hallo/models/unet_2d_blocks.py new file mode 100644 index 00000000..09883d5f --- /dev/null +++ b/Hallo2/hallo2/hallo/models/unet_2d_blocks.py @@ -0,0 +1,1343 @@ +# pylint: disable=R0801 +# pylint: disable=W1203 + +""" +This file defines the 2D blocks for the UNet model in a PyTorch implementation. +The UNet model is a popular architecture for image segmentation tasks, +which consists of an encoder, a decoder, and a skip connection mechanism. +The 2D blocks in this file include various types of layers, such as ResNet blocks, +Transformer blocks, and cross-attention blocks, +which are used to build the encoder and decoder parts of the UNet model. +The AutoencoderTinyBlock class is a simple autoencoder block for tiny models, +and the UNetMidBlock2D and CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, +and UpBlock2D classes are used for the middle and decoder parts of the UNet model. +The classes and functions in this file provide a flexible and modular way +to construct the UNet model for different image segmentation tasks. +""" + +from typing import Any, Dict, Optional, Tuple, Union + +import torch +from diffusers.models.activations import get_activation +from diffusers.models.attention_processor import Attention +from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D +from diffusers.models.transformers.dual_transformer_2d import \ + DualTransformer2DModel +from diffusers.utils import is_torch_version, logging +from diffusers.utils.torch_utils import apply_freeu +from torch import nn + +from .transformer_2d import Transformer2DModel + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_down_block( + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + temb_channels: int, + add_downsample: bool, + resnet_eps: float, + resnet_act_fn: str, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + downsample_padding: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + attention_head_dim: Optional[int] = None, + dropout: float = 0.0, +): + """ This function creates and returns a UpBlock2D or CrossAttnUpBlock2D object based on the given up_block_type. + + Args: + up_block_type (str): The type of up block to create. Must be either "UpBlock2D" or "CrossAttnUpBlock2D". + num_layers (int): The number of layers in the ResNet block. + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + prev_output_channel (int): The number of channels in the previous output. + temb_channels (int): The number of channels in the token embedding. + add_upsample (bool): Whether to add an upsample layer after the ResNet block. Defaults to True. + resnet_eps (float): The epsilon value for the ResNet block. Defaults to 1e-6. + resnet_act_fn (str): The activation function to use in the ResNet block. Defaults to "swish". + resnet_groups (int): The number of groups in the ResNet block. Defaults to 32. + resnet_pre_norm (bool): Whether to use pre-normalization in the ResNet block. Defaults to True. + output_scale_factor (float): The scale factor to apply to the output. Defaults to 1.0. + + Returns: + nn.Module: The created UpBlock2D or CrossAttnUpBlock2D object. + """ + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warning("It is recommended to provide `attention_head_dim` when calling `get_down_block`.") + logger.warning(f"Defaulting `attention_head_dim` to {num_attention_heads}.") + attention_head_dim = num_attention_heads + + down_block_type = ( + down_block_type[7:] + if down_block_type.startswith("UNetRes") + else down_block_type + ) + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + + if down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for CrossAttnDownBlock2D" + ) + return CrossAttnDownBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + add_upsample: bool, + resnet_eps: float, + resnet_act_fn: str, + resolution_idx: Optional[int] = None, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + attention_head_dim: Optional[int] = None, + dropout: float = 0.0, +) -> nn.Module: + """ This function ... + Args: + Returns: + """ + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warning("It is recommended to provide `attention_head_dim` when calling `get_up_block`.") + logger.warning(f"Defaulting `attention_head_dim` to {num_attention_heads}.") + attention_head_dim = num_attention_heads + + up_block_type = ( + up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + ) + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + if up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for CrossAttnUpBlock2D" + ) + return CrossAttnUpBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + ) + + raise ValueError(f"{up_block_type} does not exist.") + + +class AutoencoderTinyBlock(nn.Module): + """ + Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU + blocks. + + Args: + in_channels (`int`): The number of input channels. + out_channels (`int`): The number of output channels. + act_fn (`str`): + ` The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`. + + Returns: + `torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to + `out_channels`. + """ + + def __init__(self, in_channels: int, out_channels: int, act_fn: str): + super().__init__() + act_fn = get_activation(act_fn) + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + act_fn, + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + act_fn, + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + ) + self.skip = ( + nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) + if in_channels != out_channels + else nn.Identity() + ) + self.fuse = nn.ReLU() + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + """ + Forward pass of the AutoencoderTinyBlock class. + + Parameters: + x (torch.FloatTensor): The input tensor to the AutoencoderTinyBlock. + + Returns: + torch.FloatTensor: The output tensor after passing through the AutoencoderTinyBlock. + """ + return self.fuse(self.conv(x) + self.skip(x)) + + +class UNetMidBlock2D(nn.Module): + """ + A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks. + + Args: + in_channels (`int`): The number of input channels. + temb_channels (`int`): The number of temporal embedding channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_time_scale_shift (`str`, *optional*, defaults to `default`): + The type of normalization to apply to the time embeddings. This can help to improve the performance of the + model on tasks with long-range temporal dependencies. + resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks. + resnet_pre_norm (`bool`, *optional*, defaults to `True`): + Whether to use pre-normalization for the resnet blocks. + add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks. + attention_head_dim (`int`, *optional*, defaults to 1): + Dimension of a single attention head. The number of attention heads is determined based on this value and + the number of input channels. + output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor. + + Returns: + `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. + + """ + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + attn_groups: Optional[int] = None, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + ): + super().__init__() + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + self.add_attention = add_attention + + if attn_groups is None: + attn_groups = ( + resnet_groups if resnet_time_scale_shift == "default" else None + ) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}." + ) + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=attn_groups, + spatial_norm_dim=( + temb_channels + if resnet_time_scale_shift == "spatial" + else None + ), + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None + ) -> torch.FloatTensor: + """ + Forward pass of the UNetMidBlock2D class. + + Args: + hidden_states (torch.FloatTensor): The input tensor to the UNetMidBlock2D. + temb (Optional[torch.FloatTensor], optional): The token embedding tensor. Defaults to None. + + Returns: + torch.FloatTensor: The output tensor after passing through the UNetMidBlock2D. + """ + # Your implementation here + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class UNetMidBlock2DCrossAttn(nn.Module): + """ + UNetMidBlock2DCrossAttn is a class that represents a mid-block 2D UNet with cross-attention. + + This block is responsible for processing the input tensor with a series of residual blocks, + and applying cross-attention mechanism to attend to the global information in the encoder. + + Args: + in_channels (int): The number of input channels. + temb_channels (int): The number of channels for the token embedding. + dropout (float, optional): The dropout rate. Defaults to 0.0. + num_layers (int, optional): The number of layers in the residual blocks. Defaults to 1. + resnet_eps (float, optional): The epsilon value for the residual blocks. Defaults to 1e-6. + resnet_time_scale_shift (str, optional): The time scale shift type for the residual blocks. Defaults to "default". + resnet_act_fn (str, optional): The activation function for the residual blocks. Defaults to "swish". + resnet_groups (int, optional): The number of groups for the residual blocks. Defaults to 32. + resnet_pre_norm (bool, optional): Whether to apply pre-normalization for the residual blocks. Defaults to True. + num_attention_heads (int, optional): The number of attention heads for cross-attention. Defaults to 1. + cross_attention_dim (int, optional): The dimension of the cross-attention. Defaults to 1280. + output_scale_factor (float, optional): The scale factor for the output tensor. Defaults to 1.0. + """ + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for i in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Forward pass for the UNetMidBlock2DCrossAttn class. + + Args: + hidden_states (torch.FloatTensor): The input hidden states tensor. + temb (Optional[torch.FloatTensor], optional): The optional tensor for time embeddings. + encoder_hidden_states (Optional[torch.FloatTensor], optional): The optional encoder hidden states tensor. + attention_mask (Optional[torch.FloatTensor], optional): The optional attention mask tensor. + cross_attention_kwargs (Optional[Dict[str, Any]], optional): The optional cross-attention kwargs tensor. + encoder_attention_mask (Optional[torch.FloatTensor], optional): The optional encoder attention mask tensor. + + Returns: + torch.FloatTensor: The output tensor after passing through the UNetMidBlock2DCrossAttn layers. + """ + lora_scale = ( + cross_attention_kwargs.get("scale", 1.0) + if cross_attention_kwargs is not None + else 1.0 + ) + hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + hidden_states, _ref_feature = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + hidden_states, _ref_feature = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + ) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + + return hidden_states + + +class CrossAttnDownBlock2D(nn.Module): + """ + CrossAttnDownBlock2D is a class that represents a 2D cross-attention downsampling block. + + This block is used in the UNet model and consists of a series of ResNet blocks and Transformer layers. + It takes input hidden states, a tensor embedding, and optional encoder hidden states, attention mask, + and cross-attention kwargs. The block performs a series of operations including downsampling, cross-attention, + and residual connections. + + Attributes: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + temb_channels (int): The number of tensor embedding channels. + dropout (float): The dropout rate. + num_layers (int): The number of ResNet layers. + transformer_layers_per_block (Union[int, Tuple[int]]): The number of Transformer layers per block. + resnet_eps (float): The ResNet epsilon value. + resnet_time_scale_shift (str): The ResNet time scale shift type. + resnet_act_fn (str): The ResNet activation function. + resnet_groups (int): The ResNet group size. + resnet_pre_norm (bool): Whether to use ResNet pre-normalization. + num_attention_heads (int): The number of attention heads. + cross_attention_dim (int): The cross-attention dimension. + output_scale_factor (float): The output scale factor. + downsample_padding (int): The downsampling padding. + add_downsample (bool): Whether to add downsampling. + dual_cross_attention (bool): Whether to use dual cross-attention. + use_linear_projection (bool): Whether to use linear projection. + only_cross_attention (bool): Whether to use only cross-attention. + upcast_attention (bool): Whether to upcast attention. + attention_type (str): The attention type. + """ + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + additional_residuals: Optional[torch.FloatTensor] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + """ + Forward pass for the CrossAttnDownBlock2D class. + + Args: + hidden_states (torch.FloatTensor): The input hidden states. + temb (Optional[torch.FloatTensor], optional): The token embeddings. Defaults to None. + encoder_hidden_states (Optional[torch.FloatTensor], optional): The encoder hidden states. Defaults to None. + attention_mask (Optional[torch.FloatTensor], optional): The attention mask. Defaults to None. + cross_attention_kwargs (Optional[Dict[str, Any]], optional): The cross-attention kwargs. Defaults to None. + encoder_attention_mask (Optional[torch.FloatTensor], optional): The encoder attention mask. Defaults to None. + additional_residuals (Optional[torch.FloatTensor], optional): The additional residuals. Defaults to None. + + Returns: + Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: The output hidden states and residuals. + """ + output_states = () + + lora_scale = ( + cross_attention_kwargs.get("scale", 1.0) + if cross_attention_kwargs is not None + else 1.0 + ) + + blocks = list(zip(self.resnets, self.attentions)) + + for i, (resnet, attn) in enumerate(blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states, _ref_feature = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + ) + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states, _ref_feature = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + ) + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, scale=lora_scale) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class DownBlock2D(nn.Module): + """ + DownBlock2D is a class that represents a 2D downsampling block in a neural network. + + It takes the following parameters: + - in_channels (int): The number of input channels in the block. + - out_channels (int): The number of output channels in the block. + - temb_channels (int): The number of channels in the token embedding. + - dropout (float): The dropout rate for the block. + - num_layers (int): The number of layers in the block. + - resnet_eps (float): The epsilon value for the ResNet layer. + - resnet_time_scale_shift (str): The type of activation function for the ResNet layer. + - resnet_act_fn (str): The activation function for the ResNet layer. + - resnet_groups (int): The number of groups in the ResNet layer. + - resnet_pre_norm (bool): Whether to apply layer normalization before the ResNet layer. + - output_scale_factor (float): The scale factor for the output. + - add_downsample (bool): Whether to add a downsampling layer. + - downsample_padding (int): The padding value for the downsampling layer. + + The DownBlock2D class inherits from the nn.Module class and defines the following methods: + - __init__: Initializes the DownBlock2D class with the given parameters. + - forward: Forward pass of the DownBlock2D class. + + The forward method takes the following parameters: + - hidden_states (torch.FloatTensor): The input tensor to the block. + - temb (Optional[torch.FloatTensor]): The token embedding tensor. + - scale (float): The scale factor for the input tensor. + + The forward method returns a tuple containing the output tensor and a tuple of hidden states. + """ + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + """ + Forward pass of the DownBlock2D class. + + Args: + hidden_states (torch.FloatTensor): The input tensor to the DownBlock2D layer. + temb (Optional[torch.FloatTensor], optional): The token embedding tensor. Defaults to None. + scale (float, optional): The scale factor for the input tensor. Defaults to 1.0. + + Returns: + Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: The output tensor and any additional hidden states. + """ + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + use_reentrant=False, + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb, scale=scale) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, scale=scale) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock2D(nn.Module): + """ + CrossAttnUpBlock2D is a class that represents a cross-attention UpBlock in a 2D UNet architecture. + + This block is responsible for upsampling the input tensor and performing cross-attention with the encoder's hidden states. + + Args: + in_channels (int): The number of input channels in the tensor. + out_channels (int): The number of output channels in the tensor. + prev_output_channel (int): The number of channels in the previous output tensor. + temb_channels (int): The number of channels in the token embedding tensor. + resolution_idx (Optional[int]): The index of the resolution in the model. + dropout (float): The dropout rate for the layer. + num_layers (int): The number of layers in the ResNet block. + transformer_layers_per_block (Union[int, Tuple[int]]): The number of transformer layers per block. + resnet_eps (float): The epsilon value for the ResNet layer. + resnet_time_scale_shift (str): The type of time scale shift to be applied in the ResNet layer. + resnet_act_fn (str): The activation function to be used in the ResNet layer. + resnet_groups (int): The number of groups in the ResNet layer. + resnet_pre_norm (bool): Whether to use pre-normalization in the ResNet layer. + num_attention_heads (int): The number of attention heads in the cross-attention layer. + cross_attention_dim (int): The dimension of the cross-attention layer. + output_scale_factor (float): The scale factor for the output tensor. + add_upsample (bool): Whether to add upsampling to the block. + dual_cross_attention (bool): Whether to use dual cross-attention. + use_linear_projection (bool): Whether to use linear projection in the cross-attention layer. + only_cross_attention (bool): Whether to only use cross-attention and no self-attention. + upcast_attention (bool): Whether to upcast the attention weights. + attention_type (str): The type of attention to be used in the cross-attention layer. + + Attributes: + up_block (nn.Module): The UpBlock module responsible for upsampling the input tensor. + cross_attn (nn.Module): The cross-attention module that performs attention between + the decoder's hidden states and the encoder's hidden states. + resnet_blocks (nn.ModuleList): A list of ResNet blocks that make up the ResNet portion of the block. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Forward pass for the CrossAttnUpBlock2D class. + + Args: + self (CrossAttnUpBlock2D): An instance of the CrossAttnUpBlock2D class. + hidden_states (torch.FloatTensor): The input hidden states tensor. + res_hidden_states_tuple (Tuple[torch.FloatTensor, ...]): A tuple of residual hidden states tensors. + temb (Optional[torch.FloatTensor], optional): The token embeddings tensor. Defaults to None. + encoder_hidden_states (Optional[torch.FloatTensor], optional): The encoder hidden states tensor. Defaults to None. + cross_attention_kwargs (Optional[Dict[str, Any]], optional): Additional keyword arguments for cross attention. Defaults to None. + upsample_size (Optional[int], optional): The upsample size. Defaults to None. + attention_mask (Optional[torch.FloatTensor], optional): The attention mask tensor. Defaults to None. + encoder_attention_mask (Optional[torch.FloatTensor], optional): The encoder attention mask tensor. Defaults to None. + + Returns: + torch.FloatTensor: The output tensor after passing through the block. + """ + lora_scale = ( + cross_attention_kwargs.get("scale", 1.0) + if cross_attention_kwargs is not None + else 1.0 + ) + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states, _ref_feature = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + ) + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states, _ref_feature = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler( + hidden_states, upsample_size, scale=lora_scale + ) + + return hidden_states + + +class UpBlock2D(nn.Module): + """ + UpBlock2D is a class that represents a 2D upsampling block in a neural network. + + This block is used for upsampling the input tensor by a factor of 2 in both dimensions. + It takes the previous output channel, input channels, and output channels as input + and applies a series of convolutional layers, batch normalization, and activation + functions to produce the upsampled tensor. + + Args: + in_channels (int): The number of input channels in the tensor. + prev_output_channel (int): The number of channels in the previous output tensor. + out_channels (int): The number of output channels in the tensor. + temb_channels (int): The number of channels in the time embedding tensor. + resolution_idx (Optional[int], optional): The index of the resolution in the sequence of resolutions. Defaults to None. + dropout (float, optional): The dropout rate to be applied to the convolutional layers. Defaults to 0.0. + num_layers (int, optional): The number of convolutional layers in the block. Defaults to 1. + resnet_eps (float, optional): The epsilon value used in the batch normalization layer. Defaults to 1e-6. + resnet_time_scale_shift (str, optional): The type of activation function to be applied after the convolutional layers. Defaults to "default". + resnet_act_fn (str, optional): The activation function to be applied after the batch normalization layer. Defaults to "swish". + resnet_groups (int, optional): The number of groups in the group normalization layer. Defaults to 32. + resnet_pre_norm (bool, optional): A flag indicating whether to apply layer normalization before the activation function. Defaults to True. + output_scale_factor (float, optional): The scale factor to be applied to the output tensor. Defaults to 1.0. + add_upsample (bool, optional): A flag indicating whether to add an upsampling layer to the block. Defaults to True. + + Attributes: + layers (nn.ModuleList): A list of nn.Module objects representing the convolutional layers in the block. + upsample (nn.Module): The upsampling layer in the block, if add_upsample is True. + + """ + + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + scale: float = 1.0, + ) -> torch.FloatTensor: + + """ + Forward pass for the UpBlock2D class. + + Args: + self (UpBlock2D): An instance of the UpBlock2D class. + hidden_states (torch.FloatTensor): The input tensor to the block. + res_hidden_states_tuple (Tuple[torch.FloatTensor, ...]): A tuple of residual hidden states. + temb (Optional[torch.FloatTensor], optional): The token embeddings. Defaults to None. + upsample_size (Optional[int], optional): The size to upsample the input tensor to. Defaults to None. + scale (float, optional): The scale factor to apply to the input tensor. Defaults to 1.0. + + Returns: + torch.FloatTensor: The output tensor after passing through the block. + """ + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + use_reentrant=False, + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb, scale=scale) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size, scale=scale) + + return hidden_states diff --git a/Hallo2/hallo2/hallo/models/unet_2d_condition.py b/Hallo2/hallo2/hallo/models/unet_2d_condition.py new file mode 100644 index 00000000..590e8dbe --- /dev/null +++ b/Hallo2/hallo2/hallo/models/unet_2d_condition.py @@ -0,0 +1,1432 @@ +# pylint: disable=R0801 +# pylint: disable=E1101 +# pylint: disable=W1203 + +""" +This module implements the `UNet2DConditionModel`, +a variant of the 2D U-Net architecture designed for conditional image generation tasks. +The model is capable of taking a noisy input sample and conditioning it based on additional information such as class labels, +time steps, and encoder hidden states to produce a denoised output. + +The `UNet2DConditionModel` leverages various components such as time embeddings, +class embeddings, and cross-attention mechanisms to integrate the conditioning information effectively. +It is built upon several sub-blocks including down-blocks, a middle block, and up-blocks, +each responsible for different stages of the U-Net's downsampling and upsampling process. + +Key Features: +- Support for multiple types of down and up blocks, including those with cross-attention capabilities. +- Flexible configuration of the model's layers, including the number of layers per block and the output channels for each block. +- Integration of time embeddings and class embeddings to condition the model's output on additional information. +- Implementation of cross-attention to leverage encoder hidden states for conditional generation. +- The model supports gradient checkpointing to reduce memory usage during training. + +The module also includes utility functions and classes such as `UNet2DConditionOutput` for structured output +and `load_change_cross_attention_dim` for loading and modifying pre-trained models. + +Example Usage: +>>> import torch +>>> from unet_2d_condition_model import UNet2DConditionModel +>>> model = UNet2DConditionModel( +... sample_size=(64, 64), +... in_channels=3, +... out_channels=3, +... encoder_hid_dim=512, +... cross_attention_dim=1024, +... ) +>>> # Prepare input tensors +>>> sample = torch.randn(1, 3, 64, 64) +>>> timestep = 0 +>>> encoder_hidden_states = torch.randn(1, 14, 512) +>>> # Forward pass through the model +>>> output = model(sample, timestep, encoder_hidden_states) + +This module is part of a larger ecosystem of diffusion models and can be used for various conditional image generation tasks. +""" + +from dataclasses import dataclass +from os import PathLike +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import UNet2DConditionLoadersMixin +from diffusers.models.activations import get_activation +from diffusers.models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, AttnAddedKVProcessor, AttnProcessor) +from diffusers.models.embeddings import (GaussianFourierProjection, + GLIGENTextBoundingboxProjection, + ImageHintTimeEmbedding, + ImageProjection, ImageTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, TimestepEmbedding, + Timesteps) +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import (SAFETENSORS_WEIGHTS_NAME, USE_PEFT_BACKEND, + WEIGHTS_NAME, BaseOutput, deprecate, logging, + scale_lora_layers, unscale_lora_layers) +from safetensors.torch import load_file +from torch import nn + +from .unet_2d_blocks import (UNetMidBlock2D, UNetMidBlock2DCrossAttn, + get_down_block, get_up_block) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +@dataclass +class UNet2DConditionOutput(BaseOutput): + """ + The output of [`UNet2DConditionModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor = None + ref_features: Tuple[torch.FloatTensor] = None + + +class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to + `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling + blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, + *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, + *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + _out_channels: int = 4, + _center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + dropout: float = 0.0, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + attention_type: str = "default", + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + addition_embed_type_num_heads=64, + _landmark_net=False, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads`" + "because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131." + "Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in + # https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + "Must provide the same number of `down_block_types` as `up_block_types`." + f"`down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + "Must provide the same number of `block_out_channels` as `down_block_types`." + f"`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len( + only_cross_attention + ) != len(down_block_types): + raise ValueError( + "Must provide the same number of `only_cross_attention` as `down_block_types`." + f"`only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len( + down_block_types + ): + raise ValueError( + "Must provide the same number of `num_attention_heads` as `down_block_types`." + f"`num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len( + down_block_types + ): + raise ValueError( + "Must provide the same number of `attention_head_dim` as `down_block_types`." + f"`attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len( + down_block_types + ): + raise ValueError( + "Must provide the same number of `cross_attention_dim` as `down_block_types`." + f"`cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len( + down_block_types + ): + raise ValueError( + "Must provide the same number of `layers_per_block` as `down_block_types`." + f"`layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + if ( + isinstance(transformer_layers_per_block, list) + and reverse_transformer_layers_per_block is None + ): + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError( + "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet." + ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=conv_in_kernel, + padding=conv_in_padding, + ) + + # time + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError( + f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}." + ) + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, + set_W_to_weight=False, + log=False, + flip_sin_to_cos=flip_sin_to_cos, + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps( + block_out_channels[0], flip_sin_to_cos, freq_shift + ) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info( + "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined." + ) + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear( + encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding( + num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding( + timestep_input_dim, time_embed_dim, act_fn=act_fn + ) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding( + projection_class_embeddings_input_dim, time_embed_dim + ) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear( + projection_class_embeddings_input_dim, time_embed_dim + ) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, + time_embed_dim, + num_heads=addition_embed_type_num_heads, + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, + image_embed_dim=cross_attention_dim, + time_embed_dim=time_embed_dim, + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps( + addition_time_embed_dim, flip_sin_to_cos, freq_shift + ) + self.add_embedding = TimestepEmbedding( + projection_class_embeddings_input_dim, time_embed_dim + ) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding( + image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding( + image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type is not None: + raise ValueError( + f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'." + ) + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [ + only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * \ + len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * \ + len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len( + down_block_types + ) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + attention_head_dim=( + attention_head_dim[i] + if attention_head_dim[i] is not None + else output_channel + ), + dropout=dropout, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + raise NotImplementedError( + f"Unsupport mid_block_type: {mid_block_type}") + elif mid_block_type == "UNetMidBlock2D": + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + num_layers=0, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + add_attention=False, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = ( + list(reversed(transformer_layers_per_block)) + if reverse_transformer_layers_per_block is None + else reverse_transformer_layers_per_block + ) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[ + min(i + 1, len(block_out_channels) - 1) + ] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resolution_idx=i, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + attention_head_dim=( + attention_head_dim[i] + if attention_head_dim[i] is not None + else output_channel + ), + dropout=dropout, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], + num_groups=norm_num_groups, + eps=norm_eps, + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + self.conv_norm_out = None + + if attention_type in ["gated", "gated-text-image"]: + positive_len = 768 + if isinstance(cross_attention_dim, int): + positive_len = cross_attention_dim + elif isinstance(cross_attention_dim, (tuple, list)): + positive_len = cross_attention_dim[0] + + feature_type = "text-only" if attention_type == "gated" else "text-image" + self.position_net = GLIGENTextBoundingboxProjection( + positive_len=positive_len, + out_dim=cross_attention_dim, + feature_type=feature_type, + ) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors( + name: str, + module: torch.nn.Module, + processors: Dict[str, AttentionProcessor], + ): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor( + return_deprecated_lora=True + ) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors( + f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor( + self, + processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], + _remove_lora=False, + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor, _remove_lora=_remove_lora) + else: + module.set_processor( + processor.pop(f"{name}.processor"), _remove_lora=_remove_lora + ) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor( + f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all( + proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS + for proc in self.attn_processors.values() + ): + processor = AttnAddedKVProcessor() + elif all( + proc.__class__ in CROSS_ATTENTION_PROCESSORS + for proc in self.attn_processors.values() + ): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor, _remove_lora=True) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = ( + num_sliceable_layers * [slice_size] + if not isinstance(slice_size, list) + else slice_size + ) + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i, size in enumerate(slice_size): + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError( + f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice( + module: torch.nn.Module, slice_size: List[int] + ): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def enable_freeu(self, s1, s2, b1, b2): + r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stage blocks where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that + are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + for _, upsample_block in enumerate(self.up_blocks): + setattr(upsample_block, "s1", s1) + setattr(upsample_block, "s2", s2) + setattr(upsample_block, "b1", b1) + setattr(upsample_block, "b2", b2) + + def disable_freeu(self): + """Disables the FreeU mechanism.""" + freeu_keys = {"s1", "s2", "b1", "b2"} + for _, upsample_block in enumerate(self.up_blocks): + for k in freeu_keys: + if ( + hasattr(upsample_block, k) + or getattr(upsample_block, k, None) is not None + ): + setattr(upsample_block, k, None) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + cond_tensor: torch.FloatTensor=None, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + post_process: bool = False, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor] + (https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added to UNet long skip connections from down blocks to up blocks for + example from ControlNet side model(s) + mid_block_additional_residual (`torch.Tensor`, *optional*): + additional residual to be added to UNet mid block output, for example from ControlNet side model + down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + for dim in sample.shape[-2:]: + if dim % default_overall_up_factor != 0: + # Forward upsample size to force interpolation output size. + forward_upsample_size = True + break + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = ( + 1 - encoder_attention_mask.to(sample.dtype) + ) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor( + [timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError( + "class_labels should be provided when num_class_embeds > 0" + ) + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding( + class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image'" + "which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get( + "text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time'" + "which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time'" + "which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image'" + "which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if ( + "image_embeds" not in added_cond_kwargs + or "hint" not in added_cond_kwargs + ): + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint'" + "which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if ( + self.encoder_hid_proj is not None + and self.config.encoder_hid_dim_type == "text_proj" + ): + encoder_hidden_states = self.encoder_hid_proj( + encoder_hidden_states) + elif ( + self.encoder_hid_proj is not None + and self.config.encoder_hid_dim_type == "text_image_proj" + ): + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj'" + "which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj( + encoder_hidden_states, image_embeds + ) + elif ( + self.encoder_hid_proj is not None + and self.config.encoder_hid_dim_type == "image_proj" + ): + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj'" + "which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif ( + self.encoder_hid_proj is not None + and self.config.encoder_hid_dim_type == "ip_image_proj" + ): + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj'" + "which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds).to( + encoder_hidden_states.dtype + ) + encoder_hidden_states = torch.cat( + [encoder_hidden_states, image_embeds], dim=1 + ) + + # 2. pre-process + sample = self.conv_in(sample) + if cond_tensor is not None: + sample = sample + cond_tensor + + # 2.5 GLIGEN position net + if ( + cross_attention_kwargs is not None + and cross_attention_kwargs.get("gligen", None) is not None + ): + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = { + "objs": self.position_net(**gligen_args) + } + + # 3. down + lora_scale = ( + cross_attention_kwargs.get("scale", 1.0) + if cross_attention_kwargs is not None + else 1.0 + ) + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + + is_controlnet = ( + mid_block_additional_residual is not None + and down_block_additional_residuals is not None + ) + # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets + is_adapter = down_intrablock_additional_residuals is not None + # maintain backward compatibility for legacy usage, where + # T2I-Adapter and ControlNet both use down_block_additional_residuals arg + # but can only use one or the other + if ( + not is_adapter + and mid_block_additional_residual is None + and down_block_additional_residuals is not None + ): + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ + for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", + standard_warn=False, + ) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if ( + hasattr(downsample_block, "has_cross_attention") + and downsample_block.has_cross_attention + ): + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = ( + down_intrablock_additional_residuals.pop(0) + ) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block( + hidden_states=sample, temb=emb, scale=lora_scale + ) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = ( + down_block_res_sample + down_block_additional_residual + ) + new_down_block_res_samples = new_down_block_res_samples + ( + down_block_res_sample, + ) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + if ( + hasattr(self.mid_block, "has_cross_attention") + and self.mid_block.has_cross_attention + ): + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape + ): + sample += down_intrablock_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets):] + down_block_res_samples = down_block_res_samples[ + : -len(upsample_block.resnets) + ] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if ( + hasattr(upsample_block, "has_cross_attention") + and upsample_block.has_cross_attention + ): + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + scale=lora_scale, + ) + + # 6. post-process + if post_process: + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) + + @classmethod + def load_change_cross_attention_dim( + cls, + pretrained_model_path: PathLike, + subfolder=None, + # unet_additional_kwargs=None, + ): + """ + Load or change the cross-attention dimension of a pre-trained model. + + Parameters: + pretrained_model_name_or_path (:class:`~typing.Union[str, :class:`~pathlib.Path`]`): + The identifier of the pre-trained model or the path to the local folder containing the model. + force_download (:class:`~bool`): + If True, re-download the model even if it is already cached. + resume_download (:class:`~bool`): + If True, resume the download of the model if partially downloaded. + proxies (:class:`~dict`): + A dictionary of proxy servers to use for downloading the model. + cache_dir (:class:`~Optional[str]`): + The path to the cache directory for storing downloaded models. + use_auth_token (:class:`~bool`): + If True, use the authentication token for private models. + revision (:class:`~str`): + The specific model version to use. + use_safetensors (:class:`~bool`): + If True, use the SafeTensors format for loading the model weights. + **kwargs (:class:`~dict`): + Additional keyword arguments passed to the model. + + """ + pretrained_model_path = Path(pretrained_model_path) + if subfolder is not None: + pretrained_model_path = pretrained_model_path.joinpath(subfolder) + config_file = pretrained_model_path / "config.json" + if not (config_file.exists() and config_file.is_file()): + raise RuntimeError( + f"{config_file} does not exist or is not a file") + + unet_config = cls.load_config(config_file) + unet_config["cross_attention_dim"] = 1024 + + model = cls.from_config(unet_config) + # load the vanilla weights + if pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME).exists(): + logger.debug( + f"loading safeTensors weights from {pretrained_model_path} ..." + ) + state_dict = load_file( + pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME), device="cpu" + ) + + elif pretrained_model_path.joinpath(WEIGHTS_NAME).exists(): + logger.debug(f"loading weights from {pretrained_model_path} ...") + state_dict = torch.load( + pretrained_model_path.joinpath(WEIGHTS_NAME), + map_location="cpu", + weights_only=True, + ) + else: + raise FileNotFoundError( + f"no weights file found in {pretrained_model_path}") + + model_state_dict = model.state_dict() + for k in state_dict: + if k in model_state_dict: + if state_dict[k].shape != model_state_dict[k].shape: + state_dict[k] = model_state_dict[k] + # load the weights into the model + m, u = model.load_state_dict(state_dict, strict=False) + print(m, u) + + return model diff --git a/Hallo2/hallo2/hallo/models/unet_3d.py b/Hallo2/hallo2/hallo/models/unet_3d.py new file mode 100644 index 00000000..8a3fdb4c --- /dev/null +++ b/Hallo2/hallo2/hallo/models/unet_3d.py @@ -0,0 +1,839 @@ +# pylint: disable=R0801 +# pylint: disable=E1101 +# pylint: disable=R0402 +# pylint: disable=W1203 + +""" +This is the main file for the UNet3DConditionModel, which defines the UNet3D model architecture. + +The UNet3D model is a 3D convolutional neural network designed for image segmentation and +other computer vision tasks. It consists of an encoder, a decoder, and skip connections between +the corresponding layers of the encoder and decoder. The model can handle 3D data and +performs well on tasks such as image segmentation, object detection, and video analysis. + +This file contains the necessary imports, the main UNet3DConditionModel class, and its +methods for setting attention slice, setting gradient checkpointing, setting attention +processor, and the forward method for model inference. + +The module provides a comprehensive solution for 3D image segmentation tasks and can be +easily extended for other computer vision tasks as well. +""" + +from collections import OrderedDict +from dataclasses import dataclass +from os import PathLike +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.attention_processor import AttentionProcessor +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import (SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, + BaseOutput, logging) +from safetensors.torch import load_file + +from .resnet import InflatedConv3d, InflatedGroupNorm +from .unet_3d_blocks import (UNetMidBlock3DCrossAttn, get_down_block, + get_up_block) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet3DConditionOutput(BaseOutput): + """ + Data class that serves as the output of the UNet3DConditionModel. + + Attributes: + sample (`torch.FloatTensor`): + A tensor representing the processed sample. The shape and nature of this tensor will depend on the + specific configuration of the model and the input data. + """ + sample: torch.FloatTensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin): + """ + A 3D UNet model designed to handle conditional image and video generation tasks. This model is particularly + suited for tasks that require the generation of 3D data, such as volumetric medical imaging or 3D video + generation, while incorporating additional conditioning information. + + The model consists of an encoder-decoder structure with skip connections. It utilizes a series of downsampling + and upsampling blocks, with a middle block for further processing. Each block can be customized with different + types of layers and attention mechanisms. + + Parameters: + sample_size (`int`, optional): The size of the input sample. + in_channels (`int`, defaults to 8): The number of input channels. + out_channels (`int`, defaults to 8): The number of output channels. + center_input_sample (`bool`, defaults to False): Whether to center the input sample. + flip_sin_to_cos (`bool`, defaults to True): Whether to flip the sine to cosine in the time embedding. + freq_shift (`int`, defaults to 0): The frequency shift for the time embedding. + down_block_types (`Tuple[str]`): A tuple of strings specifying the types of downsampling blocks. + mid_block_type (`str`): The type of middle block. + up_block_types (`Tuple[str]`): A tuple of strings specifying the types of upsampling blocks. + only_cross_attention (`Union[bool, Tuple[bool]]`): Whether to use only cross-attention. + block_out_channels (`Tuple[int]`): A tuple of integers specifying the output channels for each block. + layers_per_block (`int`, defaults to 2): The number of layers per block. + downsample_padding (`int`, defaults to 1): The padding used in downsampling. + mid_block_scale_factor (`float`, defaults to 1): The scale factor for the middle block. + act_fn (`str`, defaults to 'silu'): The activation function to be used. + norm_num_groups (`int`, defaults to 32): The number of groups for normalization. + norm_eps (`float`, defaults to 1e-5): The epsilon for normalization. + cross_attention_dim (`int`, defaults to 1280): The dimension for cross-attention. + attention_head_dim (`Union[int, Tuple[int]]`): The dimension for attention heads. + dual_cross_attention (`bool`, defaults to False): Whether to use dual cross-attention. + use_linear_projection (`bool`, defaults to False): Whether to use linear projection. + class_embed_type (`str`, optional): The type of class embedding. + num_class_embeds (`int`, optional): The number of class embeddings. + upcast_attention (`bool`, defaults to False): Whether to upcast attention. + resnet_time_scale_shift (`str`, defaults to 'default'): The time scale shift for the ResNet. + use_inflated_groupnorm (`bool`, defaults to False): Whether to use inflated group normalization. + use_motion_module (`bool`, defaults to False): Whether to use a motion module. + motion_module_resolutions (`Tuple[int]`): A tuple of resolutions for the motion module. + motion_module_mid_block (`bool`, defaults to False): Whether to use a motion module in the middle block. + motion_module_decoder_only (`bool`, defaults to False): Whether to use the motion module only in the decoder. + motion_module_type (`str`, optional): The type of motion module. + motion_module_kwargs (`dict`): Keyword arguments for the motion module. + unet_use_cross_frame_attention (`bool`, optional): Whether to use cross-frame attention in the UNet. + unet_use_temporal_attention (`bool`, optional): Whether to use temporal attention in the UNet. + use_audio_module (`bool`, defaults to False): Whether to use an audio module. + audio_attention_dim (`int`, defaults to 768): The dimension for audio attention. + + The model supports various features such as gradient checkpointing, attention processors, and sliced attention + computation, making it flexible and efficient for different computational requirements and use cases. + + The forward method of the model accepts a sample, timestep, and encoder hidden states as input, and it returns + the processed sample as output. The method also supports additional conditioning information such as class + labels, audio embeddings, and masks for specialized tasks. + + The from_pretrained_2d class method allows loading a pre-trained 2D UNet model and adapting it for 3D tasks by + incorporating motion modules and other 3D specific features. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 8, + out_channels: int = 8, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + mid_block_type: str = "UNetMidBlock3DCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + use_inflated_groupnorm=False, + # Additional + use_motion_module=False, + motion_module_resolutions=(1, 2, 4, 8), + motion_module_mid_block=False, + motion_module_decoder_only=False, + motion_module_type=None, + motion_module_kwargs=None, + unet_use_cross_frame_attention=None, + unet_use_temporal_attention=None, + # audio + use_audio_module=False, + audio_attention_dim=768, + stack_enable_blocks_name=None, + stack_enable_blocks_depth=None, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = InflatedConv3d( + in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1) + ) + + # time + self.time_proj = Timesteps( + block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, time_embed_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding( + num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding( + timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [ + only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + res = 2**i + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module + and (res in motion_module_resolutions) + and (not motion_module_decoder_only), + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + use_audio_module=use_audio_module, + audio_attention_dim=audio_attention_dim, + depth=i, + stack_enable_blocks_name=stack_enable_blocks_name, + stack_enable_blocks_depth=stack_enable_blocks_depth, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock3DCrossAttn": + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module and motion_module_mid_block, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + use_audio_module=use_audio_module, + audio_attention_dim=audio_attention_dim, + depth=3, + stack_enable_blocks_name=stack_enable_blocks_name, + stack_enable_blocks_depth=stack_enable_blocks_depth, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the videos + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + res = 2 ** (3 - i) + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[ + min(i + 1, len(block_out_channels) - 1) + ] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module + and (res in motion_module_resolutions), + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + use_audio_module=use_audio_module, + audio_attention_dim=audio_attention_dim, + depth=3-i, + stack_enable_blocks_name=stack_enable_blocks_name, + stack_enable_blocks_depth=stack_enable_blocks_depth, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if use_inflated_groupnorm: + self.conv_norm_out = InflatedGroupNorm( + num_channels=block_out_channels[0], + num_groups=norm_num_groups, + eps=norm_eps, + ) + else: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], + num_groups=norm_num_groups, + eps=norm_eps, + ) + self.conv_act = nn.SiLU() + self.conv_out = InflatedConv3d( + block_out_channels[0], out_channels, kernel_size=3, padding=1 + ) + + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors( + name: str, + module: torch.nn.Module, + processors: Dict[str, AttentionProcessor], + ): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + if "temporal_transformer" not in sub_name: + fn_recursive_add_processors( + f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + if "temporal_transformer" not in name: + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = ( + num_slicable_layers * [slice_size] + if not isinstance(slice_size, list) + else slice_size + ) + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i, size in enumerate(slice_size): + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError( + f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice( + module: torch.nn.Module, slice_size: List[int] + ): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + if "temporal_transformer" not in sub_name: + fn_recursive_attn_processor( + f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + if "temporal_transformer" not in name: + fn_recursive_attn_processor(name, module, processor) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + audio_embedding: Optional[torch.Tensor] = None, + class_labels: Optional[torch.Tensor] = None, + mask_cond_fea: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + full_mask: Optional[torch.Tensor] = None, + face_mask: Optional[torch.Tensor] = None, + lip_mask: Optional[torch.Tensor] = None, + motion_scale: Optional[torch.Tensor] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + return_dict: bool = True, + # start: bool = False, + ) -> Union[UNet3DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info( + "Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # time + timesteps = timestep + if not torch.is_tensor(timesteps): + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor( + [timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError( + "class_labels should be provided when num_class_embeds > 0" + ) + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # pre-process + sample = self.conv_in(sample) + if mask_cond_fea is not None: + sample = sample + mask_cond_fea + + # down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if ( + hasattr(downsample_block, "has_cross_attention") + and downsample_block.has_cross_attention + ): + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + full_mask=full_mask, + face_mask=face_mask, + lip_mask=lip_mask, + audio_embedding=audio_embedding, + motion_scale=motion_scale, + ) + # print("") + else: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + # audio_embedding=audio_embedding, + ) + # print("") + + down_block_res_samples += res_samples + + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = ( + down_block_res_sample + down_block_additional_residual + ) + new_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # mid + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + full_mask=full_mask, + face_mask=face_mask, + lip_mask=lip_mask, + audio_embedding=audio_embedding, + motion_scale=motion_scale, + ) + + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + + # up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets):] + down_block_res_samples = down_block_res_samples[ + : -len(upsample_block.resnets) + ] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if ( + hasattr(upsample_block, "has_cross_attention") + and upsample_block.has_cross_attention + ): + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + full_mask=full_mask, + face_mask=face_mask, + lip_mask=lip_mask, + audio_embedding=audio_embedding, + motion_scale=motion_scale, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + encoder_hidden_states=encoder_hidden_states, + # audio_embedding=audio_embedding, + ) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) + + @classmethod + def from_pretrained_2d( + cls, + pretrained_model_path: PathLike, + motion_module_path: PathLike, + subfolder=None, + unet_additional_kwargs=None, + mm_zero_proj_out=False, + use_landmark=True, + ): + """ + Load a pre-trained 2D UNet model from a given directory. + + Parameters: + pretrained_model_path (`str` or `PathLike`): + Path to the directory containing a pre-trained 2D UNet model. + dtype (`torch.dtype`, *optional*): + The data type of the loaded model. If not provided, the default data type is used. + device (`torch.device`, *optional*): + The device on which the loaded model will be placed. If not provided, the default device is used. + **kwargs (`Any`): + Additional keyword arguments passed to the model. + + Returns: + `UNet3DConditionModel`: + The loaded 2D UNet model. + """ + pretrained_model_path = Path(pretrained_model_path) + motion_module_path = Path(motion_module_path) + if subfolder is not None: + pretrained_model_path = pretrained_model_path.joinpath(subfolder) + logger.info( + f"loaded temporal unet's pretrained weights from {pretrained_model_path} ..." + ) + + config_file = pretrained_model_path / "config.json" + if not (config_file.exists() and config_file.is_file()): + raise RuntimeError( + f"{config_file} does not exist or is not a file") + + unet_config = cls.load_config(config_file) + unet_config["_class_name"] = cls.__name__ + unet_config["down_block_types"] = [ + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ] + unet_config["up_block_types"] = [ + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + ] + unet_config["mid_block_type"] = "UNetMidBlock3DCrossAttn" + if use_landmark: + unet_config["in_channels"] = 8 + unet_config["out_channels"] = 8 + + model = cls.from_config(unet_config, **unet_additional_kwargs) + # load the vanilla weights + if pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME).exists(): + logger.debug( + f"loading safeTensors weights from {pretrained_model_path} ..." + ) + state_dict = load_file( + pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME), device="cpu" + ) + + elif pretrained_model_path.joinpath(WEIGHTS_NAME).exists(): + logger.debug(f"loading weights from {pretrained_model_path} ...") + state_dict = torch.load( + pretrained_model_path.joinpath(WEIGHTS_NAME), + map_location="cpu", + weights_only=True, + ) + else: + raise FileNotFoundError( + f"no weights file found in {pretrained_model_path}") + + # load the motion module weights + if motion_module_path.exists() and motion_module_path.is_file(): + if motion_module_path.suffix.lower() in [".pth", ".pt", ".ckpt"]: + print( + f"Load motion module params from {motion_module_path}") + motion_state_dict = torch.load( + motion_module_path, map_location="cpu", weights_only=True + ) + elif motion_module_path.suffix.lower() == ".safetensors": + motion_state_dict = load_file(motion_module_path, device="cpu") + else: + raise RuntimeError( + f"unknown file format for motion module weights: {motion_module_path.suffix}" + ) + if mm_zero_proj_out: + logger.info( + "Zero initialize proj_out layers in motion module...") + new_motion_state_dict = OrderedDict() + for k in motion_state_dict: + if "proj_out" in k: + continue + new_motion_state_dict[k] = motion_state_dict[k] + motion_state_dict = new_motion_state_dict + + # merge the state dicts + state_dict.update(motion_state_dict) + + model_state_dict = model.state_dict() + for k in state_dict: + if k in model_state_dict: + if state_dict[k].shape != model_state_dict[k].shape: + state_dict[k] = model_state_dict[k] + # load the weights into the model + m, u = model.load_state_dict(state_dict, strict=False) + logger.debug( + f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + + params = [ + p.numel() if "temporal" in n else 0 for n, p in model.named_parameters() + ] + logger.info(f"Loaded {sum(params) / 1e6}M-parameter motion module") + + return model diff --git a/Hallo2/hallo2/hallo/models/unet_3d_blocks.py b/Hallo2/hallo2/hallo/models/unet_3d_blocks.py new file mode 100644 index 00000000..d9402949 --- /dev/null +++ b/Hallo2/hallo2/hallo/models/unet_3d_blocks.py @@ -0,0 +1,1401 @@ +# pylint: disable=R0801 +# src/models/unet_3d_blocks.py + +""" +This module defines various 3D UNet blocks used in the video model. + +The blocks include: +- UNetMidBlock3DCrossAttn: The middle block of the UNet with cross attention. +- CrossAttnDownBlock3D: The downsampling block with cross attention. +- DownBlock3D: The standard downsampling block without cross attention. +- CrossAttnUpBlock3D: The upsampling block with cross attention. +- UpBlock3D: The standard upsampling block without cross attention. + +These blocks are used to construct the 3D UNet architecture for video-related tasks. +""" + +import torch +from einops import rearrange +from torch import nn + +from .motion_module import get_motion_module +from .resnet import Downsample3D, ResnetBlock3D, Upsample3D +from .transformer_3d import Transformer3DModel + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + audio_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + unet_use_cross_frame_attention=None, + unet_use_temporal_attention=None, + use_inflated_groupnorm=None, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + use_audio_module=None, + depth=0, + stack_enable_blocks_name=None, + stack_enable_blocks_depth=None, +): + """ + Factory function to instantiate a down-block module for the 3D UNet architecture. + + Down blocks are used in the downsampling part of the U-Net to reduce the spatial dimensions + of the feature maps while increasing the depth. This function can create blocks with or without + cross attention based on the specified parameters. + + Parameters: + - down_block_type (str): The type of down block to instantiate. + - num_layers (int): The number of layers in the block. + - in_channels (int): The number of input channels. + - out_channels (int): The number of output channels. + - temb_channels (int): The number of token embedding channels. + - add_downsample (bool): Flag to add a downsampling layer. + - resnet_eps (float): Epsilon for residual block stability. + - resnet_act_fn (callable): Activation function for the residual block. + - ... (remaining parameters): Additional parameters for configuring the block. + + Returns: + - nn.Module: An instance of a down-sampling block module. + """ + down_block_type = ( + down_block_type[7:] + if down_block_type.startswith("UNetRes") + else down_block_type + ) + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + + if down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for CrossAttnDownBlock3D" + ) + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + audio_attention_dim=audio_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + use_audio_module=use_audio_module, + depth=depth, + stack_enable_blocks_name=stack_enable_blocks_name, + stack_enable_blocks_depth=stack_enable_blocks_depth, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + audio_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + unet_use_cross_frame_attention=None, + unet_use_temporal_attention=None, + use_inflated_groupnorm=None, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + use_audio_module=None, + depth=0, + stack_enable_blocks_name=None, + stack_enable_blocks_depth=None, +): + """ + Factory function to instantiate an up-block module for the 3D UNet architecture. + + Up blocks are used in the upsampling part of the U-Net to increase the spatial dimensions + of the feature maps while decreasing the depth. This function can create blocks with or without + cross attention based on the specified parameters. + + Parameters: + - up_block_type (str): The type of up block to instantiate. + - num_layers (int): The number of layers in the block. + - in_channels (int): The number of input channels. + - out_channels (int): The number of output channels. + - prev_output_channel (int): The number of channels from the previous layer's output. + - temb_channels (int): The number of token embedding channels. + - add_upsample (bool): Flag to add an upsampling layer. + - resnet_eps (float): Epsilon for residual block stability. + - resnet_act_fn (callable): Activation function for the residual block. + - ... (remaining parameters): Additional parameters for configuring the block. + + Returns: + - nn.Module: An instance of an up-sampling block module. + """ + up_block_type = ( + up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + ) + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + + if up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for CrossAttnUpBlock3D" + ) + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + audio_attention_dim=audio_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + use_inflated_groupnorm=use_inflated_groupnorm, + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + use_audio_module=use_audio_module, + depth=depth, + stack_enable_blocks_name=stack_enable_blocks_name, + stack_enable_blocks_depth=stack_enable_blocks_depth, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + """ + A 3D UNet middle block with cross attention mechanism. This block is part of the U-Net architecture + and is used for feature extraction in the middle of the downsampling path. + + Parameters: + - in_channels (int): Number of input channels. + - temb_channels (int): Number of token embedding channels. + - dropout (float): Dropout rate. + - num_layers (int): Number of layers in the block. + - resnet_eps (float): Epsilon for residual block. + - resnet_time_scale_shift (str): Time scale shift for time embedding normalization. + - resnet_act_fn (str): Activation function for the residual block. + - resnet_groups (int): Number of groups for the convolutions in the residual block. + - resnet_pre_norm (bool): Whether to use pre-normalization in the residual block. + - attn_num_head_channels (int): Number of attention heads. + - cross_attention_dim (int): Dimensionality of the cross attention layers. + - audio_attention_dim (int): Dimensionality of the audio attention layers. + - dual_cross_attention (bool): Whether to use dual cross attention. + - use_linear_projection (bool): Whether to use linear projection in attention. + - upcast_attention (bool): Whether to upcast attention to the original input dimension. + - unet_use_cross_frame_attention (bool): Whether to use cross frame attention in U-Net. + - unet_use_temporal_attention (bool): Whether to use temporal attention in U-Net. + - use_inflated_groupnorm (bool): Whether to use inflated group normalization. + - use_motion_module (bool): Whether to use motion module. + - motion_module_type (str): Type of motion module. + - motion_module_kwargs (dict): Keyword arguments for the motion module. + - use_audio_module (bool): Whether to use audio module. + - depth (int): Depth of the block in the network. + - stack_enable_blocks_name (str): Name of the stack enable blocks. + - stack_enable_blocks_depth (int): Depth of the stack enable blocks. + + Forward method: + The forward method applies the residual blocks, cross attention, and optional motion and audio modules + to the input hidden states. It returns the transformed hidden states. + """ + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + audio_attention_dim=1024, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + unet_use_cross_frame_attention=None, + unet_use_temporal_attention=None, + use_inflated_groupnorm=None, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + use_audio_module=None, + depth=0, + stack_enable_blocks_name=None, + stack_enable_blocks_depth=None, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + + # there is always at least one resnet + resnets = [ + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ] + attentions = [] + motion_modules = [] + audio_modules = [] + + for _ in range(num_layers): + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + ) + ) + audio_modules.append( + Transformer3DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=audio_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + use_audio_module=use_audio_module, + depth=depth, + unet_block_name="mid", + stack_enable_blocks_name=stack_enable_blocks_name, + stack_enable_blocks_depth=stack_enable_blocks_depth, + ) + if use_audio_module + else None + ) + + motion_modules.append( + get_motion_module( + in_channels=in_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + if use_motion_module + else None + ) + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.audio_modules = nn.ModuleList(audio_modules) + self.motion_modules = nn.ModuleList(motion_modules) + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + full_mask=None, + face_mask=None, + lip_mask=None, + audio_embedding=None, + motion_scale=None, + ): + """ + Forward pass for the UNetMidBlock3DCrossAttn class. + + Args: + self (UNetMidBlock3DCrossAttn): An instance of the UNetMidBlock3DCrossAttn class. + hidden_states (Tensor): The input hidden states tensor. + temb (Tensor, optional): The input temporal embedding tensor. Defaults to None. + encoder_hidden_states (Tensor, optional): The encoder hidden states tensor. Defaults to None. + attention_mask (Tensor, optional): The attention mask tensor. Defaults to None. + full_mask (Tensor, optional): The full mask tensor. Defaults to None. + face_mask (Tensor, optional): The face mask tensor. Defaults to None. + lip_mask (Tensor, optional): The lip mask tensor. Defaults to None. + audio_embedding (Tensor, optional): The audio embedding tensor. Defaults to None. + + Returns: + Tensor: The output tensor after passing through the UNetMidBlock3DCrossAttn layers. + """ + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet, audio_module, motion_module in zip( + self.attentions, self.resnets[1:], self.audio_modules, self.motion_modules + ): + hidden_states, motion_frame = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + return_dict=False, + ) # .sample + if len(motion_frame[0]) > 0: + # if motion_frame[0][0].numel() > 0: + motion_frames = motion_frame[0][0] + motion_frames = rearrange( + motion_frames, + "b f (d1 d2) c -> b c f d1 d2", + d1=hidden_states.size(-1), + ) + + else: + motion_frames = torch.zeros( + hidden_states.shape[0], + hidden_states.shape[1], + 4, + hidden_states.shape[3], + hidden_states.shape[4], + ) + + n_motion_frames = motion_frames.size(2) + if audio_module is not None: + hidden_states = ( + audio_module( + hidden_states, + encoder_hidden_states=audio_embedding, + attention_mask=attention_mask, + full_mask=full_mask, + face_mask=face_mask, + lip_mask=lip_mask, + motion_scale=motion_scale, + return_dict=False, + ) + )[0] # .sample + if motion_module is not None: + motion_frames = motion_frames.to( + device=hidden_states.device, dtype=hidden_states.dtype + ) + + _hidden_states = ( + torch.cat([motion_frames, hidden_states], dim=2) + if n_motion_frames > 0 + else hidden_states + ) + hidden_states = motion_module( + _hidden_states, encoder_hidden_states=encoder_hidden_states + ) + hidden_states = hidden_states[:, :, n_motion_frames:] + + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + """ + A 3D downsampling block with cross attention for the U-Net architecture. + + Parameters: + - (same as above, refer to the constructor for details) + + Forward method: + The forward method downsamples the input hidden states using residual blocks and cross attention. + It also applies optional motion and audio modules. The method supports gradient checkpointing + to save memory during training. + """ + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + audio_attention_dim=1024, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + unet_use_cross_frame_attention=None, + unet_use_temporal_attention=None, + use_inflated_groupnorm=None, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + use_audio_module=None, + depth=0, + stack_enable_blocks_name=None, + stack_enable_blocks_depth=None, + ): + super().__init__() + resnets = [] + attentions = [] + audio_modules = [] + motion_modules = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + ) + ) + # TODO:检查维度 + audio_modules.append( + Transformer3DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=audio_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_audio_module=use_audio_module, + depth=depth, + unet_block_name="down", + stack_enable_blocks_name=stack_enable_blocks_name, + stack_enable_blocks_depth=stack_enable_blocks_depth, + ) + if use_audio_module + else None + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + if use_motion_module + else None + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.audio_modules = nn.ModuleList(audio_modules) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + full_mask=None, + face_mask=None, + lip_mask=None, + audio_embedding=None, + motion_scale=None, + ): + """ + Defines the forward pass for the CrossAttnDownBlock3D class. + + Parameters: + - hidden_states : torch.Tensor + The input tensor to the block. + temb : torch.Tensor, optional + The token embeddings from the previous block. + encoder_hidden_states : torch.Tensor, optional + The hidden states from the encoder. + attention_mask : torch.Tensor, optional + The attention mask for the cross-attention mechanism. + full_mask : torch.Tensor, optional + The full mask for the cross-attention mechanism. + face_mask : torch.Tensor, optional + The face mask for the cross-attention mechanism. + lip_mask : torch.Tensor, optional + The lip mask for the cross-attention mechanism. + audio_embedding : torch.Tensor, optional + The audio embedding for the cross-attention mechanism. + + Returns: + -- torch.Tensor + The output tensor from the block. + """ + output_states = () + + for _, (resnet, attn, audio_module, motion_module) in enumerate( + zip(self.resnets, self.attentions, self.audio_modules, self.motion_modules) + ): + # self.gradient_checkpointing = False + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + + motion_frames = [] + hidden_states, motion_frame = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + ) + if len(motion_frame[0]) > 0: + motion_frames = motion_frame[0][0] + # motion_frames = torch.cat(motion_frames, dim=0) + motion_frames = rearrange( + motion_frames, + "b f (d1 d2) c -> b c f d1 d2", + d1=hidden_states.size(-1), + ) + + else: + motion_frames = torch.zeros( + hidden_states.shape[0], + hidden_states.shape[1], + 4, + hidden_states.shape[3], + hidden_states.shape[4], + ) + + n_motion_frames = motion_frames.size(2) + + if audio_module is not None: + # audio_embedding = audio_embedding + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(audio_module, return_dict=False), + hidden_states, + audio_embedding, + attention_mask, + full_mask, + face_mask, + lip_mask, + motion_scale, + )[0] + + # add motion module + if motion_module is not None: + motion_frames = motion_frames.to( + device=hidden_states.device, dtype=hidden_states.dtype + ) + _hidden_states = torch.cat( + [motion_frames, hidden_states], dim=2 + ) # if n_motion_frames > 0 else hidden_states + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + _hidden_states, + encoder_hidden_states, + ) + hidden_states = hidden_states[:, :, n_motion_frames:] + + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ).sample + if audio_module is not None: + hidden_states = audio_module( + hidden_states, + audio_embedding, + attention_mask=attention_mask, + full_mask=full_mask, + face_mask=face_mask, + lip_mask=lip_mask, + return_dict=False, + )[0] + # add motion module + if motion_module is not None: + hidden_states = motion_module( + hidden_states, encoder_hidden_states=encoder_hidden_states + ) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + """ + A 3D downsampling block for the U-Net architecture. This block performs downsampling operations + using residual blocks and an optional motion module. + + Parameters: + - in_channels (int): Number of input channels. + - out_channels (int): Number of output channels. + - temb_channels (int): Number of token embedding channels. + - dropout (float): Dropout rate for the block. + - num_layers (int): Number of layers in the block. + - resnet_eps (float): Epsilon for residual block stability. + - resnet_time_scale_shift (str): Time scale shift for the residual block's time embedding. + - resnet_act_fn (str): Activation function used in the residual block. + - resnet_groups (int): Number of groups for the convolutions in the residual block. + - resnet_pre_norm (bool): Whether to use pre-normalization in the residual block. + - output_scale_factor (float): Scaling factor for the block's output. + - add_downsample (bool): Whether to add a downsampling layer. + - downsample_padding (int): Padding for the downsampling layer. + - use_inflated_groupnorm (bool): Whether to use inflated group normalization. + - use_motion_module (bool): Whether to include a motion module. + - motion_module_type (str): Type of motion module to use. + - motion_module_kwargs (dict): Keyword arguments for the motion module. + + Forward method: + The forward method processes the input hidden states through the residual blocks and optional + motion modules, followed by an optional downsampling step. It supports gradient checkpointing + during training to reduce memory usage. + """ + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + use_inflated_groupnorm=None, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + motion_modules = [] + + # use_motion_module = False + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + if use_motion_module + else None + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + ): + """ + forward method for the DownBlock3D class. + + Args: + hidden_states (Tensor): The input tensor to the DownBlock3D layer. + temb (Tensor, optional): The token embeddings, if using transformer. + encoder_hidden_states (Tensor, optional): The hidden states from the encoder. + + Returns: + Tensor: The output tensor after passing through the DownBlock3D layer. + """ + output_states = () + + for resnet, motion_module in zip(self.resnets, self.motion_modules): + # print(f"DownBlock3D {self.gradient_checkpointing = }") + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + + else: + hidden_states = resnet(hidden_states, temb) + + # add motion module + hidden_states = ( + motion_module( + hidden_states, encoder_hidden_states=encoder_hidden_states + ) + if motion_module is not None + else hidden_states + ) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + """ + Standard 3D downsampling block for the U-Net architecture. This block performs downsampling + operations in the U-Net using residual blocks and an optional motion module. + + Parameters: + - in_channels (int): Number of input channels. + - out_channels (int): Number of output channels. + - temb_channels (int): Number of channels for the temporal embedding. + - dropout (float): Dropout rate for the block. + - num_layers (int): Number of layers in the block. + - resnet_eps (float): Epsilon for residual block stability. + - resnet_time_scale_shift (str): Time scale shift for the residual block's time embedding. + - resnet_act_fn (str): Activation function used in the residual block. + - resnet_groups (int): Number of groups for the convolutions in the residual block. + - resnet_pre_norm (bool): Whether to use pre-normalization in the residual block. + - output_scale_factor (float): Scaling factor for the block's output. + - add_downsample (bool): Whether to add a downsampling layer. + - downsample_padding (int): Padding for the downsampling layer. + - use_inflated_groupnorm (bool): Whether to use inflated group normalization. + - use_motion_module (bool): Whether to include a motion module. + - motion_module_type (str): Type of motion module to use. + - motion_module_kwargs (dict): Keyword arguments for the motion module. + + Forward method: + The forward method processes the input hidden states through the residual blocks and optional + motion modules, followed by an optional downsampling step. It supports gradient checkpointing + during training to reduce memory usage. + """ + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + audio_attention_dim=1024, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + unet_use_cross_frame_attention=None, + unet_use_temporal_attention=None, + use_motion_module=None, + use_inflated_groupnorm=None, + motion_module_type=None, + motion_module_kwargs=None, + use_audio_module=None, + depth=0, + stack_enable_blocks_name=None, + stack_enable_blocks_depth=None, + ): + super().__init__() + resnets = [] + attentions = [] + audio_modules = [] + motion_modules = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + unet_use_cross_frame_attention=unet_use_cross_frame_attention, + unet_use_temporal_attention=unet_use_temporal_attention, + ) + ) + audio_modules.append( + Transformer3DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=audio_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_audio_module=use_audio_module, + depth=depth, + unet_block_name="up", + stack_enable_blocks_name=stack_enable_blocks_name, + stack_enable_blocks_depth=stack_enable_blocks_depth, + ) + if use_audio_module + else None + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + if use_motion_module + else None + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.audio_modules = nn.ModuleList(audio_modules) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + full_mask=None, + face_mask=None, + lip_mask=None, + audio_embedding=None, + motion_scale=None, + ): + """ + Forward pass for the CrossAttnUpBlock3D class. + + Args: + self (CrossAttnUpBlock3D): An instance of the CrossAttnUpBlock3D class. + hidden_states (Tensor): The input hidden states tensor. + res_hidden_states_tuple (Tuple[Tensor]): A tuple of residual hidden states tensors. + temb (Tensor, optional): The token embeddings tensor. Defaults to None. + encoder_hidden_states (Tensor, optional): The encoder hidden states tensor. Defaults to None. + upsample_size (int, optional): The upsample size. Defaults to None. + attention_mask (Tensor, optional): The attention mask tensor. Defaults to None. + full_mask (Tensor, optional): The full mask tensor. Defaults to None. + face_mask (Tensor, optional): The face mask tensor. Defaults to None. + lip_mask (Tensor, optional): The lip mask tensor. Defaults to None. + audio_embedding (Tensor, optional): The audio embedding tensor. Defaults to None. + + Returns: + Tensor: The output tensor after passing through the CrossAttnUpBlock3D. + """ + for _, (resnet, attn, audio_module, motion_module) in enumerate( + zip(self.resnets, self.attentions, self.audio_modules, self.motion_modules) + ): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + + motion_frames = [] + hidden_states, motion_frame = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + ) + if len(motion_frame[0]) > 0: + motion_frames = motion_frame[0][0] + # motion_frames = torch.cat(motion_frames, dim=0) + motion_frames = rearrange( + motion_frames, + "b f (d1 d2) c -> b c f d1 d2", + d1=hidden_states.size(-1), + ) + else: + motion_frames = torch.zeros( + hidden_states.shape[0], + hidden_states.shape[1], + 4, + hidden_states.shape[3], + hidden_states.shape[4], + ) + + n_motion_frames = motion_frames.size(2) + + if audio_module is not None: + # audio_embedding = audio_embedding + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(audio_module, return_dict=False), + hidden_states, + audio_embedding, + attention_mask, + full_mask, + face_mask, + lip_mask, + motion_scale, + )[0] + + # add motion module + if motion_module is not None: + motion_frames = motion_frames.to( + device=hidden_states.device, dtype=hidden_states.dtype + ) + + _hidden_states = ( + torch.cat([motion_frames, hidden_states], dim=2) + if n_motion_frames > 0 + else hidden_states + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + _hidden_states, + encoder_hidden_states, + ) + hidden_states = hidden_states[:, :, n_motion_frames:] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ).sample + + if audio_module is not None: + + hidden_states = ( + audio_module( + hidden_states, + encoder_hidden_states=audio_embedding, + attention_mask=attention_mask, + full_mask=full_mask, + face_mask=face_mask, + lip_mask=lip_mask, + ) + ).sample + # add motion module + hidden_states = ( + motion_module( + hidden_states, encoder_hidden_states=encoder_hidden_states + ) + if motion_module is not None + else hidden_states + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + """ + 3D upsampling block with cross attention for the U-Net architecture. This block performs + upsampling operations and incorporates cross attention mechanisms, which allow the model to + focus on different parts of the input when upscaling. + + Parameters: + - in_channels (int): Number of input channels. + - out_channels (int): Number of output channels. + - prev_output_channel (int): Number of channels from the previous layer's output. + - temb_channels (int): Number of channels for the temporal embedding. + - dropout (float): Dropout rate for the block. + - num_layers (int): Number of layers in the block. + - resnet_eps (float): Epsilon for residual block stability. + - resnet_time_scale_shift (str): Time scale shift for the residual block's time embedding. + - resnet_act_fn (str): Activation function used in the residual block. + - resnet_groups (int): Number of groups for the convolutions in the residual block. + - resnet_pre_norm (bool): Whether to use pre-normalization in the residual block. + - attn_num_head_channels (int): Number of attention heads for the cross attention mechanism. + - cross_attention_dim (int): Dimensionality of the cross attention layers. + - audio_attention_dim (int): Dimensionality of the audio attention layers. + - output_scale_factor (float): Scaling factor for the block's output. + - add_upsample (bool): Whether to add an upsampling layer. + - dual_cross_attention (bool): Whether to use dual cross attention (not implemented). + - use_linear_projection (bool): Whether to use linear projection in the cross attention. + - only_cross_attention (bool): Whether to use only cross attention (no self-attention). + - upcast_attention (bool): Whether to upcast attention to the original input dimension. + - unet_use_cross_frame_attention (bool): Whether to use cross frame attention in U-Net. + - unet_use_temporal_attention (bool): Whether to use temporal attention in U-Net. + - use_motion_module (bool): Whether to include a motion module. + - use_inflated_groupnorm (bool): Whether to use inflated group normalization. + - motion_module_type (str): Type of motion module to use. + - motion_module_kwargs (dict): Keyword arguments for the motion module. + - use_audio_module (bool): Whether to include an audio module. + - depth (int): Depth of the block in the network. + - stack_enable_blocks_name (str): Name of the stack enable blocks. + - stack_enable_blocks_depth (int): Depth of the stack enable blocks. + + Forward method: + The forward method upsamples the input hidden states and residual hidden states, processes + them through the residual and cross attention blocks, and optional motion and audio modules. + It supports gradient checkpointing during training. + """ + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + use_inflated_groupnorm=None, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + motion_modules = [] + + # use_motion_module = False + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_inflated_groupnorm=use_inflated_groupnorm, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + if use_motion_module + else None + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + upsample_size=None, + encoder_hidden_states=None, + ): + """ + Forward pass for the UpBlock3D class. + + Args: + self (UpBlock3D): An instance of the UpBlock3D class. + hidden_states (Tensor): The input hidden states tensor. + res_hidden_states_tuple (Tuple[Tensor]): A tuple of residual hidden states tensors. + temb (Tensor, optional): The token embeddings tensor. Defaults to None. + upsample_size (int, optional): The upsample size. Defaults to None. + encoder_hidden_states (Tensor, optional): The encoder hidden states tensor. Defaults to None. + + Returns: + Tensor: The output tensor after passing through the UpBlock3D layers. + """ + for resnet, motion_module in zip(self.resnets, self.motion_modules): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + # print(f"UpBlock3D {self.gradient_checkpointing = }") + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = ( + motion_module( + hidden_states, encoder_hidden_states=encoder_hidden_states + ) + if motion_module is not None + else hidden_states + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states diff --git a/Hallo2/hallo2/hallo/models/wav2vec.py b/Hallo2/hallo2/hallo/models/wav2vec.py new file mode 100644 index 00000000..cd0d002b --- /dev/null +++ b/Hallo2/hallo2/hallo/models/wav2vec.py @@ -0,0 +1,209 @@ +# pylint: disable=R0901 +# src/models/wav2vec.py + +""" +This module defines the Wav2Vec model, which is a pre-trained model for speech recognition and understanding. +It inherits from the Wav2Vec2Model class in the transformers library and provides additional functionalities +such as feature extraction and encoding. + +Classes: + Wav2VecModel: Inherits from Wav2Vec2Model and adds additional methods for feature extraction and encoding. + +Functions: + linear_interpolation: Interpolates the features based on the sequence length. +""" + +import torch.nn.functional as F +from transformers import Wav2Vec2Model +from transformers.modeling_outputs import BaseModelOutput + + +class Wav2VecModel(Wav2Vec2Model): + """ + Wav2VecModel is a custom model class that extends the Wav2Vec2Model class from the transformers library. + It inherits all the functionality of the Wav2Vec2Model and adds additional methods for feature extraction and encoding. + ... + + Attributes: + base_model (Wav2Vec2Model): The base Wav2Vec2Model object. + + Methods: + forward(input_values, seq_len, attention_mask=None, mask_time_indices=None + , output_attentions=None, output_hidden_states=None, return_dict=None): + Forward pass of the Wav2VecModel. + It takes input_values, seq_len, and other optional parameters as input and returns the output of the base model. + + feature_extract(input_values, seq_len): + Extracts features from the input_values using the base model. + + encode(extract_features, attention_mask=None, mask_time_indices=None, output_attentions=None, output_hidden_states=None, return_dict=None): + Encodes the extracted features using the base model and returns the encoded features. + """ + def forward( + self, + input_values, + seq_len, + attention_mask=None, + mask_time_indices=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + """ + Forward pass of the Wav2Vec model. + + Args: + self: The instance of the model. + input_values: The input values (waveform) to the model. + seq_len: The sequence length of the input values. + attention_mask: Attention mask to be used for the model. + mask_time_indices: Mask indices to be used for the model. + output_attentions: If set to True, returns attentions. + output_hidden_states: If set to True, returns hidden states. + return_dict: If set to True, returns a BaseModelOutput instead of a tuple. + + Returns: + The output of the Wav2Vec model. + """ + self.config.output_attentions = True + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + extract_features = linear_interpolation(extract_features, seq_len=seq_len) + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask( + extract_features.shape[1], attention_mask, add_adapter=False + ) + + hidden_states, extract_features = self.feature_projection(extract_features) + hidden_states = self._mask_hidden_states( + hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask + ) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if self.adapter is not None: + hidden_states = self.adapter(hidden_states) + + if not return_dict: + return (hidden_states, ) + encoder_outputs[1:] + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + + def feature_extract( + self, + input_values, + seq_len, + ): + """ + Extracts features from the input values and returns the extracted features. + + Parameters: + input_values (torch.Tensor): The input values to be processed. + seq_len (torch.Tensor): The sequence lengths of the input values. + + Returns: + extracted_features (torch.Tensor): The extracted features from the input values. + """ + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + extract_features = linear_interpolation(extract_features, seq_len=seq_len) + + return extract_features + + def encode( + self, + extract_features, + attention_mask=None, + mask_time_indices=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + """ + Encodes the input features into the output space. + + Args: + extract_features (torch.Tensor): The extracted features from the audio signal. + attention_mask (torch.Tensor, optional): Attention mask to be used for padding. + mask_time_indices (torch.Tensor, optional): Masked indices for the time dimension. + output_attentions (bool, optional): If set to True, returns the attention weights. + output_hidden_states (bool, optional): If set to True, returns all hidden states. + return_dict (bool, optional): If set to True, returns a BaseModelOutput instead of the tuple. + + Returns: + The encoded output features. + """ + self.config.output_attentions = True + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask( + extract_features.shape[1], attention_mask, add_adapter=False + ) + + hidden_states, extract_features = self.feature_projection(extract_features) + hidden_states = self._mask_hidden_states( + hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask + ) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if self.adapter is not None: + hidden_states = self.adapter(hidden_states) + + if not return_dict: + return (hidden_states, ) + encoder_outputs[1:] + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +def linear_interpolation(features, seq_len): + """ + Transpose the features to interpolate linearly. + + Args: + features (torch.Tensor): The extracted features to be interpolated. + seq_len (torch.Tensor): The sequence lengths of the features. + + Returns: + torch.Tensor: The interpolated features. + """ + features = features.transpose(1, 2) + output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear') + return output_features.transpose(1, 2) diff --git a/Hallo2/hallo2/hallo/utils/__init__.py b/Hallo2/hallo2/hallo/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Hallo2/hallo2/hallo/utils/config.py b/Hallo2/hallo2/hallo/utils/config.py new file mode 100644 index 00000000..69854b61 --- /dev/null +++ b/Hallo2/hallo2/hallo/utils/config.py @@ -0,0 +1,25 @@ +""" +This module provides utility functions for configuration manipulation. +""" + +from typing import Dict + + +def filter_non_none(dict_obj: Dict): + """ + Filters out key-value pairs from the given dictionary where the value is None. + + Args: + dict_obj (Dict): The dictionary to be filtered. + + Returns: + Dict: The dictionary with key-value pairs removed where the value was None. + + This function creates a new dictionary containing only the key-value pairs from + the original dictionary where the value is not None. It then clears the original + dictionary and updates it with the filtered key-value pairs. + """ + non_none_filter = { k: v for k, v in dict_obj.items() if v is not None } + dict_obj.clear() + dict_obj.update(non_none_filter) + return dict_obj diff --git a/Hallo2/hallo2/hallo/utils/util.py b/Hallo2/hallo2/hallo/utils/util.py new file mode 100644 index 00000000..2f26784e --- /dev/null +++ b/Hallo2/hallo2/hallo/utils/util.py @@ -0,0 +1,1026 @@ +# pylint: disable=C0116 +# pylint: disable=W0718 +# pylint: disable=R1732 +# pylint: disable=R0801 +""" +utils.py + +This module provides utility functions for various tasks such as setting random seeds, +importing modules from files, managing checkpoint files, and saving video files from +sequences of PIL images. + +Functions: + seed_everything(seed) + import_filename(filename) + delete_additional_ckpt(base_path, num_keep) + save_videos_from_pil(pil_images, path, fps=8) + +Dependencies: + importlib + os + os.path as osp + random + shutil + sys + pathlib.Path + av + cv2 + mediapipe as mp + numpy as np + torch + torchvision + einops.rearrange + moviepy.editor.AudioFileClip, VideoClip + PIL.Image + +Examples: + seed_everything(42) + imported_module = import_filename('path/to/your/module.py') + delete_additional_ckpt('path/to/checkpoints', 1) + save_videos_from_pil(pil_images, 'output/video.mp4', fps=12) + +The functions in this module ensure reproducibility of experiments by seeding random number +generators, allow dynamic importing of modules, manage checkpoint files by deleting extra ones, +and provide a way to save sequences of images as video files. + +Function Details: + seed_everything(seed) + Seeds all random number generators to ensure reproducibility. + + import_filename(filename) + Imports a module from a given file location. + + delete_additional_ckpt(base_path, num_keep) + Deletes additional checkpoint files in the given directory. + + save_videos_from_pil(pil_images, path, fps=8) + Saves a sequence of images as a video using the Pillow library. + +Attributes: + _ (str): Placeholder for static type checking +""" + +import importlib +import os +import os.path as osp +import random +import shutil +import subprocess +import sys +from pathlib import Path +from typing import List + +import av +import cv2 +import mediapipe as mp +import numpy as np +import torch +import torchvision +from einops import rearrange +from moviepy.editor import AudioFileClip, VideoClip +from moviepy.editor import VideoFileClip, concatenate_videoclips +from PIL import Image + + +def seed_everything(seed): + """ + Seeds all random number generators to ensure reproducibility. + + Args: + seed (int): The seed value to set for all random number generators. + """ + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed % (2**32)) + random.seed(seed) + + +def import_filename(filename): + """ + Import a module from a given file location. + + Args: + filename (str): The path to the file containing the module to be imported. + + Returns: + module: The imported module. + + Raises: + ImportError: If the module cannot be imported. + + Example: + >>> imported_module = import_filename('path/to/your/module.py') + """ + spec = importlib.util.spec_from_file_location("mymodule", filename) + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +def delete_additional_ckpt(base_path, num_keep): + """ + Deletes additional checkpoint files in the given directory. + + Args: + base_path (str): The path to the directory containing the checkpoint files. + num_keep (int): The number of most recent checkpoint files to keep. + + Returns: + None + + Raises: + FileNotFoundError: If the base_path does not exist. + + Example: + >>> delete_additional_ckpt('path/to/checkpoints', 1) + # This will delete all but the most recent checkpoint file in 'path/to/checkpoints'. + """ + dirs = [] + for d in os.listdir(base_path): + if d.startswith("checkpoint-"): + dirs.append(d) + num_tot = len(dirs) + if num_tot <= num_keep: + return + # ensure ckpt is sorted and delete the ealier! + del_dirs = sorted(dirs, key=lambda x: int( + x.split("-")[-1]))[: num_tot - num_keep] + for d in del_dirs: + path_to_dir = osp.join(base_path, d) + if osp.exists(path_to_dir): + shutil.rmtree(path_to_dir) + + +def save_videos_from_pil(pil_images, path, fps=8): + """ + Save a sequence of images as a video using the Pillow library. + + Args: + pil_images (List[PIL.Image]): A list of PIL.Image objects representing the frames of the video. + path (str): The output file path for the video. + fps (int, optional): The frames per second rate of the video. Defaults to 8. + + Returns: + None + + Raises: + ValueError: If the save format is not supported. + + This function takes a list of PIL.Image objects and saves them as a video file with a specified frame rate. + The output file format is determined by the file extension of the provided path. Supported formats include + .mp4, .avi, and .mkv. The function uses the Pillow library to handle the image processing and video + creation. + """ + save_fmt = Path(path).suffix + os.makedirs(os.path.dirname(path), exist_ok=True) + width, height = pil_images[0].size + + if save_fmt == ".mp4": + codec = "libx264" + container = av.open(path, "w") + stream = container.add_stream(codec, rate=fps) + + stream.width = width + stream.height = height + + for pil_image in pil_images: + # pil_image = Image.fromarray(image_arr).convert("RGB") + av_frame = av.VideoFrame.from_image(pil_image) + container.mux(stream.encode(av_frame)) + container.mux(stream.encode()) + container.close() + + elif save_fmt == ".gif": + pil_images[0].save( + fp=path, + format="GIF", + append_images=pil_images[1:], + save_all=True, + duration=(1 / fps * 1000), + loop=0, + ) + else: + raise ValueError("Unsupported file type. Use .mp4 or .gif.") + + +def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): + """ + Save a grid of videos as an animation or video. + + Args: + videos (torch.Tensor): A tensor of shape (batch_size, channels, time, height, width) + containing the videos to save. + path (str): The path to save the video grid. Supported formats are .mp4, .avi, and .gif. + rescale (bool, optional): If True, rescale the video to the original resolution. + Defaults to False. + n_rows (int, optional): The number of rows in the video grid. Defaults to 6. + fps (int, optional): The frame rate of the saved video. Defaults to 8. + + Raises: + ValueError: If the video format is not supported. + + Returns: + None + """ + videos = rearrange(videos, "b c t h w -> t b c h w") + # height, width = videos.shape[-2:] + outputs = [] + + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = (x * 255).numpy().astype(np.uint8) + x = Image.fromarray(x) + + outputs.append(x) + + os.makedirs(os.path.dirname(path), exist_ok=True) + + save_videos_from_pil(outputs, path, fps) + + +def read_frames(video_path): + """ + Reads video frames from a given video file. + + Args: + video_path (str): The path to the video file. + + Returns: + container (av.container.InputContainer): The input container object + containing the video stream. + + Raises: + FileNotFoundError: If the video file is not found. + RuntimeError: If there is an error in reading the video stream. + + The function reads the video frames from the specified video file using the + Python AV library (av). It returns an input container object that contains + the video stream. If the video file is not found, it raises a FileNotFoundError, + and if there is an error in reading the video stream, it raises a RuntimeError. + """ + container = av.open(video_path) + + video_stream = next(s for s in container.streams if s.type == "video") + frames = [] + for packet in container.demux(video_stream): + for frame in packet.decode(): + image = Image.frombytes( + "RGB", + (frame.width, frame.height), + frame.to_rgb().to_ndarray(), + ) + frames.append(image) + + return frames + + +def get_fps(video_path): + """ + Get the frame rate (FPS) of a video file. + + Args: + video_path (str): The path to the video file. + + Returns: + int: The frame rate (FPS) of the video file. + """ + container = av.open(video_path) + video_stream = next(s for s in container.streams if s.type == "video") + fps = video_stream.average_rate + container.close() + return fps + + +def tensor_to_video(tensor, output_video_file, audio_source, fps=25): + """ + Converts a Tensor with shape [c, f, h, w] into a video and adds an audio track from the specified audio file. + + Args: + tensor (Tensor): The Tensor to be converted, shaped [c, f, h, w]. + output_video_file (str): The file path where the output video will be saved. + audio_source (str): The path to the audio file (WAV file) that contains the audio track to be added. + fps (int): The frame rate of the output video. Default is 25 fps. + """ + tensor = tensor.permute(1, 2, 3, 0).cpu( + ).numpy() # convert to [f, h, w, c] + tensor = np.clip(tensor * 255, 0, 255).astype( + np.uint8 + ) # to [0, 255] + + def make_frame(t): + # get index + frame_index = min(int(t * fps), tensor.shape[0] - 1) + return tensor[frame_index] + new_video_clip = VideoClip(make_frame, duration=tensor.shape[0] / fps) + audio_clip = AudioFileClip(audio_source).subclip(0, tensor.shape[0] / fps) + new_video_clip = new_video_clip.set_audio(audio_clip) + new_video_clip.write_videofile(output_video_file, fps=fps, audio_codec='aac') + + +def tensor_to_video_batch(tensor, output_video_file, start, audio_source, fps=25): + """ + Converts a Tensor with shape [c, f, h, w] into a video and adds an audio track from the specified audio file. + + Args: + tensor (Tensor): The Tensor to be converted, shaped [c, f, h, w]. + output_video_file (str): The file path where the output video will be saved. + audio_source (str): The path to the audio file (WAV file) that contains the audio track to be added. + fps (int): The frame rate of the output video. Default is 25 fps. + """ + tensor = tensor.permute(1, 2, 3, 0).cpu( + ).numpy() # convert to [f, h, w, c] + tensor = np.clip(tensor * 255, 0, 255).astype( + np.uint8 + ) # to [0, 255] + + def make_frame(t): + # get index + frame_index = min(int(t * fps), tensor.shape[0] - 1) + return tensor[frame_index] + new_video_clip = VideoClip(make_frame, duration=tensor.shape[0] / fps) + audio_clip = AudioFileClip(audio_source).subclip(start / fps, (start + tensor.shape[0]) / fps) + new_video_clip = new_video_clip.set_audio(audio_clip) + new_video_clip.write_videofile(output_video_file, fps=fps, audio_codec='aac') + +def merge_videos(input_directory, output_file): + video_files = [f for f in os.listdir(input_directory) if f.endswith('.mp4')] + + video_files.sort() + + clips = [] + + for video_file in video_files: + file_path = os.path.join(input_directory, video_file) + clip = VideoFileClip(file_path) + clips.append(clip) + + final_clip = concatenate_videoclips(clips) + + final_clip.write_videofile(output_file, codec="libx264") + + for clip in clips: + clip.close() + + +silhouette_ids = [ + 10, 338, 297, 332, 284, 251, 389, 356, 454, 323, 361, 288, + 397, 365, 379, 378, 400, 377, 152, 148, 176, 149, 150, 136, + 172, 58, 132, 93, 234, 127, 162, 21, 54, 103, 67, 109 +] +lip_ids = [61, 185, 40, 39, 37, 0, 267, 269, 270, 409, 291, + 146, 91, 181, 84, 17, 314, 405, 321, 375] + + +def compute_face_landmarks(detection_result, h, w): + """ + Compute face landmarks from a detection result. + + Args: + detection_result (mediapipe.solutions.face_mesh.FaceMesh): The detection result containing face landmarks. + h (int): The height of the video frame. + w (int): The width of the video frame. + + Returns: + face_landmarks_list (list): A list of face landmarks. + """ + face_landmarks_list = detection_result.face_landmarks + if len(face_landmarks_list) != 1: + print("#face is invalid:", len(face_landmarks_list)) + return [] + return [[p.x * w, p.y * h] for p in face_landmarks_list[0]] + + +def get_landmark(file): + """ + This function takes a file as input and returns the facial landmarks detected in the file. + + Args: + file (str): The path to the file containing the video or image to be processed. + + Returns: + Tuple[List[float], List[float]]: A tuple containing two lists of floats representing the x and y coordinates of the facial landmarks. + """ + model_path = "pretrained_models/face_analysis/models/face_landmarker_v2_with_blendshapes.task" + BaseOptions = mp.tasks.BaseOptions + FaceLandmarker = mp.tasks.vision.FaceLandmarker + FaceLandmarkerOptions = mp.tasks.vision.FaceLandmarkerOptions + VisionRunningMode = mp.tasks.vision.RunningMode + # Create a face landmarker instance with the video mode: + options = FaceLandmarkerOptions( + base_options=BaseOptions(model_asset_path=model_path), + running_mode=VisionRunningMode.IMAGE, + ) + + with FaceLandmarker.create_from_options(options) as landmarker: + image = mp.Image.create_from_file(str(file)) + height, width = image.height, image.width + face_landmarker_result = landmarker.detect(image) + face_landmark = compute_face_landmarks( + face_landmarker_result, height, width) + + return np.array(face_landmark), height, width + + +def get_landmark_overframes(landmark_model, frames_path): + """ + This function iterate frames and returns the facial landmarks detected in each frame. + + Args: + landmark_model: mediapipe landmark model instance + frames_path (str): The path to the video frames. + + Returns: + List[List[float], float, float]: A List containing two lists of floats representing the x and y coordinates of the facial landmarks. + """ + + face_landmarks = [] + + for file in sorted(os.listdir(frames_path)): + image = mp.Image.create_from_file(os.path.join(frames_path, file)) + height, width = image.height, image.width + landmarker_result = landmark_model.detect(image) + frame_landmark = compute_face_landmarks( + landmarker_result, height, width) + face_landmarks.append(frame_landmark) + + return face_landmarks, height, width + + +def get_lip_mask(landmarks, height, width, out_path=None, expand_ratio=2.0): + """ + Extracts the lip region from the given landmarks and saves it as an image. + + Parameters: + landmarks (numpy.ndarray): Array of facial landmarks. + height (int): Height of the output lip mask image. + width (int): Width of the output lip mask image. + out_path (pathlib.Path): Path to save the lip mask image. + expand_ratio (float): Expand ratio of mask. + """ + lip_landmarks = np.take(landmarks, lip_ids, 0) + min_xy_lip = np.round(np.min(lip_landmarks, 0)) + max_xy_lip = np.round(np.max(lip_landmarks, 0)) + min_xy_lip[0], max_xy_lip[0], min_xy_lip[1], max_xy_lip[1] = expand_region( + [min_xy_lip[0], max_xy_lip[0], min_xy_lip[1], max_xy_lip[1]], width, height, expand_ratio) + lip_mask = np.zeros((height, width), dtype=np.uint8) + lip_mask[round(min_xy_lip[1]):round(max_xy_lip[1]), + round(min_xy_lip[0]):round(max_xy_lip[0])] = 255 + if out_path: + cv2.imwrite(str(out_path), lip_mask) + return None + + return lip_mask + + +def get_union_lip_mask(landmarks, height, width, expand_ratio=1): + """ + Extracts the lip region from the given landmarks and saves it as an image. + + Parameters: + landmarks (numpy.ndarray): Array of facial landmarks. + height (int): Height of the output lip mask image. + width (int): Width of the output lip mask image. + expand_ratio (float): Expand ratio of mask. + """ + lip_masks = [] + for landmark in landmarks: + lip_masks.append(get_lip_mask(landmarks=landmark, height=height, + width=width, expand_ratio=expand_ratio)) + union_mask = get_union_mask(lip_masks) + return union_mask + + +def get_face_mask(landmarks, height, width, out_path=None, expand_ratio=1.2): + """ + Generate a face mask based on the given landmarks. + + Args: + landmarks (numpy.ndarray): The landmarks of the face. + height (int): The height of the output face mask image. + width (int): The width of the output face mask image. + out_path (pathlib.Path): The path to save the face mask image. + expand_ratio (float): Expand ratio of mask. + Returns: + None. The face mask image is saved at the specified path. + """ + face_landmarks = np.take(landmarks, silhouette_ids, 0) + min_xy_face = np.round(np.min(face_landmarks, 0)) + max_xy_face = np.round(np.max(face_landmarks, 0)) + min_xy_face[0], max_xy_face[0], min_xy_face[1], max_xy_face[1] = expand_region( + [min_xy_face[0], max_xy_face[0], min_xy_face[1], max_xy_face[1]], width, height, expand_ratio) + face_mask = np.zeros((height, width), dtype=np.uint8) + face_mask[round(min_xy_face[1]):round(max_xy_face[1]), + round(min_xy_face[0]):round(max_xy_face[0])] = 255 + if out_path: + cv2.imwrite(str(out_path), face_mask) + return None + + return face_mask + + +def get_union_face_mask(landmarks, height, width, expand_ratio=1): + """ + Generate a face mask based on the given landmarks. + + Args: + landmarks (numpy.ndarray): The landmarks of the face. + height (int): The height of the output face mask image. + width (int): The width of the output face mask image. + expand_ratio (float): Expand ratio of mask. + Returns: + None. The face mask image is saved at the specified path. + """ + face_masks = [] + for landmark in landmarks: + face_masks.append(get_face_mask(landmarks=landmark,height=height,width=width,expand_ratio=expand_ratio)) + union_mask = get_union_mask(face_masks) + return union_mask + +def get_mask(file, cache_dir, face_expand_raio): + """ + Generate a face mask based on the given landmarks and save it to the specified cache directory. + + Args: + file (str): The path to the file containing the landmarks. + cache_dir (str): The directory to save the generated face mask. + + Returns: + None + """ + landmarks, height, width = get_landmark(file) + file_name = os.path.basename(file).split(".")[0] + get_lip_mask(landmarks, height, width, os.path.join( + cache_dir, f"{file_name}_lip_mask.png")) + get_face_mask(landmarks, height, width, os.path.join( + cache_dir, f"{file_name}_face_mask.png"), face_expand_raio) + get_blur_mask(os.path.join( + cache_dir, f"{file_name}_face_mask.png"), os.path.join( + cache_dir, f"{file_name}_face_mask_blur.png"), kernel_size=(51, 51)) + get_blur_mask(os.path.join( + cache_dir, f"{file_name}_lip_mask.png"), os.path.join( + cache_dir, f"{file_name}_sep_lip.png"), kernel_size=(31, 31)) + get_background_mask(os.path.join( + cache_dir, f"{file_name}_face_mask_blur.png"), os.path.join( + cache_dir, f"{file_name}_sep_background.png")) + get_sep_face_mask(os.path.join( + cache_dir, f"{file_name}_face_mask_blur.png"), os.path.join( + cache_dir, f"{file_name}_sep_lip.png"), os.path.join( + cache_dir, f"{file_name}_sep_face.png")) + + +def expand_region(region, image_w, image_h, expand_ratio=1.0): + """ + Expand the given region by a specified ratio. + Args: + region (tuple): A tuple containing the coordinates (min_x, max_x, min_y, max_y) of the region. + image_w (int): The width of the image. + image_h (int): The height of the image. + expand_ratio (float, optional): The ratio by which the region should be expanded. Defaults to 1.0. + + Returns: + tuple: A tuple containing the expanded coordinates (min_x, max_x, min_y, max_y) of the region. + """ + + min_x, max_x, min_y, max_y = region + mid_x = (max_x + min_x) // 2 + side_len_x = (max_x - min_x) * expand_ratio + mid_y = (max_y + min_y) // 2 + side_len_y = (max_y - min_y) * expand_ratio + min_x = mid_x - side_len_x // 2 + max_x = mid_x + side_len_x // 2 + min_y = mid_y - side_len_y // 2 + max_y = mid_y + side_len_y // 2 + if min_x < 0: + max_x -= min_x + min_x = 0 + if max_x > image_w: + min_x -= max_x - image_w + max_x = image_w + if min_y < 0: + max_y -= min_y + min_y = 0 + if max_y > image_h: + min_y -= max_y - image_h + max_y = image_h + + return round(min_x), round(max_x), round(min_y), round(max_y) + + +def get_blur_mask(file_path, output_file_path, resize_dim=(64, 64), kernel_size=(101, 101)): + """ + Read, resize, blur, normalize, and save an image. + + Parameters: + file_path (str): Path to the input image file. + output_dir (str): Path to the output directory to save blurred images. + resize_dim (tuple): Dimensions to resize the images to. + kernel_size (tuple): Size of the kernel to use for Gaussian blur. + """ + # Read the mask image + mask = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE) + + # Check if the image is loaded successfully + if mask is not None: + normalized_mask = blur_mask(mask,resize_dim=resize_dim,kernel_size=kernel_size) + # Save the normalized mask image + cv2.imwrite(output_file_path, normalized_mask) + return f"Processed, normalized, and saved: {output_file_path}" + return f"Failed to load image: {file_path}" + + +def blur_mask(mask, resize_dim=(64, 64), kernel_size=(51, 51)): + """ + Read, resize, blur, normalize, and save an image. + + Parameters: + file_path (str): Path to the input image file. + resize_dim (tuple): Dimensions to resize the images to. + kernel_size (tuple): Size of the kernel to use for Gaussian blur. + """ + # Check if the image is loaded successfully + normalized_mask = None + if mask is not None: + # Resize the mask image + resized_mask = cv2.resize(mask, resize_dim) + # Apply Gaussian blur to the resized mask image + blurred_mask = cv2.GaussianBlur(resized_mask, kernel_size, 0) + # Normalize the blurred image + normalized_mask = cv2.normalize( + blurred_mask, None, 0, 255, cv2.NORM_MINMAX) + # Save the normalized mask image + return normalized_mask + +def get_background_mask(file_path, output_file_path): + """ + Read an image, invert its values, and save the result. + + Parameters: + file_path (str): Path to the input image file. + output_dir (str): Path to the output directory to save the inverted image. + """ + # Read the image + image = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE) + + if image is None: + print(f"Failed to load image: {file_path}") + return + + # Invert the image + inverted_image = 1.0 - ( + image / 255.0 + ) # Assuming the image values are in [0, 255] range + # Convert back to uint8 + inverted_image = (inverted_image * 255).astype(np.uint8) + + # Save the inverted image + cv2.imwrite(output_file_path, inverted_image) + print(f"Processed and saved: {output_file_path}") + + +def get_sep_face_mask(file_path1, file_path2, output_file_path): + """ + Read two images, subtract the second one from the first, and save the result. + + Parameters: + output_dir (str): Path to the output directory to save the subtracted image. + """ + + # Read the images + mask1 = cv2.imread(file_path1, cv2.IMREAD_GRAYSCALE) + mask2 = cv2.imread(file_path2, cv2.IMREAD_GRAYSCALE) + + if mask1 is None or mask2 is None: + print(f"Failed to load images: {file_path1}") + return + + # Ensure the images are the same size + if mask1.shape != mask2.shape: + print( + f"Image shapes do not match for {file_path1}: {mask1.shape} vs {mask2.shape}" + ) + return + + # Subtract the second mask from the first + result_mask = cv2.subtract(mask1, mask2) + + # Save the result mask image + cv2.imwrite(output_file_path, result_mask) + print(f"Processed and saved: {output_file_path}") + +def resample_audio(input_audio_file: str, output_audio_file: str, sample_rate: int): + p = subprocess.Popen([ + "ffmpeg", "-y", "-v", "error", "-i", input_audio_file, "-ar", str(sample_rate), output_audio_file + ]) + ret = p.wait() + assert ret == 0, "Resample audio failed!" + return output_audio_file + +def get_face_region(image_path: str, detector): + try: + image = cv2.imread(image_path) + if image is None: + print(f"Failed to open image: {image_path}. Skipping...") + return None, None + + mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image) + detection_result = detector.detect(mp_image) + + # Adjust mask creation for the three-channel image + mask = np.zeros_like(image, dtype=np.uint8) + + for detection in detection_result.detections: + bbox = detection.bounding_box + start_point = (int(bbox.origin_x), int(bbox.origin_y)) + end_point = (int(bbox.origin_x + bbox.width), + int(bbox.origin_y + bbox.height)) + cv2.rectangle(mask, start_point, end_point, + (255, 255, 255), thickness=-1) + + save_path = image_path.replace("images", "face_masks") + os.makedirs(os.path.dirname(save_path), exist_ok=True) + cv2.imwrite(save_path, mask) + # print(f"Processed and saved {save_path}") + return image_path, mask + except Exception as e: + print(f"Error processing image {image_path}: {e}") + return None, None + + +def save_checkpoint(model: torch.nn.Module, save_dir: str, prefix: str, ckpt_num: int, total_limit: int = -1) -> None: + """ + Save the model's state_dict to a checkpoint file. + + If `total_limit` is provided, this function will remove the oldest checkpoints + until the total number of checkpoints is less than the specified limit. + + Args: + model (nn.Module): The model whose state_dict is to be saved. + save_dir (str): The directory where the checkpoint will be saved. + prefix (str): The prefix for the checkpoint file name. + ckpt_num (int): The checkpoint number to be saved. + total_limit (int, optional): The maximum number of checkpoints to keep. + Defaults to None, in which case no checkpoints will be removed. + + Raises: + FileNotFoundError: If the save directory does not exist. + ValueError: If the checkpoint number is negative. + OSError: If there is an error saving the checkpoint. + """ + + if not osp.exists(save_dir): + raise FileNotFoundError( + f"The save directory {save_dir} does not exist.") + + if ckpt_num < 0: + raise ValueError(f"Checkpoint number {ckpt_num} must be non-negative.") + + save_path = osp.join(save_dir, f"{prefix}-{ckpt_num}.pth") + + if total_limit > 0: + checkpoints = os.listdir(save_dir) + checkpoints = [d for d in checkpoints if d.startswith(prefix)] + checkpoints = sorted( + checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0]) + ) + + if len(checkpoints) >= total_limit: + num_to_remove = len(checkpoints) - total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + print( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + print( + f"Removing checkpoints: {', '.join(removing_checkpoints)}" + ) + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint_path = osp.join( + save_dir, removing_checkpoint) + try: + os.remove(removing_checkpoint_path) + except OSError as e: + print( + f"Error removing checkpoint {removing_checkpoint_path}: {e}") + + state_dict = model.state_dict() + try: + torch.save(state_dict, save_path) + print(f"Checkpoint saved at {save_path}") + except OSError as e: + raise OSError(f"Error saving checkpoint at {save_path}: {e}") from e + + +def init_output_dir(dir_list: List[str]): + """ + Initialize the output directories. + + This function creates the directories specified in the `dir_list`. If a directory already exists, it does nothing. + + Args: + dir_list (List[str]): List of directory paths to create. + """ + for path in dir_list: + os.makedirs(path, exist_ok=True) + + +def load_checkpoint(cfg, save_dir, accelerator): + """ + Load the most recent checkpoint from the specified directory. + + This function loads the latest checkpoint from the `save_dir` if the `resume_from_checkpoint` parameter is set to "latest". + If a specific checkpoint is provided in `resume_from_checkpoint`, it loads that checkpoint. If no checkpoint is found, + it starts training from scratch. + + Args: + cfg: The configuration object containing training parameters. + save_dir (str): The directory where checkpoints are saved. + accelerator: The accelerator object for distributed training. + + Returns: + int: The global step at which to resume training. + """ + if cfg.resume_from_checkpoint != "latest": + resume_dir = cfg.resume_from_checkpoint + else: + resume_dir = save_dir + # Get the most recent checkpoint + dirs = os.listdir(resume_dir) + + dirs = [d for d in dirs if d.startswith("checkpoint")] + if len(dirs) > 0: + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] + accelerator.load_state(os.path.join(resume_dir, path)) + accelerator.print(f"Resuming from checkpoint {path}") + global_step = int(path.split("-")[1]) + else: + accelerator.print( + f"Could not find checkpoint under {resume_dir}, start training from scratch") + global_step = 0 + + return global_step + + +def compute_snr(noise_scheduler, timesteps): + """ + Computes SNR as per + https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/ + 521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/ + # 521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ + timesteps + ].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to( + device=timesteps.device + )[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + + +def extract_audio_from_videos(video_path: Path, audio_output_path: Path) -> Path: + """ + Extract audio from a video file and save it as a WAV file. + + This function uses ffmpeg to extract the audio stream from a given video file and saves it as a WAV file + in the specified output directory. + + Args: + video_path (Path): The path to the input video file. + output_dir (Path): The directory where the extracted audio file will be saved. + + Returns: + Path: The path to the extracted audio file. + + Raises: + subprocess.CalledProcessError: If the ffmpeg command fails to execute. + """ + ffmpeg_command = [ + 'ffmpeg', '-y', + '-i', str(video_path), + '-vn', '-acodec', + "pcm_s16le", '-ar', '16000', '-ac', '2', + str(audio_output_path) + ] + + try: + print(f"Running command: {' '.join(ffmpeg_command)}") + subprocess.run(ffmpeg_command, check=True) + except subprocess.CalledProcessError as e: + print(f"Error extracting audio from video: {e}") + raise + + return audio_output_path + + +def convert_video_to_images(video_path: Path, output_dir: Path) -> Path: + """ + Convert a video file into a sequence of images. + + This function uses ffmpeg to convert each frame of the given video file into an image. The images are saved + in a directory named after the video file stem under the specified output directory. + + Args: + video_path (Path): The path to the input video file. + output_dir (Path): The directory where the extracted images will be saved. + + Returns: + Path: The path to the directory containing the extracted images. + + Raises: + subprocess.CalledProcessError: If the ffmpeg command fails to execute. + """ + ffmpeg_command = [ + 'ffmpeg', + '-i', str(video_path), + '-vf', 'fps=25', + str(output_dir / '%04d.png') + ] + + try: + print(f"Running command: {' '.join(ffmpeg_command)}") + subprocess.run(ffmpeg_command, check=True) + except subprocess.CalledProcessError as e: + print(f"Error converting video to images: {e}") + raise + + return output_dir + + +def get_union_mask(masks): + """ + Compute the union of a list of masks. + + This function takes a list of masks and computes their union by taking the maximum value at each pixel location. + Additionally, it finds the bounding box of the non-zero regions in the mask and sets the bounding box area to white. + + Args: + masks (list of np.ndarray): List of masks to be combined. + + Returns: + np.ndarray: The union of the input masks. + """ + union_mask = None + for mask in masks: + if union_mask is None: + union_mask = mask + else: + union_mask = np.maximum(union_mask, mask) + + if union_mask is not None: + # Find the bounding box of the non-zero regions in the mask + rows = np.any(union_mask, axis=1) + cols = np.any(union_mask, axis=0) + try: + ymin, ymax = np.where(rows)[0][[0, -1]] + xmin, xmax = np.where(cols)[0][[0, -1]] + except Exception as e: + print(str(e)) + return 0.0 + + # Set bounding box area to white + union_mask[ymin: ymax + 1, xmin: xmax + 1] = np.max(union_mask) + + return union_mask + + +def move_final_checkpoint(save_dir, module_dir, prefix): + """ + Move the final checkpoint file to the save directory. + + This function identifies the latest checkpoint file based on the given prefix and moves it to the specified save directory. + + Args: + save_dir (str): The directory where the final checkpoint file should be saved. + module_dir (str): The directory containing the checkpoint files. + prefix (str): The prefix used to identify checkpoint files. + + Raises: + ValueError: If no checkpoint files are found with the specified prefix. + """ + checkpoints = os.listdir(module_dir) + checkpoints = [d for d in checkpoints if d.startswith(prefix)] + checkpoints = sorted( + checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0]) + ) + shutil.copy2(os.path.join( + module_dir, checkpoints[-1]), os.path.join(save_dir, prefix + '.pth')) diff --git a/Hallo2/hallo2/hallo2_README.md b/Hallo2/hallo2/hallo2_README.md new file mode 100644 index 00000000..4322bce3 --- /dev/null +++ b/Hallo2/hallo2/hallo2_README.md @@ -0,0 +1,433 @@ +

Hallo2: Long-Duration and High-Resolution Audio-driven Portrait Image Animation

+ +
+ Jiahao Cui1*  + Hui Li1*  + Yao Yao3  + Hao Zhu3  + Hanlin Shang1  + Kaihui Cheng1  + Hang Zhou2  +
+
+ Siyu Zhu1✉️  + Jingdong Wang2  +
+ +
+ 1Fudan University  2Baidu Inc  3Nanjing University +
+ +
+
+ + + + + + +
+
+ +## 📸 Showcase + + + + + + + + + + + + + + + + + + +
Tailor Swift Speech @ NYU (4K, 23 minutes)Johan Rockstrom Speech @ TED (4K, 18 minutes)
Churchill's Iron Curtain Speech (4K, 4 minutes)An LLM Course from Stanford (4K, up to 1 hour)
+ +Visit our [project page](https://fudan-generative-vision.github.io/hallo2/#/) to view more cases. + +## 📰 News + +- **`2024/10/16`**: ✨✨✨ Source code and pretrained weights released. +- **`2024/10/10`**: 🎉🎉🎉 Paper submitted on [Arxiv](https://arxiv.org/abs/2410.07718). + +## 📅️ Roadmap + +| Status | Milestone | ETA | +| :----: | :------------------------------------------------------------------------------------------- | :--------: | +| ✅ | **[Paper submitted on Arixiv](https://arxiv.org/abs/2410.07718)** | 2024-10-10 | +| ✅ | **[Source code meet everyone on GitHub](https://github.com/fudan-generative-vision/hallo2)** | 2024-10-16 | +| 🚀 | **[Accelerate performance on inference]()** | TBD | + +## 🔧️ Framework + +![framework](assets/framework_2.jpg) + +## ⚙️ Installation + +- System requirement: Ubuntu 20.04/Ubuntu 22.04, Cuda 11.8 +- Tested GPUs: A100 + +Download the codes: + +```bash + git clone https://github.com/fudan-generative-vision/hallo2 + cd hallo2 +``` + +Create conda environment: + +```bash + conda create -n hallo python=3.10 + conda activate hallo +``` + +Install packages with `pip` + +```bash + pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118 + pip install -r requirements.txt +``` + +Besides, ffmpeg is also needed: + +```bash + apt-get install ffmpeg +``` + +### 📥 Download Pretrained Models + +You can easily get all pretrained models required by inference from our [HuggingFace repo](https://huggingface.co/fudan-generative-ai/hallo2). + +Using `huggingface-cli` to download the models: + +```shell +cd $ProjectRootDir +pip install huggingface-cli +huggingface-cli download fudan-generative-ai/hallo --local-dir ./pretrained_models +``` + +Or you can download them separately from their source repo: + +- [hallo](https://huggingface.co/fudan-generative-ai/hallo2/tree/main/hallo2): Our checkpoints consist of denoising UNet, face locator, image & audio proj. +- [audio_separator](https://huggingface.co/huangjackson/Kim_Vocal_2): Kim*Vocal_2 MDX-Net vocal removal model. (\_Thanks to [KimberleyJensen](https://github.com/KimberleyJensen)*) +- [insightface](https://github.com/deepinsight/insightface/tree/master/python-package#model-zoo): 2D and 3D Face Analysis placed into `pretrained_models/face_analysis/models/`. (_Thanks to deepinsight_) +- [face landmarker](https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task): Face detection & mesh model from [mediapipe](https://ai.google.dev/edge/mediapipe/solutions/vision/face_landmarker#models) placed into `pretrained_models/face_analysis/models`. +- [motion module](https://github.com/guoyww/AnimateDiff/blob/main/README.md#202309-animatediff-v2): motion module from [AnimateDiff](https://github.com/guoyww/AnimateDiff). (_Thanks to [guoyww](https://github.com/guoyww)_). +- [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse): Weights are intended to be used with the diffusers library. (_Thanks to [stablilityai](https://huggingface.co/stabilityai)_) +- [StableDiffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5): Initialized and fine-tuned from Stable-Diffusion-v1-2. (_Thanks to [runwayml](https://huggingface.co/runwayml)_) +- [wav2vec](https://huggingface.co/facebook/wav2vec2-base-960h): wav audio to vector model from [Facebook](https://huggingface.co/facebook/wav2vec2-base-960h). +- [facelib](https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0): pretrained face parse models +- [realesrgan](https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth): background upsample model +- [CodeFormer](https://github.com/sczhou/CodeFormer/releases/download/v0.1.0): pretrained [Codeformer](https://github.com/sczhou/CodeFormer) model, it's optional to download it, only if you want to train our video super-resolution model from scratch + +Finally, these pretrained models should be organized as follows: + +```text +./pretrained_models/ +|-- audio_separator/ +| |-- download_checks.json +| |-- mdx_model_data.json +| |-- vr_model_data.json +| `-- Kim_Vocal_2.onnx +|-- CodeFormer/ +| |-- codeformer.pth +| `-- vqgan_code1024.pth +|-- face_analysis/ +| `-- models/ +| |-- face_landmarker_v2_with_blendshapes.task # face landmarker model from mediapipe +| |-- 1k3d68.onnx +| |-- 2d106det.onnx +| |-- genderage.onnx +| |-- glintr100.onnx +| `-- scrfd_10g_bnkps.onnx +|-- facelib +| |-- detection_mobilenet0.25_Final.pth +| |-- detection_Resnet50_Final.pth +| |-- parsing_parsenet.pth +| |-- yolov5l-face.pth +| `-- yolov5n-face.pth +|-- hallo2 +| |-- net_g.pth +| `-- net.pth +|-- motion_module/ +| `-- mm_sd_v15_v2.ckpt +|-- realesrgan +| `-- RealESRGAN_x2plus.pth +|-- sd-vae-ft-mse/ +| |-- config.json +| `-- diffusion_pytorch_model.safetensors +|-- stable-diffusion-v1-5/ +| `-- unet/ +| |-- config.json +| `-- diffusion_pytorch_model.safetensors +`-- wav2vec/ + `-- wav2vec2-base-960h/ + |-- config.json + |-- feature_extractor_config.json + |-- model.safetensors + |-- preprocessor_config.json + |-- special_tokens_map.json + |-- tokenizer_config.json + `-- vocab.json +``` + +### 🛠️ Prepare Inference Data + +Hallo has a few simple requirements for input data: + +For the source image: + +1. It should be cropped into squares. +2. The face should be the main focus, making up 50%-70% of the image. +3. The face should be facing forward, with a rotation angle of less than 30° (no side profiles). + +For the driving audio: + +1. It must be in WAV format. +2. It must be in English since our training datasets are only in this language. +3. Ensure the vocals are clear; background music is acceptable. + +We have provided [some samples](examples/) for your reference. + +### 🎮 Run Inference + +#### Long-Duration animation + +Simply to run the `scripts/inference_long.py` and change `source_image`, `driving_audio` and `save_path` in the config file: + +```bash +python scripts/inference_long.py --config ./configs/inference/long.yaml +``` + +Animation results will be saved at `save_path`. You can find more examples for inference at [examples folder](https://github.com/fudan-generative-vision/hallo2/tree/main/examples). + +For more options: + +```shell +usage: inference_long.py [-h] [-c CONFIG] [--source_image SOURCE_IMAGE] [--driving_audio DRIVING_AUDIO] [--pose_weight POSE_WEIGHT] + [--face_weight FACE_WEIGHT] [--lip_weight LIP_WEIGHT] [--face_expand_ratio FACE_EXPAND_RATIO] + +options: + -h, --help show this help message and exit + -c CONFIG, --config CONFIG + --source_image SOURCE_IMAGE + source image + --driving_audio DRIVING_AUDIO + driving audio + --pose_weight POSE_WEIGHT + weight of pose + --face_weight FACE_WEIGHT + weight of face + --lip_weight LIP_WEIGHT + weight of lip + --face_expand_ratio FACE_EXPAND_RATIO + face region +``` + +#### High-Resolution animation + +Simply to run the `scripts/video_sr.py` and pass `input_video` and `output_path`: + +```bash +python scripts/video_sr.py --input_path [input_video] --output_path [output_dir] --bg_upsampler realesrgan --face_upsample -w 1 -s 4 +``` + +Animation results will be saved at `output_dir`. + +For more options: + +```shell +usage: video_sr.py [-h] [-i INPUT_PATH] [-o OUTPUT_PATH] [-w FIDELITY_WEIGHT] [-s UPSCALE] [--has_aligned] [--only_center_face] [--draw_box] + [--detection_model DETECTION_MODEL] [--bg_upsampler BG_UPSAMPLER] [--face_upsample] [--bg_tile BG_TILE] [--suffix SUFFIX] + +options: + -h, --help show this help message and exit + -i INPUT_PATH, --input_path INPUT_PATH + Input video + -o OUTPUT_PATH, --output_path OUTPUT_PATH + Output folder. + -w FIDELITY_WEIGHT, --fidelity_weight FIDELITY_WEIGHT + Balance the quality and fidelity. Default: 0.5 + -s UPSCALE, --upscale UPSCALE + The final upsampling scale of the image. Default: 2 + --has_aligned Input are cropped and aligned faces. Default: False + --only_center_face Only restore the center face. Default: False + --draw_box Draw the bounding box for the detected faces. Default: False + --detection_model DETECTION_MODEL + Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n. Default: retinaface_resnet50 + --bg_upsampler BG_UPSAMPLER + Background upsampler. Optional: realesrgan + --face_upsample Face upsampler after enhancement. Default: False + --bg_tile BG_TILE Tile size for background sampler. Default: 400 + --suffix SUFFIX Suffix of the restored faces. Default: None +``` + +> NOTICE: The High-Resolution animation feature is a modified version of [CodeFormer](https://github.com/sczhou/CodeFormer). When using or redistributing this feature, please comply with the [S-Lab License 1.0](https://github.com/sczhou/CodeFormer?tab=License-1-ov-file). We kindly request that you respect the terms of this license in any usage or redistribution of this component. + +## Training + +### Long-Duration animation + +#### prepare data for training + +The training data, which utilizes some talking-face videos similar to the source images used for inference, also needs to meet the following requirements: + +1. It should be cropped into squares. +2. The face should be the main focus, making up 50%-70% of the image. +3. The face should be facing forward, with a rotation angle of less than 30° (no side profiles). + +Organize your raw videos into the following directory structure: + +```text +dataset_name/ +|-- videos/ +| |-- 0001.mp4 +| |-- 0002.mp4 +| |-- 0003.mp4 +| `-- 0004.mp4 +``` + +You can use any `dataset_name`, but ensure the `videos` directory is named as shown above. + +Next, process the videos with the following commands: + +```bash +python -m scripts.data_preprocess --input_dir dataset_name/videos --step 1 +python -m scripts.data_preprocess --input_dir dataset_name/videos --step 2 +``` + +**Note:** Execute steps 1 and 2 sequentially as they perform different tasks. Step 1 converts videos into frames, extracts audio from each video, and generates the necessary masks. Step 2 generates face embeddings using InsightFace and audio embeddings using Wav2Vec, and requires a GPU. For parallel processing, use the `-p` and `-r` arguments. The `-p` argument specifies the total number of instances to launch, dividing the data into `p` parts. The `-r` argument specifies which part the current process should handle. You need to manually launch multiple instances with different values for `-r`. + +Generate the metadata JSON files with the following commands: + +```bash +python scripts/extract_meta_info_stage1.py -r path/to/dataset -n dataset_name +python scripts/extract_meta_info_stage2.py -r path/to/dataset -n dataset_name +``` + +Replace `path/to/dataset` with the path to the parent directory of `videos`, such as `dataset_name` in the example above. This will generate `dataset_name_stage1.json` and `dataset_name_stage2.json` in the `./data` directory. + +#### Training + +Update the data meta path settings in the configuration YAML files, `configs/train/stage1.yaml` and `configs/train/stage2_long.yaml`: + +```yaml +#stage1.yaml +data: + meta_paths: + - ./data/dataset_name_stage1.json + +#stage2.yaml +data: + meta_paths: + - ./data/dataset_name_stage2.json +``` + +Start training with the following command: + +```shell +accelerate launch -m \ + --config_file accelerate_config.yaml \ + --machine_rank 0 \ + --main_process_ip 0.0.0.0 \ + --main_process_port 20055 \ + --num_machines 1 \ + --num_processes 8 \ + scripts.train_stage1 --config ./configs/train/stage1.yaml +``` + +##### Accelerate Usage Explanation + +The `accelerate launch` command is used to start the training process with distributed settings. + +```shell +accelerate launch [arguments] {training_script} --{training_script-argument-1} --{training_script-argument-2} ... +``` + +**Arguments for Accelerate:** + +- `-m, --module`: Interpret the launch script as a Python module. +- `--config_file`: Configuration file for Hugging Face Accelerate. +- `--machine_rank`: Rank of the current machine in a multi-node setup. +- `--main_process_ip`: IP address of the master node. +- `--main_process_port`: Port of the master node. +- `--num_machines`: Total number of nodes participating in the training. +- `--num_processes`: Total number of processes for training, matching the total number of GPUs across all machines. + +**Arguments for Training:** + +- `{training_script}`: The training script, such as `scripts.train_stage1` or `scripts.train_stage2`. +- `--{training_script-argument-1}`: Arguments specific to the training script. Our training scripts accept one argument, `--config`, to specify the training configuration file. + +For multi-node training, you need to manually run the command with different `machine_rank` on each node separately. + +For more settings, refer to the [Accelerate documentation](https://huggingface.co/docs/accelerate/en/index). + +### High-Resolution animation + +#### Training + +##### prepare data for training + +We use the VFHQ dataset for training, you can download from its [homepage](https://liangbinxie.github.io/projects/vfhq/). Then updata `dataroot_gt` in `./configs/train/video_sr.yaml`. + +#### training + +Start training with the following command: + +```shell +python -m torch.distributed.launch --nproc_per_node=8 --master_port=4322 \ +basicsr/train.py -opt ./configs/train/video_sr.yaml \ +--launcher pytorch +``` + +## 📝 Citation + +If you find our work useful for your research, please consider citing the paper: + +``` +@misc{cui2024hallo2, + title={Hallo2: Long-Duration and High-Resolution Audio-driven Portrait Image Animation}, + author={Jiahao Cui and Hui Li and Yao Yao and Hao Zhu and Hanlin Shang and Kaihui Cheng and Hang Zhou and Siyu Zhu and️ Jingdong Wang}, + year={2024}, + eprint={2410.07718}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +``` + +## 🌟 Opportunities Available + +Multiple research positions are open at the **Generative Vision Lab, Fudan University**! Include: + +- Research assistant +- Postdoctoral researcher +- PhD candidate +- Master students + +Interested individuals are encouraged to contact us at [siyuzhu@fudan.edu.cn](mailto://siyuzhu@fudan.edu.cn) for further information. + +## ⚠️ Social Risks and Mitigations + +The development of portrait image animation technologies driven by audio inputs poses social risks, such as the ethical implications of creating realistic portraits that could be misused for deepfakes. To mitigate these risks, it is crucial to establish ethical guidelines and responsible use practices. Privacy and consent concerns also arise from using individuals' images and voices. Addressing these involves transparent data usage policies, informed consent, and safeguarding privacy rights. By addressing these risks and implementing mitigations, the research aims to ensure the responsible and ethical development of this technology. + +## 🤗 Acknowledgements + +We would like to thank the contributors to the [magic-animate](https://github.com/magic-research/magic-animate), [AnimateDiff](https://github.com/guoyww/AnimateDiff), [ultimatevocalremovergui](https://github.com/Anjok07/ultimatevocalremovergui), [AniPortrait](https://github.com/Zejun-Yang/AniPortrait) and [Moore-AnimateAnyone](https://github.com/MooreThreads/Moore-AnimateAnyone) repositories, for their open research and exploration. + +If we missed any open-source projects or related articles, we would like to complement the acknowledgement of this specific work immediately. + +## 👏 Community Contributors + +Thank you to all the contributors who have helped to make this project better! + + + + diff --git a/Hallo2/hallo2/requirements.txt b/Hallo2/hallo2/requirements.txt new file mode 100644 index 00000000..60a7d720 --- /dev/null +++ b/Hallo2/hallo2/requirements.txt @@ -0,0 +1,34 @@ +accelerate==0.28.0 +audio-separator==0.17.2 +av==12.1.0 +bitsandbytes==0.43.1 +decord==0.6.0 +diffusers==0.27.2 +einops==0.8.0 +ffmpeg-python==0.2.0 +icecream==2.1.3 +insightface==0.7.3 +librosa==0.10.2.post1 +lpips==0.1.4 +mediapipe[vision]==0.10.14 +mlflow==2.13.1 +moviepy==1.0.3 +numpy==1.26.4 +omegaconf==2.3.0 +onnx2torch==1.5.14 +onnx==1.16.1 +onnxruntime-gpu==1.18.0 +opencv-contrib-python==4.9.0.80 +opencv-python-headless==4.9.0.80 +opencv-python==4.9.0.80 +pillow==10.3.0 +setuptools==70.0.0 +tqdm==4.66.4 +transformers==4.39.2 +xformers==0.0.25.post1 +isort==5.13.2 +pylint==3.2.2 +pre-commit==3.7.1 +gradio==4.36.1 +lpips +ffmpeg-python==0.2.0 \ No newline at end of file diff --git a/Hallo2/hallo2/scripts/data_preprocess.py b/Hallo2/hallo2/scripts/data_preprocess.py new file mode 100644 index 00000000..92efc2fc --- /dev/null +++ b/Hallo2/hallo2/scripts/data_preprocess.py @@ -0,0 +1,191 @@ +# pylint: disable=W1203,W0718 +""" +This module is used to process videos to prepare data for training. It utilizes various libraries and models +to perform tasks such as video frame extraction, audio extraction, face mask generation, and face embedding extraction. +The script takes in command-line arguments to specify the input and output directories, GPU status, level of parallelism, +and rank for distributed processing. + +Usage: + python -m scripts.data_preprocess --input_dir /path/to/video_dir --dataset_name dataset_name --gpu_status --parallelism 4 --rank 0 + +Example: + python -m scripts.data_preprocess -i data/videos -o data/output -g -p 4 -r 0 +""" +import argparse +import logging +import os +from pathlib import Path +from typing import List + +import cv2 +import torch +from tqdm import tqdm + +from hallo.datasets.audio_processor import AudioProcessor +from hallo.datasets.image_processor import ImageProcessorForDataProcessing +from hallo.utils.util import convert_video_to_images, extract_audio_from_videos + +# Configure logging +logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s') + + +def setup_directories(video_path: Path) -> dict: + """ + Setup directories for storing processed files. + + Args: + video_path (Path): Path to the video file. + + Returns: + dict: A dictionary containing paths for various directories. + """ + base_dir = video_path.parent.parent + dirs = { + "face_mask": base_dir / "face_mask", + "sep_pose_mask": base_dir / "sep_pose_mask", + "sep_face_mask": base_dir / "sep_face_mask", + "sep_lip_mask": base_dir / "sep_lip_mask", + "face_emb": base_dir / "face_emb", + "audio_emb": base_dir / "audio_emb" + } + + for path in dirs.values(): + path.mkdir(parents=True, exist_ok=True) + + return dirs + + +def process_single_video(video_path: Path, + output_dir: Path, + image_processor: ImageProcessorForDataProcessing, + audio_processor: AudioProcessor, + step: int) -> None: + """ + Process a single video file. + + Args: + video_path (Path): Path to the video file. + output_dir (Path): Directory to save the output. + image_processor (ImageProcessorForDataProcessing): Image processor object. + audio_processor (AudioProcessor): Audio processor object. + gpu_status (bool): Whether to use GPU for processing. + """ + assert video_path.exists(), f"Video path {video_path} does not exist" + dirs = setup_directories(video_path) + logging.info(f"Processing video: {video_path}") + + try: + if step == 1: + images_output_dir = output_dir / 'images' / video_path.stem + images_output_dir.mkdir(parents=True, exist_ok=True) + images_output_dir = convert_video_to_images( + video_path, images_output_dir) + logging.info(f"Images saved to: {images_output_dir}") + + audio_output_dir = output_dir / 'audios' + audio_output_dir.mkdir(parents=True, exist_ok=True) + audio_output_path = audio_output_dir / f'{video_path.stem}.wav' + audio_output_path = extract_audio_from_videos( + video_path, audio_output_path) + logging.info(f"Audio extracted to: {audio_output_path}") + + face_mask, _, sep_pose_mask, sep_face_mask, sep_lip_mask = image_processor.preprocess( + images_output_dir) + cv2.imwrite( + str(dirs["face_mask"] / f"{video_path.stem}.png"), face_mask) + cv2.imwrite(str(dirs["sep_pose_mask"] / + f"{video_path.stem}.png"), sep_pose_mask) + cv2.imwrite(str(dirs["sep_face_mask"] / + f"{video_path.stem}.png"), sep_face_mask) + cv2.imwrite(str(dirs["sep_lip_mask"] / + f"{video_path.stem}.png"), sep_lip_mask) + else: + images_dir = output_dir / "images" / video_path.stem + audio_path = output_dir / "audios" / f"{video_path.stem}.wav" + _, face_emb, _, _, _ = image_processor.preprocess(images_dir) + torch.save(face_emb, str( + dirs["face_emb"] / f"{video_path.stem}.pt")) + audio_emb, _ = audio_processor.preprocess(audio_path) + torch.save(audio_emb, str( + dirs["audio_emb"] / f"{video_path.stem}.pt")) + except Exception as e: + logging.error(f"Failed to process video {video_path}: {e}") + + +def process_all_videos(input_video_list: List[Path], output_dir: Path, step: int) -> None: + """ + Process all videos in the input list. + + Args: + input_video_list (List[Path]): List of video paths to process. + output_dir (Path): Directory to save the output. + gpu_status (bool): Whether to use GPU for processing. + """ + face_analysis_model_path = "pretrained_models/face_analysis" + landmark_model_path = "pretrained_models/face_analysis/models/face_landmarker_v2_with_blendshapes.task" + audio_separator_model_file = "pretrained_models/audio_separator/Kim_Vocal_2.onnx" + wav2vec_model_path = 'pretrained_models/wav2vec/wav2vec2-base-960h' + + audio_processor = AudioProcessor( + 16000, + 25, + wav2vec_model_path, + False, + os.path.dirname(audio_separator_model_file), + os.path.basename(audio_separator_model_file), + os.path.join(output_dir, "vocals"), + ) if step==2 else None + + image_processor = ImageProcessorForDataProcessing( + face_analysis_model_path, landmark_model_path, step) + + for video_path in tqdm(input_video_list, desc="Processing videos"): + process_single_video(video_path, output_dir, + image_processor, audio_processor, step) + + +def get_video_paths(source_dir: Path, parallelism: int, rank: int) -> List[Path]: + """ + Get paths of videos to process, partitioned for parallel processing. + + Args: + source_dir (Path): Source directory containing videos. + parallelism (int): Level of parallelism. + rank (int): Rank for distributed processing. + + Returns: + List[Path]: List of video paths to process. + """ + video_paths = [item for item in sorted( + source_dir.iterdir()) if item.is_file() and item.suffix == '.mp4'] + return [video_paths[i] for i in range(len(video_paths)) if i % parallelism == rank] + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Process videos to prepare data for training. Run this script twice with different GPU status parameters." + ) + parser.add_argument("-i", "--input_dir", type=Path, + required=True, help="Directory containing videos") + parser.add_argument("-o", "--output_dir", type=Path, + help="Directory to save results, default is parent dir of input dir") + parser.add_argument("-s", "--step", type=int, default=1, + help="Specify data processing step 1 or 2, you should run 1 and 2 sequently") + parser.add_argument("-p", "--parallelism", default=1, + type=int, help="Level of parallelism") + parser.add_argument("-r", "--rank", default=0, type=int, + help="Rank for distributed processing") + + args = parser.parse_args() + + if args.output_dir is None: + args.output_dir = args.input_dir.parent + + video_path_list = get_video_paths( + args.input_dir, args.parallelism, args.rank) + + if not video_path_list: + logging.warning("No videos to process.") + else: + process_all_videos(video_path_list, args.output_dir, args.step) diff --git a/Hallo2/hallo2/scripts/extract_meta_info_stage1.py b/Hallo2/hallo2/scripts/extract_meta_info_stage1.py new file mode 100644 index 00000000..936cb06c --- /dev/null +++ b/Hallo2/hallo2/scripts/extract_meta_info_stage1.py @@ -0,0 +1,106 @@ +# pylint: disable=R0801 +""" +This module is used to extract meta information from video directories. + +It takes in two command-line arguments: `root_path` and `dataset_name`. The `root_path` +specifies the path to the video directory, while the `dataset_name` specifies the name +of the dataset. The module then collects all the video folder paths, and for each video +folder, it checks if a mask path and a face embedding path exist. If they do, it appends +a dictionary containing the image path, mask path, and face embedding path to a list. + +Finally, the module writes the list of dictionaries to a JSON file with the filename +constructed using the `dataset_name`. + +Usage: + python tools/extract_meta_info_stage1.py --root_path /path/to/video_dir --dataset_name hdtf + +""" + +import argparse +import json +import os +from pathlib import Path + +import torch + + +def collect_video_folder_paths(root_path: Path) -> list: + """ + Collect all video folder paths from the root path. + + Args: + root_path (Path): The root directory containing video folders. + + Returns: + list: List of video folder paths. + """ + return [frames_dir.resolve() for frames_dir in root_path.iterdir() if frames_dir.is_dir()] + + +def construct_meta_info(frames_dir_path: Path) -> dict: + """ + Construct meta information for a given frames directory. + + Args: + frames_dir_path (Path): The path to the frames directory. + + Returns: + dict: A dictionary containing the meta information for the frames directory, or None if the required files do not exist. + """ + mask_path = str(frames_dir_path).replace("images", "face_mask") + ".png" + face_emb_path = str(frames_dir_path).replace("images", "face_emb") + ".pt" + + if not os.path.exists(mask_path): + print(f"Mask path not found: {mask_path}") + return None + + if torch.load(face_emb_path) is None: + print(f"Face emb is None: {face_emb_path}") + return None + + return { + "image_path": str(frames_dir_path), + "mask_path": mask_path, + "face_emb": face_emb_path, + } + + +def main(): + """ + Main function to extract meta info for training. + """ + parser = argparse.ArgumentParser() + parser.add_argument("-r", "--root_path", type=str, + required=True, help="Root path of the video directories") + parser.add_argument("-n", "--dataset_name", type=str, + required=True, help="Name of the dataset") + parser.add_argument("--meta_info_name", type=str, + help="Name of the meta information file") + + args = parser.parse_args() + + if args.meta_info_name is None: + args.meta_info_name = args.dataset_name + + image_dir = Path(args.root_path) / "images" + output_dir = Path("./data") + output_dir.mkdir(exist_ok=True) + + # Collect all video folder paths + frames_dir_paths = collect_video_folder_paths(image_dir) + + meta_infos = [] + for frames_dir_path in frames_dir_paths: + meta_info = construct_meta_info(frames_dir_path) + if meta_info: + meta_infos.append(meta_info) + + output_file = output_dir / f"{args.meta_info_name}_stage1.json" + with output_file.open("w", encoding="utf-8") as f: + json.dump(meta_infos, f, indent=4) + + print(f"Final data count: {len(meta_infos)}") + + +if __name__ == "__main__": + main() diff --git a/Hallo2/hallo2/scripts/extract_meta_info_stage2.py b/Hallo2/hallo2/scripts/extract_meta_info_stage2.py new file mode 100644 index 00000000..e2d9301c --- /dev/null +++ b/Hallo2/hallo2/scripts/extract_meta_info_stage2.py @@ -0,0 +1,192 @@ +# pylint: disable=R0801 +""" +This module is used to extract meta information from video files and store them in a JSON file. + +The script takes in command line arguments to specify the root path of the video files, +the dataset name, and the name of the meta information file. It then generates a list of +dictionaries containing the meta information for each video file and writes it to a JSON +file with the specified name. + +The meta information includes the path to the video file, the mask path, the face mask +path, the face mask union path, the face mask gaussian path, the lip mask path, the lip +mask union path, the lip mask gaussian path, the separate mask border, the separate mask +face, the separate mask lip, the face embedding path, the audio path, the vocals embedding +base last path, the vocals embedding base all path, the vocals embedding base average +path, the vocals embedding large last path, the vocals embedding large all path, and the +vocals embedding large average path. + +The script checks if the mask path exists before adding the information to the list. + +Usage: + python tools/extract_meta_info_stage2.py --root_path --dataset_name --meta_info_name + +Example: + python tools/extract_meta_info_stage2.py --root_path data/videos_25fps --dataset_name my_dataset --meta_info_name my_meta_info +""" + +import argparse +import json +import os +from pathlib import Path + +import torch +from decord import VideoReader, cpu +from tqdm import tqdm + + +def get_video_paths(root_path: Path, extensions: list) -> list: + """ + Get a list of video paths from the root path with the specified extensions. + + Args: + root_path (Path): The root directory containing video files. + extensions (list): List of file extensions to include. + + Returns: + list: List of video file paths. + """ + return [str(path.resolve()) for path in root_path.iterdir() if path.suffix in extensions] + + +def file_exists(file_path: str) -> bool: + """ + Check if a file exists. + + Args: + file_path (str): The path to the file. + + Returns: + bool: True if the file exists, False otherwise. + """ + return os.path.exists(file_path) + + +def construct_paths(video_path: str, base_dir: str, new_dir: str, new_ext: str) -> str: + """ + Construct a new path by replacing the base directory and extension in the original path. + + Args: + video_path (str): The original video path. + base_dir (str): The base directory to be replaced. + new_dir (str): The new directory to replace the base directory. + new_ext (str): The new file extension. + + Returns: + str: The constructed path. + """ + return str(video_path).replace(base_dir, new_dir).replace(".mp4", new_ext) + + +def extract_meta_info(video_path: str) -> dict: + """ + Extract meta information for a given video file. + + Args: + video_path (str): The path to the video file. + + Returns: + dict: A dictionary containing the meta information for the video. + """ + mask_path = construct_paths( + video_path, "videos", "face_mask", ".png") + sep_mask_border = construct_paths( + video_path, "videos", "sep_pose_mask", ".png") + sep_mask_face = construct_paths( + video_path, "videos", "sep_face_mask", ".png") + sep_mask_lip = construct_paths( + video_path, "videos", "sep_lip_mask", ".png") + face_emb_path = construct_paths( + video_path, "videos", "face_emb", ".pt") + audio_path = construct_paths(video_path, "videos", "audios", ".wav") + vocal_emb_base_all = construct_paths( + video_path, "videos", "audio_emb", ".pt") + + assert_flag = True + + if not file_exists(mask_path): + print(f"Mask path not found: {mask_path}") + assert_flag = False + if not file_exists(sep_mask_border): + print(f"Separate mask border not found: {sep_mask_border}") + assert_flag = False + if not file_exists(sep_mask_face): + print(f"Separate mask face not found: {sep_mask_face}") + assert_flag = False + if not file_exists(sep_mask_lip): + print(f"Separate mask lip not found: {sep_mask_lip}") + assert_flag = False + if not file_exists(face_emb_path): + print(f"Face embedding path not found: {face_emb_path}") + assert_flag = False + if not file_exists(audio_path): + print(f"Audio path not found: {audio_path}") + assert_flag = False + if not file_exists(vocal_emb_base_all): + print(f"Vocal embedding base all not found: {vocal_emb_base_all}") + assert_flag = False + + video_frames = VideoReader(video_path, ctx=cpu(0)) + audio_emb = torch.load(vocal_emb_base_all) + if abs(len(video_frames) - audio_emb.shape[0]) > 3: + print(f"Frame count mismatch for video: {video_path}") + assert_flag = False + + face_emb = torch.load(face_emb_path) + if face_emb is None: + print(f"Face embedding is None for video: {video_path}") + assert_flag = False + + del video_frames, audio_emb + + if assert_flag: + return { + "video_path": str(video_path), + "mask_path": mask_path, + "sep_mask_border": sep_mask_border, + "sep_mask_face": sep_mask_face, + "sep_mask_lip": sep_mask_lip, + "face_emb_path": face_emb_path, + "audio_path": audio_path, + "vocals_emb_base_all": vocal_emb_base_all, + } + return None + + +def main(): + """ + Main function to extract meta info for training. + """ + parser = argparse.ArgumentParser() + parser.add_argument("-r", "--root_path", type=str, + required=True, help="Root path of the video files") + parser.add_argument("-n", "--dataset_name", type=str, + required=True, help="Name of the dataset") + parser.add_argument("--meta_info_name", type=str, + help="Name of the meta information file") + + args = parser.parse_args() + + if args.meta_info_name is None: + args.meta_info_name = args.dataset_name + + video_dir = Path(args.root_path) / "videos" + video_paths = get_video_paths(video_dir, [".mp4"]) + + meta_infos = [] + + for video_path in tqdm(video_paths, desc="Extracting meta info"): + meta_info = extract_meta_info(video_path) + if meta_info: + meta_infos.append(meta_info) + + print(f"Final data count: {len(meta_infos)}") + + output_file = Path(f"./data/{args.meta_info_name}_stage2.json") + output_file.parent.mkdir(parents=True, exist_ok=True) + + with output_file.open("w", encoding="utf-8") as f: + json.dump(meta_infos, f, indent=4) + + +if __name__ == "__main__": + main() diff --git a/Hallo2/hallo2/scripts/inference_long.py b/Hallo2/hallo2/scripts/inference_long.py new file mode 100644 index 00000000..af93733b --- /dev/null +++ b/Hallo2/hallo2/scripts/inference_long.py @@ -0,0 +1,509 @@ +# pylint: disable=E1101 +# scripts/inference.py + +""" +This script contains the main inference pipeline for processing audio and image inputs to generate a video output. + +The script imports necessary packages and classes, defines a neural network model, +and contains functions for processing audio embeddings and performing inference. + +The main inference process is outlined in the following steps: +1. Initialize the configuration. +2. Set up runtime variables. +3. Prepare the input data for inference (source image, face mask, and face embeddings). +4. Process the audio embeddings. +5. Build and freeze the model and scheduler. +6. Run the inference loop and save the result. + +Usage: +This script can be run from the command line with the following arguments: +- audio_path: Path to the audio file. +- image_path: Path to the source image. +- face_mask_path: Path to the face mask image. +- face_emb_path: Path to the face embeddings file. +- output_path: Path to save the output video. + +Example: +python scripts/inference.py --audio_path audio.wav --image_path image.jpg + --face_mask_path face_mask.png --face_emb_path face_emb.pt --output_path output.mp4 +""" + +import argparse +import os +import sys + +import torch +from diffusers import AutoencoderKL, DDIMScheduler +from omegaconf import OmegaConf +from torch import nn +from pathlib import Path +import numpy as np +import torchvision.transforms as transforms +from PIL import Image +from pydub import AudioSegment + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from hallo.animate.face_animate import FaceAnimatePipeline +from hallo.datasets.audio_processor import AudioProcessor +from hallo.datasets.image_processor import ImageProcessor +from hallo.models.audio_proj import AudioProjModel +from hallo.models.face_locator import FaceLocator +from hallo.models.image_proj import ImageProjModel +from hallo.models.unet_2d_condition import UNet2DConditionModel +from hallo.models.unet_3d import UNet3DConditionModel +from hallo.utils.config import filter_non_none +from hallo.utils.util import tensor_to_video_batch, merge_videos + +from icecream import ic + +class Net(nn.Module): + """ + The Net class combines all the necessary modules for the inference process. + + Args: + reference_unet (UNet2DConditionModel): The UNet2DConditionModel used as a reference for inference. + denoising_unet (UNet3DConditionModel): The UNet3DConditionModel used for denoising the input audio. + face_locator (FaceLocator): The FaceLocator model used to locate the face in the input image. + imageproj (nn.Module): The ImageProjector model used to project the source image onto the face. + audioproj (nn.Module): The AudioProjector model used to project the audio embeddings onto the face. + """ + def __init__( + self, + reference_unet: UNet2DConditionModel, + denoising_unet: UNet3DConditionModel, + face_locator: FaceLocator, + imageproj, + audioproj, + ): + super().__init__() + self.reference_unet = reference_unet + self.denoising_unet = denoising_unet + self.face_locator = face_locator + self.imageproj = imageproj + self.audioproj = audioproj + + def forward(self,): + """ + empty function to override abstract function of nn Module + """ + + def get_modules(self): + """ + Simple method to avoid too-few-public-methods pylint error + """ + return { + "reference_unet": self.reference_unet, + "denoising_unet": self.denoising_unet, + "face_locator": self.face_locator, + "imageproj": self.imageproj, + "audioproj": self.audioproj, + } + + +def process_audio_emb(audio_emb): + """ + Process the audio embedding to concatenate with other tensors. + + Parameters: + audio_emb (torch.Tensor): The audio embedding tensor to process. + + Returns: + concatenated_tensors (List[torch.Tensor]): The concatenated tensor list. + """ + concatenated_tensors = [] + + for i in range(audio_emb.shape[0]): + vectors_to_concat = [ + audio_emb[max(min(i + j, audio_emb.shape[0]-1), 0)]for j in range(-2, 3)] + concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0)) + + audio_emb = torch.stack(concatenated_tensors, dim=0) + + return audio_emb + +def save_image_batch(image_tensor, save_path): + image_tensor = (image_tensor + 1) / 2 + + os.makedirs(save_path, exist_ok=True) + + for i in range(image_tensor.shape[0]): + img_tensor = image_tensor[i] + + img_array = img_tensor.permute(1, 2, 0).cpu().numpy() + + img_array = (img_array * 255).astype(np.uint8) + + image = Image.fromarray(img_array) + image.save(os.path.join(save_path, f'motion_frame_{i}.png')) + + +def cut_audio(audio_path, save_dir, length=60): + audio = AudioSegment.from_wav(audio_path) + + segment_length = length * 1000 # pydub使用毫秒 + + num_segments = len(audio) // segment_length + (1 if len(audio) % segment_length != 0 else 0) + + os.makedirs(save_dir, exist_ok=True) + + audio_list = [] + + for i in range(num_segments): + start_time = i * segment_length + end_time = min((i + 1) * segment_length, len(audio)) + segment = audio[start_time:end_time] + + path = f"{save_dir}/segment_{i+1}.wav" + audio_list.append(path) + segment.export(path, format="wav") + + return audio_list + + +def inference_process(args: argparse.Namespace): + """ + Perform inference processing. + + Args: + args (argparse.Namespace): Command-line arguments. + + This function initializes the configuration for the inference process. It sets up the necessary + modules and variables to prepare for the upcoming inference steps. + """ + # 1. init config + cli_args = filter_non_none(vars(args)) + config = OmegaConf.load(args.config) + config = OmegaConf.merge(config, cli_args) + source_image_path = config.source_image + driving_audio_path = config.driving_audio + + save_path = os.path.join(config.save_path, Path(source_image_path).stem) + save_seg_path = os.path.join(save_path, "seg_video") + print("save path: ", save_path) + + if not os.path.exists(save_path): + os.makedirs(save_path) + if not os.path.exists(save_seg_path): + os.makedirs(save_seg_path) + + motion_scale = [config.pose_weight, config.face_weight, config.lip_weight] + + # 2. runtime variables + device = torch.device( + "cuda") if torch.cuda.is_available() else torch.device("cpu") + if config.weight_dtype == "fp16": + weight_dtype = torch.float16 + elif config.weight_dtype == "bf16": + weight_dtype = torch.bfloat16 + elif config.weight_dtype == "fp32": + weight_dtype = torch.float32 + else: + weight_dtype = torch.float32 + + # 3. prepare inference data + # 3.1 prepare source image, face mask, face embeddings + img_size = (config.data.source_image.width, + config.data.source_image.height) + clip_length = config.data.n_sample_frames + face_analysis_model_path = config.face_analysis.model_path + with ImageProcessor(img_size, face_analysis_model_path) as image_processor: + source_image_pixels, \ + source_image_face_region, \ + source_image_face_emb, \ + source_image_full_mask, \ + source_image_face_mask, \ + source_image_lip_mask = image_processor.preprocess( + source_image_path, save_path, config.face_expand_ratio) + + # 3.2 prepare audio embeddings + sample_rate = config.data.driving_audio.sample_rate + assert sample_rate == 16000, "audio sample rate must be 16000" + fps = config.data.export_video.fps + wav2vec_model_path = config.wav2vec.model_path + wav2vec_only_last_features = config.wav2vec.features == "last" + audio_separator_model_file = config.audio_separator.model_path + + + if config.use_cut: + audio_list = cut_audio(driving_audio_path, os.path.join( + save_path, f"seg-long-{Path(driving_audio_path).stem}")) + + audio_emb_list = [] + l = 0 + + audio_processor = AudioProcessor( + sample_rate, + fps, + wav2vec_model_path, + wav2vec_only_last_features, + os.path.dirname(audio_separator_model_file), + os.path.basename(audio_separator_model_file), + os.path.join(save_path, "audio_preprocess") + ) + + for idx, audio_path in enumerate(audio_list): + padding = (idx+1) == len(audio_list) + emb, length = audio_processor.preprocess(audio_path, clip_length, + padding=padding, processed_length=l) + audio_emb_list.append(emb) + l += length + + audio_emb = torch.cat(audio_emb_list) + audio_length = l + + else: + with AudioProcessor( + sample_rate, + fps, + wav2vec_model_path, + wav2vec_only_last_features, + os.path.dirname(audio_separator_model_file), + os.path.basename(audio_separator_model_file), + os.path.join(save_path, "audio_preprocess") + ) as audio_processor: + audio_emb, audio_length = audio_processor.preprocess(driving_audio_path, clip_length) + + # 4. build modules + sched_kwargs = OmegaConf.to_container(config.noise_scheduler_kwargs) + if config.enable_zero_snr: + sched_kwargs.update( + rescale_betas_zero_snr=True, + timestep_spacing="trailing", + prediction_type="v_prediction", + ) + val_noise_scheduler = DDIMScheduler(**sched_kwargs) + sched_kwargs.update({"beta_schedule": "scaled_linear"}) + + vae = AutoencoderKL.from_pretrained(config.vae.model_path) + reference_unet = UNet2DConditionModel.from_pretrained( + config.base_model_path, subfolder="unet") + denoising_unet = UNet3DConditionModel.from_pretrained_2d( + config.base_model_path, + config.motion_module_path, + subfolder="unet", + unet_additional_kwargs=OmegaConf.to_container( + config.unet_additional_kwargs), + use_landmark=False, + ) + # denoising_unet.set_attn_processor() + + face_locator = FaceLocator(conditioning_embedding_channels=320) + image_proj = ImageProjModel( + cross_attention_dim=denoising_unet.config.cross_attention_dim, + clip_embeddings_dim=512, + clip_extra_context_tokens=4, + ) + + audio_proj = AudioProjModel( + seq_len=5, + blocks=12, # use 12 layers' hidden states of wav2vec + channels=768, # audio embedding channel + intermediate_dim=512, + output_dim=768, + context_tokens=32, + ).to(device=device, dtype=weight_dtype) + + audio_ckpt_dir = config.audio_ckpt_dir + + + # Freeze + vae.requires_grad_(False) + image_proj.requires_grad_(False) + reference_unet.requires_grad_(False) + denoising_unet.requires_grad_(False) + face_locator.requires_grad_(False) + audio_proj.requires_grad_(False) + + reference_unet.enable_gradient_checkpointing() + denoising_unet.enable_gradient_checkpointing() + + net = Net( + reference_unet, + denoising_unet, + face_locator, + image_proj, + audio_proj, + ) + + m,u = net.load_state_dict( + torch.load( + os.path.join(audio_ckpt_dir, f"net.pth"), + map_location="cpu", + ), + ) + assert len(m) == 0 and len(u) == 0, "Fail to load correct checkpoint." + print("loaded weight from ", os.path.join(audio_ckpt_dir, "net.pth")) + + # 5. inference + pipeline = FaceAnimatePipeline( + vae=vae, + reference_unet=net.reference_unet, + denoising_unet=net.denoising_unet, + face_locator=net.face_locator, + scheduler=val_noise_scheduler, + image_proj=net.imageproj, + ) + pipeline.to(device=device, dtype=weight_dtype) + + audio_emb = process_audio_emb(audio_emb) + + source_image_pixels = source_image_pixels.unsqueeze(0) + source_image_face_region = source_image_face_region.unsqueeze(0) + source_image_face_emb = source_image_face_emb.reshape(1, -1) + source_image_face_emb = torch.tensor(source_image_face_emb) + + source_image_full_mask = [ + (mask.repeat(clip_length, 1)) + for mask in source_image_full_mask + ] + source_image_face_mask = [ + (mask.repeat(clip_length, 1)) + for mask in source_image_face_mask + ] + source_image_lip_mask = [ + (mask.repeat(clip_length, 1)) + for mask in source_image_lip_mask + ] + + + times = audio_emb.shape[0] // clip_length + + tensor_result = [] + + generator = torch.manual_seed(42) + + ic(audio_emb.shape) + ic(audio_length) + batch_size = 60 + start = 0 + + for t in range(times): + print(f"[{t+1}/{times}]") + + if len(tensor_result) == 0: + # The first iteration + motion_zeros = source_image_pixels.repeat( + config.data.n_motion_frames, 1, 1, 1) + motion_zeros = motion_zeros.to( + dtype=source_image_pixels.dtype, device=source_image_pixels.device) + pixel_values_ref_img = torch.cat( + [source_image_pixels, motion_zeros], dim=0) # concat the ref image and the first motion frames + else: + motion_frames = tensor_result[-1][0] + motion_frames = motion_frames.permute(1, 0, 2, 3) + motion_frames = motion_frames[0-config.data.n_motion_frames:] + motion_frames = motion_frames * 2.0 - 1.0 + motion_frames = motion_frames.to( + dtype=source_image_pixels.dtype, device=source_image_pixels.device) + pixel_values_ref_img = torch.cat( + [source_image_pixels, motion_frames], dim=0) # concat the ref image and the motion frames + + pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0) + + pixel_motion_values = pixel_values_ref_img[:, 1:] + + if config.use_mask: + b, f, c, h, w = pixel_motion_values.shape + rand_mask = torch.rand(h, w) + mask = rand_mask > config.mask_rate + mask = mask.unsqueeze(0).unsqueeze(0).unsqueeze(0) + mask = mask.expand(b, f, c, h, w) + + face_mask = source_image_face_region.repeat(f, 1, 1, 1).unsqueeze(0) + assert face_mask.shape == mask.shape + mask = mask | face_mask.bool() + + pixel_motion_values = pixel_motion_values * mask + pixel_values_ref_img[:, 1:] = pixel_motion_values + + + assert pixel_motion_values.shape[0] == 1 + + audio_tensor = audio_emb[ + t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0]) + ] + audio_tensor = audio_tensor.unsqueeze(0) + audio_tensor = audio_tensor.to( + device=net.audioproj.device, dtype=net.audioproj.dtype) + audio_tensor = net.audioproj(audio_tensor) + + pipeline_output = pipeline( + ref_image=pixel_values_ref_img, + audio_tensor=audio_tensor, + face_emb=source_image_face_emb, + face_mask=source_image_face_region, + pixel_values_full_mask=source_image_full_mask, + pixel_values_face_mask=source_image_face_mask, + pixel_values_lip_mask=source_image_lip_mask, + width=img_size[0], + height=img_size[1], + video_length=clip_length, + num_inference_steps=config.inference_steps, + guidance_scale=config.cfg_scale, + generator=generator, + motion_scale=motion_scale, + ) + + ic(pipeline_output.videos.shape) + tensor_result.append(pipeline_output.videos) + + if (t+1) % batch_size == 0 or (t+1)==times: + last_motion_frame = [tensor_result[-1]] + ic(len(tensor_result)) + + if start!=0: + tensor_result = torch.cat(tensor_result[1:], dim=2) + else: + tensor_result = torch.cat(tensor_result, dim=2) + + tensor_result = tensor_result.squeeze(0) + f = tensor_result.shape[1] + length = min(f, audio_length) + tensor_result = tensor_result[:, :length] + + ic(tensor_result.shape) + ic(start) + ic(audio_length) + + name = Path(save_path).name + output_file = os.path.join(save_seg_path, f"{name}-{t+1:06}.mp4") + + tensor_to_video_batch(tensor_result, output_file, start, driving_audio_path) + del tensor_result + + tensor_result = last_motion_frame + audio_length -= length + start += length + + return save_seg_path + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-c", "--config", default="configs/inference/long.yaml") + parser.add_argument("--source_image", type=str, required=False, + help="source image") + parser.add_argument("--driving_audio", type=str, required=False, + help="driving audio") + parser.add_argument( + "--pose_weight", type=float, help="weight of pose", required=False) + parser.add_argument( + "--face_weight", type=float, help="weight of face", required=False) + parser.add_argument( + "--lip_weight", type=float, help="weight of lip", required=False) + parser.add_argument( + "--face_expand_ratio", type=float, help="face region", required=False) + parser.add_argument( + "--audio_ckpt_dir", "--checkpoint", type=str, help="specific checkpoint dir", required=False) + + + command_line_args = parser.parse_args() + + + + save_path = inference_process(command_line_args) + merge_videos(save_path, os.path.join(Path(save_path).parent, "merge_video.mp4")) diff --git a/Hallo2/hallo2/scripts/train_stage1.py b/Hallo2/hallo2/scripts/train_stage1.py new file mode 100644 index 00000000..e9e7e847 --- /dev/null +++ b/Hallo2/hallo2/scripts/train_stage1.py @@ -0,0 +1,793 @@ +# pylint: disable=E1101,C0415,W0718,R0801 +# scripts/train_stage1.py +""" +This is the main training script for stage 1 of the project. +It imports necessary packages, defines necessary classes and functions, and trains the model using the provided configuration. + +The script includes the following classes and functions: + +1. Net: A PyTorch model that takes noisy latents, timesteps, reference image latents, face embeddings, + and face masks as input and returns the denoised latents. +3. log_validation: A function that logs the validation information using the given VAE, image encoder, + network, scheduler, accelerator, width, height, and configuration. +4. train_stage1_process: A function that processes the training stage 1 using the given configuration. + +The script also includes the necessary imports and a brief description of the purpose of the file. +""" + +import argparse +import copy +import logging +import math +import os +import random +import warnings +from datetime import datetime + +import cv2 +import diffusers +import mlflow +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs +from diffusers import AutoencoderKL, DDIMScheduler +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version +from diffusers.utils.import_utils import is_xformers_available +from insightface.app import FaceAnalysis +from omegaconf import OmegaConf +from PIL import Image +from torch import nn +from tqdm.auto import tqdm + +from hallo.animate.face_animate_static import StaticPipeline +from hallo.datasets.mask_image import FaceMaskDataset +from hallo.models.face_locator import FaceLocator +from hallo.models.image_proj import ImageProjModel +from hallo.models.mutual_self_attention import ReferenceAttentionControl +from hallo.models.unet_2d_condition import UNet2DConditionModel +from hallo.models.unet_3d import UNet3DConditionModel +from hallo.utils.util import (compute_snr, delete_additional_ckpt, + import_filename, init_output_dir, + load_checkpoint, move_final_checkpoint, + save_checkpoint, seed_everything) + +warnings.filterwarnings("ignore") + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.10.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + + +class Net(nn.Module): + """ + The Net class defines a neural network model that combines a reference UNet2DConditionModel, + a denoising UNet3DConditionModel, a face locator, and other components to animate a face in a static image. + + Args: + reference_unet (UNet2DConditionModel): The reference UNet2DConditionModel used for face animation. + denoising_unet (UNet3DConditionModel): The denoising UNet3DConditionModel used for face animation. + face_locator (FaceLocator): The face locator model used for face animation. + reference_control_writer: The reference control writer component. + reference_control_reader: The reference control reader component. + imageproj: The image projection model. + + Forward method: + noisy_latents (torch.Tensor): The noisy latents tensor. + timesteps (torch.Tensor): The timesteps tensor. + ref_image_latents (torch.Tensor): The reference image latents tensor. + face_emb (torch.Tensor): The face embeddings tensor. + face_mask (torch.Tensor): The face mask tensor. + uncond_fwd (bool): A flag indicating whether to perform unconditional forward pass. + + Returns: + torch.Tensor: The output tensor of the neural network model. + """ + + def __init__( + self, + reference_unet: UNet2DConditionModel, + denoising_unet: UNet3DConditionModel, + face_locator: FaceLocator, + reference_control_writer: ReferenceAttentionControl, + reference_control_reader: ReferenceAttentionControl, + imageproj: ImageProjModel, + ): + super().__init__() + self.reference_unet = reference_unet + self.denoising_unet = denoising_unet + self.face_locator = face_locator + self.reference_control_writer = reference_control_writer + self.reference_control_reader = reference_control_reader + self.imageproj = imageproj + + def forward( + self, + noisy_latents, + timesteps, + ref_image_latents, + face_emb, + face_mask, + uncond_fwd: bool = False, + ): + """ + Forward pass of the model. + Args: + self (Net): The model instance. + noisy_latents (torch.Tensor): Noisy latents. + timesteps (torch.Tensor): Timesteps. + ref_image_latents (torch.Tensor): Reference image latents. + face_emb (torch.Tensor): Face embedding. + face_mask (torch.Tensor): Face mask. + uncond_fwd (bool, optional): Unconditional forward pass. Defaults to False. + + Returns: + torch.Tensor: Model prediction. + """ + + face_emb = self.imageproj(face_emb) + face_mask = face_mask.to(device="cuda") + face_mask_feature = self.face_locator(face_mask) + + if not uncond_fwd: + ref_timesteps = torch.zeros_like(timesteps) + self.reference_unet( + ref_image_latents, + ref_timesteps, + encoder_hidden_states=face_emb, + return_dict=False, + ) + self.reference_control_reader.update(self.reference_control_writer) + model_pred = self.denoising_unet( + noisy_latents, + timesteps, + mask_cond_fea=face_mask_feature, + encoder_hidden_states=face_emb, + ).sample + + return model_pred + + +def get_noise_scheduler(cfg: argparse.Namespace): + """ + Create noise scheduler for training + + Args: + cfg (omegaconf.dictconfig.DictConfig): Configuration object. + + Returns: + train noise scheduler and val noise scheduler + """ + sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs) + if cfg.enable_zero_snr: + sched_kwargs.update( + rescale_betas_zero_snr=True, + timestep_spacing="trailing", + prediction_type="v_prediction", + ) + val_noise_scheduler = DDIMScheduler(**sched_kwargs) + sched_kwargs.update({"beta_schedule": "scaled_linear"}) + train_noise_scheduler = DDIMScheduler(**sched_kwargs) + + return train_noise_scheduler, val_noise_scheduler + + +def log_validation( + vae, + net, + scheduler, + accelerator, + width, + height, + imageproj, + cfg, + save_dir, + global_step, + face_analysis_model_path, +): + """ + Log validation generation image. + + Args: + vae (nn.Module): Variational Autoencoder model. + net (Net): Main model. + scheduler (diffusers.SchedulerMixin): Noise scheduler. + accelerator (accelerate.Accelerator): Accelerator for training. + width (int): Width of the input images. + height (int): Height of the input images. + imageproj (nn.Module): Image projection model. + cfg (omegaconf.dictconfig.DictConfig): Configuration object. + save_dir (str): directory path to save log result. + global_step (int): Global step number. + + Returns: + None + """ + logger.info("Running validation... ") + + ori_net = accelerator.unwrap_model(net) + ori_net = copy.deepcopy(ori_net) + reference_unet = ori_net.reference_unet + denoising_unet = ori_net.denoising_unet + face_locator = ori_net.face_locator + + generator = torch.manual_seed(42) + image_enc = FaceAnalysis( + name="", + root=face_analysis_model_path, + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], + ) + image_enc.prepare(ctx_id=0, det_size=(640, 640)) + + pipe = StaticPipeline( + vae=vae, + reference_unet=reference_unet, + denoising_unet=denoising_unet, + face_locator=face_locator, + scheduler=scheduler, + imageproj=imageproj, + ) + + pil_images = [] + for ref_image_path, mask_image_path in zip(cfg.ref_image_paths, cfg.mask_image_paths): + # for mask_image_path in mask_image_paths: + mask_name = os.path.splitext( + os.path.basename(mask_image_path))[0] + ref_name = os.path.splitext( + os.path.basename(ref_image_path))[0] + ref_image_pil = Image.open(ref_image_path).convert("RGB") + mask_image_pil = Image.open(mask_image_path).convert("RGB") + + # Prepare face embeds + face_info = image_enc.get( + cv2.cvtColor(np.array(ref_image_pil), cv2.COLOR_RGB2BGR)) + face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * ( + x['bbox'][3] - x['bbox'][1]))[-1] # only use the maximum face + face_emb = torch.tensor(face_info['embedding']) + face_emb = face_emb.to( + imageproj.device, imageproj.dtype) + + image = pipe( + ref_image_pil, + mask_image_pil, + width, + height, + 20, + 3.5, + face_emb, + generator=generator, + ).images + image = image[0, :, 0].permute(1, 2, 0).cpu().numpy() # (3, 512, 512) + res_image_pil = Image.fromarray((image * 255).astype(np.uint8)) + # Save ref_image, src_image and the generated_image + w, h = res_image_pil.size + canvas = Image.new("RGB", (w * 3, h), "white") + ref_image_pil = ref_image_pil.resize((w, h)) + mask_image_pil = mask_image_pil.resize((w, h)) + canvas.paste(ref_image_pil, (0, 0)) + canvas.paste(mask_image_pil, (w, 0)) + canvas.paste(res_image_pil, (w * 2, 0)) + + out_file = os.path.join( + save_dir, f"{global_step:06d}-{ref_name}_{mask_name}.jpg" + ) + canvas.save(out_file) + + del pipe + del ori_net + torch.cuda.empty_cache() + + return pil_images + + +def train_stage1_process(cfg: argparse.Namespace) -> None: + """ + Trains the model using the given configuration (cfg). + + Args: + cfg (dict): The configuration dictionary containing the parameters for training. + + Notes: + - This function trains the model using the given configuration. + - It initializes the necessary components for training, such as the pipeline, optimizer, and scheduler. + - The training progress is logged and tracked using the accelerator. + - The trained model is saved after the training is completed. + """ + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps, + mixed_precision=cfg.solver.mixed_precision, + log_with="mlflow", + project_dir="./mlruns", + kwargs_handlers=[kwargs], + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if cfg.seed is not None: + seed_everything(cfg.seed) + + # create output dir for training + exp_name = cfg.exp_name + save_dir = f"{cfg.output_dir}/{exp_name}" + checkpoint_dir = os.path.join(save_dir, "checkpoints") + module_dir = os.path.join(save_dir, "modules") + validation_dir = os.path.join(save_dir, "validation") + + if accelerator.is_main_process: + init_output_dir([save_dir, checkpoint_dir, module_dir, validation_dir]) + + accelerator.wait_for_everyone() + + # create model + if cfg.weight_dtype == "fp16": + weight_dtype = torch.float16 + elif cfg.weight_dtype == "bf16": + weight_dtype = torch.bfloat16 + elif cfg.weight_dtype == "fp32": + weight_dtype = torch.float32 + else: + raise ValueError( + f"Do not support weight dtype: {cfg.weight_dtype} during training" + ) + + # create model + vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to( + "cuda", dtype=weight_dtype + ) + reference_unet = UNet2DConditionModel.from_pretrained( + cfg.base_model_path, + subfolder="unet", + ).to(device="cuda", dtype=weight_dtype) + denoising_unet = UNet3DConditionModel.from_pretrained_2d( + cfg.base_model_path, + "", + subfolder="unet", + unet_additional_kwargs={ + "use_motion_module": False, + "unet_use_temporal_attention": False, + }, + use_landmark=False + ).to(device="cuda", dtype=weight_dtype) + imageproj = ImageProjModel( + cross_attention_dim=denoising_unet.config.cross_attention_dim, + clip_embeddings_dim=512, + clip_extra_context_tokens=4, + ).to(device="cuda", dtype=weight_dtype) + + if cfg.face_locator_pretrained: + face_locator = FaceLocator( + conditioning_embedding_channels=320, block_out_channels=(16, 32, 96, 256) + ).to(device="cuda", dtype=weight_dtype) + miss, _ = face_locator.load_state_dict( + cfg.face_state_dict_path, strict=False) + logger.info(f"Missing key for face locator: {len(miss)}") + else: + face_locator = FaceLocator( + conditioning_embedding_channels=320, + ).to(device="cuda", dtype=weight_dtype) + # Freeze + vae.requires_grad_(False) + denoising_unet.requires_grad_(True) + reference_unet.requires_grad_(True) + imageproj.requires_grad_(True) + face_locator.requires_grad_(True) + + reference_control_writer = ReferenceAttentionControl( + reference_unet, + do_classifier_free_guidance=False, + mode="write", + fusion_blocks="full", + ) + reference_control_reader = ReferenceAttentionControl( + denoising_unet, + do_classifier_free_guidance=False, + mode="read", + fusion_blocks="full", + ) + + net = Net( + reference_unet, + denoising_unet, + face_locator, + reference_control_writer, + reference_control_reader, + imageproj, + ).to(dtype=weight_dtype) + + # get noise scheduler + train_noise_scheduler, val_noise_scheduler = get_noise_scheduler(cfg) + + # init optimizer + if cfg.solver.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + reference_unet.enable_xformers_memory_efficient_attention() + denoising_unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError( + "xformers is not available. Make sure it is installed correctly" + ) + + if cfg.solver.gradient_checkpointing: + reference_unet.enable_gradient_checkpointing() + denoising_unet.enable_gradient_checkpointing() + + if cfg.solver.scale_lr: + learning_rate = ( + cfg.solver.learning_rate + * cfg.solver.gradient_accumulation_steps + * cfg.data.train_bs + * accelerator.num_processes + ) + else: + learning_rate = cfg.solver.learning_rate + + # Initialize the optimizer + if cfg.solver.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError as exc: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) from exc + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + trainable_params = list( + filter(lambda p: p.requires_grad, net.parameters())) + optimizer = optimizer_cls( + trainable_params, + lr=learning_rate, + betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2), + weight_decay=cfg.solver.adam_weight_decay, + eps=cfg.solver.adam_epsilon, + ) + + # init scheduler + lr_scheduler = get_scheduler( + cfg.solver.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=cfg.solver.lr_warmup_steps + * cfg.solver.gradient_accumulation_steps, + num_training_steps=cfg.solver.max_train_steps + * cfg.solver.gradient_accumulation_steps, + ) + + # get data loader + train_dataset = FaceMaskDataset( + img_size=(cfg.data.train_width, cfg.data.train_height), + data_meta_paths=cfg.data.meta_paths, + sample_margin=cfg.data.sample_margin, + ) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=4 + ) + + # Prepare everything with our `accelerator`. + ( + net, + optimizer, + train_dataloader, + lr_scheduler, + ) = accelerator.prepare( + net, + optimizer, + train_dataloader, + lr_scheduler, + ) + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / cfg.solver.gradient_accumulation_steps + ) + # Afterwards we recalculate our number of training epochs + num_train_epochs = math.ceil( + cfg.solver.max_train_steps / num_update_steps_per_epoch + ) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + run_time = datetime.now().strftime("%Y%m%d-%H%M") + accelerator.init_trackers( + cfg.exp_name, + init_kwargs={"mlflow": {"run_name": run_time}}, + ) + # dump config file + mlflow.log_dict(OmegaConf.to_container(cfg), "config.yaml") + + logger.info(f"save config to {save_dir}") + OmegaConf.save( + cfg, os.path.join(save_dir, "config.yaml") + ) + # Train! + total_batch_size = ( + cfg.data.train_bs + * accelerator.num_processes + * cfg.solver.gradient_accumulation_steps + ) + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {cfg.data.train_bs}") + logger.info( + f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" + ) + logger.info( + f" Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}" + ) + logger.info(f" Total optimization steps = {cfg.solver.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # load checkpoint + # Potentially load in the weights and states from a previous save + if cfg.resume_from_checkpoint: + logger.info(f"Loading checkpoint from {checkpoint_dir}") + global_step = load_checkpoint(cfg, checkpoint_dir, accelerator) + first_epoch = global_step // num_update_steps_per_epoch + + # Only show the progress bar once on each machine. + progress_bar = tqdm( + range(global_step, cfg.solver.max_train_steps), + disable=not accelerator.is_main_process, + ) + progress_bar.set_description("Steps") + net.train() + for _ in range(first_epoch, num_train_epochs): + train_loss = 0.0 + for _, batch in enumerate(train_dataloader): + with accelerator.accumulate(net): + # Convert videos to latent space + pixel_values = batch["img"].to(weight_dtype) + with torch.no_grad(): + latents = vae.encode(pixel_values).latent_dist.sample() + latents = latents.unsqueeze(2) # (b, c, 1, h, w) + latents = latents * 0.18215 + + noise = torch.randn_like(latents) + if cfg.noise_offset > 0.0: + noise += cfg.noise_offset * torch.randn( + (noise.shape[0], noise.shape[1], 1, 1, 1), + device=noise.device, + ) + + bsz = latents.shape[0] + # Sample a random timestep for each video + timesteps = torch.randint( + 0, + train_noise_scheduler.num_train_timesteps, + (bsz,), + device=latents.device, + ) + timesteps = timesteps.long() + + face_mask_img = batch["tgt_mask"] + face_mask_img = face_mask_img.unsqueeze( + 2) + face_mask_img = face_mask_img.to(weight_dtype) + + uncond_fwd = random.random() < cfg.uncond_ratio + face_emb_list = [] + ref_image_list = [] + for _, (ref_img, face_emb) in enumerate( + zip(batch["ref_img"], batch["face_emb"]) + ): + if uncond_fwd: + face_emb_list.append(torch.zeros_like(face_emb)) + else: + face_emb_list.append(face_emb) + ref_image_list.append(ref_img) + + with torch.no_grad(): + ref_img = torch.stack(ref_image_list, dim=0).to( + dtype=vae.dtype, device=vae.device + ) + ref_image_latents = vae.encode( + ref_img + ).latent_dist.sample() + ref_image_latents = ref_image_latents * 0.18215 + + face_emb = torch.stack(face_emb_list, dim=0).to( + dtype=imageproj.dtype, device=imageproj.device + ) + + # add noise + noisy_latents = train_noise_scheduler.add_noise( + latents, noise, timesteps + ) + + # Get the target for loss depending on the prediction type + if train_noise_scheduler.prediction_type == "epsilon": + target = noise + elif train_noise_scheduler.prediction_type == "v_prediction": + target = train_noise_scheduler.get_velocity( + latents, noise, timesteps + ) + else: + raise ValueError( + f"Unknown prediction type {train_noise_scheduler.prediction_type}" + ) + model_pred = net( + noisy_latents, + timesteps, + ref_image_latents, + face_emb, + face_mask_img, + uncond_fwd, + ) + + if cfg.snr_gamma == 0: + loss = F.mse_loss( + model_pred.float(), target.float(), reduction="mean" + ) + else: + snr = compute_snr(train_noise_scheduler, timesteps) + if train_noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective requires that we add one to SNR values before we divide by them. + snr = snr + 1 + mse_loss_weights = ( + torch.stack( + [snr, cfg.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + / snr + ) + loss = F.mse_loss( + model_pred.float(), target.float(), reduction="none" + ) + loss = ( + loss.mean(dim=list(range(1, len(loss.shape)))) + * mse_loss_weights + ) + loss = loss.mean() + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather( + loss.repeat(cfg.data.train_bs)).mean() + train_loss += avg_loss.item() / cfg.solver.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_( + trainable_params, + cfg.solver.max_grad_norm, + ) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + if accelerator.sync_gradients: + reference_control_reader.clear() + reference_control_writer.clear() + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + if global_step % cfg.checkpointing_steps == 0 or global_step == cfg.solver.max_train_steps: + accelerator.wait_for_everyone() + save_path = os.path.join( + checkpoint_dir, f"checkpoint-{global_step}") + if accelerator.is_main_process: + delete_additional_ckpt(checkpoint_dir, 3) + accelerator.save_state(save_path) + accelerator.wait_for_everyone() + unwrap_net = accelerator.unwrap_model(net) + if accelerator.is_main_process: + save_checkpoint( + unwrap_net.reference_unet, + module_dir, + "reference_unet", + global_step, + total_limit=3, + ) + save_checkpoint( + unwrap_net.imageproj, + module_dir, + "imageproj", + global_step, + total_limit=3, + ) + save_checkpoint( + unwrap_net.denoising_unet, + module_dir, + "denoising_unet", + global_step, + total_limit=3, + ) + save_checkpoint( + unwrap_net.face_locator, + module_dir, + "face_locator", + global_step, + total_limit=3, + ) + + if global_step % cfg.val.validation_steps == 0 or global_step == 1: + if accelerator.is_main_process: + generator = torch.Generator(device=accelerator.device) + generator.manual_seed(cfg.seed) + log_validation( + vae=vae, + net=net, + scheduler=val_noise_scheduler, + accelerator=accelerator, + width=cfg.data.train_width, + height=cfg.data.train_height, + imageproj=imageproj, + cfg=cfg, + save_dir=validation_dir, + global_step=global_step, + face_analysis_model_path=cfg.face_analysis_model_path + ) + + logs = { + "step_loss": loss.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + } + progress_bar.set_postfix(**logs) + + if global_step >= cfg.solver.max_train_steps: + # process final module weight for stage2 + if accelerator.is_main_process: + move_final_checkpoint(save_dir, module_dir, "reference_unet") + move_final_checkpoint(save_dir, module_dir, "imageproj") + move_final_checkpoint(save_dir, module_dir, "denoising_unet") + move_final_checkpoint(save_dir, module_dir, "face_locator") + break + + accelerator.wait_for_everyone() + accelerator.end_training() + + +def load_config(config_path: str) -> dict: + """ + Loads the configuration file. + + Args: + config_path (str): Path to the configuration file. + + Returns: + dict: The configuration dictionary. + """ + + if config_path.endswith(".yaml"): + return OmegaConf.load(config_path) + if config_path.endswith(".py"): + return import_filename(config_path).cfg + raise ValueError("Unsupported format for config file") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, + default="./configs/train/stage1.yaml") + args = parser.parse_args() + + try: + config = load_config(args.config) + train_stage1_process(config) + except Exception as e: + logging.error("Failed to execute the training process: %s", e) diff --git a/Hallo2/hallo2/scripts/train_stage2_long.py b/Hallo2/hallo2/scripts/train_stage2_long.py new file mode 100644 index 00000000..961d32dd --- /dev/null +++ b/Hallo2/hallo2/scripts/train_stage2_long.py @@ -0,0 +1,1051 @@ +# pylint: disable=E1101,C0415,W0718,R0801 +# scripts/train_stage2.py +""" +This is the main training script for stage 2 of the project. +It imports necessary packages, defines necessary classes and functions, and trains the model using the provided configuration. + +The script includes the following classes and functions: + +1. Net: A PyTorch model that takes noisy latents, timesteps, reference image latents, face embeddings, + and face masks as input and returns the denoised latents. +2. get_attention_mask: A function that rearranges the mask tensors to the required format. +3. get_noise_scheduler: A function that creates and returns the noise schedulers for training and validation. +4. process_audio_emb: A function that processes the audio embeddings to concatenate with other tensors. +5. log_validation: A function that logs the validation information using the given VAE, image encoder, + network, scheduler, accelerator, width, height, and configuration. +6. train_stage2_process: A function that processes the training stage 2 using the given configuration. +7. load_config: A function that loads the configuration file from the given path. + +The script also includes the necessary imports and a brief description of the purpose of the file. +""" + +import argparse +import copy +import logging +import math +import os +import random +import time +import warnings +from datetime import datetime +from typing import List, Tuple + +import diffusers +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs +from diffusers import AutoencoderKL, DDIMScheduler +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version +from diffusers.utils.import_utils import is_xformers_available +from einops import rearrange, repeat +from omegaconf import OmegaConf +from torch import nn +from tqdm.auto import tqdm +from accelerate.utils import ProjectConfiguration +from torchvision import transforms +from PIL import Image +import numpy as np +from pathlib import Path + +from hallo.animate.face_animate import FaceAnimatePipeline +from hallo.datasets.audio_processor import AudioProcessor +from hallo.datasets.image_processor import ImageProcessor +from hallo.datasets.talk_video import TalkingVideoDataset +from hallo.models.audio_proj import AudioProjModel +from hallo.models.face_locator import FaceLocator +from hallo.models.image_proj import ImageProjModel +from hallo.models.mutual_self_attention import ReferenceAttentionControl +from hallo.models.unet_2d_condition import UNet2DConditionModel +from hallo.models.unet_3d import UNet3DConditionModel +from hallo.utils.util import (compute_snr, delete_additional_ckpt, + import_filename, init_output_dir, + load_checkpoint, save_checkpoint, + seed_everything, tensor_to_video) + +warnings.filterwarnings("ignore") + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.10.0.dev0") + +logger = get_logger(__name__, log_level="INFO") +from icecream import ic + +class Net(nn.Module): + """ + The Net class defines a neural network model that combines a reference UNet2DConditionModel, + a denoising UNet3DConditionModel, a face locator, and other components to animate a face in a static image. + + Args: + reference_unet (UNet2DConditionModel): The reference UNet2DConditionModel used for face animation. + denoising_unet (UNet3DConditionModel): The denoising UNet3DConditionModel used for face animation. + face_locator (FaceLocator): The face locator model used for face animation. + reference_control_writer: The reference control writer component. + reference_control_reader: The reference control reader component. + imageproj: The image projection model. + audioproj: The audio projection model. + + Forward method: + noisy_latents (torch.Tensor): The noisy latents tensor. + timesteps (torch.Tensor): The timesteps tensor. + ref_image_latents (torch.Tensor): The reference image latents tensor. + face_emb (torch.Tensor): The face embeddings tensor. + audio_emb (torch.Tensor): The audio embeddings tensor. + mask (torch.Tensor): Hard face mask for face locator. + full_mask (torch.Tensor): Pose Mask. + face_mask (torch.Tensor): Face Mask + lip_mask (torch.Tensor): Lip Mask + uncond_img_fwd (bool): A flag indicating whether to perform reference image unconditional forward pass. + uncond_audio_fwd (bool): A flag indicating whether to perform audio unconditional forward pass. + + Returns: + torch.Tensor: The output tensor of the neural network model. + """ + def __init__( + self, + reference_unet: UNet2DConditionModel, + denoising_unet: UNet3DConditionModel, + face_locator: FaceLocator, + reference_control_writer, + reference_control_reader, + imageproj, + audioproj, + ): + super().__init__() + self.reference_unet = reference_unet + self.denoising_unet = denoising_unet + self.face_locator = face_locator + self.reference_control_writer = reference_control_writer + self.reference_control_reader = reference_control_reader + self.imageproj = imageproj + self.audioproj = audioproj + + def forward( + self, + noisy_latents: torch.Tensor, + timesteps: torch.Tensor, + ref_image_latents: torch.Tensor, + face_emb: torch.Tensor, + audio_emb: torch.Tensor, + mask: torch.Tensor, + full_mask: torch.Tensor, + face_mask: torch.Tensor, + lip_mask: torch.Tensor, + uncond_img_fwd: bool = False, + uncond_audio_fwd: bool = False, + ): + """ + simple docstring to prevent pylint error + """ + face_emb = self.imageproj(face_emb) + mask = mask.to(device="cuda") + mask_feature = self.face_locator(mask) + audio_emb = audio_emb.to( + device=self.audioproj.device, dtype=self.audioproj.dtype) + audio_emb = self.audioproj(audio_emb) + + # condition forward + if not uncond_img_fwd: + ref_timesteps = torch.zeros_like(timesteps) + ref_timesteps = repeat( + ref_timesteps, + "b -> (repeat b)", + repeat=ref_image_latents.size(0) // ref_timesteps.size(0), + ) + self.reference_unet( + ref_image_latents, + ref_timesteps, + encoder_hidden_states=face_emb, + return_dict=False, + ) + self.reference_control_reader.update(self.reference_control_writer) + + if uncond_audio_fwd: + audio_emb = torch.zeros_like(audio_emb).to( + device=audio_emb.device, dtype=audio_emb.dtype + ) + + model_pred = self.denoising_unet( + noisy_latents, + timesteps, + mask_cond_fea=mask_feature, + encoder_hidden_states=face_emb, + audio_embedding=audio_emb, + full_mask=full_mask, + face_mask=face_mask, + lip_mask=lip_mask + ).sample + + return model_pred + + +def get_attention_mask(mask: torch.Tensor, weight_dtype: torch.dtype) -> torch.Tensor: + """ + Rearrange the mask tensors to the required format. + + Args: + mask (torch.Tensor): The input mask tensor. + weight_dtype (torch.dtype): The data type for the mask tensor. + + Returns: + torch.Tensor: The rearranged mask tensor. + """ + if isinstance(mask, List): + _mask = [] + for m in mask: + _mask.append( + rearrange(m, "b f 1 h w -> (b f) (h w)").to(weight_dtype)) + return _mask + mask = rearrange(mask, "b f 1 h w -> (b f) (h w)").to(weight_dtype) + return mask + + +def get_noise_scheduler(cfg: argparse.Namespace) -> Tuple[DDIMScheduler, DDIMScheduler]: + """ + Create noise scheduler for training. + + Args: + cfg (argparse.Namespace): Configuration object. + + Returns: + Tuple[DDIMScheduler, DDIMScheduler]: Train noise scheduler and validation noise scheduler. + """ + + sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs) + if cfg.enable_zero_snr: + sched_kwargs.update( + rescale_betas_zero_snr=True, + timestep_spacing="trailing", + prediction_type="v_prediction", + ) + val_noise_scheduler = DDIMScheduler(**sched_kwargs) + sched_kwargs.update({"beta_schedule": "scaled_linear"}) + train_noise_scheduler = DDIMScheduler(**sched_kwargs) + + return train_noise_scheduler, val_noise_scheduler + + +def process_audio_emb(audio_emb: torch.Tensor) -> torch.Tensor: + """ + Process the audio embedding to concatenate with other tensors. + + Parameters: + audio_emb (torch.Tensor): The audio embedding tensor to process. + + Returns: + concatenated_tensors (List[torch.Tensor]): The concatenated tensor list. + """ + concatenated_tensors = [] + + for i in range(audio_emb.shape[0]): + vectors_to_concat = [ + audio_emb[max(min(i + j, audio_emb.shape[0] - 1), 0)]for j in range(-2, 3)] + concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0)) + + audio_emb = torch.stack(concatenated_tensors, dim=0) + + return audio_emb + +def save_image_batch(image_tensor, save_path): + image_tensor = (image_tensor + 1) / 2 + + os.makedirs(save_path, exist_ok=True) + + for i in range(image_tensor.shape[0]): + img_tensor = image_tensor[i] + + img_array = img_tensor.permute(1, 2, 0).cpu().numpy() + + img_array = (img_array * 255).astype(np.uint8) + + image = Image.fromarray(img_array) + image.save(os.path.join(save_path, f'motion_frame_{i}.png')) + +def log_validation( + accelerator: Accelerator, + vae: AutoencoderKL, + net: Net, + scheduler: DDIMScheduler, + width: int, + height: int, + clip_length: int = 24, + generator: torch.Generator = None, + cfg: dict = None, + save_dir: str = None, + global_step: int = 0, + times: int = None, + face_analysis_model_path: str = "", +) -> None: + """ + Log validation video during the training process. + + Args: + accelerator (Accelerator): The accelerator for distributed training. + vae (AutoencoderKL): The autoencoder model. + net (Net): The main neural network model. + scheduler (DDIMScheduler): The scheduler for noise. + width (int): The width of the input images. + height (int): The height of the input images. + clip_length (int): The length of the video clips. Defaults to 24. + generator (torch.Generator): The random number generator. Defaults to None. + cfg (dict): The configuration dictionary. Defaults to None. + save_dir (str): The directory to save validation results. Defaults to None. + global_step (int): The current global step in training. Defaults to 0. + times (int): The number of inference times. Defaults to None. + face_analysis_model_path (str): The path to the face analysis model. Defaults to "". + + Returns: + torch.Tensor: The tensor result of the validation. + """ + ori_net = accelerator.unwrap_model(net) + reference_unet = ori_net.reference_unet + denoising_unet = ori_net.denoising_unet + face_locator = ori_net.face_locator + imageproj = ori_net.imageproj + audioproj = ori_net.audioproj + + generator = torch.manual_seed(42) + tmp_denoising_unet = copy.deepcopy(denoising_unet) + + pipeline = FaceAnimatePipeline( + vae=vae, + reference_unet=reference_unet, + denoising_unet=tmp_denoising_unet, + face_locator=face_locator, + image_proj=imageproj, + scheduler=scheduler, + ) + pipeline = pipeline.to("cuda") + + image_processor = ImageProcessor((width, height), face_analysis_model_path) + audio_processor = AudioProcessor( + cfg.data.sample_rate, + cfg.data.fps, + cfg.wav2vec_config.model_path, + cfg.wav2vec_config.features == "last", + os.path.dirname(cfg.audio_separator.model_path), + os.path.basename(cfg.audio_separator.model_path), + os.path.join(save_dir, '.cache', "audio_preprocess") + ) + + for idx, ref_img_path in enumerate(cfg.ref_img_path): + audio_path = cfg.audio_path[idx] + source_image_pixels, \ + source_image_face_region, \ + source_image_face_emb, \ + source_image_full_mask, \ + source_image_face_mask, \ + source_image_lip_mask = image_processor.preprocess( + ref_img_path, os.path.join(save_dir, '.cache'), cfg.face_expand_ratio) + audio_emb, audio_length = audio_processor.preprocess( + audio_path, clip_length) + + audio_emb = process_audio_emb(audio_emb) + + source_image_pixels = source_image_pixels.unsqueeze(0) + source_image_face_region = source_image_face_region.unsqueeze(0) + source_image_face_emb = source_image_face_emb.reshape(1, -1) + source_image_face_emb = torch.tensor(source_image_face_emb) + + source_image_full_mask = [ + (mask.repeat(clip_length, 1)) + for mask in source_image_full_mask + ] + source_image_face_mask = [ + (mask.repeat(clip_length, 1)) + for mask in source_image_face_mask + ] + source_image_lip_mask = [ + (mask.repeat(clip_length, 1)) + for mask in source_image_lip_mask + ] + + times = audio_emb.shape[0] // clip_length + tensor_result = [] + generator = torch.manual_seed(42) + + save_path = os.path.join(save_dir, f"{global_step}_{Path(ref_img_path).name}") + + for t in range(times): + print(f"[{t+1}/{times}]") + + if len(tensor_result) == 0: + # The first iteration + motion_zeros = source_image_pixels.repeat( + cfg.data.n_motion_frames, 1, 1, 1) + motion_zeros = motion_zeros.to( + dtype=source_image_pixels.dtype, device=source_image_pixels.device) + pixel_values_ref_img = torch.cat( + [source_image_pixels, motion_zeros], dim=0) # concat the ref image and the first motion frames + else: + motion_frames = tensor_result[-1][0] + motion_frames = motion_frames.permute(1, 0, 2, 3) + motion_frames = motion_frames[0 - cfg.data.n_motion_frames:] + motion_frames = motion_frames * 2.0 - 1.0 + motion_frames = motion_frames.to( + dtype=source_image_pixels.dtype, device=source_image_pixels.device) + pixel_values_ref_img = torch.cat( + [source_image_pixels, motion_frames], dim=0) # concat the ref image and the motion frames + + pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0) + + pixel_motion_values = pixel_values_ref_img[:, 1:] + + if cfg.use_mask: + + b, f, c, h, w = pixel_motion_values.shape + rand_mask = torch.rand(h, w) + mask = rand_mask > cfg.mask_rate + mask = mask.unsqueeze(0).unsqueeze(0).unsqueeze(0) + mask = mask.expand(b, f, c, h, w) + + face_mask = source_image_face_region.repeat(f, 1, 1, 1).unsqueeze(0) + assert face_mask.shape == mask.shape + mask = mask | face_mask.bool() + + pixel_motion_values = pixel_motion_values * mask + pixel_values_ref_img[:, 1:] = pixel_motion_values + + assert pixel_motion_values.shape[0] == 1 + + audio_tensor = audio_emb[ + t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0]) + ] + audio_tensor = audio_tensor.unsqueeze(0) + audio_tensor = audio_tensor.to( + device=audioproj.device, dtype=audioproj.dtype) + audio_tensor = audioproj(audio_tensor) + + pipeline_output = pipeline( + ref_image=pixel_values_ref_img, + audio_tensor=audio_tensor, + face_emb=source_image_face_emb, + face_mask=source_image_face_region, + pixel_values_full_mask=source_image_full_mask, + pixel_values_face_mask=source_image_face_mask, + pixel_values_lip_mask=source_image_lip_mask, + width=cfg.data.train_width, + height=cfg.data.train_height, + video_length=clip_length, + num_inference_steps=cfg.inference_steps, + guidance_scale=cfg.cfg_scale, + generator=generator + ) + + tensor_result.append(pipeline_output.videos) + + tensor_result = torch.cat(tensor_result, dim=2) + tensor_result = tensor_result.squeeze(0) + tensor_result = tensor_result[:, :audio_length] + audio_name = os.path.basename(audio_path).split('.')[0] + ref_name = os.path.basename(ref_img_path).split('.')[0] + output_file = os.path.join(save_path,f"{global_step}_{ref_name}_{audio_name}.mp4") + # save the result after all iteration + tensor_to_video(tensor_result, output_file, audio_path) + + + # clean up + del tmp_denoising_unet + del pipeline + del image_processor + del audio_processor + torch.cuda.empty_cache() + + return tensor_result + + + +def train_stage2_process(cfg: argparse.Namespace) -> None: + """ + Trains the model using the given configuration (cfg). + + Args: + cfg (dict): The configuration dictionary containing the parameters for training. + + Notes: + - This function trains the model using the given configuration. + - It initializes the necessary components for training, such as the pipeline, optimizer, and scheduler. + - The training progress is logged and tracked using the accelerator. + - The trained model is saved after the training is completed. + """ + config = ProjectConfiguration(project_dir=".", logging_dir="log") + accelerator = Accelerator( + gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps, + mixed_precision=cfg.solver.mixed_precision, + log_with="tensorboard", + project_config=config + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if cfg.seed is not None: + seed_everything(cfg.seed) + + # create output dir for training + exp_name = cfg.exp_name + save_dir = f"{cfg.output_dir}/{exp_name}" + checkpoint_dir = os.path.join(save_dir, "checkpoints") + module_dir = os.path.join(save_dir, "modules") + validation_dir = os.path.join(save_dir, "validation") + if accelerator.is_main_process: + init_output_dir([save_dir, checkpoint_dir, module_dir, validation_dir]) + + accelerator.wait_for_everyone() + + if cfg.weight_dtype == "fp16": + weight_dtype = torch.float16 + elif cfg.weight_dtype == "bf16": + weight_dtype = torch.bfloat16 + elif cfg.weight_dtype == "fp32": + weight_dtype = torch.float32 + else: + raise ValueError( + f"Do not support weight dtype: {cfg.weight_dtype} during training" + ) + + # Create Models + vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to( + "cuda", dtype=weight_dtype + ) + reference_unet = UNet2DConditionModel.from_pretrained( + cfg.base_model_path, + subfolder="unet", + ).to(device="cuda", dtype=weight_dtype) + denoising_unet = UNet3DConditionModel.from_pretrained_2d( + cfg.base_model_path, + cfg.mm_path, + subfolder="unet", + unet_additional_kwargs=OmegaConf.to_container( + cfg.unet_additional_kwargs), + use_landmark=False + ).to(device="cuda", dtype=weight_dtype) + imageproj = ImageProjModel( + cross_attention_dim=denoising_unet.config.cross_attention_dim, + clip_embeddings_dim=512, + clip_extra_context_tokens=4, + ).to(device="cuda", dtype=weight_dtype) + face_locator = FaceLocator( + conditioning_embedding_channels=320, + ).to(device="cuda", dtype=weight_dtype) + audioproj = AudioProjModel( + seq_len=5, + blocks=12, + channels=768, + intermediate_dim=512, + output_dim=768, + context_tokens=32, + ).to(device="cuda", dtype=weight_dtype) + + # Freeze + vae.requires_grad_(False) + imageproj.requires_grad_(False) + reference_unet.requires_grad_(False) + denoising_unet.requires_grad_(False) + face_locator.requires_grad_(False) + audioproj.requires_grad_(False) + + # Set motion module learnable + trainable_modules = cfg.trainable_para + for name, module in denoising_unet.named_modules(): + if any(trainable_mod in name for trainable_mod in trainable_modules): + for params in module.parameters(): + params.requires_grad_(True) + + reference_control_writer = ReferenceAttentionControl( + reference_unet, + do_classifier_free_guidance=False, + mode="write", + fusion_blocks="full", + ) + reference_control_reader = ReferenceAttentionControl( + denoising_unet, + do_classifier_free_guidance=False, + mode="read", + fusion_blocks="full", + ) + + net = Net( + reference_unet, + denoising_unet, + face_locator, + reference_control_writer, + reference_control_reader, + imageproj, + audioproj, + ).to(dtype=weight_dtype) + + m,u = net.load_state_dict( + torch.load( + os.path.join(cfg.audio_ckpt_dir, "net-3000.pth"), + map_location="cpu", + ), + strict=False + ) + + logger.info(f"missing key: {m}") + logger.info(f"unexcepted key: {u}") + + # get noise scheduler + train_noise_scheduler, val_noise_scheduler = get_noise_scheduler(cfg) + + if cfg.solver.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + reference_unet.enable_xformers_memory_efficient_attention() + denoising_unet.enable_xformers_memory_efficient_attention() + + else: + raise ValueError( + "xformers is not available. Make sure it is installed correctly" + ) + + if cfg.solver.gradient_checkpointing: + reference_unet.enable_gradient_checkpointing() + denoising_unet.enable_gradient_checkpointing() + + if cfg.solver.scale_lr: + learning_rate = ( + cfg.solver.learning_rate + * cfg.solver.gradient_accumulation_steps + * cfg.data.train_bs + * accelerator.num_processes + ) + else: + learning_rate = cfg.solver.learning_rate + + # Initialize the optimizer + if cfg.solver.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError as exc: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) from exc + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + trainable_params = list( + filter(lambda p: p.requires_grad, net.parameters())) + logger.info(f"Total trainable params {len(trainable_params)}") + optimizer = optimizer_cls( + trainable_params, + lr=learning_rate, + betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2), + weight_decay=cfg.solver.adam_weight_decay, + eps=cfg.solver.adam_epsilon, + ) + + # Scheduler + lr_scheduler = get_scheduler( + cfg.solver.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=cfg.solver.lr_warmup_steps + * cfg.solver.gradient_accumulation_steps, + num_training_steps=cfg.solver.max_train_steps + * cfg.solver.gradient_accumulation_steps, + ) + + # get data loader + train_dataset = TalkingVideoDataset( + img_size=(cfg.data.train_width, cfg.data.train_height), + sample_rate=cfg.data.sample_rate, + n_sample_frames=cfg.data.n_sample_frames, + n_motion_frames=cfg.data.n_motion_frames, + audio_margin=cfg.data.audio_margin, + data_meta_paths=cfg.data.train_meta_paths, + wav2vec_cfg=cfg.wav2vec_config, + ) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=16 + ) + + # Prepare everything with our `accelerator`. + ( + net, + optimizer, + train_dataloader, + lr_scheduler, + ) = accelerator.prepare( + net, + optimizer, + train_dataloader, + lr_scheduler, + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / cfg.solver.gradient_accumulation_steps + ) + # Afterwards we recalculate our number of training epochs + num_train_epochs = math.ceil( + cfg.solver.max_train_steps / num_update_steps_per_epoch + ) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + run_time = datetime.now().strftime("%Y%m%d-%H%M") + accelerator.init_trackers( + exp_name, + init_kwargs={"tensorboard": {"flush_secs": 60}}, + ) + + logger.info(f"save config to {save_dir}") + OmegaConf.save( + cfg, os.path.join(save_dir, "config.yaml") + ) + + # Train! + total_batch_size = ( + cfg.data.train_bs + * accelerator.num_processes + * cfg.solver.gradient_accumulation_steps + ) + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {cfg.data.train_bs}") + logger.info( + f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" + ) + logger.info( + f" Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}" + ) + logger.info(f" Total optimization steps = {cfg.solver.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # # Potentially load in the weights and states from a previous save + if cfg.resume_from_checkpoint: + logger.info(f"Loading checkpoint from {checkpoint_dir}") + global_step = load_checkpoint(cfg, checkpoint_dir, accelerator) + first_epoch = global_step // num_update_steps_per_epoch + + # Only show the progress bar once on each machine. + progress_bar = tqdm( + range(global_step, cfg.solver.max_train_steps), + disable=not accelerator.is_local_main_process, + ) + progress_bar.set_description("Steps") + + for _ in range(first_epoch, num_train_epochs): + train_loss = 0.0 + t_data_start = time.time() + for idx, batch in enumerate(train_dataloader): + t_data = time.time() - t_data_start + with accelerator.accumulate(net): + # Convert videos to latent space + pixel_values_vid = batch["pixel_values_vid"].to(weight_dtype) + + pixel_values_face_mask = batch["pixel_values_face_mask"] + pixel_values_face_mask = get_attention_mask( + pixel_values_face_mask, weight_dtype + ) + pixel_values_lip_mask = batch["pixel_values_lip_mask"] + pixel_values_lip_mask = get_attention_mask( + pixel_values_lip_mask, weight_dtype + ) + pixel_values_full_mask = batch["pixel_values_full_mask"] + pixel_values_full_mask = get_attention_mask( + pixel_values_full_mask, weight_dtype + ) + + + with torch.no_grad(): + video_length = pixel_values_vid.shape[1] + pixel_values_vid = rearrange( + pixel_values_vid, "b f c h w -> (b f) c h w" + ) + latents = vae.encode(pixel_values_vid).latent_dist.sample() + latents = rearrange( + latents, "(b f) c h w -> b c f h w", f=video_length + ) + latents = latents * 0.18215 + + noise = torch.randn_like(latents) + if cfg.noise_offset > 0: + noise += cfg.noise_offset * torch.randn( + (latents.shape[0], latents.shape[1], 1, 1, 1), + device=latents.device, + ) + + bsz = latents.shape[0] + # Sample a random timestep for each video + timesteps = torch.randint( + 0, + train_noise_scheduler.num_train_timesteps, + (bsz,), + device=latents.device, + ) + timesteps = timesteps.long() + + motion_timesteps = torch.randint( + 0, + 50, + (bsz,), + device=latents.device, + ) + motion_timesteps = motion_timesteps.long() + + # mask for face locator + pixel_values_mask = ( + batch["pixel_values_mask"].unsqueeze( + 1).to(dtype=weight_dtype) + ) + pixel_values_mask = repeat( + pixel_values_mask, + "b f c h w -> b (repeat f) c h w", + repeat=video_length, + ) + pixel_values_mask = pixel_values_mask.transpose( + 1, 2) + + uncond_img_fwd = random.random() < cfg.uncond_img_ratio + uncond_audio_fwd = random.random() < cfg.uncond_audio_ratio + + start_frame = random.random() < cfg.start_ratio + pixel_values_ref_img = batch["pixel_values_ref_img"].to( + dtype=weight_dtype + ) + # initialize the motion frames as zero maps + if start_frame: + pixel_values_ref_img[:, 1:] = 0.0 + + # random mask + use_mask = random.random() < cfg.use_mask + + # assert use_mask + + with torch.no_grad(): + + motion_latents = pixel_values_ref_img[:, 1:] + motion_noise = torch.randn_like(motion_latents) + if cfg.noise_offset > 0: + motion_noise += cfg.noise_offset * torch.randn( + (motion_latents.shape[0], motion_latents.shape[1], 1, 1, 1), + device=latents.device, + ) + + # add motion noise + noisy_motion_latents = train_noise_scheduler.add_noise( + motion_latents, motion_noise, motion_timesteps + ) + pixel_values_ref_img[:, 1:] = noisy_motion_latents + + + if use_mask: + pixel_motion_values = pixel_values_ref_img[:, 1:] + + b, f, c, h, w = pixel_motion_values.shape + rand_mask = torch.rand(h, w).to(device=pixel_motion_values.device) + mask = rand_mask > cfg.mask_rate + mask = mask.unsqueeze(0).unsqueeze(0).unsqueeze(0) + mask = mask.expand(b, f, c, h, w) + + face_mask = pixel_values_mask.transpose(1, 2)[:,:f] + assert face_mask.shape == mask.shape + mask = mask | face_mask.bool() + + pixel_motion_values = pixel_motion_values * mask + pixel_values_ref_img[:, 1:] = pixel_motion_values + + + ref_img_and_motion = rearrange( + pixel_values_ref_img, "b f c h w -> (b f) c h w" + ) + + ref_image_latents = vae.encode( + ref_img_and_motion + ).latent_dist.sample() + ref_image_latents = ref_image_latents * 0.18215 + image_prompt_embeds = batch["face_emb"].to( + dtype=imageproj.dtype, device=imageproj.device + ) + + # add noise + noisy_latents = train_noise_scheduler.add_noise( + latents, noise, timesteps + ) + + # Get the target for loss depending on the prediction type + if train_noise_scheduler.prediction_type == "epsilon": + target = noise + elif train_noise_scheduler.prediction_type == "v_prediction": + target = train_noise_scheduler.get_velocity( + latents, noise, timesteps + ) + else: + raise ValueError( + f"Unknown prediction type {train_noise_scheduler.prediction_type}" + ) + + # ---- Forward!!! ----- + model_pred = net( + noisy_latents=noisy_latents, + timesteps=timesteps, + ref_image_latents=ref_image_latents, + face_emb=image_prompt_embeds, + mask=pixel_values_mask, + full_mask=pixel_values_full_mask, + face_mask=pixel_values_face_mask, + lip_mask=pixel_values_lip_mask, + audio_emb=batch["audio_tensor"].to( + dtype=weight_dtype), + uncond_img_fwd=uncond_img_fwd, + uncond_audio_fwd=uncond_audio_fwd + ) + + if cfg.snr_gamma == 0: + loss = F.mse_loss( + model_pred.float(), + target.float(), + reduction="mean", + ) + else: + snr = compute_snr(train_noise_scheduler, timesteps) + if train_noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective requires that we add one to SNR values before we divide by them. + snr = snr + 1 + mse_loss_weights = ( + torch.stack( + [snr, cfg.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + / snr + ) + loss = F.mse_loss( + model_pred.float(), + target.float(), + reduction="mean", + ) + loss = ( + loss.mean(dim=list(range(1, len(loss.shape)))) + * mse_loss_weights + ).mean() + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather( + loss.repeat(cfg.data.train_bs)).mean() + train_loss += avg_loss.item() / cfg.solver.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_( + trainable_params, + cfg.solver.max_grad_norm, + ) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + if accelerator.sync_gradients: + reference_control_reader.clear() + reference_control_writer.clear() + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % cfg.val.validation_steps == 0 or global_step==1: + if accelerator.is_main_process: + generator = torch.Generator(device=accelerator.device) + generator.manual_seed(cfg.seed) + + log_validation( + accelerator=accelerator, + vae=vae, + net=net, + scheduler=val_noise_scheduler, + width=cfg.data.train_width, + height=cfg.data.train_height, + clip_length=cfg.data.n_sample_frames, + cfg=cfg, + save_dir=validation_dir, + global_step=global_step, + times=cfg.single_inference_times if cfg.single_inference_times is not None else None, + face_analysis_model_path=cfg.face_analysis_model_path + ) + + logs = { + "step_loss": loss.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + "td": f"{t_data:.2f}s", + } + t_data_start = time.time() + progress_bar.set_postfix(**logs) + + if ( + global_step % cfg.checkpointing_steps == 0 + or global_step == cfg.solver.max_train_steps + ): + # save model + save_path = os.path.join( + checkpoint_dir, f"checkpoint-{global_step}") + if accelerator.is_main_process: + delete_additional_ckpt(checkpoint_dir, 100) + accelerator.wait_for_everyone() + accelerator.save_state(save_path) + + # save model weight + unwrap_net = accelerator.unwrap_model(net) + if accelerator.is_main_process: + save_checkpoint( + unwrap_net, + module_dir, + "net", + global_step, + total_limit=100, + ) + if global_step >= cfg.solver.max_train_steps: + break + + # Create the pipeline using the trained modules and save it. + accelerator.wait_for_everyone() + accelerator.end_training() + + +def load_config(config_path: str) -> dict: + """ + Loads the configuration file. + + Args: + config_path (str): Path to the configuration file. + + Returns: + dict: The configuration dictionary. + """ + + if config_path.endswith(".yaml"): + return OmegaConf.load(config_path) + if config_path.endswith(".py"): + return import_filename(config_path).cfg + raise ValueError("Unsupported format for config file") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", type=str, default="./configs/train/stage2_long.yaml" + ) + args = parser.parse_args() + + config = load_config(args.config) + train_stage2_process(config) diff --git a/Hallo2/hallo2/scripts/video_sr.py b/Hallo2/hallo2/scripts/video_sr.py new file mode 100644 index 00000000..b3c79f8e --- /dev/null +++ b/Hallo2/hallo2/scripts/video_sr.py @@ -0,0 +1,311 @@ +""" +Modified from [CodeFormer](https://github.com/sczhou/CodeFormer). +When using or redistributing this feature, please comply with the [S-Lab License 1.0](https://github.com/sczhou/CodeFormer?tab=License-1-ov-file). +We kindly request that you respect the terms of this license in any usage or redistribution of this component. +""" + +import os +import cv2 +import argparse +import glob +import sys + +import torch +from torchvision.transforms.functional import normalize + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from basicsr.utils import imwrite, img2tensor, tensor2img +from basicsr.utils.download_util import load_file_from_url +from basicsr.utils.misc import gpu_is_available, get_device +from facelib.utils.face_restoration_helper import FaceRestoreHelper +from facelib.utils.misc import is_gray + +from basicsr.utils.registry import ARCH_REGISTRY + + +def set_realesrgan(): + from basicsr.archs.rrdbnet_arch import RRDBNet + from basicsr.utils.realesrgan_utils import RealESRGANer + + use_half = False + if torch.cuda.is_available(): # set False in CPU/MPS mode + no_half_gpu_list = ['1650', '1660'] # set False for GPUs that don't support f16 + if not True in [gpu in torch.cuda.get_device_name(0) for gpu in no_half_gpu_list]: + use_half = True + + model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=2, + ) + upsampler = RealESRGANer( + scale=2, + model_path="./pretrained_models/realesrgan/RealESRGAN_x2plus.pth", + model=model, + tile=args.bg_tile, + tile_pad=40, + pre_pad=0, + half=use_half + ) + + if not gpu_is_available(): # CPU + import warnings + warnings.warn('Running on CPU now! Make sure your PyTorch version matches your CUDA.' + 'The unoptimized RealESRGAN is slow on CPU. ' + 'If you want to disable it, please remove `--bg_upsampler` and `--face_upsample` in command.', + category=RuntimeWarning) + return upsampler + + + + +if __name__ == '__main__': + # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + device = get_device() + parser = argparse.ArgumentParser() + + parser.add_argument('-i', '--input_path', type=str, help='Input video') + parser.add_argument('-o', '--output_path', type=str, default=None, + help='Output folder') + parser.add_argument('-w', '--fidelity_weight', type=float, default=0.5, + help='Balance the quality and fidelity. Default: 0.5') + parser.add_argument('-s', '--upscale', type=int, default=2, + help='The final upsampling scale of the image. Default: 2') + parser.add_argument('--has_aligned', action='store_true', help='Input are cropped and aligned faces. Default: False') + parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face. Default: False') + parser.add_argument('--draw_box', action='store_true', help='Draw the bounding box for the detected faces. Default: False') + # large det_model: 'YOLOv5l', 'retinaface_resnet50' + # small det_model: 'YOLOv5n', 'retinaface_mobile0.25' + parser.add_argument('--detection_model', type=str, default='retinaface_resnet50', + help='Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n. \ + Default: retinaface_resnet50') + parser.add_argument('--bg_upsampler', type=str, default='None', help='Background upsampler. Optional: realesrgan') + parser.add_argument('--face_upsample', action='store_true', help='Face upsampler after enhancement. Default: False') + parser.add_argument('--bg_tile', type=int, default=400, help='Tile size for background sampler. Default: 400') + parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces. Default: None') + + args = parser.parse_args() + + # ------------------------ input & output ------------------------ + w = args.fidelity_weight + input_video = False + if args.input_path.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')): # input video path + from basicsr.utils.video_util import VideoReader, VideoWriter + input_img_list = [] + vidreader = VideoReader(args.input_path) + image = vidreader.get_frame() + while image is not None: + input_img_list.append(image) + image = vidreader.get_frame() + audio = vidreader.get_audio() + fps = vidreader.get_fps() + video_name = os.path.basename(args.input_path)[:-4] + result_root = f'./hq_results/{video_name}_{w}_{args.upscale}' + input_video = True + vidreader.close() + else: + raise RuntimeError("input should be mp4 file") + + if not args.output_path is None: # set output path + result_root = args.output_path + + test_img_num = len(input_img_list) + if test_img_num == 0: + raise FileNotFoundError('No input image/video is found...\n' + '\tNote that --input_path for video should end with .mp4|.mov|.avi') + + # ------------------ set up background upsampler ------------------ + if args.bg_upsampler == 'realesrgan': + bg_upsampler = set_realesrgan() + else: + bg_upsampler = None + + # ------------------ set up face upsampler ------------------ + if args.face_upsample: + if bg_upsampler is not None: + face_upsampler = bg_upsampler + else: + face_upsampler = set_realesrgan() + else: + face_upsampler = None + + # ------------------ set up CodeFormer restorer ------------------- + net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, + connect_list=['32', '64', '128', '256']).to(device) + + ckpt_path = './pretrained_models/hallo2/net_g.pth' + + checkpoint = torch.load(ckpt_path)['params_ema'] + m, n = net.load_state_dict(checkpoint, strict=False) + print("missing key: ", m) + assert len(n)==0 + net.eval() + + # ------------------ set up FaceRestoreHelper ------------------- + # large det_model: 'YOLOv5l', 'retinaface_resnet50' + # small det_model: 'YOLOv5n', 'retinaface_mobile0.25' + if not args.has_aligned: + print(f'Face detection model: {args.detection_model}') + if bg_upsampler is not None: + print(f'Background upsampling: True, Face upsampling: {args.face_upsample}') + else: + print(f'Background upsampling: False, Face upsampling: {args.face_upsample}') + + face_helper = FaceRestoreHelper( + args.upscale, + face_size=512, + crop_ratio=(1, 1), + det_model = args.detection_model, + save_ext='png', + use_parse=True, + device=device) + + n = -1 + input_img_list = input_img_list[:n] + length = len(input_img_list) + + overlay = 4 + chunk = 16 + idx_list = [] + + i=0 + j=0 + while i < length and j < length: + j = min(i+chunk, length) + idx_list.append([i, j]) + i = j-overlay + + + id_list = [] + + # -------------------- start to processing --------------------- + for i, idx in enumerate(idx_list): + # clean all the intermediate results to process the next image + face_helper.clean_all() + + start = idx[0] + end = idx[1] + + img_list = input_img_list[start:end] + + for j, img_path in enumerate(img_list): + + if isinstance(img_path, str): + img_name = os.path.basename(img_path) + basename, ext = os.path.splitext(img_name) + print(f'[{j+1}/{chunk}] Processing: {img_name}') + img = cv2.imread(img_path, cv2.IMREAD_COLOR) + else: # for video processing + basename = str(i).zfill(4) + img_name = f'{video_name}_{basename}_{j}' if input_video else basename + print(f'[{j+1}/{chunk}] Processing: {img_name}') + img = img_path + + if args.has_aligned: + # the input faces are already cropped and aligned + img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR) + face_helper.is_gray = is_gray(img, threshold=10) + if face_helper.is_gray: + print('Grayscale input: True') + face_helper.cropped_faces = [img] + else: + face_helper.read_image(img) + # get face landmarks for each face + num_det_faces = face_helper.get_face_landmarks_5( + only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5) + print(f'\tdetect {num_det_faces} faces') + # align and warp each face + face_helper.align_warp_face() + + crop_image = [] + # face restoration for each cropped face + for idx, cropped_face in enumerate(face_helper.cropped_faces): + # prepare data + cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = cropped_face_t.unsqueeze(0) + + crop_image.append(cropped_face_t) + + assert len(crop_image)==len(img_list) + + crop_image = torch.cat(crop_image, dim=0).to(device) + crop_image = crop_image.unsqueeze(0) + + output, top_idx = net.inference(crop_image, w=w, adain=True) + assert output.shape==crop_image.shape + + for k in range(output.shape[1]): + face_output = output[:, k:k+1] + restored_face = tensor2img(face_output.squeeze_(1), rgb2bgr=True, min_max=(-1, 1)) + + restored_face = restored_face.astype('uint8') + cropped_face = face_helper.cropped_faces[k] + face_helper.add_restored_face(restored_face, cropped_face) + + bg_img_list = [] + # paste_back + if not args.has_aligned: + for img in img_list: + # upsample the background + if bg_upsampler is not None: + # Now only support RealESRGAN for upsampling background + bg_img = bg_upsampler.enhance(img, outscale=args.upscale)[0] + else: + bg_img = None + bg_img_list.append(bg_img) + + + face_helper.get_inverse_affine(None) + # paste each restored face to the input image + if args.face_upsample and face_upsampler is not None: + restored_img_list = face_helper.paste_faces_to_input_image(upsample_img_list=bg_img_list, draw_box=args.draw_box, face_upsampler=face_upsampler) + else: + restored_img_list = face_helper.paste_faces_to_input_image(upsample_img_list=bg_img_list, draw_box=args.draw_box) + + torch.cuda.empty_cache() + + if i!=0: + restored_img_list = restored_img_list[overlay:] + + + # save restored img + if not args.has_aligned and len(restored_img_list)!=0: + if args.suffix is not None: + basename = f'{video_name}_{args.suffix}_{i}' + for k, restored_img in enumerate(restored_img_list): + kk = str(k).zfill(3) + save_restore_path = os.path.join(result_root, 'final_results', f'{basename}_{kk}.png') + imwrite(restored_img, save_restore_path) + + # save enhanced video + if input_video: + print('Video Saving...') + # load images + video_frames = [] + img_list = sorted(glob.glob(os.path.join(result_root, 'final_results', '*.[jp][pn]g'))) + + assert len(img_list)==length, print(len(img_list), length) + + # write images to video + sample_img = cv2.imread(img_list[0]) + height, width = sample_img.shape[:2] + + if args.suffix is not None: + video_name = f'{video_name}_{args.suffix}.png' + save_restore_path = os.path.join(result_root, f'{video_name}.mp4') + + vidwriter = VideoWriter(save_restore_path, height, width, fps, audio) + + for img_path in img_list: + print(img_path) + img = cv2.imread(img_path) + vidwriter.write_frame(img) + + vidwriter.close() + + print(f'\nAll results are saved in {result_root}')