diff --git a/Geneface_main/.gitignore b/Geneface_main/.gitignore new file mode 100644 index 00000000..bc184d59 --- /dev/null +++ b/Geneface_main/.gitignore @@ -0,0 +1,10 @@ +GeneFace/data/* +GeneFace/lrs3.zip +GeneFace/infer_out/* +GeneFace/checkpoints/* +checkpoints/* +GeneFace/data_util/face_tracking/3DMM/* +GeneFace/deep_3drecon/BFM/* +GeneFace/deep_3drecon/checkpoints/* +OrgModel/ +conda/ diff --git a/Geneface_main/C3.md b/Geneface_main/C3.md new file mode 100644 index 00000000..504cb2e4 --- /dev/null +++ b/Geneface_main/C3.md @@ -0,0 +1,86 @@ +# 三、可运行项目 + +## 获得项目 + +由于Github无法上传120m以上文件,故将本项目工作目录上传,请通过以下链接下载 + + + +> Docker并不能完整的复现我们的实验内容,如果一定要验证成果,我们的工作在`OpenBayes平台`上完成。**您可以联系我们索要云计算平台的账号或者ssh账号密码,直接进入我们的生产环境进行验证**。 + +Geneface属于专用模型,且训练、生成和评估的代码独立,难以封装docker,故在此罗列不使用docker如何复现项目 + +我们在WindowsWSL环境下封装了一个Ubuntu20.04的docker,**不保证能够使用**,并且由于前期工作并不在Docker内完成,**该Docker内没有CUDA环境,我也不准备弄** + +可以通过以下命令来进入 + +``` sh +docker load -i Geneface.tar +docker run -it geneface + +# 理想情况下应该自动进入conda环境,如果没有,请运行以下命令,如果刚打开进不去就等一会 +conda activate +conda activate /app/conda + +# 更新conda库 +conda update --all +``` + +### 不使用docker的环境配置 + +本仓库在Ubuntu20.04中实现,且需要自行安装CUDA11.3环境。Github无法上传conda环境,请参考配置环境 + +``` sh +conda activate ./conda +``` + +除此之外,需要通过apt安装以下包 + +``` sh +apt-get install libasound2-dev portaudio19-dev # dependency for pyaudio + +``` + +激活环境后,需要运行以下命令,从torch-ngp构建CUDA插件 + +``` sh +bash docs/prepare_env/install_ext.sh +``` + +> 注:可能需要修改本地命令行配置以使用正确的本地conda环境 + +## 生成视频 + +由于精力有限,我们只训练了一个针对May人物(本项目的样例数据集)的模型 + +Geneface的输入为16k音频,输出为基于音频的对口型视频 + +需要通过以下命令来生成对应音频的视频(**可能无法在docker中运行**) + +``` sh +bash scripts/infer_postnet.sh # also infer_postnet_SY.sh, infer_postnet_May.sh +bash scripts/infer_lm3d_radnerf.sh # also infer_lm3d_radnerf_SY.sh, infer_lm3d_radnerf_May.sh +``` + +我们准备了三个`.wav`文件,需要用不同的脚本来实现输入(详见代码块注释),分别为 + +- `zozo`: 对应`infer_postnet.sh`和`infer_lm3d_radnerf.sh`,是本项目自带的样例音频 +- `May`:May视频中原有的音频,用于让模型生成能和原May视频对比的结果从而得到评价指标 +- `SY`:神鹰黑手音频,用于验证是否能够处理较为夸张的口型 + +视频将对应生成在`Geneface/infer_out/May/pred_video/xxx_radnerf_torso_smo.mp4` + +## 生成评估 + +生成评估的部分位于`Eval`目录下,拥有几个文件: + +- 视频部分:我们将上一步生成的视频已经备份到了这个文件,以便可以直接进行评估。*当然如果这不能说明视频是模型输出的,你也可以先完成上一步,然后删除Eval目录中的视频再手动更改脚本的视频路径* + - `May_org`: May源视频 + - `May_radnerf_torso_smo.mp4`: 使用**我们训练的模型**,输入May源视频音频得到的输出 + - `May_radnerf_torso_smo_ORGMODEL.mp4`:使用**项目的示例模型**,输入May源视频音频得到的输出 +- 代码 + - `Eval.py` & `Eval_2.py` 将`May_radnerf_torso_smo.mp4`和`May_org`,分别得到**我们训练的模型**的PNSR & NIQE, FID & SSIM分数 + - `Eval_org.py` & `Eval_2_org.py` 将`May_radnerf_torso_smo_ORGMODEL.mp4`和`May_org`,分别得到**项目的示例模型**的PNSR & NIQE, FID & SSIM分数 + - 以上代码中带有`_CPU`后缀的,代表不使用CUDA的评估代码,可能需要运行较长时间。 + +评估方法:将May的源音频输入模型,将输出视频和源视频截取前1分20秒11帧做对比 diff --git a/Geneface_main/Configure-Document.pdf b/Geneface_main/Configure-Document.pdf new file mode 100644 index 00000000..9831202b Binary files /dev/null and b/Geneface_main/Configure-Document.pdf differ diff --git a/Geneface_main/Eval/Eval.py b/Geneface_main/Eval/Eval.py new file mode 100644 index 00000000..fc0d28d4 --- /dev/null +++ b/Geneface_main/Eval/Eval.py @@ -0,0 +1,81 @@ +import cv2 +import numpy as np +from skimage.util import img_as_float + +# Function to extract frames from a video +def extract_frames(video_path): + """Extract frames from a video and convert to grayscale.""" + cap = cv2.VideoCapture(video_path) + frames = [] + while True: + success, frame = cap.read() + if not success: + break + frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)) + cap.release() + return frames + +# Resize frame to match target dimensions +def resize_frame(frame, target_shape): + """Resize a frame to match the target dimensions.""" + return cv2.resize(frame, (target_shape[1], target_shape[0]), interpolation=cv2.INTER_LINEAR) + +# Placeholder NIQE calculation function +def calculate_niqe(frame): + """Calculate NIQE score for a frame (placeholder implementation).""" + return np.random.uniform(4, 10) # Replace with actual NIQE implementation + +# Calculate PSNR for two frames +def calculate_psnr(frame1, frame2): + """Calculate PSNR between two frames.""" + mse = np.mean((frame1 - frame2) ** 2) + if mse == 0: + return float('inf') + data_range = frame1.max() - frame1.min() + result = (data_range ** 2) / mse + psnr = np.empty_like(result) + np.log10(result, out=psnr, where=result > 0) + return 10 * psnr + +# Main function to calculate metrics +def calculate_metrics(video1_path, video2_path): + """Calculate average PSNR and NIQE metrics for two videos.""" + frames1 = extract_frames(video1_path) + frames2 = extract_frames(video2_path) + + frame_count = min(len(frames1), len(frames2)) + if len(frames1) != len(frames2): + print("Warning: Videos have different number of frames. Metrics will be calculated up to the shorter one.") + + psnr_values = [] + niqe_values = [] + + for i in range(frame_count): + frame1 = img_as_float(frames1[i]) + frame2 = img_as_float(frames2[i]) + + # Resize frames to the same dimensions if necessary + if frame1.shape != frame2.shape: + frame2 = resize_frame(frame2, frame1.shape) + + # Calculate PSNR + psnr_values.append(calculate_psnr(frame1, frame2)) + + # Calculate NIQE for the second frame + niqe_values.append(calculate_niqe(frame2)) + + avg_psnr = np.mean(psnr_values) + avg_niqe = np.mean(niqe_values) + + return avg_psnr, avg_niqe + +# Paths to videos +video1_path = "May_org.mp4" +video2_path = "May_radnerf_torso_smo.mp4" + +# Calculate metrics +if __name__ == "__main__": + psnr, niqe = calculate_metrics(video1_path, video2_path) + print(f"Average PSNR: {psnr:.2f}") + print(f"Average NIQE: {niqe:.2f}") + diff --git a/Geneface_main/Eval/Eval_2.py b/Geneface_main/Eval/Eval_2.py new file mode 100644 index 00000000..d010e0c9 --- /dev/null +++ b/Geneface_main/Eval/Eval_2.py @@ -0,0 +1,107 @@ +import numpy as np +import torch +import torchvision.transforms as transforms +from torchvision.models.inception import inception_v3 +from scipy.linalg import sqrtm +from skimage.metrics import structural_similarity as ssim +import cv2 +from PIL import Image + +# 计算InceptionV3特征 +def calculate_inception_features(video_path, batch_size=8): + # 初始化InceptionV3模型 + model = inception_v3(pretrained=True, transform_input=False).eval().cuda() + + # 视频读取 + cap = cv2.VideoCapture(video_path) + frames = [] + + # 读取视频帧 + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame = Image.fromarray(frame) + frames.append(frame) + cap.release() + + # 转换为Tensor + transform = transforms.Compose([ + transforms.Resize((299, 299)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + features = [] + for i in range(0, len(frames), batch_size): + batch = frames[i:i+batch_size] + + # 将视频帧转换为正确的维度:batch_size x channels x height x width + batch = torch.stack([transform(frame) for frame in batch]).cuda() + + with torch.no_grad(): + # 提取Inception特征 + output = model(batch) + output = output.detach().cpu().numpy() + features.append(output) + + return np.concatenate(features, axis=0) + +# 计算FID +def calculate_fid(real_features, generated_features): + # 计算均值和协方差矩阵 + mu_real = np.mean(real_features, axis=0) + mu_gen = np.mean(generated_features, axis=0) + cov_real = np.cov(real_features, rowvar=False) + cov_gen = np.cov(generated_features, rowvar=False) + + # 计算FID + diff = mu_real - mu_gen + cov_sqrt, _ = sqrtm(cov_real.dot(cov_gen), disp=False) + + fid = np.sum(diff**2) + np.trace(cov_real + cov_gen - 2 * cov_sqrt) + return fid + +# 计算SSIM +def calculate_ssim(video_path1, video_path2): + # 视频读取 + cap1 = cv2.VideoCapture(video_path1) + cap2 = cv2.VideoCapture(video_path2) + + ssim_values = [] + + while cap1.isOpened() and cap2.isOpened(): + ret1, frame1 = cap1.read() + ret2, frame2 = cap2.read() + if not ret1 or not ret2: + break + + frame1 = cv2.cvtColor(frame1, cv2.COLOR_BGR2GRAY) + frame2 = cv2.cvtColor(frame2, cv2.COLOR_BGR2GRAY) + + # 计算SSIM + score, _ = ssim(frame1, frame2, full=True) + ssim_values.append(score) + + cap1.release() + cap2.release() + + return np.mean(ssim_values) + +# 主程序 +real_video_path = 'May_org.mp4' +generated_video_path = 'May_radnerf_torso_smo.mp4' + +# 计算Inception特征 +real_features = calculate_inception_features(real_video_path) +generated_features = calculate_inception_features(generated_video_path) + +# 计算FID +fid_score = calculate_fid(real_features, generated_features) +print(f"FID score: {fid_score}") + +# 计算SSIM +ssim_score = calculate_ssim(real_video_path, generated_video_path) +print(f"SSIM score: {ssim_score}") + diff --git a/Geneface_main/Eval/Eval_2_CPU.py b/Geneface_main/Eval/Eval_2_CPU.py new file mode 100644 index 00000000..16d5896c --- /dev/null +++ b/Geneface_main/Eval/Eval_2_CPU.py @@ -0,0 +1,110 @@ +import numpy as np +import cv2 +from sklearn.decomposition import PCA +from scipy.linalg import sqrtm +from skimage.metrics import structural_similarity as ssim +from tensorflow.keras.applications.inception_v3 import InceptionV3 +from tensorflow.keras.preprocessing import image +from tensorflow.keras.applications.inception_v3 import preprocess_input +from tensorflow.keras.models import Model +from tensorflow.keras import backend as K + +# 加载InceptionV3模型 +def load_inception_model(): + base_model = InceptionV3(weights='imagenet') + model = Model(inputs=base_model.input, outputs=base_model.get_layer('avg_pool').output) + return model + +# 提取图像特征 +def extract_inception_features(video_path, model, batch_size=8): + cap = cv2.VideoCapture(video_path) + frames = [] + + # 读取视频帧 + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame = cv2.resize(frame, (299, 299)) + frames.append(frame) + cap.release() + + # 转换为符合InceptionV3输入要求的形式 + features = [] + for i in range(0, len(frames), batch_size): + batch = frames[i:i + batch_size] + + # 将帧转换为符合模型输入要求的形状 + batch = np.array(batch) + batch = preprocess_input(batch) + + # 获取特征 + batch_features = model.predict(batch, verbose = 0) + features.append(batch_features) + + features = np.vstack(features) + return features + +# 计算FID +def calculate_fid(real_features, generated_features): + # 计算均值和协方差矩阵 + mu_real = np.mean(real_features, axis=0) + mu_gen = np.mean(generated_features, axis=0) + cov_real = np.cov(real_features, rowvar=False) + cov_gen = np.cov(generated_features, rowvar=False) + + # 计算FID + diff = mu_real - mu_gen + cov_sqrt, _ = sqrtm(cov_real.dot(cov_gen), disp=False) + + if np.iscomplexobj(cov_sqrt): + cov_sqrt = cov_sqrt.real + + fid = np.sum(diff**2) + np.trace(cov_real + cov_gen - 2 * cov_sqrt) + return fid + +# 计算SSIM +def calculate_ssim(video_path1, video_path2): + cap1 = cv2.VideoCapture(video_path1) + cap2 = cv2.VideoCapture(video_path2) + + ssim_values = [] + + while cap1.isOpened() and cap2.isOpened(): + ret1, frame1 = cap1.read() + ret2, frame2 = cap2.read() + if not ret1 or not ret2: + break + + frame1 = cv2.cvtColor(frame1, cv2.COLOR_BGR2GRAY) + frame2 = cv2.cvtColor(frame2, cv2.COLOR_BGR2GRAY) + + # 计算SSIM + score, _ = ssim(frame1, frame2, full=True) + ssim_values.append(score) + + cap1.release() + cap2.release() + + return np.mean(ssim_values) + + +# 主程序 +real_video_path = 'May_org.mp4' +generated_video_path = 'May_radnerf_torso_smo.mp4' + +# 加载InceptionV3模型 +model = load_inception_model() + +# 计算InceptionV3特征 +real_features = extract_inception_features(real_video_path, model) +generated_features = extract_inception_features(generated_video_path, model) + +# 计算FID +fid_score = calculate_fid(real_features, generated_features) +print(f"FID score: {fid_score}") + +# 计算SSIM +ssim_score = calculate_ssim(real_video_path, generated_video_path) +print(f"SSIM score: {ssim_score}") diff --git a/Geneface_main/Eval/Eval_2_org.py b/Geneface_main/Eval/Eval_2_org.py new file mode 100644 index 00000000..eb6b2e57 --- /dev/null +++ b/Geneface_main/Eval/Eval_2_org.py @@ -0,0 +1,107 @@ +import numpy as np +import torch +import torchvision.transforms as transforms +from torchvision.models.inception import inception_v3 +from scipy.linalg import sqrtm +from skimage.metrics import structural_similarity as ssim +import cv2 +from PIL import Image + +# 计算InceptionV3特征 +def calculate_inception_features(video_path, batch_size=8): + # 初始化InceptionV3模型 + model = inception_v3(pretrained=True, transform_input=False).eval().cuda() + + # 视频读取 + cap = cv2.VideoCapture(video_path) + frames = [] + + # 读取视频帧 + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame = Image.fromarray(frame) + frames.append(frame) + cap.release() + + # 转换为Tensor + transform = transforms.Compose([ + transforms.Resize((299, 299)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + features = [] + for i in range(0, len(frames), batch_size): + batch = frames[i:i+batch_size] + + # 将视频帧转换为正确的维度:batch_size x channels x height x width + batch = torch.stack([transform(frame) for frame in batch]).cuda() + + with torch.no_grad(): + # 提取Inception特征 + output = model(batch) + output = output.detach().cpu().numpy() + features.append(output) + + return np.concatenate(features, axis=0) + +# 计算FID +def calculate_fid(real_features, generated_features): + # 计算均值和协方差矩阵 + mu_real = np.mean(real_features, axis=0) + mu_gen = np.mean(generated_features, axis=0) + cov_real = np.cov(real_features, rowvar=False) + cov_gen = np.cov(generated_features, rowvar=False) + + # 计算FID + diff = mu_real - mu_gen + cov_sqrt, _ = sqrtm(cov_real.dot(cov_gen), disp=False) + + fid = np.sum(diff**2) + np.trace(cov_real + cov_gen - 2 * cov_sqrt) + return fid + +# 计算SSIM +def calculate_ssim(video_path1, video_path2): + # 视频读取 + cap1 = cv2.VideoCapture(video_path1) + cap2 = cv2.VideoCapture(video_path2) + + ssim_values = [] + + while cap1.isOpened() and cap2.isOpened(): + ret1, frame1 = cap1.read() + ret2, frame2 = cap2.read() + if not ret1 or not ret2: + break + + frame1 = cv2.cvtColor(frame1, cv2.COLOR_BGR2GRAY) + frame2 = cv2.cvtColor(frame2, cv2.COLOR_BGR2GRAY) + + # 计算SSIM + score, _ = ssim(frame1, frame2, full=True) + ssim_values.append(score) + + cap1.release() + cap2.release() + + return np.mean(ssim_values) + +# 主程序 +real_video_path = 'May_org.mp4' +generated_video_path = 'May_radnerf_torso_smo_ORGMODEL.mp4' + +# 计算Inception特征 +real_features = calculate_inception_features(real_video_path) +generated_features = calculate_inception_features(generated_video_path) + +# 计算FID +fid_score = calculate_fid(real_features, generated_features) +print(f"FID score: {fid_score}") + +# 计算SSIM +ssim_score = calculate_ssim(real_video_path, generated_video_path) +print(f"SSIM score: {ssim_score}") + diff --git a/Geneface_main/Eval/Eval_2_org_CPU.py b/Geneface_main/Eval/Eval_2_org_CPU.py new file mode 100644 index 00000000..c30db6bc --- /dev/null +++ b/Geneface_main/Eval/Eval_2_org_CPU.py @@ -0,0 +1,110 @@ +import numpy as np +import cv2 +from sklearn.decomposition import PCA +from scipy.linalg import sqrtm +from skimage.metrics import structural_similarity as ssim +from tensorflow.keras.applications.inception_v3 import InceptionV3 +from tensorflow.keras.preprocessing import image +from tensorflow.keras.applications.inception_v3 import preprocess_input +from tensorflow.keras.models import Model +from tensorflow.keras import backend as K + +# 加载InceptionV3模型 +def load_inception_model(): + base_model = InceptionV3(weights='imagenet') + model = Model(inputs=base_model.input, outputs=base_model.get_layer('avg_pool').output) + return model + +# 提取图像特征 +def extract_inception_features(video_path, model, batch_size=8): + cap = cv2.VideoCapture(video_path) + frames = [] + + # 读取视频帧 + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame = cv2.resize(frame, (299, 299)) + frames.append(frame) + cap.release() + + # 转换为符合InceptionV3输入要求的形式 + features = [] + for i in range(0, len(frames), batch_size): + batch = frames[i:i + batch_size] + + # 将帧转换为符合模型输入要求的形状 + batch = np.array(batch) + batch = preprocess_input(batch) + + # 获取特征 + batch_features = model.predict(batch, verbose = 0) + features.append(batch_features) + + features = np.vstack(features) + return features + +# 计算FID +def calculate_fid(real_features, generated_features): + # 计算均值和协方差矩阵 + mu_real = np.mean(real_features, axis=0) + mu_gen = np.mean(generated_features, axis=0) + cov_real = np.cov(real_features, rowvar=False) + cov_gen = np.cov(generated_features, rowvar=False) + + # 计算FID + diff = mu_real - mu_gen + cov_sqrt, _ = sqrtm(cov_real.dot(cov_gen), disp=False) + + if np.iscomplexobj(cov_sqrt): + cov_sqrt = cov_sqrt.real + + fid = np.sum(diff**2) + np.trace(cov_real + cov_gen - 2 * cov_sqrt) + return fid + +# 计算SSIM +def calculate_ssim(video_path1, video_path2): + cap1 = cv2.VideoCapture(video_path1) + cap2 = cv2.VideoCapture(video_path2) + + ssim_values = [] + + while cap1.isOpened() and cap2.isOpened(): + ret1, frame1 = cap1.read() + ret2, frame2 = cap2.read() + if not ret1 or not ret2: + break + + frame1 = cv2.cvtColor(frame1, cv2.COLOR_BGR2GRAY) + frame2 = cv2.cvtColor(frame2, cv2.COLOR_BGR2GRAY) + + # 计算SSIM + score, _ = ssim(frame1, frame2, full=True) + ssim_values.append(score) + + cap1.release() + cap2.release() + + return np.mean(ssim_values) + + +# 主程序 +real_video_path = 'May_org.mp4' +generated_video_path = 'May_radnerf_torso_smo_ORGMODEL.mp4' + +# 加载InceptionV3模型 +model = load_inception_model() + +# 计算InceptionV3特征 +real_features = extract_inception_features(real_video_path, model) +generated_features = extract_inception_features(generated_video_path, model) + +# 计算FID +fid_score = calculate_fid(real_features, generated_features) +print(f"FID score: {fid_score}") + +# 计算SSIM +ssim_score = calculate_ssim(real_video_path, generated_video_path) +print(f"SSIM score: {ssim_score}") diff --git a/Geneface_main/Eval/Eval_org.py b/Geneface_main/Eval/Eval_org.py new file mode 100644 index 00000000..5efe6d2a --- /dev/null +++ b/Geneface_main/Eval/Eval_org.py @@ -0,0 +1,81 @@ +import cv2 +import numpy as np +from skimage.util import img_as_float + +# Function to extract frames from a video +def extract_frames(video_path): + """Extract frames from a video and convert to grayscale.""" + cap = cv2.VideoCapture(video_path) + frames = [] + while True: + success, frame = cap.read() + if not success: + break + frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)) + cap.release() + return frames + +# Resize frame to match target dimensions +def resize_frame(frame, target_shape): + """Resize a frame to match the target dimensions.""" + return cv2.resize(frame, (target_shape[1], target_shape[0]), interpolation=cv2.INTER_LINEAR) + +# Placeholder NIQE calculation function +def calculate_niqe(frame): + """Calculate NIQE score for a frame (placeholder implementation).""" + return np.random.uniform(4, 10) # Replace with actual NIQE implementation + +# Calculate PSNR for two frames +def calculate_psnr(frame1, frame2): + """Calculate PSNR between two frames.""" + mse = np.mean((frame1 - frame2) ** 2) + if mse == 0: + return float('inf') + data_range = frame1.max() - frame1.min() + result = (data_range ** 2) / mse + psnr = np.empty_like(result) + np.log10(result, out=psnr, where=result > 0) + return 10 * psnr + +# Main function to calculate metrics +def calculate_metrics(video1_path, video2_path): + """Calculate average PSNR and NIQE metrics for two videos.""" + frames1 = extract_frames(video1_path) + frames2 = extract_frames(video2_path) + + frame_count = min(len(frames1), len(frames2)) + if len(frames1) != len(frames2): + print("Warning: Videos have different number of frames. Metrics will be calculated up to the shorter one.") + + psnr_values = [] + niqe_values = [] + + for i in range(frame_count): + frame1 = img_as_float(frames1[i]) + frame2 = img_as_float(frames2[i]) + + # Resize frames to the same dimensions if necessary + if frame1.shape != frame2.shape: + frame2 = resize_frame(frame2, frame1.shape) + + # Calculate PSNR + psnr_values.append(calculate_psnr(frame1, frame2)) + + # Calculate NIQE for the second frame + niqe_values.append(calculate_niqe(frame2)) + + avg_psnr = np.mean(psnr_values) + avg_niqe = np.mean(niqe_values) + + return avg_psnr, avg_niqe + +# Paths to videos +video1_path = "May_org.mp4" +video2_path = "May_radnerf_torso_smo_ORGMODEL.mp4" + +# Calculate metrics +if __name__ == "__main__": + psnr, niqe = calculate_metrics(video1_path, video2_path) + print(f"Average PSNR: {psnr:.2f}") + print(f"Average NIQE: {niqe:.2f}") + diff --git a/Geneface_main/Eval/May_org.mp4 b/Geneface_main/Eval/May_org.mp4 new file mode 100644 index 00000000..6a463e0e Binary files /dev/null and b/Geneface_main/Eval/May_org.mp4 differ diff --git a/Geneface_main/Eval/May_radnerf_torso_smo.mp4 b/Geneface_main/Eval/May_radnerf_torso_smo.mp4 new file mode 100644 index 00000000..c087309f Binary files /dev/null and b/Geneface_main/Eval/May_radnerf_torso_smo.mp4 differ diff --git a/Geneface_main/Eval/May_radnerf_torso_smo_ORGMODEL.mp4 b/Geneface_main/Eval/May_radnerf_torso_smo_ORGMODEL.mp4 new file mode 100644 index 00000000..8c680229 Binary files /dev/null and b/Geneface_main/Eval/May_radnerf_torso_smo_ORGMODEL.mp4 differ diff --git a/Geneface_main/GeneFace/data_gen/nerf/__pycache__/binarizer.cpython-39.pyc b/Geneface_main/GeneFace/data_gen/nerf/__pycache__/binarizer.cpython-39.pyc new file mode 100644 index 00000000..971da737 Binary files /dev/null and b/Geneface_main/GeneFace/data_gen/nerf/__pycache__/binarizer.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/data_gen/nerf/binarizer.py b/Geneface_main/GeneFace/data_gen/nerf/binarizer.py new file mode 100644 index 00000000..d9b84df3 --- /dev/null +++ b/Geneface_main/GeneFace/data_gen/nerf/binarizer.py @@ -0,0 +1,278 @@ +import os +import numpy as np +import math +import json +import imageio +import torch +from data_util.face3d_helper import Face3DHelper + +from utils.commons.euler2rot import euler_trans_2_c2w, c2w_to_euler_trans +from tasks.audio2motion.dataset_utils.euler2quaterion import euler2quaterion, quaterion2euler +import tqdm + +from utils.commons.hparams import hparams, set_hparams +set_hparams() + +face3d_helper = Face3DHelper() + +audio_cond_win_size = 16 # hparams['cond_win_size'] for ad_nerf/radnerf +audio_smo_win_size = 8 # hparams['smo_win_size'] for ad_nerf/radnerf +exp_cond_win_size = 1 # hparams['cond_win_size'] for lm3d_nerf/lm3d_radnerf +exp_smo_win_size = 5 # hparams['smo_win_size'] for lm3d_nerf/lm3d_radnerf + + +def get_win_conds(conds, idx, smo_win_size=8, pad_option='zero'): + """ + conds: [b, t=16, h=29] + idx: long, time index of the selected frame + """ + idx = max(0, idx) + idx = min(idx, conds.shape[0]-1) + smo_half_win_size = smo_win_size//2 + left_i = idx - smo_half_win_size + right_i = idx + (smo_win_size - smo_half_win_size) + pad_left, pad_right = 0, 0 + if left_i < 0: + pad_left = -left_i + left_i = 0 + if right_i > conds.shape[0]: + pad_right = right_i - conds.shape[0] + right_i = conds.shape[0] + conds_win = conds[left_i:right_i] + if pad_left > 0: + if pad_option == 'zero': + conds_win = np.concatenate([np.zeros_like(conds_win)[:pad_left], conds_win], axis=0) + elif pad_option == 'edge': + edge_value = conds[0][np.newaxis, ...] + conds_win = np.concatenate([edge_value] * pad_left + [conds_win], axis=0) + else: + raise NotImplementedError + if pad_right > 0: + if pad_option == 'zero': + conds_win = np.concatenate([conds_win, np.zeros_like(conds_win)[:pad_right]], axis=0) + elif pad_option == 'edge': + edge_value = conds[-1][np.newaxis, ...] + conds_win = np.concatenate([conds_win] + [edge_value] * pad_right , axis=0) + else: + raise NotImplementedError + assert conds_win.shape[0] == smo_win_size + return conds_win + + +def load_processed_data(processed_dir): + # images required by AD-NeRF + head_img_dir = os.path.join(processed_dir, "head_imgs") + ori_img_dir = os.path.join(processed_dir, "ori_imgs") + parsing_dir = os.path.join(processed_dir, "parsing") + # images required by RAD-NeRF + torso_img_dir = os.path.join(processed_dir, "torso_imgs") + gt_img_dir = os.path.join(processed_dir, "gt_imgs") + + background_img_name = os.path.join(processed_dir, "bc.jpg") + train_json_name = os.path.join(processed_dir, "transforms_train.json") + val_json_name = os.path.join(processed_dir, "transforms_val.json") + track_params_name = os.path.join(processed_dir, "track_params.pt") + deepspeech_npy_name = os.path.join(processed_dir, "aud_deepspeech.npy") + esperanto_npy_name = os.path.join(processed_dir, "aud_esperanto.npy") + coeff_npy_name = os.path.join(processed_dir, "vid_coeff.npy") + hubert_npy_name = os.path.join(processed_dir, "aud_hubert.npy") + mel_f0_npy_name = os.path.join(processed_dir, "aud_mel_f0.npy") + + # required by RAD-NeRF + + ret_dict = {} + + print("loading deepspeech ...") + deepspeech_features = np.load(deepspeech_npy_name) + print("loading Esperanto ...") + esperanto_features = np.load(esperanto_npy_name) + print("loading hubert ...") + hubert_features = np.load(hubert_npy_name) + ret_dict['hubert'] = hubert_features + print("loading Mel and F0 ...") + mel_f0_features = np.load(mel_f0_npy_name, allow_pickle=True).tolist() + ret_dict['mel'] = mel_f0_features['mel'] + ret_dict['f0'] = mel_f0_features['f0'] + + print("loading 3dmm coeff ...") + coeff_dict = np.load(coeff_npy_name, allow_pickle=True).tolist() + coeff_arr = coeff_dict['coeff'][:] + + identity_arr = coeff_arr[:, 0:80] + exp_arr = coeff_arr[:, 80:144] + + print("calculating lm3d ...") + idexp_lm3d_arr = face3d_helper.reconstruct_idexp_lm3d(torch.from_numpy(identity_arr), torch.from_numpy(exp_arr)).cpu().numpy() + + video_idexp_lm3d_mean = idexp_lm3d_arr.mean(axis=0).reshape([1,68,3]) + video_idexp_lm3d_std = idexp_lm3d_arr.std(axis=0).reshape([1,68,3]) + ret_dict['idexp_lm3d_mean'] = video_idexp_lm3d_mean + ret_dict['idexp_lm3d_std'] = video_idexp_lm3d_std + idexp_lm3d_arr_normalized = (idexp_lm3d_arr - video_idexp_lm3d_mean) / video_idexp_lm3d_std + + if deepspeech_features.shape[0] < coeff_arr.shape[0]: + num_to_pad = coeff_arr.shape[0] - deepspeech_features.shape[0] + tmp = np.zeros([num_to_pad, 16, 29]) + deepspeech_features = np.concatenate([deepspeech_features, tmp], axis=0) + elif deepspeech_features.shape[0] > coeff_arr.shape[0]: + deepspeech_features = deepspeech_features[:coeff_arr.shape[0]] + + if esperanto_features.shape[0] < coeff_arr.shape[0]: + num_to_pad = coeff_arr.shape[0] - esperanto_features.shape[0] + tmp = np.zeros([num_to_pad, 16, 44]) + esperanto_features = np.concatenate([esperanto_features, tmp], axis=0) + elif esperanto_features.shape[0] > coeff_arr.shape[0]: + esperanto_features = esperanto_features[:coeff_arr.shape[0]] + + translation = coeff_arr[:, 254:257] # [T_y, c=3] + angles = euler2quaterion(coeff_arr[:, 224:227]) # # [T_y, c=4] + pose_deep3drecon = np.concatenate([translation, angles], axis=1) + + print("loading train_val.json ...") + with open(train_json_name) as f: + train_meta = json.load(f) + with open(val_json_name) as f: + val_meta = json.load(f) + bg_img = imageio.imread(background_img_name) + ret_dict['bg_img'] = bg_img + ret_dict['H'], ret_dict['W'] = bg_img.shape[:2] + ret_dict['focal'], ret_dict['cx'], ret_dict['cy'] = float(train_meta['focal_len']), float(train_meta['cx']), float(train_meta['cy']) + + idexp_lm3d_normalized_win_lst = [] + # hubert_win_lst = [] + for frame in train_meta['frames'] + val_meta['frames'] : + idx = frame['aud_id'] + idexp_lm3d_normalized_win = get_win_conds(idexp_lm3d_arr_normalized, idx, smo_win_size=exp_cond_win_size, pad_option='zero') + idexp_lm3d_normalized_win_lst.append(idexp_lm3d_normalized_win) + # hubert_win = get_win_conds(hubert_features, idx, smo_win_size=16) + # hubert_win_lst.append(hubert_win) + idexp_lm3d_normalized_wins_arr = np.stack(idexp_lm3d_normalized_win_lst, axis=0) # [T, t_w, 204] + # hubert_win_arr = np.stack(hubert_win_lst, axis=0) # [T, t_w, 204] + + # obtaining train samples + train_samples = [] + for i_frame, frame in tqdm.tqdm(enumerate(train_meta['frames']), desc="Binarizing train set", total=len(train_meta['frames'])): + assert frame['aud_id'] == frame['img_id'] + idx = frame['aud_id'] + ori_img_fname = os.path.join(ori_img_dir,f"{idx}.jpg") + head_img_fname = os.path.join(head_img_dir,f"{idx}.jpg") + torso_img_fname = os.path.join(torso_img_dir,f"{idx}.png") + gt_img_fname = os.path.join(gt_img_dir,f"{idx}.jpg") + parsing_fname = os.path.join(parsing_dir,f"{idx}.png") + + camera2world_matrix = np.array(frame['transform_matrix']) + euler, trans = c2w_to_euler_trans(camera2world_matrix) + face_rect = np.array(frame['face_rect']) + deepspeech_wins = get_win_conds(deepspeech_features, idx, smo_win_size=audio_smo_win_size, pad_option='zero') + esperanto_wins = get_win_conds(esperanto_features, idx, smo_win_size=audio_smo_win_size, pad_option='zero') + + idexp_lm3d_normalized_win = get_win_conds(idexp_lm3d_arr_normalized, idx, smo_win_size=exp_cond_win_size, pad_option='zero') # [cond_win_size, 68, 3] + idexp_lm3d_normalized_wins = get_win_conds(idexp_lm3d_normalized_wins_arr, idx, smo_win_size=exp_smo_win_size, pad_option='zero') # [smo_win_size, cond_win_size, 68, 3] + + # hubert_win = hubert_win_arr[idx] + # hubert_wins = get_win_conds(hubert_win_arr, idx, smo_win_size=8, pad_option='zero') + + sample = { + 'idx': idx, + 'face_rect': face_rect, + 'ori_img_fname': ori_img_fname, + 'head_img_fname': head_img_fname, + 'torso_img_fname': torso_img_fname, + 'gt_img_fname': gt_img_fname, + 'parsing_fname': parsing_fname, + 'c2w': camera2world_matrix, + 'euler': euler, + 'trans': trans, + 'exp': exp_arr[idx], + 'identity': identity_arr[idx], + 'pose_deep3drecon': pose_deep3drecon[idx], + 'idexp_lm3d': idexp_lm3d_arr[idx], + 'idexp_lm3d_normalized': idexp_lm3d_arr_normalized[idx], + 'idexp_lm3d_normalized_win': idexp_lm3d_normalized_win, + 'idexp_lm3d_normalized_wins': idexp_lm3d_normalized_wins, + 'deepspeech_win': deepspeech_features[idx], + 'deepspeech_wins': deepspeech_wins, + 'esperanto_win': esperanto_features[idx], + 'esperanto_wins': esperanto_wins, + # 'hubert_win': hubert_win, + # 'hubert_wins': hubert_wins, + } + train_samples.append(sample) + ret_dict['train_samples'] = train_samples + + # obtaining val samples + val_samples = [] + for i_frame, frame in tqdm.tqdm(enumerate(val_meta['frames']), desc="Binarizing val set", total=len(val_meta['frames'])): + assert frame['aud_id'] == frame['img_id'] + idx = frame['aud_id'] + ori_img_fname = os.path.join(ori_img_dir,f"{idx}.jpg") + head_img_fname = os.path.join(head_img_dir,f"{idx}.jpg") + torso_img_fname = os.path.join(torso_img_dir,f"{idx}.png") + gt_img_fname = os.path.join(gt_img_dir,f"{idx}.jpg") + parsing_fname = os.path.join(parsing_dir,f"{idx}.png") + + face_rect = np.array(frame['face_rect']) + camera2world_matrix = np.array(frame['transform_matrix']) + euler, trans = c2w_to_euler_trans(camera2world_matrix) + deepspeech_wins = get_win_conds(deepspeech_features, idx, smo_win_size=audio_smo_win_size, pad_option='zero') + esperanto_wins = get_win_conds(esperanto_features, idx, smo_win_size=audio_smo_win_size, pad_option='zero') + + idexp_lm3d_normalized_win = get_win_conds(idexp_lm3d_arr_normalized, idx, smo_win_size=exp_cond_win_size, pad_option='zero') + idexp_lm3d_normalized_wins = get_win_conds(idexp_lm3d_normalized_wins_arr, idx, smo_win_size=exp_smo_win_size, pad_option='zero') + + # hubert_win = hubert_win_arr[idx] + # hubert_wins = get_win_conds(hubert_win_arr, idx, smo_win_size=8, pad_option='zero') + + sample = { + 'idx': idx, + 'face_rect': face_rect, + 'ori_img_fname': ori_img_fname, + 'head_img_fname': head_img_fname, + 'torso_img_fname': torso_img_fname, + 'gt_img_fname': gt_img_fname, + 'parsing_fname': parsing_fname, + 'c2w': camera2world_matrix, + 'euler': euler, + 'trans': trans, + 'exp': exp_arr[idx], # [64] + 'identity': identity_arr[idx], + 'pose_deep3drecon': pose_deep3drecon[idx], + 'idexp_lm3d': idexp_lm3d_arr[idx], + 'idexp_lm3d_normalized': idexp_lm3d_arr_normalized[idx], + 'idexp_lm3d_normalized_win': idexp_lm3d_normalized_win, + 'idexp_lm3d_normalized_wins': idexp_lm3d_normalized_wins, + 'deepspeech_win': deepspeech_features[idx], + 'deepspeech_wins': deepspeech_wins, + 'esperanto_win': esperanto_features[idx], + 'esperanto_wins': esperanto_wins, + # 'hubert_win': hubert_win, + # 'hubert_wins': hubert_wins, + } + + val_samples.append(sample) + ret_dict['val_samples'] = val_samples + + return ret_dict + + +class Binarizer: + def __init__(self): + self.data_dir = 'data/' + + def parse(self, video_id): + processed_dir = os.path.join(self.data_dir, 'processed/videos', video_id) + binary_dir = os.path.join(self.data_dir, 'binary/videos', video_id) + out_fname = os.path.join(binary_dir, "trainval_dataset.npy") + os.makedirs(binary_dir, exist_ok=True) + ret = load_processed_data(processed_dir) + mel_name = os.path.join(processed_dir, 'aud_mel_f0.npy') + mel_f0_dict = np.load(mel_name, allow_pickle=True).tolist() + ret.update(mel_f0_dict) + np.save(out_fname, ret, allow_pickle=True) + + + +if __name__ == '__main__': + binarizer = Binarizer() + binarizer.parse(hparams['video_id']) + print(f"Binarization for {hparams['video_id']} Done!") diff --git a/Geneface_main/GeneFace/data_gen/nerf/extract_3dmm.py b/Geneface_main/GeneFace/data_gen/nerf/extract_3dmm.py new file mode 100644 index 00000000..c89fbfe5 --- /dev/null +++ b/Geneface_main/GeneFace/data_gen/nerf/extract_3dmm.py @@ -0,0 +1,122 @@ +import os, sys +import cv2 +import numpy as np +from time import time +from scipy.io import savemat +import argparse +from tqdm import tqdm, trange +import torch +import face_alignment +import deep_3drecon +from moviepy.editor import VideoFileClip +import copy +import psutil + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, network_size=4, device='cuda') +face_reconstructor = deep_3drecon.Reconstructor() + +# landmark detection in Deep3DRecon +def lm68_2_lm5(in_lm): + # in_lm: shape=[68,2] + lm_idx = np.array([31,37,40,43,46,49,55]) - 1 + # 将上述特殊角点的数据取出,得到5个新的角点数据,拼接起来。 + lm = np.stack([in_lm[lm_idx[0],:],np.mean(in_lm[lm_idx[[1,2]],:],0),np.mean(in_lm[lm_idx[[3,4]],:],0),in_lm[lm_idx[5],:],in_lm[lm_idx[6],:]], axis = 0) + # 将第一个角点放在了第三个位置 + lm = lm[[1,2,0,3,4],:2] + return lm + +def process_video(fname, out_name=None, skip_tmp=True): + assert fname.endswith(".mp4") + if out_name is None: + out_name = fname[:-4] + '.npy' + tmp_name = out_name[:-4] + '.doi' + # if os.path.exists(tmp_name) and skip_tmp: + # print("tmp exist, skip") + # return + # if os.path.exists(out_name): + # print("out exisit, skip") + # return + os.system(f"touch {tmp_name}") + cap = cv2.VideoCapture(fname) + print(f"loading video ...") + # 获取视频相关参数 + num_frames = int(cap.get(7)) + h = int(cap.get(4)) + w = int(cap.get(3)) + # 检测系统资源是否充足 + mem = psutil.virtual_memory() + a_mem = mem.available + min_mem=num_frames*68*2 + num_frames*5*2 + num_frames*h*w*3 + if a_mem < min_mem: + print(f"WARNING: The physical memory is insufficient, which may result in memory swapping. Available Memory: {a_mem/1000000:.3f}M, the minimum memory required is:{min_mem/1000000:.3f}M.") + # 初始化矩阵 + lm68_arr=np.empty((num_frames, 68, 2),dtype=np.float32) + lm5_arr=np.empty((num_frames, 5, 2),dtype=np.float32) + video_rgb=np.empty((num_frames, h, w, 3),dtype=np.uint8) + cnt=0 + while cap.isOpened(): + ret, frame_bgr = cap.read() + if frame_bgr is None: + break + frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB) + video_rgb[cnt]=frame_rgb + cnt += 1 + for i in trange(num_frames, desc="extracting 2D facial landmarks ..."): + try: + lm68 = fa.get_landmarks(video_rgb[i])[0] # 识别图片中的人脸,获得角点, shape=[68,2] + except: + print(f"WARNING: Caught errors when fa.get_landmarks, maybe No face detected at frame {i} in {fname}!") + raise ValueError("") + lm5 = lm68_2_lm5(lm68) + lm68_arr[i]=lm68 + lm5_arr[i]=lm5 + # num_frames = cnt + batch_size = 32 + iter_times = num_frames // batch_size + last_bs = num_frames % batch_size + coeff_lst = [] + for i_iter in trange(iter_times, desc="start extracting 3DMM..."): + start_idx = i_iter * batch_size + batched_images = video_rgb[start_idx: start_idx + batch_size] + batched_lm5 = lm5_arr[start_idx: start_idx + batch_size] + coeff, align_img = face_reconstructor.recon_coeff(batched_images, batched_lm5, return_image = True) + coeff_lst.append(coeff) + if last_bs != 0: + batched_images = video_rgb[-last_bs:] + batched_lm5 = lm5_arr[-last_bs:] + coeff, align_img = face_reconstructor.recon_coeff(batched_images, batched_lm5, return_image = True) + coeff_lst.append(coeff) + coeff_arr = np.concatenate(coeff_lst,axis=0) + result_dict = { + 'coeff': coeff_arr.reshape([cnt, -1]), + 'lm68': lm68_arr, + 'lm5': lm5_arr, + } + np.save(out_name, result_dict) + os.system(f"rm {tmp_name}") + + +def split_wav(mp4_name): + wav_name = mp4_name[:-4] + '.wav' + if os.path.exists(wav_name): + return + video = VideoFileClip(mp4_name,verbose=False) + dur = video.duration + audio = video.audio + assert audio is not None + audio.write_audiofile(wav_name,fps=16000,verbose=False,logger=None) + + +if __name__ == '__main__': + from argparse import ArgumentParser + parser = ArgumentParser() + parser.add_argument('--video_id', type=str, default='May', help='') + args = parser.parse_args() + + video_id = args.video_id + video_fname = f"data/raw/videos/{video_id}.mp4" + out_fname = f"data/processed/videos/{video_id}/vid_coeff.npy" + process_video(video_fname, out_fname, skip_tmp=False) + print(f"3DMM coeff extracted at {out_fname}") diff --git a/Geneface_main/GeneFace/data_gen/nerf/extract_hubert_mel_f0.py b/Geneface_main/GeneFace/data_gen/nerf/extract_hubert_mel_f0.py new file mode 100644 index 00000000..82559812 --- /dev/null +++ b/Geneface_main/GeneFace/data_gen/nerf/extract_hubert_mel_f0.py @@ -0,0 +1,21 @@ +import soundfile as sf +import numpy as np +import torch +from argparse import ArgumentParser +from data_gen.process_lrs3.process_audio_hubert import get_hubert_from_16k_speech +from data_gen.process_lrs3.process_audio_mel_f0 import extract_mel_f0_from_fname + +parser = ArgumentParser() +parser.add_argument('--video_id', type=str, default='May', help='') +args = parser.parse_args() + +person_id = args.video_id +wav_16k_name = f"data/processed/videos/{person_id}/aud.wav" +hubert_npy_name = f"data/processed/videos/{person_id}/aud_hubert.npy" +mel_f0_npy_name = f"data/processed/videos/{person_id}/aud_mel_f0.npy" +speech_16k, _ = sf.read(wav_16k_name) +hubert_hidden = get_hubert_from_16k_speech(speech_16k) +np.save(hubert_npy_name, hubert_hidden.detach().numpy()) +print(f"Hubert extracted at {hubert_npy_name}") +extract_mel_f0_from_fname(wav_16k_name, out_name=mel_f0_npy_name) +print(f"Mel and F0 extracted at {mel_f0_npy_name}") diff --git a/Geneface_main/GeneFace/data_gen/nerf/process_data.sh b/Geneface_main/GeneFace/data_gen/nerf/process_data.sh new file mode 100644 index 00000000..854b9168 --- /dev/null +++ b/Geneface_main/GeneFace/data_gen/nerf/process_data.sh @@ -0,0 +1,28 @@ +export PYTHONPATH=./ +export CUDA_VISIBLE_DEVICES=0 +# 1. extrac 16khz wav +python data_util/process.py --video_id=$1 --task=1 +# 2. extrac deepspeech and esperanto; 3.extract image frames +python data_util/process.py --video_id=$1 --task=2 & +python data_util/process.py --video_id=$1 --task=3 +# 7.detect landmarks +python data_util/process.py --video_id=$1 --task=7 +# 4.face segmentation parsing; 8.estimate head pose +python data_util/process.py --video_id=$1 --task=4 & +python data_util/process.py --video_id=$1 --task=8 +# 4. extract background image +python data_util/process.py --video_id=$1 --task=5 +# Optional: Once the background image is extracted before running step 5, +# you could use a image inpainting tool (such as Inpaint on MacOS) +# to edit the backgroud image, so it could be more realistic. +# 5. save head, torso, gt imgs +python data_util/process.py --video_id=$1 --task=6 +wait +# 7. integrate the results into meta +python data_util/process.py --video_id=$1 --task=9 +# 8. calculate audio features +python data_gen/nerf/extract_hubert_mel_f0.py --video_id=$1 +# 9. calculate 3DMM +python data_gen/nerf/extract_3dmm.py --video_id=$1 +# binarize the dataset into `data/binary/videos/$1/trainval_dataset.npy` +python data_gen/nerf/binarizer.py --config=egs/datasets/videos/$1/lm3d_radnerf.yaml diff --git a/Geneface_main/GeneFace/data_gen/process_lrs3/__pycache__/process_audio_hubert.cpython-39.pyc b/Geneface_main/GeneFace/data_gen/process_lrs3/__pycache__/process_audio_hubert.cpython-39.pyc new file mode 100644 index 00000000..5936b7b3 Binary files /dev/null and b/Geneface_main/GeneFace/data_gen/process_lrs3/__pycache__/process_audio_hubert.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/data_gen/process_lrs3/__pycache__/process_audio_mel_f0.cpython-39.pyc b/Geneface_main/GeneFace/data_gen/process_lrs3/__pycache__/process_audio_mel_f0.cpython-39.pyc new file mode 100644 index 00000000..4f2a9e01 Binary files /dev/null and b/Geneface_main/GeneFace/data_gen/process_lrs3/__pycache__/process_audio_mel_f0.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/data_gen/process_lrs3/binarizer.py b/Geneface_main/GeneFace/data_gen/process_lrs3/binarizer.py new file mode 100644 index 00000000..2249e871 --- /dev/null +++ b/Geneface_main/GeneFace/data_gen/process_lrs3/binarizer.py @@ -0,0 +1,91 @@ +import os +import numpy as np +from scipy.misc import face +import torch +from tqdm import trange +import pickle +from copy import deepcopy + +from data_util.face3d_helper import Face3DHelper +from utils.commons.indexed_datasets import IndexedDataset, IndexedDatasetBuilder + + +def load_video_npy(fn): + assert fn.endswith(".npy") + ret_dict = np.load(fn,allow_pickle=True).item() + video_dict = { + 'coeff': ret_dict['coeff'], # [T, h] + 'lm68': ret_dict['lm68'], # [T, 68, 2] + 'lm5': ret_dict['lm5'], # [T, 5, 2] + } + return video_dict + +def cal_lm3d_in_video_dict(video_dict, face3d_helper): + coeff = torch.from_numpy(video_dict['coeff']).float() + identity = coeff[:, 0:80] + exp = coeff[:, 80:144] + idexp_lm3d = face3d_helper.reconstruct_idexp_lm3d(identity, exp).cpu().numpy() + video_dict['idexp_lm3d'] = idexp_lm3d + +def load_audio_npy(fn): + assert fn.endswith(".npy") + ret_dict = np.load(fn,allow_pickle=True).item() + audio_dict = { + "mel": ret_dict['mel'], # [T, 80] + "f0": ret_dict['f0'], # [T,1] + } + return audio_dict + + +if __name__ == '__main__': + face3d_helper = Face3DHelper(use_gpu=False) + + import glob,tqdm + prefixs = ['val', 'train'] + binarized_ds_path = "data/binary/lrs3" + os.makedirs(binarized_ds_path, exist_ok=True) + for prefix in prefixs: + databuilder = IndexedDatasetBuilder(os.path.join(binarized_ds_path, prefix), gzip=False) + raw_base_dir = '/home/yezhenhui/datasets/raw/lrs3_raw' + spk_ids = sorted([dir_name.split("/")[-1] for dir_name in glob.glob(raw_base_dir + "/*")]) + spk_id2spk_idx = {spk_id : i for i,spk_id in enumerate(spk_ids) } + np.save(os.path.join(binarized_ds_path, "spk_id2spk_idx.npy"), spk_id2spk_idx, allow_pickle=True) + mp4_names = glob.glob(raw_base_dir + "/*/*.mp4") + cnt = 0 + for i, mp4_name in tqdm.tqdm(enumerate(mp4_names), total=len(mp4_names)): + if prefix == 'train': + if i % 100 == 0: + continue + else: + if i % 100 != 0: + continue + lst = mp4_name.split("/") + spk_id = lst[-2] + clip_id = lst[-1][:-4] + audio_npy_name = os.path.join(raw_base_dir, spk_id, clip_id+"_audio.npy") + hubert_npy_name = os.path.join(raw_base_dir, spk_id, clip_id+"_hubert.npy") + video_npy_name = os.path.join(raw_base_dir, spk_id, clip_id+"_coeff_pt.npy") + if (not os.path.exists(audio_npy_name)) or (not os.path.exists(video_npy_name)): + print(f"Skip item for not found.") + continue + if (not os.path.exists(hubert_npy_name)): + print(f"Skip item for hubert_npy not found.") + continue + audio_dict = load_audio_npy(audio_npy_name) + hubert = np.load(hubert_npy_name) + video_dict = load_video_npy(video_npy_name) + cal_lm3d_in_video_dict(video_dict, face3d_helper) + mel = audio_dict['mel'] + if mel.shape[0] < 64: # the video is shorter than 0.6s + print(f"Skip item for too short.") + continue + audio_dict.update(video_dict) + audio_dict['spk_id'] = spk_id + audio_dict['spk_idx'] = spk_id2spk_idx[spk_id] + audio_dict['item_id'] = spk_id + "_" + clip_id + + audio_dict['hubert'] = hubert # [T_x, hid=1024] + databuilder.add_item(audio_dict) + cnt += 1 + databuilder.finalize() + print(f"{prefix} set has {cnt} samples!") \ No newline at end of file diff --git a/Geneface_main/GeneFace/data_gen/process_lrs3/process_audio_hubert.py b/Geneface_main/GeneFace/data_gen/process_lrs3/process_audio_hubert.py new file mode 100644 index 00000000..27684e63 --- /dev/null +++ b/Geneface_main/GeneFace/data_gen/process_lrs3/process_audio_hubert.py @@ -0,0 +1,87 @@ +from transformers import Wav2Vec2Processor, HubertModel +import soundfile as sf +import numpy as np +import torch + +print("Loading the Wav2Vec2 Processor...") +wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft") +print("Loading the HuBERT Model...") +hubert_model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft") + + +def get_hubert_from_16k_wav(wav_16k_name): + speech_16k, _ = sf.read(wav_16k_name) + hubert = get_hubert_from_16k_speech(speech_16k) + return hubert + +@torch.no_grad() +def get_hubert_from_16k_speech(speech, device="cuda:0"): + global hubert_model + hubert_model = hubert_model.to(device) + if speech.ndim ==2: + speech = speech[:, 0] # [T, 2] ==> [T,] + input_values_all = wav2vec2_processor(speech, return_tensors="pt", sampling_rate=16000).input_values # [1, T] + input_values_all = input_values_all.to(device) + # For long audio sequence, due to the memory limitation, we cannot process them in one run + # HuBERT process the wav with a CNN of stride [5,2,2,2,2,2], making a stride of 320 + # Besides, the kernel is [10,3,3,3,3,2,2], making 400 a fundamental unit to get 1 time step. + # So the CNN is euqal to a big Conv1D with kernel k=400 and stride s=320 + # We have the equation to calculate out time step: T = floor((t-k)/s) + # To prevent overlap, we set each clip length of (K+S*(N-1)), where N is the expected length T of this clip + # The start point of next clip should roll back with a length of (kernel-stride) so it is stride * N + kernel = 400 + stride = 320 + clip_length = stride * 1000 + num_iter = input_values_all.shape[1] // clip_length + expected_T = (input_values_all.shape[1] - (kernel-stride)) // stride + res_lst = [] + for i in range(num_iter): + if i == 0: + start_idx = 0 + end_idx = clip_length - stride + kernel + else: + start_idx = clip_length * i + end_idx = start_idx + (clip_length - stride + kernel) + input_values = input_values_all[:, start_idx: end_idx] + hidden_states = hubert_model.forward(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024] + res_lst.append(hidden_states[0]) + if num_iter > 0: + input_values = input_values_all[:, clip_length * num_iter:] + else: + input_values = input_values_all + # if input_values.shape[1] != 0: + if input_values.shape[1] >= kernel: # if the last batch is shorter than kernel_size, skip it + hidden_states = hubert_model(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024] + res_lst.append(hidden_states[0]) + ret = torch.cat(res_lst, dim=0).cpu() # [T, 1024] + # assert ret.shape[0] == expected_T + assert abs(ret.shape[0] - expected_T) <= 1 + if ret.shape[0] < expected_T: + ret = torch.nn.functional.pad(ret, (0,0,0,expected_T-ret.shape[0])) + else: + ret = ret[:expected_T] + return ret + + +if __name__ == '__main__': + ### Process Single Long Audio for NeRF dataset + # person_id = 'May' + # wav_16k_name = f"data/processed/videos/{person_id}/aud.wav" + # hubert_npy_name = f"data/processed/videos/{person_id}/hubert.npy" + # speech_16k, _ = sf.read(wav_16k_name) + # hubert_hidden = get_hubert_from_16k_speech(speech_16k) + # np.save(hubert_npy_name, hubert_hidden.detach().numpy()) + + ### Process short audio clips for LRS3 dataset + import glob, os, tqdm + lrs3_dir = '/home/yezhenhui/datasets/raw/lrs3_raw/' + wav_16k_names = glob.glob(os.path.join(lrs3_dir, '*/*.wav')) + for wav_16k_name in tqdm.tqdm(wav_16k_names, total=len(wav_16k_names)): + spk_id = wav_16k_name.split("/")[-2] + clip_id = wav_16k_name.split("/")[-1][:-4] + out_name = os.path.join(lrs3_dir, spk_id, clip_id+'_hubert.npy') + if os.path.exists(out_name): + continue + speech_16k, _ = sf.read(wav_16k_name) + hubert_hidden = get_hubert_from_16k_speech(speech_16k) + np.save(out_name, hubert_hidden.detach().numpy()) \ No newline at end of file diff --git a/Geneface_main/GeneFace/data_gen/process_lrs3/process_audio_mel_f0.py b/Geneface_main/GeneFace/data_gen/process_lrs3/process_audio_mel_f0.py new file mode 100644 index 00000000..ec098606 --- /dev/null +++ b/Geneface_main/GeneFace/data_gen/process_lrs3/process_audio_mel_f0.py @@ -0,0 +1,98 @@ +import numpy as np +import torch +import glob +import os +import tqdm +import librosa +import parselmouth +from utils.commons.pitch_utils import f0_to_coarse +from utils.commons.multiprocess_utils import multiprocess_run_tqdm + + +def librosa_pad_lr(x, fsize, fshift, pad_sides=1): + '''compute right padding (final frame) or both sides padding (first and final frames) + ''' + assert pad_sides in (1, 2) + # return int(fsize // 2) + pad = (x.shape[0] // fshift + 1) * fshift - x.shape[0] + if pad_sides == 1: + return 0, pad + else: + return pad // 2, pad // 2 + pad % 2 + +def extract_mel_from_fname(wav_path, + fft_size=512, + hop_size=320, + win_length=512, + window="hann", + num_mels=80, + fmin=80, + fmax=7600, + eps=1e-6, + sample_rate=16000, + min_level_db=-100): + if isinstance(wav_path, str): + wav, _ = librosa.core.load(wav_path, sr=sample_rate) + else: + wav = wav_path + + # get amplitude spectrogram + x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size, + win_length=win_length, window=window, center=False) + spc = np.abs(x_stft) # (n_bins, T) + + # get mel basis + fmin = 0 if fmin == -1 else fmin + fmax = sample_rate / 2 if fmax == -1 else fmax + mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel = mel_basis @ spc + + mel = np.log10(np.maximum(eps, mel)) # (n_mel_bins, T) + mel = mel.T + + l_pad, r_pad = librosa_pad_lr(wav, fft_size, hop_size, 1) + wav = np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=0.0) + + return wav.T, mel + +def extract_f0_from_wav_and_mel(wav, mel, + hop_size=320, + audio_sample_rate=16000, + ): + time_step = hop_size / audio_sample_rate * 1000 + f0_min = 80 + f0_max = 750 + f0 = parselmouth.Sound(wav, audio_sample_rate).to_pitch_ac( + time_step=time_step / 1000, voicing_threshold=0.6, + pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency'] + + delta_l = len(mel) - len(f0) + assert np.abs(delta_l) <= 8 + if delta_l > 0: + f0 = np.concatenate([f0, [f0[-1]] * delta_l], 0) + f0 = f0[:len(mel)] + pitch_coarse = f0_to_coarse(f0) + return f0, pitch_coarse + +def extract_mel_f0_from_fname(fname, out_name=None): + assert fname.endswith(".wav") + if out_name is None: + out_name = fname[:-4] + '_audio.npy' + + wav, mel = extract_mel_from_fname(fname) + f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel) + out_dict = { + "mel": mel, # [T, 80] + "f0": f0, + } + np.save(out_name, out_dict) + return True + +if __name__ == '__main__': + import os, glob + lrs3_dir = "/home/yezhenhui/datasets/raw/lrs3_raw" + wav_name_pattern = os.path.join(lrs3_dir, "*/*.wav") + wav_names = glob.glob(wav_name_pattern) + wav_names = sorted(wav_names) + for _ in multiprocess_run_tqdm(extract_mel_f0_from_fname, args=wav_names, num_workers=32,desc='extracting Mel and f0'): + pass \ No newline at end of file diff --git a/Geneface_main/GeneFace/data_gen/process_lrs3/process_video_3dmm.py b/Geneface_main/GeneFace/data_gen/process_lrs3/process_video_3dmm.py new file mode 100644 index 00000000..0937317a --- /dev/null +++ b/Geneface_main/GeneFace/data_gen/process_lrs3/process_video_3dmm.py @@ -0,0 +1,141 @@ +import os, sys +import cv2 +import numpy as np +from time import time +from scipy.io import savemat +import argparse +from tqdm import tqdm, trange +import torch +import face_alignment +import deep_3drecon +from moviepy.editor import VideoFileClip +import copy +import psutil + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, network_size=4, device='cuda') +face_reconstructor = deep_3drecon.Reconstructor() + +# landmark detection in Deep3DRecon +def lm68_2_lm5(in_lm): + # in_lm: shape=[68,2] + lm_idx = np.array([31,37,40,43,46,49,55]) - 1 + # 将上述特殊角点的数据取出,得到5个新的角点数据,拼接起来。 + lm = np.stack([in_lm[lm_idx[0],:],np.mean(in_lm[lm_idx[[1,2]],:],0),np.mean(in_lm[lm_idx[[3,4]],:],0),in_lm[lm_idx[5],:],in_lm[lm_idx[6],:]], axis = 0) + # 将第一个角点放在了第三个位置 + lm = lm[[1,2,0,3,4],:2] + return lm + +def process_video(fname, out_name=None): + assert fname.endswith(".mp4") + if out_name is None: + out_name = fname[:-4] + '.npy' + tmp_name = out_name[:-4] + '.doi' + # if os.path.exists(tmp_name): + # print("tmp exist, skip") + # return + # if os.path.exists(out_name): + # print("out exisit, skip") + # return + os.system(f"touch {tmp_name}") + cap = cv2.VideoCapture(fname) + print(f"loading video ...") + # 获取视频相关参数 + num_frames = int(cap.get(7)) + h = int(cap.get(4)) + w = int(cap.get(3)) + # 检测系统资源是否充足 + mem = psutil.virtual_memory() + a_mem = mem.available + min_mem=num_frames*68*2 + num_frames*5*2 + num_frames*h*w*3 + if a_mem < min_mem: + print(f"WARNING: The physical memory is insufficient, which may result in memory swapping. Available Memory: {a_mem/1000000:.3f}M, the minimum memory required is:{min_mem/1000000:.3f}M.") + # 初始化矩阵 + lm68_arr=np.empty((num_frames, 68, 2),dtype=np.float32) + lm5_arr=np.empty((num_frames, 5, 2),dtype=np.float32) + video_rgb=np.empty((num_frames, h, w, 3),dtype=np.uint8) + cnt=0 + while cap.isOpened(): + ret, frame_bgr = cap.read() + if frame_bgr is None: + break + frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB) + video_rgb[cnt]=frame_rgb + cnt += 1 + for i in trange(num_frames, desc="extracting 2D facial landmarks ..."): + try: + lm68 = fa.get_landmarks(video_rgb[i])[0] # 识别图片中的人脸,获得角点, shape=[68,2] + except: + print(f"WARNING: Caught errors when fa.get_landmarks, maybe No face detected at frame {i} in {fname}!") + raise ValueError("") + lm5 = lm68_2_lm5(lm68) + lm68_arr[i]=lm68 + lm5_arr[i]=lm5 + # num_frames = cnt + batch_size = 32 + iter_times = num_frames // batch_size + last_bs = num_frames % batch_size + coeff_lst = [] + for i_iter in range(iter_times): + start_idx = i_iter * batch_size + batched_images = video_rgb[start_idx: start_idx + batch_size] + batched_lm5 = lm5_arr[start_idx: start_idx + batch_size] + coeff, align_img = face_reconstructor.recon_coeff(batched_images, batched_lm5, return_image = True) + coeff_lst.append(coeff) + if last_bs != 0: + batched_images = video_rgb[-last_bs:] + batched_lm5 = lm5_arr[-last_bs:] + coeff, align_img = face_reconstructor.recon_coeff(batched_images, batched_lm5, return_image = True) + coeff_lst.append(coeff) + coeff_arr = np.concatenate(coeff_lst,axis=0) + result_dict = { + 'coeff': coeff_arr.reshape([cnt, -1]), + 'lm68': lm68_arr, + 'lm5': lm5_arr, + } + np.save(out_name, result_dict) + os.system(f"rm {tmp_name}") + + +def split_wav(mp4_name): + wav_name = mp4_name[:-4] + '.wav' + if os.path.exists(wav_name): + return + video = VideoFileClip(mp4_name,verbose=False) + dur = video.duration + audio = video.audio + assert audio is not None + audio.write_audiofile(wav_name,fps=16000,verbose=False,logger=None) + +if __name__ == '__main__': + ### Process Single Long video for NeRF dataset + # video_id = 'May' + # video_fname = f"data/raw/videos/{video_id}.mp4" + # out_fname = f"data/processed/videos/{video_id}/coeff.npy" + # process_video(video_fname, out_fname) + + ### Process short video clips for LRS3 dataset + from argparse import ArgumentParser + parser = ArgumentParser() + parser.add_argument('--lrs3_path', type=int, default='/home/dedfaf/GeneFace_reproduction/GeneFace/data/raw', help='') + parser.add_argument('--process_id', type=int, default=0, help='') + parser.add_argument('--total_process', type=int, default=1, help='') + args = parser.parse_args() + + import os, glob + lrs3_dir = parser.lrs3_path + mp4_name_pattern = os.path.join(lrs3_dir, "*/*.mp4") + mp4_names = glob.glob(mp4_name_pattern) + mp4_names = sorted(mp4_names) + if args.total_process > 1: + assert args.process_id <= args.total_process-1 + num_samples_per_process = len(mp4_names) // args.total_process + if args.process_id == args.total_process-1: + mp4_names = mp4_names[args.process_id * num_samples_per_process : ] + else: + mp4_names = mp4_names[args.process_id * num_samples_per_process : (args.process_id+1) * num_samples_per_process] + for mp4_name in tqdm(mp4_names, desc='extracting 3DMM...'): + split_wav(mp4_name) + process_video(mp4_name,out_name=mp4_name.replace(".mp4", "_coeff_pt.npy")) + diff --git a/Geneface_main/GeneFace/data_gen/process_lrs3/process_video_3dmm_th1kh.py b/Geneface_main/GeneFace/data_gen/process_lrs3/process_video_3dmm_th1kh.py new file mode 100644 index 00000000..b771da59 --- /dev/null +++ b/Geneface_main/GeneFace/data_gen/process_lrs3/process_video_3dmm_th1kh.py @@ -0,0 +1,209 @@ +import os, sys +import cv2 +import numpy as np +from time import time +from scipy.io import savemat +import argparse +from tqdm import tqdm, trange +import torch +import face_alignment +import deep_3drecon +from moviepy.editor import VideoFileClip +import copy +from utils.commons.multiprocess_utils import multiprocess_run_tqdm, multiprocess_run +from utils.commons.meters import Timer +from decord import VideoReader +from decord import cpu, gpu +from utils.commons.face_alignment_utils import mediapipe_lm478_to_face_alignment_lm68 +import mediapipe + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +# fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, network_size=4, device='cuda') +mp_face_mesh = mediapipe.solutions.face_mesh +face_reconstructor = deep_3drecon.Reconstructor() + + +def chunk(iterable, chunk_size): + final_ret = [] + cnt = 0 + ret = [] + for record in iterable: + if cnt == 0: + ret = [] + ret.append(record) + cnt += 1 + if len(ret) == chunk_size: + final_ret.append(ret) + ret = [] + if len(final_ret[-1]) != chunk_size: + final_ret.append(ret) + return final_ret + +# landmark detection in Deep3DRecon +def lm68_2_lm5(in_lm): + assert in_lm.ndim == 2 + # in_lm: shape=[68,2] + lm_idx = np.array([31,37,40,43,46,49,55]) - 1 + # 将上述特殊角点的数据取出,得到5个新的角点数据,拼接起来。 + lm = np.stack([in_lm[lm_idx[0],:],np.mean(in_lm[lm_idx[[1,2]],:],0),np.mean(in_lm[lm_idx[[3,4]],:],0),in_lm[lm_idx[5],:],in_lm[lm_idx[6],:]], axis = 0) + # 将第一个角点放在了第三个位置 + lm = lm[[1,2,0,3,4],:2] + return lm + +def extract_frames_job(fname): + out_name=fname.replace(".mp4", "_coeff_pt.npy").replace("datasets/raw/cropped_clips", "datasets/processed/coeff") + if os.path.exists(out_name): + return None + video_reader = VideoReader(fname, ctx=cpu(0)) + frame_rgb_lst = video_reader.get_batch(list(range(0,len(video_reader)))).asnumpy() + return frame_rgb_lst + +def extract_lms_mediapipe_job(frames): + if frames is None: + return None + with mp_face_mesh.FaceMesh( + static_image_mode=False, + max_num_faces=1, + refine_landmarks=True, + min_detection_confidence=0.5) as face_mesh: + ldms_normed = [] + frame_i = 0 + frame_ids = [] + for i in range(len(frames)): + # Convert the BGR image to RGB before processing. + ret = face_mesh.process(frames[i]) + # Print and draw face mesh landmarks on the image. + if not ret.multi_face_landmarks: + print(f"Skip Item: Caught errors when mediapipe get face_mesh, maybe No face detected in some frames!") + return None + else: + myFaceLandmarks = [] + lms = ret.multi_face_landmarks[0] + for lm in lms.landmark: + myFaceLandmarks.append([lm.x, lm.y, lm.z]) + ldms_normed.append(myFaceLandmarks) + frame_ids.append(frame_i) + frame_i += 1 + bs, H, W, _ = frames.shape + ldms478 = np.array(ldms_normed) + lm68 = mediapipe_lm478_to_face_alignment_lm68(ldms478, H, W, return_2d=True) + lm5_lst = [lm68_2_lm5(lm68[i]) for i in range(lm68.shape[0])] + lm5 = np.stack(lm5_lst) + return ldms478, lm68, lm5 + +def process_video_batch(fname_lst, out_name_lst=None): + frames_lst = [] + with Timer("load_frames", True): + for (i, res) in multiprocess_run_tqdm(extract_frames_job, fname_lst, num_workers=2, desc="decord is loading frames in the batch videos..."): + frames_lst.append(res) + + lm478s_lst = [] + lm68s_lst = [] + lm5s_lst = [] + with Timer("mediapipe_faceAlign", True): + for (i, res) in multiprocess_run_tqdm(extract_lms_mediapipe_job, frames_lst, num_workers=2, desc="mediapipe is predicting face mesh in batch videos..."): + if res is None: + res = (None, None, None) + lm478s, lm68s, lm5s = res + lm478s_lst.append(lm478s) + lm68s_lst.append(lm68s) + lm5s_lst.append(lm5s) + + processed_cnt_in_this_batch = 0 + with Timer("deep_3drecon_pytorch", True): + for i, fname in tqdm(enumerate(fname_lst), total=len(fname_lst), desc="extracting 3DMM in the batch videos..."): + video_rgb = frames_lst[i] # [t, 224,224, 3] + lm478_arr = lm478s_lst[i] + lm68_arr = lm68s_lst[i] + lm5_arr = lm5s_lst[i] + if lm5_arr is None: + continue + num_frames = len(video_rgb) + batch_size = 32 + iter_times = num_frames // batch_size + last_bs = num_frames % batch_size + + coeff_lst = [] + for i_iter in range(iter_times): + start_idx = i_iter * batch_size + batched_images = video_rgb[start_idx: start_idx + batch_size] + batched_lm5 = lm5_arr[start_idx: start_idx + batch_size] + coeff, align_img = face_reconstructor.recon_coeff(batched_images, batched_lm5, return_image = True) + coeff_lst.append(coeff) + if last_bs != 0: + batched_images = video_rgb[-last_bs:] + batched_lm5 = lm5_arr[-last_bs:] + coeff, align_img = face_reconstructor.recon_coeff(batched_images, batched_lm5, return_image = True) + coeff_lst.append(coeff) + coeff_arr = np.concatenate(coeff_lst,axis=0) + result_dict = { + 'coeff': coeff_arr.reshape([num_frames, -1]).astype(np.float32), + 'lm478': lm478_arr.reshape([num_frames, 478, 3]).astype(np.float32), + 'lm68': lm68_arr.reshape([num_frames, 68, 2]).astype(np.int16), + 'lm5': lm5_arr.reshape([num_frames, 5, 2]).astype(np.int16), + } + np.save(out_name_lst[i], result_dict) + processed_cnt_in_this_batch +=1 + + print(f"In this batch {processed_cnt_in_this_batch} files are processed") + + + +def split_wav(mp4_name): + wav_name = mp4_name[:-4] + '.wav' + if os.path.exists(wav_name): + return + video = VideoFileClip(mp4_name,verbose=False) + dur = video.duration + audio = video.audio + assert audio is not None + audio.write_audiofile(wav_name,fps=16000,verbose=False,logger=None) + +if __name__ == '__main__': + ### Process Single Long video for NeRF dataset + # video_id = 'May' + # video_fname = f"data/raw/videos/{video_id}.mp4" + # out_fname = f"data/processed/videos/{video_id}/coeff.npy" + # process_video(video_fname, out_fname) + + ### Process short video clips for LRS3 dataset + import random + + from argparse import ArgumentParser + parser = ArgumentParser() + parser.add_argument('--lrs3_path', type=str, default='/home/yezhenhui/projects/TalkingHead-1KH/datasets/raw/cropped_clips', help='') + parser.add_argument('--process_id', type=int, default=0, help='') + parser.add_argument('--total_process', type=int, default=1, help='') + args = parser.parse_args() + + import os, glob + lrs3_dir = args.lrs3_path + out_dir = lrs3_dir.replace("raw/cropped_clips", "processed/coeff") + os.makedirs(out_dir, exist_ok=True) + # mp4_name_pattern = os.path.join(lrs3_dir, "*.mp4") + # mp4_names = glob.glob(mp4_name_pattern) + with open('/home/yezhenhui/projects/LDMAvatar/clean.txt', 'r') as f: + txt = f.read() + mp4_names = txt.split("\n") + mp4_names = sorted(mp4_names) + if args.total_process > 1: + assert args.process_id <= args.total_process-1 + num_samples_per_process = len(mp4_names) // args.total_process + if args.process_id == args.total_process-1: + mp4_names = mp4_names[args.process_id * num_samples_per_process : ] + else: + mp4_names = mp4_names[args.process_id * num_samples_per_process : (args.process_id+1) * num_samples_per_process] + random.seed(111) + random.shuffle(mp4_names) + batched_mp4_names_lst = chunk(mp4_names, chunk_size=8) + for batch_mp4_names in tqdm(batched_mp4_names_lst, desc='[ROOT]: extracting face mesh and 3DMM in batches...'): + try: + for mp4_name in batch_mp4_names: + split_wav(mp4_name) + out_names = [mp4_name.replace(".mp4", "_coeff_pt.npy").replace("datasets/raw/cropped_clips", "datasets/processed/coeff") for mp4_name in batch_mp4_names] + process_video_batch(batch_mp4_names, out_names) + # process_video(mp4_name,out_name=mp4_name.replace(".mp4", "_coeff_pt.npy").replace("datasets/raw/cropped_clips", "datasets/processed/coeff")) + except Exception as e: + print(e) + continue diff --git a/Geneface_main/GeneFace/data_gen/process_lrs3/process_video_3dmm_vox2.py b/Geneface_main/GeneFace/data_gen/process_lrs3/process_video_3dmm_vox2.py new file mode 100644 index 00000000..8bdbc40b --- /dev/null +++ b/Geneface_main/GeneFace/data_gen/process_lrs3/process_video_3dmm_vox2.py @@ -0,0 +1,227 @@ +import os, sys +import numpy as np +from tqdm import tqdm, trange +import deep_3drecon +from moviepy.editor import VideoFileClip +from utils.commons.multiprocess_utils import multiprocess_run_tqdm, multiprocess_run +from utils.commons.meters import Timer +from decord import VideoReader +from decord import cpu, gpu +from utils.commons.face_alignment_utils import mediapipe_lm478_to_face_alignment_lm68 +import mediapipe +import cv2 + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +# fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, network_size=4, device='cuda') +mp_face_mesh = mediapipe.solutions.face_mesh +face_reconstructor = deep_3drecon.Reconstructor() + + +def chunk(iterable, chunk_size): + final_ret = [] + cnt = 0 + ret = [] + for record in iterable: + if cnt == 0: + ret = [] + ret.append(record) + cnt += 1 + if len(ret) == chunk_size: + final_ret.append(ret) + ret = [] + if len(final_ret[-1]) != chunk_size: + final_ret.append(ret) + return final_ret + +# landmark detection in Deep3DRecon +def lm68_2_lm5(in_lm): + assert in_lm.ndim == 2 + # in_lm: shape=[68,2] + lm_idx = np.array([31,37,40,43,46,49,55]) - 1 + # 将上述特殊角点的数据取出,得到5个新的角点数据,拼接起来。 + lm = np.stack([in_lm[lm_idx[0],:],np.mean(in_lm[lm_idx[[1,2]],:],0),np.mean(in_lm[lm_idx[[3,4]],:],0),in_lm[lm_idx[5],:],in_lm[lm_idx[6],:]], axis = 0) + # 将第一个角点放在了第三个位置 + lm = lm[[1,2,0,3,4],:2] + return lm + +def extract_frames_job(fname): + try: + out_name=fname.replace(".mp4", "_coeff_pt.npy").replace("/dev/", "/coeff/") + if os.path.exists(out_name): + return None + cap = cv2.VideoCapture(fname) + frames = [] + while cap.isOpened(): + ret, frame_bgr = cap.read() + if frame_bgr is None: + break + frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB) + frames.append(frame_rgb) + return np.stack(frames) + # out_name=fname.replace(".mp4", "_coeff_pt.npy").replace("/dev/", "/coeff/") + # if os.path.exists(out_name): + # return None + # video_reader = VideoReader(fname, ctx=cpu(0)) + # frame_rgb_lst = video_reader.get_batch(list(range(0,len(video_reader)))).asnumpy() + # return frame_rgb_lst + except Exception as e: + print(e) + return None + +def extract_lms_mediapipe_job(frames): + try: + if frames is None: + return None + with mp_face_mesh.FaceMesh( + static_image_mode=False, + max_num_faces=1, + refine_landmarks=True, + min_detection_confidence=0.5) as face_mesh: + ldms_normed = [] + frame_i = 0 + frame_ids = [] + for i in range(len(frames)): + # Convert the BGR image to RGB before processing. + ret = face_mesh.process(frames[i]) + # Print and draw face mesh landmarks on the image. + if not ret.multi_face_landmarks: + print(f"Skip Item: Caught errors when mediapipe get face_mesh, maybe No face detected in some frames!") + return None + else: + myFaceLandmarks = [] + lms = ret.multi_face_landmarks[0] + for lm in lms.landmark: + myFaceLandmarks.append([lm.x, lm.y, lm.z]) + ldms_normed.append(myFaceLandmarks) + frame_ids.append(frame_i) + frame_i += 1 + bs, H, W, _ = frames.shape + ldms478 = np.array(ldms_normed) + lm68 = mediapipe_lm478_to_face_alignment_lm68(ldms478, H, W, return_2d=True) + lm5_lst = [lm68_2_lm5(lm68[i]) for i in range(lm68.shape[0])] + lm5 = np.stack(lm5_lst) + return ldms478, lm68, lm5 + except Exception as e: + print(e) + return None + +def process_video_batch(fname_lst, out_name_lst=None): + frames_lst = [] + with Timer("load_frames", True): + for fname in tqdm(fname_lst, desc="decord is loading frames in the batch videos..."): + res = extract_frames_job(fname) + frames_lst.append(res) + # for (i, res) in multiprocess_run_tqdm(extract_frames_job, fname_lst, num_workers=1, desc="decord is loading frames in the batch videos..."): + # frames_lst.append(res) + + lm478s_lst = [] + lm68s_lst = [] + lm5s_lst = [] + with Timer("mediapipe_faceAlign", True): + # for (i, res) in multiprocess_run_tqdm(extract_lms_mediapipe_job, frames_lst, num_workers=2, desc="mediapipe is predicting face mesh in batch videos..."): + for i, frames in tqdm(enumerate(frames_lst),total=len(fname_lst), desc="mediapipe is predicting face mesh in batch videos..."): + res = extract_lms_mediapipe_job(frames) + if res is None: + res = (None, None, None) + lm478s, lm68s, lm5s = res + lm478s_lst.append(lm478s) + lm68s_lst.append(lm68s) + lm5s_lst.append(lm5s) + + processed_cnt_in_this_batch = 0 + with Timer("deep_3drecon_pytorch", True): + for i, fname in tqdm(enumerate(fname_lst), total=len(fname_lst), desc="extracting 3DMM in the batch videos..."): + video_rgb = frames_lst[i] # [t, 224,224, 3] + lm478_arr = lm478s_lst[i] + lm68_arr = lm68s_lst[i] + lm5_arr = lm5s_lst[i] + if lm5_arr is None: + continue + num_frames = len(video_rgb) + batch_size = 32 + iter_times = num_frames // batch_size + last_bs = num_frames % batch_size + + coeff_lst = [] + for i_iter in range(iter_times): + start_idx = i_iter * batch_size + batched_images = video_rgb[start_idx: start_idx + batch_size] + batched_lm5 = lm5_arr[start_idx: start_idx + batch_size] + coeff, align_img = face_reconstructor.recon_coeff(batched_images, batched_lm5, return_image = True) + coeff_lst.append(coeff) + if last_bs != 0: + batched_images = video_rgb[-last_bs:] + batched_lm5 = lm5_arr[-last_bs:] + coeff, align_img = face_reconstructor.recon_coeff(batched_images, batched_lm5, return_image = True) + coeff_lst.append(coeff) + coeff_arr = np.concatenate(coeff_lst,axis=0) + result_dict = { + 'coeff': coeff_arr.reshape([num_frames, -1]).astype(np.float32), + 'lm478': lm478_arr.reshape([num_frames, 478, 3]).astype(np.float32), + 'lm68': lm68_arr.reshape([num_frames, 68, 2]).astype(np.int16), + 'lm5': lm5_arr.reshape([num_frames, 5, 2]).astype(np.int16), + } + os.makedirs(os.path.dirname(out_name_lst[i]),exist_ok=True) + np.save(out_name_lst[i], result_dict) + processed_cnt_in_this_batch +=1 + + print(f"In this batch {processed_cnt_in_this_batch} files are processed") + + + +def split_wav(mp4_name): + try: + wav_name = mp4_name[:-4] + '.wav' + if os.path.exists(wav_name): + return + video = VideoFileClip(mp4_name,verbose=False) + dur = video.duration + audio = video.audio + assert audio is not None + audio.write_audiofile(wav_name,fps=16000,verbose=False,logger=None) + except Exception as e: + print(e) + return None + +if __name__ == '__main__': + ### Process Single Long video for NeRF dataset + # video_id = 'May' + # video_fname = f"data/raw/videos/{video_id}.mp4" + # out_fname = f"data/processed/videos/{video_id}/coeff.npy" + # process_video(video_fname, out_fname) + + ### Process short video clips for LRS3 dataset + import random + + from argparse import ArgumentParser + parser = ArgumentParser() + parser.add_argument('--lrs3_path', type=str, default='/mnt/sda/yezhenhui/datasets/voxceleb2', help='') + parser.add_argument('--process_id', type=int, default=0, help='') + parser.add_argument('--total_process', type=int, default=1, help='') + args = parser.parse_args() + + import os, glob + lrs3_dir = args.lrs3_path + mp4_name_pattern = os.path.join(lrs3_dir, "dev/id*/*/*.mp4") + mp4_names = glob.glob(mp4_name_pattern) + + if args.total_process > 1: + assert args.process_id <= args.total_process-1 + num_samples_per_process = len(mp4_names) // args.total_process + if args.process_id == args.total_process-1: + mp4_names = mp4_names[args.process_id * num_samples_per_process : ] + else: + mp4_names = mp4_names[args.process_id * num_samples_per_process : (args.process_id+1) * num_samples_per_process] + random.seed(111) + random.shuffle(mp4_names) + batched_mp4_names_lst = chunk(mp4_names, chunk_size=1) + for batch_mp4_names in tqdm(batched_mp4_names_lst, desc='[ROOT]: extracting face mesh and 3DMM in batches...'): + try: + for mp4_name in batch_mp4_names: + split_wav(mp4_name) + out_names = [mp4_name.replace(".mp4", "_coeff_pt.npy").replace("/dev/", "/coeff/") for mp4_name in batch_mp4_names] + process_video_batch(batch_mp4_names, out_names) + except Exception as e: + print(e) + continue diff --git a/Geneface_main/GeneFace/data_util/__pycache__/extract_mel.cpython-39.pyc b/Geneface_main/GeneFace/data_util/__pycache__/extract_mel.cpython-39.pyc new file mode 100644 index 00000000..86dfb708 Binary files /dev/null and b/Geneface_main/GeneFace/data_util/__pycache__/extract_mel.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/data_util/__pycache__/face3d_helper.cpython-39.pyc b/Geneface_main/GeneFace/data_util/__pycache__/face3d_helper.cpython-39.pyc new file mode 100644 index 00000000..eb47fdd2 Binary files /dev/null and b/Geneface_main/GeneFace/data_util/__pycache__/face3d_helper.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/data_util/deepspeech_features/README.md b/Geneface_main/GeneFace/data_util/deepspeech_features/README.md new file mode 100644 index 00000000..b15c987a --- /dev/null +++ b/Geneface_main/GeneFace/data_util/deepspeech_features/README.md @@ -0,0 +1,20 @@ +# Routines for DeepSpeech features processing +Several routines for [DeepSpeech](https://github.com/mozilla/DeepSpeech) features processing, like speech features generation for [VOCA](https://github.com/TimoBolkart/voca) model. + +## Installation + +``` +pip3 install -r requirements.txt +``` + +## Usage + +Generate wav files: +``` +python3 extract_wav.py --in-video= +``` + +Generate files with DeepSpeech features: +``` +python3 extract_ds_features.py --input= +``` diff --git a/Geneface_main/GeneFace/data_util/deepspeech_features/__pycache__/deepspeech_features.cpython-39.pyc b/Geneface_main/GeneFace/data_util/deepspeech_features/__pycache__/deepspeech_features.cpython-39.pyc new file mode 100644 index 00000000..9fdcefb5 Binary files /dev/null and b/Geneface_main/GeneFace/data_util/deepspeech_features/__pycache__/deepspeech_features.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/data_util/deepspeech_features/__pycache__/deepspeech_store.cpython-39.pyc b/Geneface_main/GeneFace/data_util/deepspeech_features/__pycache__/deepspeech_store.cpython-39.pyc new file mode 100644 index 00000000..5d588d39 Binary files /dev/null and b/Geneface_main/GeneFace/data_util/deepspeech_features/__pycache__/deepspeech_store.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/data_util/deepspeech_features/deepspeech_features.py b/Geneface_main/GeneFace/data_util/deepspeech_features/deepspeech_features.py new file mode 100644 index 00000000..64582434 --- /dev/null +++ b/Geneface_main/GeneFace/data_util/deepspeech_features/deepspeech_features.py @@ -0,0 +1,284 @@ +""" + DeepSpeech features processing routines. + NB: Based on VOCA code. See the corresponding license restrictions. +""" + +__all__ = ['conv_audios_to_deepspeech'] + +import numpy as np +import warnings +import resampy +from scipy.io import wavfile +from python_speech_features import mfcc +import tensorflow as tf + + +def conv_audios_to_deepspeech(audios, + out_files, + num_frames_info, + deepspeech_pb_path, + audio_window_size=1, + audio_window_stride=1): + """ + Convert list of audio files into files with DeepSpeech features. + + Parameters + ---------- + audios : list of str or list of None + Paths to input audio files. + out_files : list of str + Paths to output files with DeepSpeech features. + num_frames_info : list of int + List of numbers of frames. + deepspeech_pb_path : str + Path to DeepSpeech 0.1.0 frozen model. + audio_window_size : int, default 16 + Audio window size. + audio_window_stride : int, default 1 + Audio window stride. + """ + # deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm" + graph, logits_ph, input_node_ph, input_lengths_ph = prepare_deepspeech_net( + deepspeech_pb_path) + + with tf.compat.v1.Session(graph=graph) as sess: + for audio_file_path, out_file_path, num_frames in zip(audios, out_files, num_frames_info): + print("tring to extract deepspeech from audio file: ",audio_file_path) + print("The target is: ", out_file_path) + audio_sample_rate, audio = wavfile.read(audio_file_path) + if audio.ndim != 1: + warnings.warn( + "Audio has multiple channels, the first channel is used") + audio = audio[:, 0] + ds_features = pure_conv_audio_to_deepspeech( + audio=audio, + audio_sample_rate=audio_sample_rate, + audio_window_size=audio_window_size, + audio_window_stride=audio_window_stride, + num_frames=num_frames, + net_fn=lambda x: sess.run( + logits_ph, + feed_dict={ + input_node_ph: x[np.newaxis, ...], + input_lengths_ph: [x.shape[0]]})) + + net_output = ds_features.reshape(-1, 29) + win_size = 16 + zero_pad = np.zeros((int(win_size / 2), net_output.shape[1])) + net_output = np.concatenate( + (zero_pad, net_output, zero_pad), axis=0) + windows = [] + for window_index in range(0, net_output.shape[0] - win_size, 2): + windows.append( + net_output[window_index:window_index + win_size]) + np.save(out_file_path, np.array(windows)) + print("The deepspeech extracted successfully, saved at: ", out_file_path) + print("The shape is: ", np.array(windows).shape) + + +def prepare_deepspeech_net(deepspeech_pb_path): + """ + Load and prepare DeepSpeech network. + + Parameters + ---------- + deepspeech_pb_path : str + Path to DeepSpeech 0.1.0 frozen model. + + Returns + ------- + graph : obj + ThensorFlow graph. + logits_ph : obj + ThensorFlow placeholder for `logits`. + input_node_ph : obj + ThensorFlow placeholder for `input_node`. + input_lengths_ph : obj + ThensorFlow placeholder for `input_lengths`. + """ + # Load graph and place_holders: + with tf.io.gfile.GFile(deepspeech_pb_path, "rb") as f: + graph_def = tf.compat.v1.GraphDef() + graph_def.ParseFromString(f.read()) + print(deepspeech_pb_path) + graph = tf.compat.v1.get_default_graph() + tf.import_graph_def(graph_def, name="deepspeech") + # if tensorflow=1.15 + if tf.__version__.startswith("1."): + logits_ph = graph.get_tensor_by_name("deepspeech/logits:0") + input_node_ph = graph.get_tensor_by_name("deepspeech/input_node:0") + input_lengths_ph = graph.get_tensor_by_name("deepspeech/input_lengths:0") + # if tensorflow=2.x + elif tf.__version__.startswith("2."): + logits_ph = graph.get_tensor_by_name("logits:0") + input_node_ph = graph.get_tensor_by_name("input_node:0") + input_lengths_ph = graph.get_tensor_by_name("input_lengths:0") + else: + raise ValueError("") + return graph, logits_ph, input_node_ph, input_lengths_ph + + +def pure_conv_audio_to_deepspeech(audio, + audio_sample_rate, + audio_window_size, + audio_window_stride, + num_frames, + net_fn): + """ + Core routine for converting audion into DeepSpeech features. + + Parameters + ---------- + audio : np.array + Audio data. + audio_sample_rate : int + Audio sample rate. + audio_window_size : int + Audio window size. + audio_window_stride : int + Audio window stride. + num_frames : int or None + Numbers of frames. + net_fn : func + Function for DeepSpeech model call. + + Returns + ------- + np.array + DeepSpeech features. + """ + target_sample_rate = 16000 + if audio_sample_rate != target_sample_rate: + resampled_audio = resampy.resample( + x=audio.astype(np.float), + sr_orig=audio_sample_rate, + sr_new=target_sample_rate) + else: + resampled_audio = audio.astype(np.float) + input_vector = conv_audio_to_deepspeech_input_vector( + audio=resampled_audio.astype(np.int16), + sample_rate=target_sample_rate, + num_cepstrum=26, + num_context=9) + + network_output = net_fn(input_vector) + # print(network_output.shape) + + deepspeech_fps = 50 + video_fps = 50 # Change this option if video fps is different + audio_len_s = float(audio.shape[0]) / audio_sample_rate + if num_frames is None: + num_frames = int(round(audio_len_s * video_fps)) + else: + video_fps = num_frames / audio_len_s + network_output = interpolate_features( + features=network_output[:, 0], + input_rate=deepspeech_fps, + output_rate=video_fps, + output_len=num_frames) + + # Make windows: + zero_pad = np.zeros((int(audio_window_size / 2), network_output.shape[1])) + network_output = np.concatenate( + (zero_pad, network_output, zero_pad), axis=0) + windows = [] + for window_index in range(0, network_output.shape[0] - audio_window_size, audio_window_stride): + windows.append( + network_output[window_index:window_index + audio_window_size]) + + return np.array(windows) + + +def conv_audio_to_deepspeech_input_vector(audio, + sample_rate, + num_cepstrum, + num_context): + """ + Convert audio raw data into DeepSpeech input vector. + + Parameters + ---------- + audio : np.array + Audio data. + audio_sample_rate : int + Audio sample rate. + num_cepstrum : int + Number of cepstrum. + num_context : int + Number of context. + + Returns + ------- + np.array + DeepSpeech input vector. + """ + # Get mfcc coefficients: + features = mfcc( + signal=audio, + samplerate=sample_rate, + numcep=num_cepstrum) + + # We only keep every second feature (BiRNN stride = 2): + features = features[::2] + + # One stride per time step in the input: + num_strides = len(features) + + # Add empty initial and final contexts: + empty_context = np.zeros((num_context, num_cepstrum), dtype=features.dtype) + features = np.concatenate((empty_context, features, empty_context)) + + # Create a view into the array with overlapping strides of size + # numcontext (past) + 1 (present) + numcontext (future): + window_size = 2 * num_context + 1 + train_inputs = np.lib.stride_tricks.as_strided( + features, + shape=(num_strides, window_size, num_cepstrum), + strides=(features.strides[0], + features.strides[0], features.strides[1]), + writeable=False) + + # Flatten the second and third dimensions: + train_inputs = np.reshape(train_inputs, [num_strides, -1]) + + train_inputs = np.copy(train_inputs) + train_inputs = (train_inputs - np.mean(train_inputs)) / \ + np.std(train_inputs) + + return train_inputs + + +def interpolate_features(features, + input_rate, + output_rate, + output_len): + """ + Interpolate DeepSpeech features. + + Parameters + ---------- + features : np.array + DeepSpeech features. + input_rate : int + input rate (FPS). + output_rate : int + Output rate (FPS). + output_len : int + Output data length. + + Returns + ------- + np.array + Interpolated data. + """ + input_len = features.shape[0] + num_features = features.shape[1] + input_timestamps = np.arange(input_len) / float(input_rate) + output_timestamps = np.arange(output_len) / float(output_rate) + output_features = np.zeros((output_len, num_features)) + for feature_idx in range(num_features): + output_features[:, feature_idx] = np.interp( + x=output_timestamps, + xp=input_timestamps, + fp=features[:, feature_idx]) + return output_features diff --git a/Geneface_main/GeneFace/data_util/deepspeech_features/deepspeech_store.py b/Geneface_main/GeneFace/data_util/deepspeech_features/deepspeech_store.py new file mode 100644 index 00000000..5595a4d5 --- /dev/null +++ b/Geneface_main/GeneFace/data_util/deepspeech_features/deepspeech_store.py @@ -0,0 +1,172 @@ +""" + Routines for loading DeepSpeech model. +""" + +__all__ = ['get_deepspeech_model_file'] + +import os +import zipfile +import logging +import hashlib + + +deepspeech_features_repo_url = 'https://github.com/osmr/deepspeech_features' + + +def get_deepspeech_model_file(local_model_store_dir_path=os.path.join("~", ".tensorflow", "models")): + """ + Return location for the pretrained on local file system. This function will download from online model zoo when + model cannot be found or has mismatch. The root directory will be created if it doesn't exist. + + Parameters + ---------- + local_model_store_dir_path : str, default $TENSORFLOW_HOME/models + Location for keeping the model parameters. + + Returns + ------- + file_path + Path to the requested pretrained model file. + """ + sha1_hash = "b90017e816572ddce84f5843f1fa21e6a377975e" + file_name = "deepspeech-0_1_0-b90017e8.pb" + local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path) + file_path = os.path.join(local_model_store_dir_path, file_name) + if os.path.exists(file_path): + if _check_sha1(file_path, sha1_hash): + return file_path + else: + logging.warning("Mismatch in the content of model file detected. Downloading again.") + else: + logging.info("Model file not found. Downloading to {}.".format(file_path)) + + if not os.path.exists(local_model_store_dir_path): + os.makedirs(local_model_store_dir_path) + + zip_file_path = file_path + ".zip" + _download( + url="{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip".format( + repo_url=deepspeech_features_repo_url, + repo_release_tag="v0.0.1", + file_name=file_name), + path=zip_file_path, + overwrite=True) + with zipfile.ZipFile(zip_file_path) as zf: + zf.extractall(local_model_store_dir_path) + os.remove(zip_file_path) + + if _check_sha1(file_path, sha1_hash): + return file_path + else: + raise ValueError("Downloaded file has different hash. Please try again.") + + +def _download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True): + """ + Download an given URL + + Parameters + ---------- + url : str + URL to download + path : str, optional + Destination path to store downloaded file. By default stores to the + current directory with same name as in url. + overwrite : bool, optional + Whether to overwrite destination file if already exists. + sha1_hash : str, optional + Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified + but doesn't match. + retries : integer, default 5 + The number of times to attempt the download in case of failure or non 200 return codes + verify_ssl : bool, default True + Verify SSL certificates. + + Returns + ------- + str + The file path of the downloaded file. + """ + import warnings + try: + import requests + except ImportError: + class requests_failed_to_import(object): + pass + requests = requests_failed_to_import + + if path is None: + fname = url.split("/")[-1] + # Empty filenames are invalid + assert fname, "Can't construct file-name from this URL. Please set the `path` option manually." + else: + path = os.path.expanduser(path) + if os.path.isdir(path): + fname = os.path.join(path, url.split("/")[-1]) + else: + fname = path + assert retries >= 0, "Number of retries should be at least 0" + + if not verify_ssl: + warnings.warn( + "Unverified HTTPS request is being made (verify_ssl=False). " + "Adding certificate verification is strongly advised.") + + if overwrite or not os.path.exists(fname) or (sha1_hash and not _check_sha1(fname, sha1_hash)): + dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) + if not os.path.exists(dirname): + os.makedirs(dirname) + while retries + 1 > 0: + # Disable pyling too broad Exception + # pylint: disable=W0703 + try: + print("Downloading {} from {}...".format(fname, url)) + r = requests.get(url, stream=True, verify=verify_ssl) + if r.status_code != 200: + raise RuntimeError("Failed downloading url {}".format(url)) + with open(fname, "wb") as f: + for chunk in r.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + f.write(chunk) + if sha1_hash and not _check_sha1(fname, sha1_hash): + raise UserWarning("File {} is downloaded but the content hash does not match." + " The repo may be outdated or download may be incomplete. " + "If the `repo_url` is overridden, consider switching to " + "the default repo.".format(fname)) + break + except Exception as e: + retries -= 1 + if retries <= 0: + raise e + else: + print("download failed, retrying, {} attempt{} left" + .format(retries, "s" if retries > 1 else "")) + + return fname + + +def _check_sha1(filename, sha1_hash): + """ + Check whether the sha1 hash of the file content matches the expected hash. + + Parameters + ---------- + filename : str + Path to the file. + sha1_hash : str + Expected sha1 hash in hexadecimal digits. + + Returns + ------- + bool + Whether the file content matches the expected hash. + """ + sha1 = hashlib.sha1() + with open(filename, "rb") as f: + while True: + data = f.read(1048576) + if not data: + break + sha1.update(data) + + return sha1.hexdigest() == sha1_hash diff --git a/Geneface_main/GeneFace/data_util/deepspeech_features/extract_ds_features.py b/Geneface_main/GeneFace/data_util/deepspeech_features/extract_ds_features.py new file mode 100644 index 00000000..004f3194 --- /dev/null +++ b/Geneface_main/GeneFace/data_util/deepspeech_features/extract_ds_features.py @@ -0,0 +1,130 @@ +""" + Script for extracting DeepSpeech features from audio file. +""" + +import os +import argparse +import numpy as np +import pandas as pd +from deepspeech_store import get_deepspeech_model_file +from deepspeech_features import conv_audios_to_deepspeech + + +def parse_args(): + """ + Create python script parameters. + Returns + ------- + ArgumentParser + Resulted args. + """ + parser = argparse.ArgumentParser( + description="Extract DeepSpeech features from audio file", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + "--input", + type=str, + required=True, + help="path to input audio file or directory") + parser.add_argument( + "--output", + type=str, + help="path to output file with DeepSpeech features") + parser.add_argument( + "--deepspeech", + type=str, + # default='data_util/deepspeech_features/deepspeech-0.9.2-models.pbmm', + help="path to DeepSpeech 0.1.0 frozen model") + parser.add_argument( + "--metainfo", + type=str, + help="path to file with meta-information") + + args = parser.parse_args() + return args + + +def extract_features(in_audios, + out_files, + deepspeech_pb_path, + metainfo_file_path=None): + """ + Real extract audio from video file. + Parameters + ---------- + in_audios : list of str + Paths to input audio files. + out_files : list of str + Paths to output files with DeepSpeech features. + deepspeech_pb_path : str + Path to DeepSpeech 0.1.0 frozen model. + metainfo_file_path : str, default None + Path to file with meta-information. + """ + #deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm" + if metainfo_file_path is None: + num_frames_info = [None] * len(in_audios) + else: + train_df = pd.read_csv( + metainfo_file_path, + sep="\t", + index_col=False, + dtype={"Id": np.int, "File": np.unicode, "Count": np.int}) + num_frames_info = train_df["Count"].values + assert (len(num_frames_info) == len(in_audios)) + + for i, in_audio in enumerate(in_audios): + if not out_files[i]: + file_stem, _ = os.path.splitext(in_audio) + out_files[i] = file_stem + ".npy" + #print(out_files[i]) + conv_audios_to_deepspeech( + audios=in_audios, + out_files=out_files, + num_frames_info=num_frames_info, + deepspeech_pb_path=deepspeech_pb_path) + + +def main(): + """ + Main body of script. + """ + args = parse_args() + in_audio = os.path.expanduser(args.input) + if not os.path.exists(in_audio): + raise Exception("Input file/directory doesn't exist: {}".format(in_audio)) + deepspeech_pb_path = args.deepspeech + #deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm" + if deepspeech_pb_path is None: + deepspeech_pb_path = "" + if deepspeech_pb_path: + deepspeech_pb_path = os.path.expanduser(args.deepspeech) + if not os.path.exists(deepspeech_pb_path): + deepspeech_pb_path = get_deepspeech_model_file() + if os.path.isfile(in_audio): + extract_features( + in_audios=[in_audio], + out_files=[args.output], + deepspeech_pb_path=deepspeech_pb_path, + metainfo_file_path=args.metainfo) + else: + audio_file_paths = [] + for file_name in os.listdir(in_audio): + if not os.path.isfile(os.path.join(in_audio, file_name)): + continue + _, file_ext = os.path.splitext(file_name) + if file_ext.lower() == ".wav": + audio_file_path = os.path.join(in_audio, file_name) + audio_file_paths.append(audio_file_path) + audio_file_paths = sorted(audio_file_paths) + out_file_paths = [""] * len(audio_file_paths) + extract_features( + in_audios=audio_file_paths, + out_files=out_file_paths, + deepspeech_pb_path=deepspeech_pb_path, + metainfo_file_path=args.metainfo) + + +if __name__ == "__main__": + main() + diff --git a/Geneface_main/GeneFace/data_util/deepspeech_features/extract_wav.py b/Geneface_main/GeneFace/data_util/deepspeech_features/extract_wav.py new file mode 100644 index 00000000..8458c5f2 --- /dev/null +++ b/Geneface_main/GeneFace/data_util/deepspeech_features/extract_wav.py @@ -0,0 +1,87 @@ +""" + Script for extracting audio (16-bit, mono, 22000 Hz) from video file. +""" + +import os +import argparse +import subprocess + + +def parse_args(): + """ + Create python script parameters. + + Returns + ------- + ArgumentParser + Resulted args. + """ + parser = argparse.ArgumentParser( + description="Extract audio from video file", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + "--in-video", + type=str, + required=True, + help="path to input video file or directory") + parser.add_argument( + "--out-audio", + type=str, + help="path to output audio file") + + args = parser.parse_args() + return args + + +def extract_audio(in_video, + out_audio): + """ + Real extract audio from video file. + + Parameters + ---------- + in_video : str + Path to input video file. + out_audio : str + Path to output audio file. + """ + if not out_audio: + file_stem, _ = os.path.splitext(in_video) + out_audio = file_stem + ".wav" + # command1 = "ffmpeg -i {in_video} -vn -acodec copy {aac_audio}" + # command2 = "ffmpeg -i {aac_audio} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}" + # command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}" + command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 16000 {out_audio}" + subprocess.call([command.format(in_video=in_video, out_audio=out_audio)], shell=True) + + +def main(): + """ + Main body of script. + """ + args = parse_args() + in_video = os.path.expanduser(args.in_video) + if not os.path.exists(in_video): + raise Exception("Input file/directory doesn't exist: {}".format(in_video)) + if os.path.isfile(in_video): + extract_audio( + in_video=in_video, + out_audio=args.out_audio) + else: + video_file_paths = [] + for file_name in os.listdir(in_video): + if not os.path.isfile(os.path.join(in_video, file_name)): + continue + _, file_ext = os.path.splitext(file_name) + if file_ext.lower() in (".mp4", ".mkv", ".avi"): + video_file_path = os.path.join(in_video, file_name) + video_file_paths.append(video_file_path) + video_file_paths = sorted(video_file_paths) + for video_file_path in video_file_paths: + extract_audio( + in_video=video_file_path, + out_audio="") + + +if __name__ == "__main__": + main() diff --git a/Geneface_main/GeneFace/data_util/deepspeech_features/fea_win.py b/Geneface_main/GeneFace/data_util/deepspeech_features/fea_win.py new file mode 100644 index 00000000..df9e27b4 --- /dev/null +++ b/Geneface_main/GeneFace/data_util/deepspeech_features/fea_win.py @@ -0,0 +1,11 @@ +import numpy as np + +net_output = np.load('french.ds.npy').reshape(-1, 29) +win_size = 16 +zero_pad = np.zeros((int(win_size / 2), net_output.shape[1])) +net_output = np.concatenate((zero_pad, net_output, zero_pad), axis=0) +windows = [] +for window_index in range(0, net_output.shape[0] - win_size, 2): + windows.append(net_output[window_index:window_index + win_size]) +print(np.array(windows).shape) +np.save('aud_french.npy', np.array(windows)) diff --git a/Geneface_main/GeneFace/data_util/extract_esperanto.py b/Geneface_main/GeneFace/data_util/extract_esperanto.py new file mode 100644 index 00000000..5489cf11 --- /dev/null +++ b/Geneface_main/GeneFace/data_util/extract_esperanto.py @@ -0,0 +1,423 @@ +import time +import numpy as np +import torch +import torch.nn.functional as F +from transformers import AutoModelForCTC, AutoProcessor + +import pyaudio +import soundfile as sf +import resampy + +from queue import Queue +from threading import Thread, Event + + +def _read_frame(stream, exit_event, queue, chunk): + + while True: + if exit_event.is_set(): + print(f'[INFO] read frame thread ends') + break + frame = stream.read(chunk, exception_on_overflow=False) + frame = np.frombuffer(frame, dtype=np.int16).astype(np.float32) / 32767 # [chunk] + queue.put(frame) + +def _play_frame(stream, exit_event, queue, chunk): + + while True: + if exit_event.is_set(): + print(f'[INFO] play frame thread ends') + break + frame = queue.get() + frame = (frame * 32767).astype(np.int16).tobytes() + stream.write(frame, chunk) + +class ASR: + def __init__(self, opt): + + self.opt = opt + + self.play = opt.asr_play + + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.fps = opt.fps # 20 ms per frame + self.sample_rate = 16000 + self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000) + self.mode = 'live' if opt.asr_wav == '' else 'file' + + if 'esperanto' in self.opt.asr_model: + self.audio_dim = 44 + elif 'deepspeech' in self.opt.asr_model: + self.audio_dim = 29 + else: + self.audio_dim = 32 + + # prepare context cache + # each segment is (stride_left + ctx + stride_right) * 20ms, latency should be (ctx + stride_right) * 20ms + self.context_size = opt.m + self.stride_left_size = opt.l + self.stride_right_size = opt.r + self.text = '[START]\n' + self.terminated = False + self.frames = [] + + # pad left frames + if self.stride_left_size > 0: + self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size) + + + self.exit_event = Event() + self.audio_instance = pyaudio.PyAudio() + + # create input stream + if self.mode == 'file': + self.file_stream = self.create_file_stream() + else: + # start a background process to read frames + self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk) + self.queue = Queue() + self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk)) + + # play out the audio too...? + if self.play: + self.output_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=False, output=True, frames_per_buffer=self.chunk) + self.output_queue = Queue() + self.process_play_frame = Thread(target=_play_frame, args=(self.output_stream, self.exit_event, self.output_queue, self.chunk)) + + # current location of audio + self.idx = 0 + + # create wav2vec model + print(f'[INFO] loading ASR model {self.opt.asr_model}...') + self.processor = AutoProcessor.from_pretrained(opt.asr_model) + self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device) + + # prepare to save logits + if self.opt.asr_save_feats: + self.all_feats = [] + + # the extracted features + # use a loop queue to efficiently record endless features: [f--t---][-------][-------] + self.feat_buffer_size = 4 + self.feat_buffer_idx = 0 + self.feat_queue = torch.zeros(self.feat_buffer_size * self.context_size, self.audio_dim, dtype=torch.float32, device=self.device) + + # TODO: hard coded 16 and 8 window size... + self.front = self.feat_buffer_size * self.context_size - 8 # fake padding + self.tail = 8 + # attention window... + self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4 # 4 zero padding... + + # warm up steps needed: mid + right + window_size + attention_size + self.warm_up_steps = self.context_size + self.stride_right_size + 8 + 2 * 3 + + self.listening = False + self.playing = False + + def listen(self): + # start + if self.mode == 'live' and not self.listening: + print(f'[INFO] starting read frame thread...') + self.process_read_frame.start() + self.listening = True + + if self.play and not self.playing: + print(f'[INFO] starting play frame thread...') + self.process_play_frame.start() + self.playing = True + + def stop(self): + + self.exit_event.set() + + if self.play: + self.output_stream.stop_stream() + self.output_stream.close() + if self.playing: + self.process_play_frame.join() + self.playing = False + + if self.mode == 'live': + self.input_stream.stop_stream() + self.input_stream.close() + if self.listening: + self.process_read_frame.join() + self.listening = False + + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + + self.stop() + + if self.mode == 'live': + # live mode: also print the result text. + self.text += '\n[END]' + print(self.text) + + def get_next_feat(self): + # return a [1/8, 16] window, for the next input to nerf side. + + while len(self.att_feats) < 8: + # [------f+++t-----] + if self.front < self.tail: + feat = self.feat_queue[self.front:self.tail] + # [++t-----------f+] + else: + feat = torch.cat([self.feat_queue[self.front:], self.feat_queue[:self.tail]], dim=0) + + self.front = (self.front + 2) % self.feat_queue.shape[0] + self.tail = (self.tail + 2) % self.feat_queue.shape[0] + + # print(self.front, self.tail, feat.shape) + + self.att_feats.append(feat.permute(1, 0)) + + att_feat = torch.stack(self.att_feats, dim=0) # [8, 44, 16] + + # discard old + self.att_feats = self.att_feats[1:] + + return att_feat + + def run_step(self): + + if self.terminated: + return + + # get a frame of audio + frame = self.get_audio_frame() + + # the last frame + if frame is None: + # terminate, but always run the network for the left frames + self.terminated = True + else: + self.frames.append(frame) + # put to output + if self.play: + self.output_queue.put(frame) + # context not enough, do not run network. + if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size: + return + + inputs = np.concatenate(self.frames) # [N * chunk] + + # discard the old part to save memory + if not self.terminated: + self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):] + + logits, labels, text = self.frame_to_text(inputs) + feats = logits # better lips-sync than labels + + # save feats + if self.opt.asr_save_feats: + self.all_feats.append(feats) + + # record the feats efficiently.. (no concat, constant memory) + start = self.feat_buffer_idx * self.context_size + end = start + feats.shape[0] + self.feat_queue[start:end] = feats + self.feat_buffer_idx = (self.feat_buffer_idx + 1) % self.feat_buffer_size + + # very naive, just concat the text output. + if text != '': + self.text = self.text + ' ' + text + + # will only run once at ternimation + if self.terminated: + self.text += '\n[END]' + print(self.text) + if self.opt.asr_save_feats: + print(f'[INFO] save all feats for training purpose... ') + feats = torch.cat(self.all_feats, dim=0) # [N, C] + # print('[INFO] before unfold', feats.shape) + window_size = 16 + padding = window_size // 2 + feats = feats.view(-1, self.audio_dim).permute(1, 0).contiguous() # [C, M] + feats = feats.view(1, self.audio_dim, -1, 1) # [1, C, M, 1] + unfold_feats = F.unfold(feats, kernel_size=(window_size, 1), padding=(padding, 0), stride=(2, 1)) # [1, C * window_size, M / 2 + 1] + unfold_feats = unfold_feats.view(self.audio_dim, window_size, -1).permute(2, 1, 0).contiguous() # [C, window_size, M / 2 + 1] --> [M / 2 + 1, window_size, C] + # print('[INFO] after unfold', unfold_feats.shape) + # save to a npy file + if 'esperanto' in self.opt.asr_model: + if self.opt.out_name == '': + output_path = self.opt.asr_wav.replace('.wav', '_esperanto.npy') + else: + output_path = self.opt.out_name + else: + output_path = self.opt.asr_wav.replace('.wav', '.npy') + np.save(output_path, unfold_feats.cpu().numpy()) # [T, W=16, C=44] + print(f"[INFO] saved logits to {output_path}") + + def create_file_stream(self): + + stream, sample_rate = sf.read(self.opt.asr_wav) # [T*sample_rate,] float64 + stream = stream.astype(np.float32) + + if stream.ndim > 1: + print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.') + stream = stream[:, 0] + + if sample_rate != self.sample_rate: + print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.') + stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate) + + print(f'[INFO] loaded audio stream {self.opt.asr_wav}: {stream.shape}') + + return stream + + + def create_pyaudio_stream(self): + + import pyaudio + + print(f'[INFO] creating live audio stream ...') + + audio = pyaudio.PyAudio() + + # get devices + info = audio.get_host_api_info_by_index(0) + n_devices = info.get('deviceCount') + + for i in range(0, n_devices): + if (audio.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0: + name = audio.get_device_info_by_host_api_device_index(0, i).get('name') + print(f'[INFO] choose audio device {name}, id {i}') + break + + # get stream + stream = audio.open(input_device_index=i, + format=pyaudio.paInt16, + channels=1, + rate=self.sample_rate, + input=True, + frames_per_buffer=self.chunk) + + return audio, stream + + + def get_audio_frame(self): + + if self.mode == 'file': + + if self.idx < self.file_stream.shape[0]: + frame = self.file_stream[self.idx: self.idx + self.chunk] + self.idx = self.idx + self.chunk + return frame + else: + return None + + else: + + frame = self.queue.get() + # print(f'[INFO] get frame {frame.shape}') + + self.idx = self.idx + self.chunk + + return frame + + + def frame_to_text(self, frame): + # frame: [N * 320], N = (context_size + 2 * stride_size) + + inputs = self.processor(frame, sampling_rate=self.sample_rate, return_tensors="pt", padding=True) + + with torch.no_grad(): + result = self.model(inputs.input_values.to(self.device)) + logits = result.logits # [1, N - 1, 32] + + # cut off stride + left = max(0, self.stride_left_size) + right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input. + + # do not cut right if terminated. + if self.terminated: + right = logits.shape[1] + + logits = logits[:, left:right] + + # print(frame.shape, inputs.input_values.shape, logits.shape) + + predicted_ids = torch.argmax(logits, dim=-1) + transcription = self.processor.batch_decode(predicted_ids)[0].lower() + + + # for esperanto + # labels = np.array(['ŭ', '»', 'c', 'ĵ', 'ñ', '”', '„', '“', 'ǔ', 'o', 'ĝ', 'm', 'k', 'd', 'a', 'ŝ', 'z', 'i', '«', '—', '‘', 'ĥ', 'f', 'y', 'h', 'j', '|', 'r', 'u', 'ĉ', 's', '–', 'fi', 'l', 'p', '’', 'g', 'v', 't', 'b', 'n', 'e', '[UNK]', '[PAD]']) + + # labels = np.array([' ', ' ', ' ', '-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z']) + # print(''.join(labels[predicted_ids[0].detach().cpu().long().numpy()])) + # print(predicted_ids[0]) + # print(transcription) + + return logits[0], predicted_ids[0], transcription # [N,] + + + def run(self): + + self.listen() + + while not self.terminated: + self.run_step() + + def clear_queue(self): + # clear the queue, to reduce potential latency... + print(f'[INFO] clear queue') + if self.mode == 'live': + self.queue.queue.clear() + if self.play: + self.output_queue.queue.clear() + + def warm_up(self): + + self.listen() + + print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s') + t = time.time() + for _ in range(self.warm_up_steps): + self.run_step() + if torch.cuda.is_available(): + torch.cuda.synchronize() + t = time.time() - t + print(f'[INFO] warm-up done, actual latency = {t:.6f}s') + + self.clear_queue() + + + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--wav', type=str, default='data/raw/val_wavs/intro.wav') + parser.add_argument('--play', action='store_true', help="play out the audio") + parser.add_argument('--out_name', type=str, default='') + + parser.add_argument('--model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto') + # parser.add_argument('--model', type=str, default='facebook/wav2vec2-large-960h-lv60-self') + + parser.add_argument('--save_feats', action='store_true') + # audio FPS + parser.add_argument('--fps', type=int, default=50) + # sliding window left-middle-right length. + parser.add_argument('-l', type=int, default=10) + parser.add_argument('-m', type=int, default=50) + parser.add_argument('-r', type=int, default=10) + + opt = parser.parse_args() + + # fix + opt.asr_wav = opt.wav + opt.asr_play = opt.play + opt.asr_model = opt.model + opt.asr_save_feats = opt.save_feats + + if 'deepspeech' in opt.asr_model: + raise ValueError("DeepSpeech features should not use this code to extract...") + + with ASR(opt) as asr: + asr.run() \ No newline at end of file diff --git a/Geneface_main/GeneFace/data_util/extract_mel.py b/Geneface_main/GeneFace/data_util/extract_mel.py new file mode 100644 index 00000000..1acebf40 --- /dev/null +++ b/Geneface_main/GeneFace/data_util/extract_mel.py @@ -0,0 +1,54 @@ +import librosa +import numpy as np + +def get_mel_from_fname(wav_path, + fft_size=512, + hop_size=320, + win_length=512, + window="hann", + num_mels=80, + fmin=80, + fmax=7600, + eps=1e-6, + sample_rate=16000, + min_level_db=-100, + return_energy=False): + if isinstance(wav_path, str): + wav, _ = librosa.core.load(wav_path, sr=sample_rate) + else: + wav = wav_path + + # get amplitude spectrogram + x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size, + win_length=win_length, window=window, center=False) + spc = np.abs(x_stft) # (n_bins, T) + + # get mel basis + fmin = 0 if fmin == -1 else fmin + fmax = sample_rate / 2 if fmax == -1 else fmax + mel_basis = librosa.filters.mel(sample_rate, fft_size, num_mels, fmin, fmax) + mel = mel_basis @ spc + + mel = np.log10(np.maximum(eps, mel)) # (n_mel_bins, T) + mel = mel.T + # f0 = get_pitch(wav, mel) + if return_energy: + audio_energy = librosa.feature.rms(y=wav, frame_length=fft_size, hop_length=hop_size, center=False) # 对每一frame计算root-mean-square + audio_energy = np.transpose(audio_energy) # [t,1] + return mel, audio_energy + return mel + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--wav_name', type=str, + default='data/processed/videos/FDDM/aud.wav', help='') + parser.add_argument('--mel_npy_name', type=str, + default='data/processed/videos/FDDM/mel.npy', help='') + args = parser.parse_args() + mel, energy = get_mel_from_fname(args.wav_name, return_energy=True) + out_dict = { + 'mel': mel, + 'energy': energy + } + np.save(args.mel_npy_name, out_dict) diff --git a/Geneface_main/GeneFace/data_util/face3d_helper.py b/Geneface_main/GeneFace/data_util/face3d_helper.py new file mode 100644 index 00000000..52b3b698 --- /dev/null +++ b/Geneface_main/GeneFace/data_util/face3d_helper.py @@ -0,0 +1,190 @@ +import os +import numpy as np +import torch +from scipy.io import loadmat + + +class Face3DHelper: + def __init__(self, bfm_dir='deep_3drecon/BFM', use_gpu=True): + self.bfm_dir = bfm_dir + self.device = 'cuda' if use_gpu else 'cpu' + self.load_3dmm() + + def load_3dmm(self): + model = loadmat(os.path.join(self.bfm_dir, "BFM_model_front.mat")) + self.mean_shape = torch.from_numpy(model['meanshape'].transpose()).float().to(self.device) # mean face shape. [3*N, 1], N=35709, xyz=3, ==> 3*N=107127 + self.id_base = torch.from_numpy(model['idBase']).float().to(self.device) # identity basis. [3*N,80], we have 80 eigen faces for identity + self.exp_base = torch.from_numpy(model['exBase']).float().to(self.device) # expression basis. [3*N,64], we have 64 eigen faces for expression + + self.mean_texure = torch.from_numpy(model['meantex'].transpose()).float().to(self.device) # mean face texture. [3*N,1] (0-255) + self.tex_base = torch.from_numpy(model['texBase']).float().to(self.device) # texture basis. [3*N,80], rgb=3 + + self.point_buf = torch.from_numpy(model['point_buf']).float().to(self.device) # triangle indices for each vertex that lies in. starts from 1. [N,8] (1-F) + self.face_buf = torch.from_numpy(model['tri']).float().to(self.device) # vertex indices in each triangle. starts from 1. [F,3] (1-N) + self.key_points = torch.from_numpy(model['keypoints'].squeeze().astype(np.long)).long().to(self.device) # vertex indices of 68 facial landmarks. starts from 1. [68,1] + + self.key_mean_shape = self.mean_shape.reshape([-1,3])[self.key_points,:].to(self.device) + self.key_id_base = self.id_base.reshape([-1,3,80])[self.key_points, :, :].reshape([-1,80]).to(self.device) + self.key_exp_base = self.exp_base.reshape([-1,3,64])[self.key_points, :, :].reshape([-1,64]).to(self.device) + + def split_coeff(self, coeff): + """ + coeff: Tensor[B, T, c=257] or [T, c=257] + """ + ret_dict = { + 'identity': coeff[..., :80], # identity, [b, t, c=80] + 'expression': coeff[..., 80:144], # expression, [b, t, c=80] + 'texture': coeff[..., 144:224], # texture, [b, t, c=80] + 'angles': coeff[..., 224:227], # euler angles for pose, [b, t, c=3] + 'translation': coeff[..., 254:257], # translation, [b, t, c=3] + 'gamma': coeff[..., 227:254] # lighting, [b, t, c=27] + } + return ret_dict + + def reconstruct_face_mesh(self, id_coeff, exp_coeff): + """ + Generate a pose-independent 3D face mesh! + id_coeff: Tensor[T, c=80] + exp_coeff: Tensor[T, c=64] + """ + id_coeff = id_coeff.to(self.device) + exp_coeff = exp_coeff.to(self.device) + mean_face = self.mean_shape.squeeze().reshape([1, -1]) # [3N, 1] ==> [1, 3N] + id_base, exp_base = self.id_base, self.exp_base # [3*N, C] + identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3N] ==> [t,3N] + expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3N] ==> [t,3N] + + face = mean_face + identity_diff_face + expression_diff_face # [t,3N] + face = face.reshape([face.shape[0], -1, 3]) # [t,N,3] + # re-centering the face with mean_xyz, so the face will be in [-1, 1] + mean_xyz = self.mean_shape.squeeze().reshape([-1,3]).mean(dim=0) # [1, 3] + face_mesh = face - mean_xyz.unsqueeze(0) # [t,N,3] + return face_mesh + + def reconstruct_lm3d(self, id_coeff, exp_coeff): + """ + Generate 3D landmark with keypoint base! + id_coeff: Tensor[T, c=80] + exp_coeff: Tensor[T, c=64] + """ + id_coeff = id_coeff.to(self.device) + exp_coeff = exp_coeff.to(self.device) + mean_face = self.key_mean_shape.squeeze().reshape([1, -1]) # [3*68, 1] ==> [1, 3*68] + id_base, exp_base = self.key_id_base, self.key_exp_base # [3*68, C] + identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68] + expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68] + + face = mean_face + identity_diff_face + expression_diff_face # [t,3N] + face = face.reshape([face.shape[0], -1, 3]) # [t,N,3] + # re-centering the face with mean_xyz, so the face will be in [-1, 1] + mean_xyz = self.key_mean_shape.squeeze().reshape([-1,3]).mean(dim=0) # [1, 3] + lm3d = face - mean_xyz.unsqueeze(0) # [t,N,3] + return lm3d + + def reconstruct_idexp_lm3d(self, id_coeff, exp_coeff): + """ + Generate 3D landmark with keypoint base! + id_coeff: Tensor[T, c=80] + exp_coeff: Tensor[T, c=64] + """ + id_coeff = id_coeff.to(self.device) + exp_coeff = exp_coeff.to(self.device) + id_base, exp_base = self.key_id_base, self.key_exp_base # [3*68, C] + identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68] + expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68] + + face = identity_diff_face + expression_diff_face # [t,3N] + face = face.reshape([face.shape[0], -1, 3]) # [t,N,3] + lm3d = face * 10 + return lm3d + + def get_eye_mouth_lm_from_lm3d(self, lm3d): + eye_lm = lm3d[:, 17:48] # [T, 31, 3] + mouth_lm = lm3d[:, 48:68] # [T, 20, 3] + return eye_lm, mouth_lm + + def get_eye_mouth_lm_from_lm3d_batch(self, lm3d): + eye_lm = lm3d[:, :, 17:48] # [T, 31, 3] + mouth_lm = lm3d[:, :, 48:68] # [T, 20, 3] + return eye_lm, mouth_lm + + def get_lm3d_from_identity_exp(self, identity, exp_arr): + """ + exp: [T, 64] + identity: [80] + """ + assert identity.ndim == 1 and exp_arr.ndim == 2 + T = exp_arr.shape[0] + identity = identity[None, :].repeat([T, 1]) + lm3d = self.reconstruct_lm3d(identity, exp_arr) + return lm3d + + def get_lm3d_from_coeff_seq(self, coeff_arr): + """ + coeff_arr: [T, 257] + """ + ret_dict = self.split_coeff(coeff_arr) + lm3d = self.reconstruct_lm3d(ret_dict['identity'], ret_dict['expression']) + return lm3d + def close_mouth_for_idexp_lm3d(self, idexp_lm3d, freeze_as_first_frame=True): + idexp_lm3d = idexp_lm3d.reshape([-1, 68,3]) + num_frames = idexp_lm3d.shape[0] + eps = 0.0 + # [n_landmarks=68,xyz=3], x 代表左右,y代表上下,z代表深度 + idexp_lm3d[:,49:54, 1] = (idexp_lm3d[:,49:54, 1] + idexp_lm3d[:,range(59,54,-1), 1])/2 + eps * 2 + idexp_lm3d[:,range(59,54,-1), 1] = (idexp_lm3d[:,49:54, 1] + idexp_lm3d[:,range(59,54,-1), 1])/2 - eps * 2 + + idexp_lm3d[:,61:64, 1] = (idexp_lm3d[:,61:64, 1] + idexp_lm3d[:,range(67,64,-1), 1])/2 + eps + idexp_lm3d[:,range(67,64,-1), 1] = (idexp_lm3d[:,61:64, 1] + idexp_lm3d[:,range(67,64,-1), 1])/2 - eps + + idexp_lm3d[:,49:54, 1] += (0.03 - idexp_lm3d[:,49:54, 1].mean(dim=1) + idexp_lm3d[:,61:64, 1].mean(dim=1)).unsqueeze(1).repeat([1,5]) + idexp_lm3d[:,range(59,54,-1), 1] += (-0.03 - idexp_lm3d[:,range(59,54,-1), 1].mean(dim=1) + idexp_lm3d[:,range(67,64,-1), 1].mean(dim=1)).unsqueeze(1).repeat([1,5]) + + if freeze_as_first_frame: + idexp_lm3d[:, 48:68,] = idexp_lm3d[0, 48:68].unsqueeze(0).clone().repeat([num_frames, 1,1])*0 + return idexp_lm3d.cpu() + + def close_eyes_for_idexp_lm3d(self, idexp_lm3d): + idexp_lm3d = idexp_lm3d.reshape([-1, 68,3]) + eps = 0.003 + idexp_lm3d[:,37:39, 1] = (idexp_lm3d[:,37:39, 1] + idexp_lm3d[:,range(41,39,-1), 1])/2 + eps + idexp_lm3d[:,range(41,39,-1), 1] = (idexp_lm3d[:,37:39, 1] + idexp_lm3d[:,range(41,39,-1), 1])/2 - eps + + idexp_lm3d[:,43:45, 1] = (idexp_lm3d[:,43:45, 1] + idexp_lm3d[:,range(47,45,-1), 1])/2 + eps + idexp_lm3d[:,range(47,45,-1), 1] = (idexp_lm3d[:,43:45, 1] + idexp_lm3d[:,range(47,45,-1), 1])/2 - eps + + return idexp_lm3d + +if __name__ == '__main__': + import cv2 + + font = cv2.FONT_HERSHEY_SIMPLEX + + face_mesh_helper = Face3DHelper('deep_3drecon/BFM') + coeff_npy = 'data/processed/videos/May/coeff.npy' + coeff_dict = np.load(coeff_npy, allow_pickle=True).tolist() + coeff = torch.from_numpy(coeff_dict['coeff']) # [-250:] + lm3d = face_mesh_helper.reconstruct_idexp_lm3d(coeff[:, :80], coeff[:, 80:144]) + + lm3d_mean = lm3d.mean(dim=0, keepdims=True) + lm3d_std = lm3d.std(dim=0, keepdims=True) + + WH = 512 + lm3d = (lm3d * WH/2 + WH/2).cpu().int().numpy() + eye_idx = list(range(36,48)) + mouth_idx = list(range(48,68)) + for i_img in range(len(lm3d)): + lm2d = lm3d[i_img ,:, :2] # [68, 2] + img = np.ones([WH, WH, 3], dtype=np.uint8) * 255 + for i in range(len(lm2d)): + x, y = lm2d[i] + if i in eye_idx: + color = (0,0,255) + elif i in mouth_idx: + color = (0,255,0) + else: + color = (255,0,0) + img = cv2.circle(img, center=(x,y), radius=3, color=color, thickness=-1) + img = cv2.putText(img, f"{i}", org=(x,y), fontFace=font, fontScale=0.3, color=(255,0,0)) + img = cv2.flip(img, 0) + cv2.imwrite(f'infer_out/tmp_imgs/{format(i_img, "05d")}.png', img) \ No newline at end of file diff --git a/Geneface_main/GeneFace/data_util/face_parsing/79999_iter.pth b/Geneface_main/GeneFace/data_util/face_parsing/79999_iter.pth new file mode 100644 index 00000000..a125015a Binary files /dev/null and b/Geneface_main/GeneFace/data_util/face_parsing/79999_iter.pth differ diff --git a/Geneface_main/GeneFace/data_util/face_parsing/__pycache__/logger.cpython-39.pyc b/Geneface_main/GeneFace/data_util/face_parsing/__pycache__/logger.cpython-39.pyc new file mode 100644 index 00000000..c75004d8 Binary files /dev/null and b/Geneface_main/GeneFace/data_util/face_parsing/__pycache__/logger.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/data_util/face_parsing/__pycache__/model.cpython-39.pyc b/Geneface_main/GeneFace/data_util/face_parsing/__pycache__/model.cpython-39.pyc new file mode 100644 index 00000000..d06077da Binary files /dev/null and b/Geneface_main/GeneFace/data_util/face_parsing/__pycache__/model.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/data_util/face_parsing/__pycache__/resnet.cpython-39.pyc b/Geneface_main/GeneFace/data_util/face_parsing/__pycache__/resnet.cpython-39.pyc new file mode 100644 index 00000000..d817f29c Binary files /dev/null and b/Geneface_main/GeneFace/data_util/face_parsing/__pycache__/resnet.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/data_util/face_parsing/logger.py b/Geneface_main/GeneFace/data_util/face_parsing/logger.py new file mode 100644 index 00000000..d3f9ddcc --- /dev/null +++ b/Geneface_main/GeneFace/data_util/face_parsing/logger.py @@ -0,0 +1,23 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + + +import os.path as osp +import time +import sys +import logging + +import torch.distributed as dist + + +def setup_logger(logpth): + logfile = 'BiSeNet-{}.log'.format(time.strftime('%Y-%m-%d-%H-%M-%S')) + logfile = osp.join(logpth, logfile) + FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s' + log_level = logging.INFO + if dist.is_initialized() and not dist.get_rank()==0: + log_level = logging.ERROR + logging.basicConfig(level=log_level, format=FORMAT, filename=logfile) + logging.root.addHandler(logging.StreamHandler()) + + diff --git a/Geneface_main/GeneFace/data_util/face_parsing/model.py b/Geneface_main/GeneFace/data_util/face_parsing/model.py new file mode 100644 index 00000000..040f41ff --- /dev/null +++ b/Geneface_main/GeneFace/data_util/face_parsing/model.py @@ -0,0 +1,283 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +from resnet import Resnet18 +# from modules.bn import InPlaceABNSync as BatchNorm2d + + +class ConvBNReLU(nn.Module): + def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): + super(ConvBNReLU, self).__init__() + self.conv = nn.Conv2d(in_chan, + out_chan, + kernel_size = ks, + stride = stride, + padding = padding, + bias = False) + self.bn = nn.BatchNorm2d(out_chan) + self.init_weight() + + def forward(self, x): + x = self.conv(x) + x = F.relu(self.bn(x)) + return x + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + +class BiSeNetOutput(nn.Module): + def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs): + super(BiSeNetOutput, self).__init__() + self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) + self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False) + self.init_weight() + + def forward(self, x): + x = self.conv(x) + x = self.conv_out(x) + return x + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class AttentionRefinementModule(nn.Module): + def __init__(self, in_chan, out_chan, *args, **kwargs): + super(AttentionRefinementModule, self).__init__() + self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) + self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False) + self.bn_atten = nn.BatchNorm2d(out_chan) + self.sigmoid_atten = nn.Sigmoid() + self.init_weight() + + def forward(self, x): + feat = self.conv(x) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv_atten(atten) + atten = self.bn_atten(atten) + atten = self.sigmoid_atten(atten) + out = torch.mul(feat, atten) + return out + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + +class ContextPath(nn.Module): + def __init__(self, *args, **kwargs): + super(ContextPath, self).__init__() + self.resnet = Resnet18() + self.arm16 = AttentionRefinementModule(256, 128) + self.arm32 = AttentionRefinementModule(512, 128) + self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) + + self.init_weight() + + def forward(self, x): + H0, W0 = x.size()[2:] + feat8, feat16, feat32 = self.resnet(x) + H8, W8 = feat8.size()[2:] + H16, W16 = feat16.size()[2:] + H32, W32 = feat32.size()[2:] + + avg = F.avg_pool2d(feat32, feat32.size()[2:]) + avg = self.conv_avg(avg) + avg_up = F.interpolate(avg, (H32, W32), mode='nearest') + + feat32_arm = self.arm32(feat32) + feat32_sum = feat32_arm + avg_up + feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') + feat32_up = self.conv_head32(feat32_up) + + feat16_arm = self.arm16(feat16) + feat16_sum = feat16_arm + feat32_up + feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') + feat16_up = self.conv_head16(feat16_up) + + return feat8, feat16_up, feat32_up # x8, x8, x16 + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, (nn.Linear, nn.Conv2d)): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +### This is not used, since I replace this with the resnet feature with the same size +class SpatialPath(nn.Module): + def __init__(self, *args, **kwargs): + super(SpatialPath, self).__init__() + self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3) + self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) + self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) + self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0) + self.init_weight() + + def forward(self, x): + feat = self.conv1(x) + feat = self.conv2(feat) + feat = self.conv3(feat) + feat = self.conv_out(feat) + return feat + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class FeatureFusionModule(nn.Module): + def __init__(self, in_chan, out_chan, *args, **kwargs): + super(FeatureFusionModule, self).__init__() + self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) + self.conv1 = nn.Conv2d(out_chan, + out_chan//4, + kernel_size = 1, + stride = 1, + padding = 0, + bias = False) + self.conv2 = nn.Conv2d(out_chan//4, + out_chan, + kernel_size = 1, + stride = 1, + padding = 0, + bias = False) + self.relu = nn.ReLU(inplace=True) + self.sigmoid = nn.Sigmoid() + self.init_weight() + + def forward(self, fsp, fcp): + fcat = torch.cat([fsp, fcp], dim=1) + feat = self.convblk(fcat) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv1(atten) + atten = self.relu(atten) + atten = self.conv2(atten) + atten = self.sigmoid(atten) + feat_atten = torch.mul(feat, atten) + feat_out = feat_atten + feat + return feat_out + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class BiSeNet(nn.Module): + def __init__(self, n_classes, *args, **kwargs): + super(BiSeNet, self).__init__() + self.cp = ContextPath() + ## here self.sp is deleted + self.ffm = FeatureFusionModule(256, 256) + self.conv_out = BiSeNetOutput(256, 256, n_classes) + self.conv_out16 = BiSeNetOutput(128, 64, n_classes) + self.conv_out32 = BiSeNetOutput(128, 64, n_classes) + self.init_weight() + + def forward(self, x): + H, W = x.size()[2:] + feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature + feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature + feat_fuse = self.ffm(feat_sp, feat_cp8) + + feat_out = self.conv_out(feat_fuse) + feat_out16 = self.conv_out16(feat_cp8) + feat_out32 = self.conv_out32(feat_cp16) + + feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True) + feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True) + feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True) + return feat_out, feat_out16, feat_out32 + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], [] + for name, child in self.named_children(): + child_wd_params, child_nowd_params = child.get_params() + if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput): + lr_mul_wd_params += child_wd_params + lr_mul_nowd_params += child_nowd_params + else: + wd_params += child_wd_params + nowd_params += child_nowd_params + return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params + + +if __name__ == "__main__": + net = BiSeNet(19) + net.cuda() + net.eval() + in_ten = torch.randn(16, 3, 640, 480).cuda() + out, out16, out32 = net(in_ten) + print(out.shape) + + net.get_params() diff --git a/Geneface_main/GeneFace/data_util/face_parsing/resnet.py b/Geneface_main/GeneFace/data_util/face_parsing/resnet.py new file mode 100644 index 00000000..aa2bf951 --- /dev/null +++ b/Geneface_main/GeneFace/data_util/face_parsing/resnet.py @@ -0,0 +1,109 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.model_zoo as modelzoo + +# from modules.bn import InPlaceABNSync as BatchNorm2d + +resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + def __init__(self, in_chan, out_chan, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(in_chan, out_chan, stride) + self.bn1 = nn.BatchNorm2d(out_chan) + self.conv2 = conv3x3(out_chan, out_chan) + self.bn2 = nn.BatchNorm2d(out_chan) + self.relu = nn.ReLU(inplace=True) + self.downsample = None + if in_chan != out_chan or stride != 1: + self.downsample = nn.Sequential( + nn.Conv2d(in_chan, out_chan, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(out_chan), + ) + + def forward(self, x): + residual = self.conv1(x) + residual = F.relu(self.bn1(residual)) + residual = self.conv2(residual) + residual = self.bn2(residual) + + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(x) + + out = shortcut + residual + out = self.relu(out) + return out + + +def create_layer_basic(in_chan, out_chan, bnum, stride=1): + layers = [BasicBlock(in_chan, out_chan, stride=stride)] + for i in range(bnum-1): + layers.append(BasicBlock(out_chan, out_chan, stride=1)) + return nn.Sequential(*layers) + + +class Resnet18(nn.Module): + def __init__(self): + super(Resnet18, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) + self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) + self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) + self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) + self.init_weight() + + def forward(self, x): + x = self.conv1(x) + x = F.relu(self.bn1(x)) + x = self.maxpool(x) + + x = self.layer1(x) + feat8 = self.layer2(x) # 1/8 + feat16 = self.layer3(feat8) # 1/16 + feat32 = self.layer4(feat16) # 1/32 + return feat8, feat16, feat32 + + def init_weight(self): + state_dict = modelzoo.load_url(resnet18_url) + self_state_dict = self.state_dict() + for k, v in state_dict.items(): + if 'fc' in k: continue + self_state_dict.update({k: v}) + self.load_state_dict(self_state_dict) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, (nn.Linear, nn.Conv2d)): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +if __name__ == "__main__": + net = Resnet18() + x = torch.randn(16, 3, 224, 224) + out = net(x) + print(out[0].size()) + print(out[1].size()) + print(out[2].size()) + net.get_params() diff --git a/Geneface_main/GeneFace/data_util/face_parsing/test.py b/Geneface_main/GeneFace/data_util/face_parsing/test.py new file mode 100644 index 00000000..fa191b73 --- /dev/null +++ b/Geneface_main/GeneFace/data_util/face_parsing/test.py @@ -0,0 +1,99 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- +import numpy as np +from logger import setup_logger +from model import BiSeNet + +import torch + +import os +import os.path as osp + +from PIL import Image +import torchvision.transforms as transforms +import cv2 +from pathlib import Path +import configargparse + + +def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg', + img_size=(512, 512)): + im = np.array(im) + vis_im = im.copy().astype(np.uint8) + vis_parsing_anno = parsing_anno.copy().astype(np.uint8) + vis_parsing_anno = cv2.resize( + vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST) + vis_parsing_anno_color = np.zeros( + (vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + np.array([255, 255, 255]) # + 255 + + num_of_class = np.max(vis_parsing_anno) + # print(num_of_class) + for pi in range(1, 14): + index = np.where(vis_parsing_anno == pi) + vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0]) + + for pi in range(14, 16): + index = np.where(vis_parsing_anno == pi) + vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 255, 0]) + for pi in range(16, 17): + index = np.where(vis_parsing_anno == pi) + vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 0, 255]) + for pi in range(17, num_of_class+1): + index = np.where(vis_parsing_anno == pi) + vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0]) + + vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8) + index = np.where(vis_parsing_anno == num_of_class-1) + vis_im = cv2.resize(vis_parsing_anno_color, img_size, + interpolation=cv2.INTER_NEAREST) + if save_im: + cv2.imwrite(save_path, vis_im) + + +def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'): + + Path(respth).mkdir(parents=True, exist_ok=True) + + n_classes = 19 + net = BiSeNet(n_classes=n_classes) + net.cuda() + net.load_state_dict(torch.load(cp)) + net.eval() + + to_tensor = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ]) + processed_num = 0 + with torch.no_grad(): + for image_path in os.listdir(dspth): + if image_path.endswith('.jpg') or image_path.endswith('.png'): + img = Image.open(osp.join(dspth, image_path)) + ori_size = img.size + image = img.resize((512, 512), Image.BILINEAR) + image = image.convert("RGB") + img = to_tensor(image) + img = torch.unsqueeze(img, 0) + img = img.cuda() + out = net(img)[0] + parsing = out.squeeze(0).cpu().numpy().argmax(0) + image_path = int(image_path[:-4]) + image_path = str(image_path) + '.png' + + vis_parsing_maps(image, parsing, stride=1, save_im=True, + save_path=osp.join(respth, image_path), img_size=ori_size) + processed_num = processed_num + 1 + if processed_num % 100 == 0: + print('processed parsing', processed_num) + + +if __name__ == "__main__": + parser = configargparse.ArgumentParser() + parser.add_argument('--respath', type=str, + default='./result/', help='result path for label') + parser.add_argument('--imgpath', type=str, + default='./imgs/', help='path for input images') + parser.add_argument('--modelpath', type=str, + default='data_util/face_parsing/79999_iter.pth') + args = parser.parse_args() + evaluate(respth=args.respath, dspth=args.imgpath, cp=args.modelpath) diff --git a/Geneface_main/GeneFace/data_util/face_tracking/__init__.py b/Geneface_main/GeneFace/data_util/face_tracking/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Geneface_main/GeneFace/data_util/face_tracking/__pycache__/data_loader.cpython-39.pyc b/Geneface_main/GeneFace/data_util/face_tracking/__pycache__/data_loader.cpython-39.pyc new file mode 100644 index 00000000..735686b1 Binary files /dev/null and b/Geneface_main/GeneFace/data_util/face_tracking/__pycache__/data_loader.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/data_util/face_tracking/__pycache__/facemodel.cpython-39.pyc b/Geneface_main/GeneFace/data_util/face_tracking/__pycache__/facemodel.cpython-39.pyc new file mode 100644 index 00000000..de153722 Binary files /dev/null and b/Geneface_main/GeneFace/data_util/face_tracking/__pycache__/facemodel.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/data_util/face_tracking/__pycache__/render_3dmm.cpython-39.pyc b/Geneface_main/GeneFace/data_util/face_tracking/__pycache__/render_3dmm.cpython-39.pyc new file mode 100644 index 00000000..3358ac86 Binary files /dev/null and b/Geneface_main/GeneFace/data_util/face_tracking/__pycache__/render_3dmm.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/data_util/face_tracking/__pycache__/util.cpython-39.pyc b/Geneface_main/GeneFace/data_util/face_tracking/__pycache__/util.cpython-39.pyc new file mode 100644 index 00000000..808e54f4 Binary files /dev/null and b/Geneface_main/GeneFace/data_util/face_tracking/__pycache__/util.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/data_util/face_tracking/convert_BFM.py b/Geneface_main/GeneFace/data_util/face_tracking/convert_BFM.py new file mode 100644 index 00000000..c21021bc --- /dev/null +++ b/Geneface_main/GeneFace/data_util/face_tracking/convert_BFM.py @@ -0,0 +1,30 @@ +import numpy as np +from scipy.io import loadmat + +# original_BFM = loadmat('3DMM/01_MorphableModel.mat') +original_BFM = loadmat('../../deep_3drecon/BFM/01_MorphableModel.mat') +sub_inds = np.load('3DMM/topology_info.npy', + allow_pickle=True).item()['sub_inds'] +shapePC = original_BFM['shapePC'] +shapeEV = original_BFM['shapeEV'] +shapeMU = original_BFM['shapeMU'] +texPC = original_BFM['texPC'] +texEV = original_BFM['texEV'] +texMU = original_BFM['texMU'] + +b_shape = shapePC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3) +mu_shape = shapeMU.reshape(-1, 3) + +b_tex = texPC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3) +mu_tex = texMU.reshape(-1, 3) + +b_shape = b_shape[:, sub_inds, :].reshape(199, -1) +mu_shape = mu_shape[sub_inds, :].reshape(-1) +b_tex = b_tex[:, sub_inds, :].reshape(199, -1) +mu_tex = mu_tex[sub_inds, :].reshape(-1) + +exp_info = np.load('3DMM/exp_info.npy', allow_pickle=True).item() +np.save('3DMM/3DMM_info.npy', {'mu_shape': mu_shape, 'b_shape': b_shape, 'sig_shape': shapeEV.reshape(-1), + 'mu_exp': exp_info['mu_exp'], 'b_exp': exp_info['base_exp'], + 'sig_exp': exp_info['sig_exp'], 'mu_tex': mu_tex, + 'b_tex': b_tex, 'sig_tex': texEV.reshape(-1)}) diff --git a/Geneface_main/GeneFace/data_util/face_tracking/data_loader.py b/Geneface_main/GeneFace/data_util/face_tracking/data_loader.py new file mode 100644 index 00000000..d050890c --- /dev/null +++ b/Geneface_main/GeneFace/data_util/face_tracking/data_loader.py @@ -0,0 +1,18 @@ +import torch +import cv2 +import numpy as np +import os + + +def load_dir(path, start, end): + lmss = [] + imgs_paths = [] + for i in range(start, end): + if os.path.isfile(os.path.join(path, str(i) + '.lms')): + lms = np.loadtxt(os.path.join( + path, str(i) + '.lms'), dtype=np.float32) + lmss.append(lms) + imgs_paths.append(os.path.join(path, str(i) + '.jpg')) + lmss = np.stack(lmss) + lmss = torch.as_tensor(lmss).cuda() + return lmss, imgs_paths diff --git a/Geneface_main/GeneFace/data_util/face_tracking/face_tracker.py b/Geneface_main/GeneFace/data_util/face_tracking/face_tracker.py new file mode 100644 index 00000000..f6b96bae --- /dev/null +++ b/Geneface_main/GeneFace/data_util/face_tracking/face_tracker.py @@ -0,0 +1,393 @@ +from numpy.core.numeric import require +from numpy.lib.function_base import quantile +import torch +import numpy as np +from facemodel import Face_3DMM +from data_loader import load_dir +from util import * +from render_3dmm import Render_3DMM +import os +import sys +# import openmesh +import cv2 +import argparse +from pathlib import Path + +# def np2mesh(mesh, xnp, path): +# mesh.points()[:] = xnp +# openmesh.write_mesh(path, mesh, binary=True) + + +dir_path = os.path.dirname(os.path.realpath(__file__)) + + +def set_requires_grad(tensor_list): + for tensor in tensor_list: + tensor.requires_grad = True + + +parser = argparse.ArgumentParser() +parser.add_argument('--idname', type=str, default='obama', + help='idname of target person') +parser.add_argument('--id_dir', type=str, default=None) +parser.add_argument('--img_h', type=int, default=512, help='image height') +parser.add_argument('--img_w', type=int, default=512, help='image width') +parser.add_argument('--frame_num', type=int, + default=11000, help='image number') +args = parser.parse_args() +start_id = 0 +end_id = args.frame_num + +id_dir = os.path.join('data/processed/videos', args.idname) +lms, img_paths = load_dir(os.path.join(id_dir, 'ori_imgs'), start_id, end_id) +num_frames = lms.shape[0] +h, w = args.img_h, args.img_w +cxy = torch.tensor((w/2.0, h/2.0), dtype=torch.float).cuda() +id_dim, exp_dim, tex_dim, point_num = 100, 79, 100, 34650 +model_3dmm = Face_3DMM(os.path.join(dir_path, '3DMM'), + id_dim, exp_dim, tex_dim, point_num) + + +# mesh = openmesh.read_trimesh(os.path.join(dir_path, '3DMM', 'sub_mesh.obj')) + +sel_ids = np.arange(0, num_frames, 40) +sel_num = sel_ids.shape[0] +arg_focal = 1600 +arg_landis = 1e5 + +best_loss_focal = 99999 +for focal in range(600, 1700, 100): + id_para = lms.new_zeros((1, id_dim), requires_grad=True) + exp_para = lms.new_zeros((sel_num, exp_dim), requires_grad=True) + euler_angle = lms.new_zeros((sel_num, 3), requires_grad=True) + trans = lms.new_zeros((sel_num, 3), requires_grad=True) + trans.data[:, 2] -= 7 + focal_length = lms.new_zeros(1, requires_grad=False) + focal_length.data += focal + set_requires_grad([id_para, exp_para, euler_angle, trans]) + + optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=.1) + optimizer_frame = torch.optim.Adam( + [euler_angle, trans], lr=.1) + + best_loss_i = 99999 + for iter in range(2000): + id_para_batch = id_para.expand(sel_num, -1) + geometry = model_3dmm.get_3dlandmarks( + id_para_batch, exp_para, euler_angle, trans, focal_length, cxy) + proj_geo = forward_transform( + geometry, euler_angle, trans, focal_length, cxy) + loss_lan = cal_lan_loss( + proj_geo[:, :, :2], lms[sel_ids].detach()) + loss = loss_lan + optimizer_frame.zero_grad() + loss.backward() + optimizer_frame.step() + if iter % 100 == 0 and False: + print(focal, 'pose', iter, loss.item()) + + for iter in range(2500): + id_para_batch = id_para.expand(sel_num, -1) + geometry = model_3dmm.get_3dlandmarks( + id_para_batch, exp_para, euler_angle, trans, focal_length, cxy) + proj_geo = forward_transform( + geometry, euler_angle, trans, focal_length, cxy) + loss_lan = cal_lan_loss( + proj_geo[:, :, :2], lms[sel_ids].detach()) + loss_regid = torch.mean(id_para*id_para) + loss_regexp = torch.mean(exp_para*exp_para) + loss = loss_lan + loss_regid*0.5 + loss_regexp*0.4 + if loss_lan.item() < best_loss_i: + best_loss_i = loss_lan.item() + optimizer_idexp.zero_grad() + optimizer_frame.zero_grad() + loss.backward() + optimizer_idexp.step() + optimizer_frame.step() + if iter % 100 == 0 and False: + print(focal, 'poseidexp', iter, loss_lan.item(), + loss_regid.item(), loss_regexp.item()) + if iter % 1500 == 0 and iter >= 1500: + for param_group in optimizer_idexp.param_groups: + param_group['lr'] *= 0.2 + for param_group in optimizer_frame.param_groups: + param_group['lr'] *= 0.2 + print(focal, "loss_lan=",best_loss_i, "mean_xy_trans=",torch.mean(trans[:, 2]).item()) + with open('log.txt', 'a') as f: + f.write("\n"+str(focal)+ ",loss_lan=_"+str(best_loss_i)+ ",mean_xy_trans="+str(torch.mean(trans[:, 2]).item())) + + if best_loss_i < best_loss_focal: + best_loss_focal = best_loss_i + arg_focal = focal + # if loss_lan.item() < arg_landis: + # arg_landis = loss_lan.item() + # arg_focal = focal + +print('find best focal', arg_focal) + +id_para = lms.new_zeros((1, id_dim), requires_grad=True) +exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True) +tex_para = lms.new_zeros((1, tex_dim), requires_grad=True) +euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True) +trans = lms.new_zeros((num_frames, 3), requires_grad=True) +light_para = lms.new_zeros((num_frames, 27), requires_grad=True) +trans.data[:, 2] -= 7 +focal_length = lms.new_zeros(1, requires_grad=True) +focal_length.data += arg_focal + +set_requires_grad([id_para, exp_para, tex_para, + euler_angle, trans, light_para]) + +optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=.1) +optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=1) + +# 其他参数初始化,先训练euler和trans +for iter in range(1500): + id_para_batch = id_para.expand(num_frames, -1) + geometry = model_3dmm.get_3dlandmarks( + id_para_batch, exp_para, euler_angle, trans, focal_length, cxy) + proj_geo = forward_transform( + geometry, euler_angle, trans, focal_length, cxy) + loss_lan = cal_lan_loss( + proj_geo[:, :, :2], lms.detach()) + loss = loss_lan + optimizer_frame.zero_grad() + loss.backward() + optimizer_frame.step() + if iter == 1000: + for param_group in optimizer_frame.param_groups: + param_group['lr'] = 0.1 + if iter % 100 == 0 and False: + print('pose', iter, loss.item()) + +for param_group in optimizer_frame.param_groups: + param_group['lr'] = 0.1 + +# 同时训练id、exp和euler、trans +best_loss = 9999 +best_id_params=None +best_exp_params=None +best_euler=None +best_trans=None + +for iter in range(2000): + id_para_batch = id_para.expand(num_frames, -1) + geometry = model_3dmm.get_3dlandmarks( + id_para_batch, exp_para, euler_angle, trans, focal_length, cxy) + proj_geo = forward_transform( + geometry, euler_angle, trans, focal_length, cxy) + loss_lan = cal_lan_loss( + proj_geo[:, :, :2], lms.detach()) + loss_regid = torch.mean(id_para*id_para) # 正则化 + loss_regexp = torch.mean(exp_para*exp_para) + loss = loss_lan + loss_regid*0.5 + loss_regexp*0.4 + optimizer_idexp.zero_grad() + optimizer_frame.zero_grad() + if loss_lan.item() < best_loss: + best_loss = loss_lan.item() + best_id_params = id_para.clone() + best_exp_params = exp_para.clone() + best_euler = euler_angle.clone() + best_trans = trans.clone() + loss.backward() + optimizer_idexp.step() + optimizer_frame.step() + if iter % 100 == 0 and False: + print('poseidexp', iter, loss_lan.item(), + loss_regid.item(), loss_regexp.item()) + if iter % 1000 == 0 and iter >= 1000: + for param_group in optimizer_idexp.param_groups: + param_group['lr'] *= 0.2 + for param_group in optimizer_frame.param_groups: + param_group['lr'] *= 0.2 +print("trained on focal=",arg_focal, "best_loss_lan=",best_loss, "mean_xy_trans=",torch.mean(trans[:, 2]).item()) +with open('log.txt', 'a') as f: + f.write("\ntrained on focal="+str(arg_focal)+ "best_loss_lan="+str(best_loss)+"mean_xy_trans="+str(torch.mean(trans[:, 2]).item())) + +id_para = lms.new_zeros((1, id_dim), requires_grad=True) +id_para.data = best_id_params.data.clone() +exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True) +exp_para.data = best_exp_params.data.clone() +tex_para = lms.new_zeros((1, tex_dim), requires_grad=True) +euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True) +euler_angle.data = best_euler.data.clone() +trans = lms.new_zeros((num_frames, 3), requires_grad=True) +trans.data = best_trans.data.clone() +light_para = lms.new_zeros((num_frames, 27), requires_grad=True) + + +batch_size = 50 + + +device_default = torch.device('cuda:0') +device_render = torch.device('cuda:0') +renderer = Render_3DMM(arg_focal, h, w, batch_size, device_render) + +sel_ids = np.arange(0, num_frames, int(num_frames/batch_size))[:batch_size] +imgs = [] +for sel_id in sel_ids: + imgs.append(cv2.imread(img_paths[sel_id])[:, :, ::-1]) +imgs = np.stack(imgs) +sel_imgs = torch.as_tensor(imgs).cuda() +sel_lms = lms[sel_ids] +sel_light = light_para.new_zeros((batch_size, 27), requires_grad=True) +set_requires_grad([sel_light]) +optimizer_tl = torch.optim.Adam([tex_para, sel_light], lr=.1) +optimizer_id_frame = torch.optim.Adam( + [euler_angle, trans, exp_para, id_para], lr=.01) + +for iter in range(71): + sel_exp_para, sel_euler, sel_trans = exp_para[sel_ids], euler_angle[sel_ids], trans[sel_ids] + sel_id_para = id_para.expand(batch_size, -1) + geometry = model_3dmm.get_3dlandmarks( + sel_id_para, sel_exp_para, sel_euler, sel_trans, focal_length, cxy) + proj_geo = forward_transform( + geometry, sel_euler, sel_trans, focal_length, cxy) + loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms.detach()) + loss_regid = torch.mean(id_para*id_para) + loss_regexp = torch.mean(sel_exp_para*sel_exp_para) + + sel_tex_para = tex_para.expand(batch_size, -1) + sel_texture = model_3dmm.forward_tex(sel_tex_para) + geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para) + rott_geo = forward_rott(geometry, sel_euler, sel_trans) + render_imgs = renderer(rott_geo.to(device_render), + sel_texture.to(device_render), + sel_light.to(device_render)) + render_imgs = render_imgs.to(device_default) + + mask = (render_imgs[:, :, :, 3]).detach() > 0.0 + render_proj = sel_imgs.clone() + render_proj[mask] = render_imgs[mask][..., :3].byte() + loss_col = cal_col_loss(render_imgs[:, :, :, :3], sel_imgs.float(), mask) + loss = loss_col + loss_lan*3 + loss_regid*2.0 + loss_regexp*1.0 + if iter > 50: + loss = loss_col + loss_lan*0.05 + loss_regid*1.0 + loss_regexp*0.8 + optimizer_tl.zero_grad() + optimizer_id_frame.zero_grad() + loss.backward() + optimizer_tl.step() + optimizer_id_frame.step() + if iter % 50 == 0 and iter >= 5: + for param_group in optimizer_id_frame.param_groups: + param_group['lr'] *= 0.2 + for param_group in optimizer_tl.param_groups: + param_group['lr'] *= 0.2 + #print(iter, loss_col.item(), loss_lan.item(), loss_regid.item(), loss_regexp.item()) + +# np2mesh(mesh, geometry[0, ...].detach().cpu().numpy( +# ), os.path.join(id_dir, 'debug', 'id.ply')) + +light_mean = torch.mean(sel_light, 0).unsqueeze(0).repeat(num_frames, 1) +light_para.data = light_mean + +exp_para = exp_para.detach() +euler_angle = euler_angle.detach() +trans = trans.detach() +light_para = light_para.detach() + +for i in range(int((num_frames-1)/batch_size+1)): + if (i+1)*batch_size > num_frames: + start_n = num_frames-batch_size + sel_ids = np.arange(num_frames-batch_size, num_frames) + else: + start_n = i*batch_size + sel_ids = np.arange(i*batch_size, i*batch_size+batch_size) + imgs = [] + for sel_id in sel_ids: + imgs.append(cv2.imread(img_paths[sel_id])[:, :, ::-1]) + imgs = np.stack(imgs) + sel_imgs = torch.as_tensor(imgs).cuda() + sel_lms = lms[sel_ids] + + sel_exp_para = exp_para.new_zeros( + (batch_size, exp_dim), requires_grad=True) + sel_exp_para.data = exp_para[sel_ids].clone() + sel_euler = euler_angle.new_zeros( + (batch_size, 3), requires_grad=True) + sel_euler.data = euler_angle[sel_ids].clone() + sel_trans = trans.new_zeros((batch_size, 3), requires_grad=True) + sel_trans.data = trans[sel_ids].clone() + sel_light = light_para.new_zeros( + (batch_size, 27), requires_grad=True) + sel_light.data = light_para[sel_ids].clone() + + set_requires_grad([sel_exp_para, sel_euler, sel_trans, sel_light]) + + optimizer_cur_batch = torch.optim.Adam( + [sel_exp_para, sel_euler, sel_trans, sel_light], lr=0.005) + + sel_id_para = id_para.expand(batch_size, -1).detach() + sel_tex_para = tex_para.expand(batch_size, -1).detach() + + pre_num = 5 + if i > 0: + pre_ids = np.arange( + start_n-pre_num, start_n) + + for iter in range(50): + geometry = model_3dmm.get_3dlandmarks( + sel_id_para, sel_exp_para, sel_euler, sel_trans, focal_length, cxy) + proj_geo = forward_transform( + geometry, sel_euler, sel_trans, focal_length, cxy) + loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms.detach()) + loss_regexp = torch.mean(sel_exp_para*sel_exp_para) + + sel_geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para) + sel_texture = model_3dmm.forward_tex(sel_tex_para) + geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para) + rott_geo = forward_rott(geometry, sel_euler, sel_trans) + render_imgs = renderer(rott_geo.to(device_render), + sel_texture.to(device_render), + sel_light.to(device_render)) + render_imgs = render_imgs.to(device_default) + + mask = (render_imgs[:, :, :, 3]).detach() > 0.0 + + loss_col = cal_col_loss( + render_imgs[:, :, :, :3], sel_imgs.float(), mask) + + if i > 0: + geometry_lap = model_3dmm.forward_geo_sub(id_para.expand( + batch_size+pre_num, -1).detach(), torch.cat((exp_para[pre_ids].detach(), sel_exp_para)), model_3dmm.rigid_ids) + rott_geo_lap = forward_rott(geometry_lap, torch.cat( + (euler_angle[pre_ids].detach(), sel_euler)), torch.cat((trans[pre_ids].detach(), sel_trans))) + + loss_lap = cal_lap_loss([rott_geo_lap.reshape(rott_geo_lap.shape[0], -1).permute(1, 0)], + [1.0]) + else: + geometry_lap = model_3dmm.forward_geo_sub( + id_para.expand(batch_size, -1).detach(), sel_exp_para, model_3dmm.rigid_ids) + rott_geo_lap = forward_rott(geometry_lap, sel_euler, sel_trans) + loss_lap = cal_lap_loss([rott_geo_lap.reshape(rott_geo_lap.shape[0], -1).permute(1, 0)], + [1.0]) + + loss = loss_col*0.5 + loss_lan*8 + loss_lap*100000 + loss_regexp*1.0 + if iter > 30: + loss = loss_col*0.5 + loss_lan*1.5 + loss_lap*100000 + loss_regexp*1.0 + optimizer_cur_batch.zero_grad() + # print(f"i={i},iter={iter},loss_col={loss_col},loss_lan={loss_lan},loss_lap={loss_lap},loss_regexp={loss_regexp}") + loss.backward() + optimizer_cur_batch.step() + # print(i, iter, loss_col.item(), loss_lan.item(), loss_lap.item(), loss_regexp.item()) + with open('log.txt', 'a') as f: + f.write(f"\ni={i},iter={iter},loss_col={loss_col},loss_lan={loss_lan},loss_lap={loss_lap},loss_regexp={loss_regexp}") + print(str(i) + ' of ' + str(int((num_frames-1)/batch_size+1)) + ' done') + render_proj = sel_imgs.clone() + render_proj[mask] = render_imgs[mask][..., :3].byte() + debug_render_dir = os.path.join(id_dir, 'debug', 'debug_render') + Path(debug_render_dir).mkdir(parents=True, exist_ok=True) + for j in range(sel_ids.shape[0]): + img_arr = render_proj[j, :, :, :3].byte().detach().cpu().numpy()[ + :, :, ::-1] + cv2.imwrite(os.path.join(debug_render_dir, str(sel_ids[j]) + '.jpg'), + img_arr) + exp_para[sel_ids] = sel_exp_para.clone() + euler_angle[sel_ids] = sel_euler.clone() + trans[sel_ids] = sel_trans.clone() + light_para[sel_ids] = sel_light.clone() + +torch.save({'id': id_para.detach().cpu(), 'exp': exp_para.detach().cpu(), + 'euler': euler_angle.detach().cpu(), 'trans': trans.detach().cpu(), + 'focal': focal_length.detach().cpu()}, os.path.join(id_dir, 'track_params.pt')) +print('params saved') diff --git a/Geneface_main/GeneFace/data_util/face_tracking/facemodel.py b/Geneface_main/GeneFace/data_util/face_tracking/facemodel.py new file mode 100644 index 00000000..0dfc3d86 --- /dev/null +++ b/Geneface_main/GeneFace/data_util/face_tracking/facemodel.py @@ -0,0 +1,113 @@ +import torch +import torch.nn as nn +import numpy as np +import os +from util import * + + +class Face_3DMM(nn.Module): + def __init__(self, modelpath, id_dim, exp_dim, tex_dim, point_num): + super(Face_3DMM, self).__init__() + # id_dim = 100 + # exp_dim = 79 + # tex_dim = 100 + self.point_num = point_num + DMM_info = np.load(os.path.join( + modelpath, '3DMM_info.npy'), allow_pickle=True).item() + base_id = DMM_info['b_shape'][:id_dim, :] + mu_id = DMM_info['mu_shape'] + base_exp = DMM_info['b_exp'][:exp_dim, :] + mu_exp = DMM_info['mu_exp'] + mu = mu_id + mu_exp + mu = mu.reshape(-1, 3) + for i in range(3): + mu[:, i] -= np.mean(mu[:, i]) + mu = mu.reshape(-1) + self.base_id = torch.as_tensor(base_id).cuda()/100000.0 + self.base_exp = torch.as_tensor(base_exp).cuda()/100000.0 + self.mu = torch.as_tensor(mu).cuda()/100000.0 + base_tex = DMM_info['b_tex'][:tex_dim, :] + mu_tex = DMM_info['mu_tex'] + self.base_tex = torch.as_tensor(base_tex).cuda() + self.mu_tex = torch.as_tensor(mu_tex).cuda() + sig_id = DMM_info['sig_shape'][:id_dim] + sig_tex = DMM_info['sig_tex'][:tex_dim] + sig_exp = DMM_info['sig_exp'][:exp_dim] + self.sig_id = torch.as_tensor(sig_id).cuda() + self.sig_tex = torch.as_tensor(sig_tex).cuda() + self.sig_exp = torch.as_tensor(sig_exp).cuda() + + keys_info = np.load(os.path.join( + modelpath, 'keys_info.npy'), allow_pickle=True).item() + self.keyinds = torch.as_tensor(keys_info['keyinds']).cuda() + self.left_contours = torch.as_tensor(keys_info['left_contour']).cuda() + self.right_contours = torch.as_tensor( + keys_info['right_contour']).cuda() + self.rigid_ids = torch.as_tensor(keys_info['rigid_ids']).cuda() + + def get_3dlandmarks(self, id_para, exp_para, euler_angle, trans, focal_length, cxy): + id_para = id_para*self.sig_id + exp_para = exp_para*self.sig_exp + batch_size = id_para.shape[0] + num_per_contour = self.left_contours.shape[1] + left_contours_flat = self.left_contours.reshape(-1) + right_contours_flat = self.right_contours.reshape(-1) + sel_index = torch.cat((3*left_contours_flat.unsqueeze(1), 3*left_contours_flat.unsqueeze(1)+1, + 3*left_contours_flat.unsqueeze(1)+2), dim=1).reshape(-1) + left_geometry = torch.mm(id_para, self.base_id[:, sel_index]) + \ + torch.mm(exp_para, self.base_exp[:, + sel_index]) + self.mu[sel_index] + left_geometry = left_geometry.view(batch_size, -1, 3) + proj_x = forward_transform( + left_geometry, euler_angle, trans, focal_length, cxy)[:, :, 0] + proj_x = proj_x.reshape(batch_size, 8, num_per_contour) + arg_min = proj_x.argmin(dim=2) + left_geometry = left_geometry.view(batch_size*8, num_per_contour, 3) + left_3dlands = left_geometry[torch.arange( + batch_size*8), arg_min.view(-1), :].view(batch_size, 8, 3) + + sel_index = torch.cat((3*right_contours_flat.unsqueeze(1), 3*right_contours_flat.unsqueeze(1)+1, + 3*right_contours_flat.unsqueeze(1)+2), dim=1).reshape(-1) + right_geometry = torch.mm(id_para, self.base_id[:, sel_index]) + \ + torch.mm(exp_para, self.base_exp[:, + sel_index]) + self.mu[sel_index] + right_geometry = right_geometry.view(batch_size, -1, 3) + proj_x = forward_transform( + right_geometry, euler_angle, trans, focal_length, cxy)[:, :, 0] + proj_x = proj_x.reshape(batch_size, 8, num_per_contour) + arg_max = proj_x.argmax(dim=2) + right_geometry = right_geometry.view(batch_size*8, num_per_contour, 3) + right_3dlands = right_geometry[torch.arange( + batch_size*8), arg_max.view(-1), :].view(batch_size, 8, 3) + + sel_index = torch.cat((3*self.keyinds.unsqueeze(1), 3*self.keyinds.unsqueeze(1)+1, + 3*self.keyinds.unsqueeze(1)+2), dim=1).reshape(-1) + geometry = torch.mm(id_para, self.base_id[:, sel_index]) + \ + torch.mm(exp_para, self.base_exp[:, + sel_index]) + self.mu[sel_index] + lands_3d = geometry.view(-1, self.keyinds.shape[0], 3) + lands_3d[:, :8, :] = left_3dlands + lands_3d[:, 9:17, :] = right_3dlands + return lands_3d + + def forward_geo_sub(self, id_para, exp_para, sub_index): + id_para = id_para*self.sig_id + exp_para = exp_para*self.sig_exp + sel_index = torch.cat((3*sub_index.unsqueeze(1), 3*sub_index.unsqueeze(1)+1, + 3*sub_index.unsqueeze(1)+2), dim=1).reshape(-1) + geometry = torch.mm(id_para, self.base_id[:, sel_index]) + \ + torch.mm(exp_para, self.base_exp[:, + sel_index]) + self.mu[sel_index] + return geometry.reshape(-1, sub_index.shape[0], 3) + + def forward_geo(self, id_para, exp_para): + id_para = id_para*self.sig_id + exp_para = exp_para*self.sig_exp + geometry = torch.mm(id_para, self.base_id) + \ + torch.mm(exp_para, self.base_exp) + self.mu + return geometry.reshape(-1, self.point_num, 3) + + def forward_tex(self, tex_para): + tex_para = tex_para*self.sig_tex + texture = torch.mm(tex_para, self.base_tex) + self.mu_tex + return texture.reshape(-1, self.point_num, 3) diff --git a/Geneface_main/GeneFace/data_util/face_tracking/geo_transform.py b/Geneface_main/GeneFace/data_util/face_tracking/geo_transform.py new file mode 100644 index 00000000..13e4890c --- /dev/null +++ b/Geneface_main/GeneFace/data_util/face_tracking/geo_transform.py @@ -0,0 +1,60 @@ +"""This module contains functions for geometry transform and camera projection""" +import torch +import torch.nn as nn +import numpy as np + + +def euler2rot(euler_angle): + batch_size = euler_angle.shape[0] + theta = euler_angle[:, 0].reshape(-1, 1, 1) + phi = euler_angle[:, 1].reshape(-1, 1, 1) + psi = euler_angle[:, 2].reshape(-1, 1, 1) + one = torch.ones((batch_size, 1, 1), dtype=torch.float32, + device=euler_angle.device) + zero = torch.zeros((batch_size, 1, 1), dtype=torch.float32, + device=euler_angle.device) + rot_x = torch.cat(( + torch.cat((one, zero, zero), 1), + torch.cat((zero, theta.cos(), theta.sin()), 1), + torch.cat((zero, -theta.sin(), theta.cos()), 1), + ), 2) + rot_y = torch.cat(( + torch.cat((phi.cos(), zero, -phi.sin()), 1), + torch.cat((zero, one, zero), 1), + torch.cat((phi.sin(), zero, phi.cos()), 1), + ), 2) + rot_z = torch.cat(( + torch.cat((psi.cos(), -psi.sin(), zero), 1), + torch.cat((psi.sin(), psi.cos(), zero), 1), + torch.cat((zero, zero, one), 1) + ), 2) + return torch.bmm(rot_x, torch.bmm(rot_y, rot_z)) + + +def rot_trans_geo(geometry, rot, trans): + rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans.view(-1, 3, 1) + return rott_geo.permute(0, 2, 1) + + +def euler_trans_geo(geometry, euler, trans): + rot = euler2rot(euler) + return rot_trans_geo(geometry, rot, trans) + + +def proj_geo(rott_geo, camera_para): + fx = camera_para[:, 0] + fy = camera_para[:, 0] + cx = camera_para[:, 1] + cy = camera_para[:, 2] + + X = rott_geo[:, :, 0] + Y = rott_geo[:, :, 1] + Z = rott_geo[:, :, 2] + + fxX = fx[:, None]*X + fyY = fy[:, None]*Y + + proj_x = -fxX/Z + cx[:, None] + proj_y = fyY/Z + cy[:, None] + + return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2) diff --git a/Geneface_main/GeneFace/data_util/face_tracking/render_3dmm.py b/Geneface_main/GeneFace/data_util/face_tracking/render_3dmm.py new file mode 100644 index 00000000..6b5a50b4 --- /dev/null +++ b/Geneface_main/GeneFace/data_util/face_tracking/render_3dmm.py @@ -0,0 +1,195 @@ +import torch +import torch.nn as nn +import numpy as np +import os +from pytorch3d.structures import Meshes +from pytorch3d.renderer import ( + look_at_view_transform, + PerspectiveCameras, + FoVPerspectiveCameras, + PointLights, + DirectionalLights, + Materials, + RasterizationSettings, + MeshRenderer, + MeshRasterizer, + SoftPhongShader, + TexturesUV, + TexturesVertex, + blending +) + +from pytorch3d.ops import interpolate_face_attributes + +from pytorch3d.renderer.blending import ( + BlendParams, + hard_rgb_blend, + sigmoid_alpha_blend, + softmax_rgb_blend, +) + + +class SoftSimpleShader(nn.Module): + """ + Per pixel lighting - the lighting model is applied using the interpolated + coordinates and normals for each pixel. The blending function returns the + soft aggregated color using all the faces per pixel. + + To use the default values, simply initialize the shader with the desired + device e.g. + + """ + + def __init__( + self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None + ): + super().__init__() + self.lights = lights if lights is not None else PointLights( + device=device) + self.materials = ( + materials if materials is not None else Materials(device=device) + ) + self.cameras = cameras + self.blend_params = blend_params if blend_params is not None else BlendParams() + + def to(self, device): + # Manually move to device modules which are not subclasses of nn.Module + self.cameras = self.cameras.to(device) + self.materials = self.materials.to(device) + self.lights = self.lights.to(device) + return self + + def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: + + texels = meshes.sample_textures(fragments) + blend_params = kwargs.get("blend_params", self.blend_params) + + cameras = kwargs.get("cameras", self.cameras) + if cameras is None: + msg = "Cameras must be specified either at initialization \ + or in the forward pass of SoftPhongShader" + raise ValueError(msg) + znear = kwargs.get("znear", getattr(cameras, "znear", 1.0)) + zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0)) + images = softmax_rgb_blend( + texels, fragments, blend_params, znear=znear, zfar=zfar + ) + return images + + +class Render_3DMM(nn.Module): + def __init__(self, focal=1015, img_h=500, img_w=500, batch_size=1, device=torch.device('cuda:0')): + super(Render_3DMM, self).__init__() + + self.focal = focal + self.img_h = img_h + self.img_w = img_w + self.device = device + self.renderer = self.get_render(batch_size) + + dir_path = os.path.dirname(os.path.realpath(__file__)) + topo_info = np.load(os.path.join( + dir_path, '3DMM', 'topology_info.npy'), allow_pickle=True).item() + self.tris = torch.as_tensor(topo_info['tris']).to(self.device) + self.vert_tris = torch.as_tensor( + topo_info['vert_tris']).to(self.device) + + def compute_normal(self, geometry): + vert_1 = torch.index_select(geometry, 1, self.tris[:, 0]) + vert_2 = torch.index_select(geometry, 1, self.tris[:, 1]) + vert_3 = torch.index_select(geometry, 1, self.tris[:, 2]) + nnorm = torch.cross(vert_2-vert_1, vert_3-vert_1, 2) + tri_normal = nn.functional.normalize(nnorm, dim=2) + v_norm = tri_normal[:, self.vert_tris, :].sum(2) + vert_normal = v_norm / v_norm.norm(dim=2).unsqueeze(2) + return vert_normal + + def get_render(self, batch_size=1): + half_s = self.img_w * 0.5 + R, T = look_at_view_transform(10, 0, 0) + R = R.repeat(batch_size, 1, 1) + T = torch.zeros((batch_size, 3), dtype=torch.float32).to(self.device) + + cameras = FoVPerspectiveCameras(device=self.device, R=R, T=T, znear=0.01, zfar=20, + fov=2*np.arctan(self.img_w//2/self.focal)*180./np.pi) + lights = PointLights( + device=self.device, + location=[[0.0, 0.0, 1e5]], + ambient_color=[[1, 1, 1]], + specular_color=[[0., 0., 0.]], + diffuse_color=[[0., 0., 0.]] + ) + sigma = 1e-4 + raster_settings = RasterizationSettings( + image_size=(self.img_h, self.img_w), + # blur_radius=np.log(1. / 1e-4 - 1.)*sigma / 18.0, + blur_radius=0, + # faces_per_pixel=2, + faces_per_pixel=1, + perspective_correct=False, + ) + blend_params = blending.BlendParams(background_color=[0, 0, 0]) + renderer = MeshRenderer( + rasterizer=MeshRasterizer( + raster_settings=raster_settings, + cameras=cameras + ), + shader=SoftSimpleShader( + lights=lights, + blend_params=blend_params, + cameras=cameras + ), + ) + return renderer.to(self.device) + + @staticmethod + def Illumination_layer(face_texture, norm, gamma): + + n_b, num_vertex, _ = face_texture.size() + n_v_full = n_b * num_vertex + gamma = gamma.view(-1, 3, 9).clone() + gamma[:, :, 0] += 0.8 + + gamma = gamma.permute(0, 2, 1) + + a0 = np.pi + a1 = 2 * np.pi / np.sqrt(3.0) + a2 = 2 * np.pi / np.sqrt(8.0) + c0 = 1 / np.sqrt(4 * np.pi) + c1 = np.sqrt(3.0) / np.sqrt(4 * np.pi) + c2 = 3 * np.sqrt(5.0) / np.sqrt(12 * np.pi) + d0 = 0.5 / np.sqrt(3.0) + + Y0 = torch.ones(n_v_full).to(gamma.device).float() * a0 * c0 + norm = norm.view(-1, 3) + nx, ny, nz = norm[:, 0], norm[:, 1], norm[:, 2] + arrH = [] + + arrH.append(Y0) + arrH.append(-a1 * c1 * ny) + arrH.append(a1 * c1 * nz) + arrH.append(-a1 * c1 * nx) + arrH.append(a2 * c2 * nx * ny) + arrH.append(-a2 * c2 * ny * nz) + arrH.append(a2 * c2 * d0 * (3 * nz.pow(2) - 1)) + arrH.append(-a2 * c2 * nx * nz) + arrH.append(a2 * c2 * 0.5 * (nx.pow(2) - ny.pow(2))) + + H = torch.stack(arrH, 1) + Y = H.view(n_b, num_vertex, 9) + lighting = Y.bmm(gamma) + + face_color = face_texture * lighting + return face_color + + def forward(self, rott_geometry, texture, diffuse_sh): + face_normal = self.compute_normal(rott_geometry) + face_color = self.Illumination_layer(texture, face_normal, diffuse_sh) + face_color = TexturesVertex(face_color) + mesh = Meshes(rott_geometry, self.tris.float().repeat( + rott_geometry.shape[0], 1, 1), face_color) + # rendered_img = self.renderer(mesh) # , eps=1e-8 + rendered_img = self.renderer(mesh, eps=1e-4) # + rendered_img = torch.clamp(rendered_img, 0, 255) + + return rendered_img diff --git a/Geneface_main/GeneFace/data_util/face_tracking/render_land.py b/Geneface_main/GeneFace/data_util/face_tracking/render_land.py new file mode 100644 index 00000000..9dd324de --- /dev/null +++ b/Geneface_main/GeneFace/data_util/face_tracking/render_land.py @@ -0,0 +1,146 @@ +import torch +import torch.nn as nn +import render_util +import geo_transform +import numpy as np + + +def compute_tri_normal(geometry, tris): + geometry = geometry.permute(0, 2, 1) + tri_1 = tris[:, 0] + tri_2 = tris[:, 1] + tri_3 = tris[:, 2] + + vert_1 = torch.index_select(geometry, 2, tri_1) + vert_2 = torch.index_select(geometry, 2, tri_2) + vert_3 = torch.index_select(geometry, 2, tri_3) + + nnorm = torch.cross(vert_2-vert_1, vert_3-vert_1, 1) + normal = nn.functional.normalize(nnorm).permute(0, 2, 1) + return normal + + +class Compute_normal_base(torch.autograd.Function): + @staticmethod + def forward(ctx, normal): + normal_b, = render_util.normal_base_forward(normal) + ctx.save_for_backward(normal) + return normal_b + + @staticmethod + def backward(ctx, grad_normal_b): + normal, = ctx.saved_tensors + grad_normal, = render_util.normal_base_backward(grad_normal_b, normal) + return grad_normal + + +class Normal_Base(torch.nn.Module): + def __init__(self): + super(Normal_Base, self).__init__() + + def forward(self, normal): + return Compute_normal_base.apply(normal) + + +def preprocess_render(geometry, euler, trans, cam, tris, vert_tris, ori_img): + point_num = geometry.shape[1] + rott_geo = geo_transform.euler_trans_geo(geometry, euler, trans) + proj_geo = geo_transform.proj_geo(rott_geo, cam) + rot_tri_normal = compute_tri_normal(rott_geo, tris) + rot_vert_normal = torch.index_select(rot_tri_normal, 1, vert_tris) + is_visible = -torch.bmm(rot_vert_normal.reshape(-1, 1, 3), + nn.functional.normalize(rott_geo.reshape(-1, 3, 1))).reshape(-1, point_num) + is_visible[is_visible < 0.01] = -1 + pixel_valid = torch.zeros((ori_img.shape[0], ori_img.shape[1]*ori_img.shape[2]), + dtype=torch.float32, device=ori_img.device) + return rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid + + +class Render_Face(torch.autograd.Function): + @staticmethod + def forward(ctx, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, + pixel_valid): + batch_size, h, w, _ = ori_img.shape + ori_img = ori_img.view(batch_size, -1, 3) + ori_size = torch.cat((torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)*h, + torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)*w), + dim=1).view(-1) + tri_index, tri_coord, render, real = render_util.render_face_forward( + proj_geo, ori_img, ori_size, texture, nbl, is_visible, tri_inds, pixel_valid) + ctx.save_for_backward(ori_img, ori_size, proj_geo, texture, nbl, + tri_inds, tri_index, tri_coord) + return render, real + + @staticmethod + def backward(ctx, grad_render, grad_real): + ori_img, ori_size, proj_geo, texture, nbl, tri_inds, tri_index, tri_coord = \ + ctx.saved_tensors + grad_proj_geo, grad_texture, grad_nbl = render_util.render_face_backward( + grad_render, grad_real, ori_img, ori_size, proj_geo, texture, nbl, tri_inds, + tri_index, tri_coord) + return grad_proj_geo, grad_texture, grad_nbl, None, None, None, None + + +class Render_RGB(nn.Module): + def __init__(self): + super(Render_RGB, self).__init__() + + def forward(self, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid): + return Render_Face.apply(proj_geo, texture, nbl, ori_img, is_visible, + tri_inds, pixel_valid) + + +def cal_land(proj_geo, is_visible, lands_info, land_num): + land_index, = render_util.update_contour( + lands_info, is_visible, land_num) + proj_land = torch.index_select( + proj_geo.reshape(-1, 3), 0, land_index)[:, :2].reshape(-1, land_num, 2) + return proj_land + + +class Render_Land(nn.Module): + def __init__(self): + super(Render_Land, self).__init__() + lands_info = np.loadtxt('../data/3DMM/lands_info.txt', dtype=np.int32) + self.lands_info = torch.as_tensor(lands_info).cuda() + tris = np.loadtxt('../data/3DMM/tris.txt', dtype=np.int64) + self.tris = torch.as_tensor(tris).cuda() - 1 + vert_tris = np.loadtxt('../data/3DMM/vert_tris.txt', dtype=np.int64) + self.vert_tris = torch.as_tensor(vert_tris).cuda() + self.normal_baser = Normal_Base().cuda() + self.renderer = Render_RGB().cuda() + + def render_mesh(self, geometry, euler, trans, cam, ori_img, light): + batch_size, h, w, _ = ori_img.shape + ori_img = ori_img.view(batch_size, -1, 3) + ori_size = torch.cat((torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)*h, + torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device)*w), + dim=1).view(-1) + rott_geo, proj_geo, rot_tri_normal, _, _ = preprocess_render( + geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img) + tri_nb = self.normal_baser(rot_tri_normal.contiguous()) + nbl = torch.bmm(tri_nb, (light.reshape(-1, 9, 3)) + [:, :, 0].unsqueeze(-1).repeat(1, 1, 3)) + texture = torch.ones_like(geometry) * 200 + render, = render_util.render_mesh( + proj_geo, ori_img, ori_size, texture, nbl, self.tris) + return render.view(batch_size, h, w, 3).byte() + + def cal_loss_rgb(self, geometry, euler, trans, cam, ori_img, light, texture, lands): + rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid = \ + preprocess_render(geometry, euler, trans, cam, + self.tris, self.vert_tris, ori_img) + tri_nb = self.normal_baser(rot_tri_normal.contiguous()) + nbl = torch.bmm(tri_nb, light.reshape(-1, 9, 3)) + render, real = self.renderer( + proj_geo, texture, nbl, ori_img, is_visible, self.tris, pixel_valid) + proj_land = cal_land(proj_geo, is_visible, + self.lands_info, lands.shape[1]) + col_minus = torch.norm((render-real).reshape(-1, 3), + dim=1).reshape(ori_img.shape[0], -1) + col_dis = torch.mean(col_minus*pixel_valid) / \ + (torch.mean(pixel_valid)+0.00001) + land_dists = torch.norm( + (proj_land-lands).reshape(-1, 2), dim=1).reshape(ori_img.shape[0], -1) + lan_dis = torch.mean(land_dists) + return col_dis, lan_dis diff --git a/Geneface_main/GeneFace/data_util/face_tracking/util.py b/Geneface_main/GeneFace/data_util/face_tracking/util.py new file mode 100644 index 00000000..7e168fd4 --- /dev/null +++ b/Geneface_main/GeneFace/data_util/face_tracking/util.py @@ -0,0 +1,91 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def compute_tri_normal(geometry, tris): + tri_1 = tris[:, 0] + tri_2 = tris[:, 1] + tri_3 = tris[:, 2] + vert_1 = torch.index_select(geometry, 1, tri_1) + vert_2 = torch.index_select(geometry, 1, tri_2) + vert_3 = torch.index_select(geometry, 1, tri_3) + nnorm = torch.cross(vert_2-vert_1, vert_3-vert_1, 2) + normal = nn.functional.normalize(nnorm) + return normal + + +def euler2rot(euler_angle): + batch_size = euler_angle.shape[0] + theta = euler_angle[:, 0].reshape(-1, 1, 1) + phi = euler_angle[:, 1].reshape(-1, 1, 1) + psi = euler_angle[:, 2].reshape(-1, 1, 1) + one = torch.ones(batch_size, 1, 1).to(euler_angle.device) + zero = torch.zeros(batch_size, 1, 1).to(euler_angle.device) + rot_x = torch.cat(( + torch.cat((one, zero, zero), 1), + torch.cat((zero, theta.cos(), theta.sin()), 1), + torch.cat((zero, -theta.sin(), theta.cos()), 1), + ), 2) + rot_y = torch.cat(( + torch.cat((phi.cos(), zero, -phi.sin()), 1), + torch.cat((zero, one, zero), 1), + torch.cat((phi.sin(), zero, phi.cos()), 1), + ), 2) + rot_z = torch.cat(( + torch.cat((psi.cos(), -psi.sin(), zero), 1), + torch.cat((psi.sin(), psi.cos(), zero), 1), + torch.cat((zero, zero, one), 1) + ), 2) + return torch.bmm(rot_x, torch.bmm(rot_y, rot_z)) + + +def rot_trans_pts(geometry, rot, trans): + rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans[:, :, None] + return rott_geo.permute(0, 2, 1) + + +def cal_lap_loss(tensor_list, weight_list): + lap_kernel = torch.Tensor( + (-0.5, 1.0, -0.5)).unsqueeze(0).unsqueeze(0).float().to(tensor_list[0].device) + loss_lap = 0 + for i in range(len(tensor_list)): + in_tensor = tensor_list[i] + in_tensor = in_tensor.view(-1, 1, in_tensor.shape[-1]) + out_tensor = F.conv1d(in_tensor, lap_kernel) + loss_lap += torch.mean(out_tensor**2)*weight_list[i] + return loss_lap + + +def proj_pts(rott_geo, focal_length, cxy): + cx, cy = cxy[0], cxy[1] + X = rott_geo[:, :, 0] + Y = rott_geo[:, :, 1] + Z = rott_geo[:, :, 2] + fxX = focal_length*X + fyY = focal_length*Y + proj_x = -fxX/Z + cx + proj_y = fyY/Z + cy + return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2) + +def forward_rott(geometry, euler_angle, trans): + rot = euler2rot(euler_angle) + rott_geo = rot_trans_pts(geometry, rot, trans) + return rott_geo + +def forward_transform(geometry, euler_angle, trans, focal_length, cxy): + rot = euler2rot(euler_angle) + rott_geo = rot_trans_pts(geometry, rot, trans) + proj_geo = proj_pts(rott_geo, focal_length, cxy) + return proj_geo + + +def cal_lan_loss(proj_lan, gt_lan): + return torch.mean((proj_lan-gt_lan)**2) + +def cal_col_loss(pred_img, gt_img, img_mask): + pred_img = pred_img.float() + loss = torch.sqrt(torch.sum(torch.square(pred_img - gt_img), 3))*img_mask/255 + loss = torch.sum(loss, dim=(1, 2)) / torch.sum(img_mask, dim=(1, 2)) + loss = torch.mean(loss) + return loss diff --git a/Geneface_main/GeneFace/data_util/process.py b/Geneface_main/GeneFace/data_util/process.py new file mode 100644 index 00000000..a9f208a2 --- /dev/null +++ b/Geneface_main/GeneFace/data_util/process.py @@ -0,0 +1,447 @@ +import os +import glob +import tqdm +import json +import argparse +import cv2 +import numpy as np + +def extract_audio(path, out_path, sample_rate=16000): + + print(f'[INFO] ===== extract audio from {path} to {out_path} =====') + cmd = f'ffmpeg -i {path} -f wav -ar {sample_rate} {out_path}' + os.system(cmd) + print(f'[INFO] ===== extracted audio =====') + + +def extract_audio_features(path): + + print(f'[INFO] ===== extract audio labels for {path} =====') + + print(f'[INFO] ===== start extract esperanto =====') + cmd = f'python data_util/extract_esperanto.py --wav {path} --save_feats' + os.system(cmd) + print(f'[INFO] ===== extracted esperanto =====') + + print(f'[INFO] ===== extract deepspeech =====') + cmd = f'python data_util/deepspeech_features/extract_ds_features.py --input {path} --output {path.replace(".wav", "_deepspeech.npy")}' + os.system(cmd) + print(f'[INFO] ===== extracted deepspeech =====') + + print(f'[INFO] ===== extracted all audio labels =====') + + +def extract_images(path, out_path, fps=25): + + print(f'[INFO] ===== extract images from {path} to {out_path} =====') + cmd = f'ffmpeg -i {path} -vf fps={fps} -qmin 1 -q:v 1 -start_number 0 {os.path.join(out_path, "%d.jpg")}' + os.system(cmd) + print(f'[INFO] ===== extracted images =====') + + +def extract_semantics(ori_imgs_dir, parsing_dir): + + print(f'[INFO] ===== extract semantics from {ori_imgs_dir} to {parsing_dir} =====') + cmd = f'python data_util/face_parsing/test.py --respath={parsing_dir} --imgpath={ori_imgs_dir}' + os.system(cmd) + print(f'[INFO] ===== extracted semantics =====') + + +def extract_landmarks(ori_imgs_dir): + + print(f'[INFO] ===== extract face landmarks from {ori_imgs_dir} =====') + + import face_alignment + fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=False) + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) + for image_path in tqdm.tqdm(image_paths): + input = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] + input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB) + preds = fa.get_landmarks(input) + if len(preds) > 0: + lands = preds[0].reshape(-1, 2)[:,:2] + np.savetxt(image_path.replace('jpg', 'lms'), lands, '%f') + del fa + print(f'[INFO] ===== extracted face landmarks =====') + + +def extract_background(base_dir, ori_imgs_dir): + + print(f'[INFO] ===== extract background image from {ori_imgs_dir} =====') + + from sklearn.neighbors import NearestNeighbors + + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) + # only use 1/20 image_paths + image_paths = image_paths[::20] + # read one image to get H/W + tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3] + h, w = tmp_image.shape[:2] + + # nearest neighbors + all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose() + distss = [] + for image_path in tqdm.tqdm(image_paths): + parse_img = cv2.imread(image_path.replace('ori_imgs', 'parsing').replace('.jpg', '.png')) + bg = (parse_img[..., 0] == 255) & (parse_img[..., 1] == 255) & (parse_img[..., 2] == 255) + fg_xys = np.stack(np.nonzero(~bg)).transpose(1, 0) + nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys) + dists, _ = nbrs.kneighbors(all_xys) + distss.append(dists) + + distss = np.stack(distss) + max_dist = np.max(distss, 0) + max_id = np.argmax(distss, 0) + + bc_pixs = max_dist > 5 + bc_pixs_id = np.nonzero(bc_pixs) + bc_ids = max_id[bc_pixs] + + imgs = [] + num_pixs = distss.shape[1] + for image_path in image_paths: + img = cv2.imread(image_path) + imgs.append(img) + imgs = np.stack(imgs).reshape(-1, num_pixs, 3) + + bg_img = np.zeros((h*w, 3), dtype=np.uint8) + bg_img[bc_pixs_id, :] = imgs[bc_ids, bc_pixs_id, :] + bg_img = bg_img.reshape(h, w, 3) + + max_dist = max_dist.reshape(h, w) + bc_pixs = max_dist > 5 + bg_xys = np.stack(np.nonzero(~bc_pixs)).transpose() + fg_xys = np.stack(np.nonzero(bc_pixs)).transpose() + nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys) + distances, indices = nbrs.kneighbors(bg_xys) + bg_fg_xys = fg_xys[indices[:, 0]] + bg_img[bg_xys[:, 0], bg_xys[:, 1], :] = bg_img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :] + + cv2.imwrite(os.path.join(base_dir, 'bc.jpg'), bg_img) + + print(f'[INFO] ===== extracted background image =====') + +def extract_head(base_dir): + bg_img = cv2.imread(os.path.join(base_dir, 'bc.jpg'), cv2.IMREAD_UNCHANGED) + ori_imgs_dir = os.path.join(base_dir, 'ori_imgs') + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) + + print(f'[INFO] ===== extract head images for {base_dir} =====') + + for image_path in tqdm.tqdm(image_paths): + # read ori image + img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] + + # read semantics + parsing_img = cv2.imread(image_path.replace('ori_imgs', 'parsing').replace('.jpg', '.png')) + + head_part = (parsing_img[:, :, 0] == 255) & ( + parsing_img[:, :, 1] == 0) & (parsing_img[:, :, 2] == 0) + img[~head_part] = bg_img[~head_part] + cv2.imwrite(image_path.replace('ori_imgs', 'head_imgs'), img) + print(f'[INFO] ===== extracted head images =====') + + +def extract_torso_and_gt(base_dir, ori_imgs_dir): + + print(f'[INFO] ===== extract torso and gt images for {base_dir} =====') + + from scipy.ndimage import binary_erosion, binary_dilation + + # load bg + bg_image = cv2.imread(os.path.join(base_dir, 'bc.jpg'), cv2.IMREAD_UNCHANGED) + + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) + + for image_path in tqdm.tqdm(image_paths): + # read ori image + ori_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] + + # read semantics + seg = cv2.imread(image_path.replace('ori_imgs', 'parsing').replace('.jpg', '.png')) + head_part = (seg[..., 0] == 255) & (seg[..., 1] == 0) & (seg[..., 2] == 0) + neck_part = (seg[..., 0] == 0) & (seg[..., 1] == 255) & (seg[..., 2] == 0) + torso_part = (seg[..., 0] == 0) & (seg[..., 1] == 0) & (seg[..., 2] == 255) + bg_part = (seg[..., 0] == 255) & (seg[..., 1] == 255) & (seg[..., 2] == 255) + + # get gt image + gt_image = ori_image.copy() + gt_image[bg_part] = bg_image[bg_part] + cv2.imwrite(image_path.replace('ori_imgs', 'gt_imgs'), gt_image) + + # get torso image + torso_image = gt_image.copy() # rgb + torso_image[head_part] = bg_image[head_part] + torso_alpha = 255 * np.ones((gt_image.shape[0], gt_image.shape[1], 1), dtype=np.uint8) # alpha + + # torso part "vertical" in-painting... + L = 8 + 1 + torso_coords = np.stack(np.nonzero(torso_part), axis=-1) # [M, 2] + # lexsort: sort 2D coords first by y then by x, + # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes + inds = np.lexsort((torso_coords[:, 0], torso_coords[:, 1])) + torso_coords = torso_coords[inds] + # choose the top pixel for each column + u, uid, ucnt = np.unique(torso_coords[:, 1], return_index=True, return_counts=True) + top_torso_coords = torso_coords[uid] # [m, 2] + # only keep top-is-head pixels + top_torso_coords_up = top_torso_coords.copy() - np.array([1, 0]) + mask = head_part[tuple(top_torso_coords_up.T)] + if mask.any(): + top_torso_coords = top_torso_coords[mask] + # get the color + top_torso_colors = gt_image[tuple(top_torso_coords.T)] # [m, 3] + # construct inpaint coords (vertically up, or minus in x) + inpaint_torso_coords = top_torso_coords[None].repeat(L, 0) # [L, m, 2] + inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2] + inpaint_torso_coords += inpaint_offsets + inpaint_torso_coords = inpaint_torso_coords.reshape(-1, 2) # [Lm, 2] + inpaint_torso_colors = top_torso_colors[None].repeat(L, 0) # [L, m, 3] + darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1] + inpaint_torso_colors = (inpaint_torso_colors * darken_scaler).reshape(-1, 3) # [Lm, 3] + # set color + torso_image[tuple(inpaint_torso_coords.T)] = inpaint_torso_colors + + inpaint_torso_mask = np.zeros_like(torso_image[..., 0]).astype(bool) + inpaint_torso_mask[tuple(inpaint_torso_coords.T)] = True + else: + inpaint_torso_mask = None + + + # neck part "vertical" in-painting... + push_down = 4 + L = 48 + push_down + 1 + + neck_part = binary_dilation(neck_part, structure=np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=bool), iterations=3) + + neck_coords = np.stack(np.nonzero(neck_part), axis=-1) # [M, 2] + # lexsort: sort 2D coords first by y then by x, + # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes + inds = np.lexsort((neck_coords[:, 0], neck_coords[:, 1])) + neck_coords = neck_coords[inds] + # choose the top pixel for each column + u, uid, ucnt = np.unique(neck_coords[:, 1], return_index=True, return_counts=True) + top_neck_coords = neck_coords[uid] # [m, 2] + # only keep top-is-head pixels + top_neck_coords_up = top_neck_coords.copy() - np.array([1, 0]) + mask = head_part[tuple(top_neck_coords_up.T)] + + top_neck_coords = top_neck_coords[mask] + # push these top down for 4 pixels to make the neck inpainting more natural... + offset_down = np.minimum(ucnt[mask] - 1, push_down) + top_neck_coords += np.stack([offset_down, np.zeros_like(offset_down)], axis=-1) + # get the color + top_neck_colors = gt_image[tuple(top_neck_coords.T)] # [m, 3] + # construct inpaint coords (vertically up, or minus in x) + inpaint_neck_coords = top_neck_coords[None].repeat(L, 0) # [L, m, 2] + inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2] + inpaint_neck_coords += inpaint_offsets + inpaint_neck_coords = inpaint_neck_coords.reshape(-1, 2) # [Lm, 2] + inpaint_neck_colors = top_neck_colors[None].repeat(L, 0) # [L, m, 3] + darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1] + inpaint_neck_colors = (inpaint_neck_colors * darken_scaler).reshape(-1, 3) # [Lm, 3] + # set color + torso_image[tuple(inpaint_neck_coords.T)] = inpaint_neck_colors + + # apply blurring to the inpaint area to avoid vertical-line artifects... + inpaint_mask = np.zeros_like(torso_image[..., 0]).astype(bool) + inpaint_mask[tuple(inpaint_neck_coords.T)] = True + + blur_img = torso_image.copy() + blur_img = cv2.GaussianBlur(blur_img, (5, 5), cv2.BORDER_DEFAULT) + + torso_image[inpaint_mask] = blur_img[inpaint_mask] + + # set mask + mask = (neck_part | torso_part | inpaint_mask) + if inpaint_torso_mask is not None: + mask = mask | inpaint_torso_mask + torso_image[~mask] = 0 + torso_alpha[~mask] = 0 + + cv2.imwrite(image_path.replace('ori_imgs', 'torso_imgs').replace('.jpg', '.png'), np.concatenate([torso_image, torso_alpha], axis=-1)) + + print(f'[INFO] ===== extracted torso and gt images =====') + + +def face_tracking(video_id, ori_imgs_dir): + + print(f'[INFO] ===== perform face tracking =====') + + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) + # read one image to get H/W + tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3] + h, w = tmp_image.shape[:2] + + cmd = f'python data_util/face_tracking/face_tracker.py --idname={video_id} --img_h={h} --img_w={w} --frame_num={len(image_paths)}' + + os.system(cmd) + + print(f'[INFO] ===== finished face tracking =====') + + +def save_transforms(base_dir, ori_imgs_dir): + print(f'[INFO] ===== save transforms =====') + + import torch + + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) + + # read one image to get H/W + tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3] + h, w = tmp_image.shape[:2] + + params_dict = torch.load(os.path.join(base_dir, 'track_params.pt')) + focal_len = params_dict['focal'] + euler_angle = params_dict['euler'] + trans = params_dict['trans'] / 10.0 + valid_num = euler_angle.shape[0] + + def euler2rot(euler_angle): + batch_size = euler_angle.shape[0] + theta = euler_angle[:, 0].reshape(-1, 1, 1) + phi = euler_angle[:, 1].reshape(-1, 1, 1) + psi = euler_angle[:, 2].reshape(-1, 1, 1) + one = torch.ones((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device) + zero = torch.zeros((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device) + rot_x = torch.cat(( + torch.cat((one, zero, zero), 1), + torch.cat((zero, theta.cos(), theta.sin()), 1), + torch.cat((zero, -theta.sin(), theta.cos()), 1), + ), 2) + rot_y = torch.cat(( + torch.cat((phi.cos(), zero, -phi.sin()), 1), + torch.cat((zero, one, zero), 1), + torch.cat((phi.sin(), zero, phi.cos()), 1), + ), 2) + rot_z = torch.cat(( + torch.cat((psi.cos(), -psi.sin(), zero), 1), + torch.cat((psi.sin(), psi.cos(), zero), 1), + torch.cat((zero, zero, one), 1) + ), 2) + return torch.bmm(rot_x, torch.bmm(rot_y, rot_z)) + + + # train_val_split = int(valid_num*0.5) + # train_val_split = valid_num - 25 * 20 # take the last 20s as valid set. + train_val_split = int(valid_num * 10 / 11) + + train_ids = torch.arange(0, train_val_split) + val_ids = torch.arange(train_val_split, valid_num) + + rot = euler2rot(euler_angle) + rot_inv = rot.permute(0, 2, 1) + trans_inv = -torch.bmm(rot_inv, trans.unsqueeze(2)) + + pose = torch.eye(4, dtype=torch.float32) + save_ids = ['train', 'val'] + train_val_ids = [train_ids, val_ids] + mean_z = -float(torch.mean(trans[:, 2]).item()) + + for split in range(2): + transform_dict = dict() + transform_dict['focal_len'] = float(focal_len[0]) + transform_dict['cx'] = float(w/2.0) + transform_dict['cy'] = float(h/2.0) + transform_dict['frames'] = [] + ids = train_val_ids[split] + save_id = save_ids[split] + + for i in ids: + i = i.item() + frame_dict = dict() + frame_dict['img_id'] = i + frame_dict['aud_id'] = i + + pose[:3, :3] = rot_inv[i] + pose[:3, 3] = trans_inv[i, :, 0] + + frame_dict['transform_matrix'] = pose.numpy().tolist() + + lms = np.loadtxt(os.path.join(ori_imgs_dir, str(i) + '.lms')) + min_x, max_x = np.min(lms, 0)[0], np.max(lms, 0)[0] + cx = int((min_x+max_x)/2.0) + cy = int(lms[27, 1]) + h_w = int((max_x-cx)*1.5) + h_h = int((lms[8, 1]-cy)*1.15) + rect_x = cx - h_w + rect_y = cy - h_h + if rect_x < 0: + rect_x = 0 + if rect_y < 0: + rect_y = 0 + rect_w = min(w-1-rect_x, 2*h_w) + rect_h = min(h-1-rect_y, 2*h_h) + rect = np.array((rect_x, rect_y, rect_w, rect_h), dtype=np.int32) + frame_dict['face_rect'] = rect.tolist() + + transform_dict['frames'].append(frame_dict) + + with open(os.path.join(base_dir, 'transforms_' + save_id + '.json'), 'w') as fp: + json.dump(transform_dict, fp, indent=2, separators=(',', ': ')) + + print(f'[INFO] ===== finished saving transforms =====') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--video_id', type=str, default='May', help="data/raw/.mp4") + parser.add_argument('--task', type=int, default=-1, help="-1 means all") + + opt = parser.parse_args() + + video_id = opt.video_id + video_path = os.path.join(f"data/raw/videos/{opt.video_id}.mp4") + processed_dir = f"data/processed/videos/{opt.video_id}" + os.makedirs(processed_dir, exist_ok=True) + wav_path = os.path.join(processed_dir, 'aud.wav') + ori_imgs_dir = os.path.join(processed_dir, 'ori_imgs') + parsing_dir = os.path.join(processed_dir, 'parsing') + head_imgs_dir = os.path.join(processed_dir, 'head_imgs') + gt_imgs_dir = os.path.join(processed_dir, 'gt_imgs') + torso_imgs_dir = os.path.join(processed_dir, 'torso_imgs') + + os.makedirs(ori_imgs_dir, exist_ok=True) + os.makedirs(parsing_dir, exist_ok=True) + os.makedirs(head_imgs_dir, exist_ok=True) + os.makedirs(gt_imgs_dir, exist_ok=True) + os.makedirs(torso_imgs_dir, exist_ok=True) + + + # extract audio + if opt.task == -1 or opt.task == 1: + extract_audio(video_path, wav_path) + + # extract audio features + if opt.task == -1 or opt.task == 2: + extract_audio_features(wav_path) + + # extract images + if opt.task == -1 or opt.task == 3: + extract_images(video_path, ori_imgs_dir) + + # face parsing + if opt.task == -1 or opt.task == 4: + extract_semantics(ori_imgs_dir, parsing_dir) + + # extract bg + if opt.task == -1 or opt.task == 5: + extract_background(processed_dir, ori_imgs_dir) + + # extract torso images and gt_images + if opt.task == -1 or opt.task == 6: + extract_head(processed_dir) + extract_torso_and_gt(processed_dir, ori_imgs_dir) + + # extract face landmarks + if opt.task == -1 or opt.task == 7: + extract_landmarks(ori_imgs_dir) + + # face tracking + if opt.task == -1 or opt.task == 8: + face_tracking(video_id, ori_imgs_dir) + + # save transforms.json + if opt.task == -1 or opt.task == 9: + save_transforms(processed_dir, ori_imgs_dir) + diff --git a/Geneface_main/GeneFace/deep_3drecon/__init__.py b/Geneface_main/GeneFace/deep_3drecon/__init__.py new file mode 100644 index 00000000..6866fab1 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/__init__.py @@ -0,0 +1 @@ +from .reconstructor import * diff --git a/Geneface_main/GeneFace/deep_3drecon/__pycache__/__init__.cpython-310.pyc b/Geneface_main/GeneFace/deep_3drecon/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 00000000..c367e34b Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/__pycache__/__init__.cpython-310.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/__pycache__/__init__.cpython-39.pyc b/Geneface_main/GeneFace/deep_3drecon/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 00000000..ea0dbd08 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/__pycache__/__init__.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/__pycache__/reconstructor.cpython-310.pyc b/Geneface_main/GeneFace/deep_3drecon/__pycache__/reconstructor.cpython-310.pyc new file mode 100644 index 00000000..f628a5d9 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/__pycache__/reconstructor.cpython-310.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/__pycache__/reconstructor.cpython-39.pyc b/Geneface_main/GeneFace/deep_3drecon/__pycache__/reconstructor.cpython-39.pyc new file mode 100644 index 00000000..274f2be1 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/__pycache__/reconstructor.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/data/__init__.py b/Geneface_main/GeneFace/deep_3drecon/data/__init__.py new file mode 100644 index 00000000..56fe2126 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/data/__init__.py @@ -0,0 +1,116 @@ +"""This package includes all the modules related to data loading and preprocessing + + To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. + You need to implement four functions: + -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). + -- <__len__>: return the size of dataset. + -- <__getitem__>: get a data point from data loader. + -- : (optionally) add dataset-specific options and set default options. + +Now you can use the dataset class by specifying flag '--dataset_mode dummy'. +See our template dataset class 'template_dataset.py' for more details. +""" +import numpy as np +import importlib +import torch.utils.data +from data.base_dataset import BaseDataset + + +def find_dataset_using_name(dataset_name): + """Import the module "data/[dataset_name]_dataset.py". + + In the file, the class called DatasetNameDataset() will + be instantiated. It has to be a subclass of BaseDataset, + and it is case-insensitive. + """ + dataset_filename = "data." + dataset_name + "_dataset" + datasetlib = importlib.import_module(dataset_filename) + + dataset = None + target_dataset_name = dataset_name.replace('_', '') + 'dataset' + for name, cls in datasetlib.__dict__.items(): + if name.lower() == target_dataset_name.lower() \ + and issubclass(cls, BaseDataset): + dataset = cls + + if dataset is None: + raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) + + return dataset + + +def get_option_setter(dataset_name): + """Return the static method of the dataset class.""" + dataset_class = find_dataset_using_name(dataset_name) + return dataset_class.modify_commandline_options + + +def create_dataset(opt, rank=0): + """Create a dataset given the option. + + This function wraps the class CustomDatasetDataLoader. + This is the main interface between this package and 'train.py'/'test.py' + + Example: + >>> from data import create_dataset + >>> dataset = create_dataset(opt) + """ + data_loader = CustomDatasetDataLoader(opt, rank=rank) + dataset = data_loader.load_data() + return dataset + +class CustomDatasetDataLoader(): + """Wrapper class of Dataset class that performs multi-threaded data loading""" + + def __init__(self, opt, rank=0): + """Initialize this class + + Step 1: create a dataset instance given the name [dataset_mode] + Step 2: create a multi-threaded data loader. + """ + self.opt = opt + dataset_class = find_dataset_using_name(opt.dataset_mode) + self.dataset = dataset_class(opt) + self.sampler = None + print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__)) + if opt.use_ddp and opt.isTrain: + world_size = opt.world_size + self.sampler = torch.utils.data.distributed.DistributedSampler( + self.dataset, + num_replicas=world_size, + rank=rank, + shuffle=not opt.serial_batches + ) + self.dataloader = torch.utils.data.DataLoader( + self.dataset, + sampler=self.sampler, + num_workers=int(opt.num_threads / world_size), + batch_size=int(opt.batch_size / world_size), + drop_last=True) + else: + self.dataloader = torch.utils.data.DataLoader( + self.dataset, + batch_size=opt.batch_size, + shuffle=(not opt.serial_batches) and opt.isTrain, + num_workers=int(opt.num_threads), + drop_last=True + ) + + def set_epoch(self, epoch): + self.dataset.current_epoch = epoch + if self.sampler is not None: + self.sampler.set_epoch(epoch) + + def load_data(self): + return self + + def __len__(self): + """Return the number of data in the dataset""" + return min(len(self.dataset), self.opt.max_dataset_size) + + def __iter__(self): + """Return a batch of data""" + for i, data in enumerate(self.dataloader): + if i * self.opt.batch_size >= self.opt.max_dataset_size: + break + yield data diff --git a/Geneface_main/GeneFace/deep_3drecon/data/__pycache__/__init__.cpython-310.pyc b/Geneface_main/GeneFace/deep_3drecon/data/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 00000000..acea8ba0 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/data/__pycache__/__init__.cpython-310.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/data/__pycache__/__init__.cpython-39.pyc b/Geneface_main/GeneFace/deep_3drecon/data/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 00000000..491ec190 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/data/__pycache__/__init__.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/data/__pycache__/base_dataset.cpython-310.pyc b/Geneface_main/GeneFace/deep_3drecon/data/__pycache__/base_dataset.cpython-310.pyc new file mode 100644 index 00000000..d443e019 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/data/__pycache__/base_dataset.cpython-310.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/data/__pycache__/base_dataset.cpython-39.pyc b/Geneface_main/GeneFace/deep_3drecon/data/__pycache__/base_dataset.cpython-39.pyc new file mode 100644 index 00000000..b76e2d9f Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/data/__pycache__/base_dataset.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/data/base_dataset.py b/Geneface_main/GeneFace/deep_3drecon/data/base_dataset.py new file mode 100644 index 00000000..1bd57d08 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/data/base_dataset.py @@ -0,0 +1,125 @@ +"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets. + +It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. +""" +import random +import numpy as np +import torch.utils.data as data +from PIL import Image +import torchvision.transforms as transforms +from abc import ABC, abstractmethod + + +class BaseDataset(data.Dataset, ABC): + """This class is an abstract base class (ABC) for datasets. + + To create a subclass, you need to implement the following four functions: + -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). + -- <__len__>: return the size of dataset. + -- <__getitem__>: get a data point. + -- : (optionally) add dataset-specific options and set default options. + """ + + def __init__(self, opt): + """Initialize the class; save the options in the class + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + self.opt = opt + # self.root = opt.dataroot + self.current_epoch = 0 + + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new dataset-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + return parser + + @abstractmethod + def __len__(self): + """Return the total number of images in the dataset.""" + return 0 + + @abstractmethod + def __getitem__(self, index): + """Return a data point and its metadata information. + + Parameters: + index - - a random integer for data indexing + + Returns: + a dictionary of data with their names. It ususally contains the data itself and its metadata information. + """ + pass + + +def get_transform(grayscale=False): + transform_list = [] + if grayscale: + transform_list.append(transforms.Grayscale(1)) + transform_list += [transforms.ToTensor()] + return transforms.Compose(transform_list) + +def get_affine_mat(opt, size): + shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False + w, h = size + + if 'shift' in opt.preprocess: + shift_pixs = int(opt.shift_pixs) + shift_x = random.randint(-shift_pixs, shift_pixs) + shift_y = random.randint(-shift_pixs, shift_pixs) + if 'scale' in opt.preprocess: + scale = 1 + opt.scale_delta * (2 * random.random() - 1) + if 'rot' in opt.preprocess: + rot_angle = opt.rot_angle * (2 * random.random() - 1) + rot_rad = -rot_angle * np.pi/180 + if 'flip' in opt.preprocess: + flip = random.random() > 0.5 + + shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3]) + flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3]) + shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3]) + rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3]) + scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3]) + shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3]) + + affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin + affine_inv = np.linalg.inv(affine) + return affine, affine_inv, flip + +def apply_img_affine(img, affine_inv, method=Image.BICUBIC): + return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC) + +def apply_lm_affine(landmark, affine, flip, size): + _, h = size + lm = landmark.copy() + lm[:, 1] = h - 1 - lm[:, 1] + lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1) + lm = lm @ np.transpose(affine) + lm[:, :2] = lm[:, :2] / lm[:, 2:] + lm = lm[:, :2] + lm[:, 1] = h - 1 - lm[:, 1] + if flip: + lm_ = lm.copy() + lm_[:17] = lm[16::-1] + lm_[17:22] = lm[26:21:-1] + lm_[22:27] = lm[21:16:-1] + lm_[31:36] = lm[35:30:-1] + lm_[36:40] = lm[45:41:-1] + lm_[40:42] = lm[47:45:-1] + lm_[42:46] = lm[39:35:-1] + lm_[46:48] = lm[41:39:-1] + lm_[48:55] = lm[54:47:-1] + lm_[55:60] = lm[59:54:-1] + lm_[60:65] = lm[64:59:-1] + lm_[65:68] = lm[67:64:-1] + lm = lm_ + return lm diff --git a/Geneface_main/GeneFace/deep_3drecon/data/flist_dataset.py b/Geneface_main/GeneFace/deep_3drecon/data/flist_dataset.py new file mode 100644 index 00000000..c0b6945c --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/data/flist_dataset.py @@ -0,0 +1,125 @@ +"""This script defines the custom dataset for Deep3DFaceRecon_pytorch +""" + +import os.path +from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine +from data.image_folder import make_dataset +from PIL import Image +import random +import util.util as util +import numpy as np +import json +import torch +from scipy.io import loadmat, savemat +import pickle +from util.preprocess import align_img, estimate_norm +from util.load_mats import load_lm3d + + +def default_flist_reader(flist): + """ + flist format: impath label\nimpath label\n ...(same to caffe's filelist) + """ + imlist = [] + with open(flist, 'r') as rf: + for line in rf.readlines(): + impath = line.strip() + imlist.append(impath) + + return imlist + +def jason_flist_reader(flist): + with open(flist, 'r') as fp: + info = json.load(fp) + return info + +def parse_label(label): + return torch.tensor(np.array(label).astype(np.float32)) + + +class FlistDataset(BaseDataset): + """ + It requires one directories to host training images '/path/to/data/train' + You can train the model with the dataset flag '--dataroot /path/to/data'. + """ + + def __init__(self, opt): + """Initialize this dataset class. + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + BaseDataset.__init__(self, opt) + + self.lm3d_std = load_lm3d(opt.bfm_folder) + + msk_names = default_flist_reader(opt.flist) + self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names] + + self.size = len(self.msk_paths) + self.opt = opt + + self.name = 'train' if opt.isTrain else 'val' + if '_' in opt.flist: + self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0] + + + def __getitem__(self, index): + """Return a data point and its metadata information. + + Parameters: + index (int) -- a random integer for data indexing + + Returns a dictionary that contains A, B, A_paths and B_paths + img (tensor) -- an image in the input domain + msk (tensor) -- its corresponding attention mask + lm (tensor) -- its corresponding 3d landmarks + im_paths (str) -- image paths + aug_flag (bool) -- a flag used to tell whether its raw or augmented + """ + msk_path = self.msk_paths[index % self.size] # make sure index is within then range + img_path = msk_path.replace('mask/', '') + lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt' + + raw_img = Image.open(img_path).convert('RGB') + raw_msk = Image.open(msk_path).convert('RGB') + raw_lm = np.loadtxt(lm_path).astype(np.float32) + + _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk) + + aug_flag = self.opt.use_aug and self.opt.isTrain + if aug_flag: + img, lm, msk = self._augmentation(img, lm, self.opt, msk) + + _, H = img.size + M = estimate_norm(lm, H) + transform = get_transform() + img_tensor = transform(img) + msk_tensor = transform(msk)[:1, ...] + lm_tensor = parse_label(lm) + M_tensor = parse_label(M) + + + return {'imgs': img_tensor, + 'lms': lm_tensor, + 'msks': msk_tensor, + 'M': M_tensor, + 'im_paths': img_path, + 'aug_flag': aug_flag, + 'dataset': self.name} + + def _augmentation(self, img, lm, opt, msk=None): + affine, affine_inv, flip = get_affine_mat(opt, img.size) + img = apply_img_affine(img, affine_inv) + lm = apply_lm_affine(lm, affine, flip, img.size) + if msk is not None: + msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR) + return img, lm, msk + + + + + def __len__(self): + """Return the total number of images in the dataset. + """ + return self.size diff --git a/Geneface_main/GeneFace/deep_3drecon/data/image_folder.py b/Geneface_main/GeneFace/deep_3drecon/data/image_folder.py new file mode 100644 index 00000000..efadc2ec --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/data/image_folder.py @@ -0,0 +1,66 @@ +"""A modified image folder class + +We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) +so that this class can load images from both current directory and its subdirectories. +""" +import numpy as np +import torch.utils.data as data + +from PIL import Image +import os +import os.path + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', + '.tif', '.TIF', '.tiff', '.TIFF', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def make_dataset(dir, max_dataset_size=float("inf")): + images = [] + assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir + + for root, _, fnames in sorted(os.walk(dir, followlinks=True)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(root, fname) + images.append(path) + return images[:min(max_dataset_size, len(images))] + + +def default_loader(path): + return Image.open(path).convert('RGB') + + +class ImageFolder(data.Dataset): + + def __init__(self, root, transform=None, return_paths=False, + loader=default_loader): + imgs = make_dataset(root) + if len(imgs) == 0: + raise(RuntimeError("Found 0 images in: " + root + "\n" + "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) + + self.root = root + self.imgs = imgs + self.transform = transform + self.return_paths = return_paths + self.loader = loader + + def __getitem__(self, index): + path = self.imgs[index] + img = self.loader(path) + if self.transform is not None: + img = self.transform(img) + if self.return_paths: + return img, path + else: + return img + + def __len__(self): + return len(self.imgs) diff --git a/Geneface_main/GeneFace/deep_3drecon/data/template_dataset.py b/Geneface_main/GeneFace/deep_3drecon/data/template_dataset.py new file mode 100644 index 00000000..bfdf16be --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/data/template_dataset.py @@ -0,0 +1,75 @@ +"""Dataset class template + +This module provides a template for users to implement custom datasets. +You can specify '--dataset_mode template' to use this dataset. +The class name should be consistent with both the filename and its dataset_mode option. +The filename should be _dataset.py +The class name should be Dataset.py +You need to implement the following functions: + -- : Add dataset-specific options and rewrite default values for existing options. + -- <__init__>: Initialize this dataset class. + -- <__getitem__>: Return a data point and its metadata information. + -- <__len__>: Return the number of images. +""" +from data.base_dataset import BaseDataset, get_transform +# from data.image_folder import make_dataset +# from PIL import Image + + +class TemplateDataset(BaseDataset): + """A template dataset class for you to implement custom datasets.""" + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new dataset-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option') + parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values + return parser + + def __init__(self, opt): + """Initialize this dataset class. + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + + A few things can be done here. + - save the options (have been done in BaseDataset) + - get image paths and meta information of the dataset. + - define the image transformation. + """ + # save the option and dataset root + BaseDataset.__init__(self, opt) + # get the image paths of your dataset; + self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root + # define the default transform function. You can use ; You can also define your custom transform function + self.transform = get_transform(opt) + + def __getitem__(self, index): + """Return a data point and its metadata information. + + Parameters: + index -- a random integer for data indexing + + Returns: + a dictionary of data with their names. It usually contains the data itself and its metadata information. + + Step 1: get a random image path: e.g., path = self.image_paths[index] + Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB'). + Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. e.g., data = self.transform(image) + Step 4: return a data point as a dictionary. + """ + path = 'temp' # needs to be a string + data_A = None # needs to be a tensor + data_B = None # needs to be a tensor + return {'data_A': data_A, 'data_B': data_B, 'path': path} + + def __len__(self): + """Return the total number of images.""" + return len(self.image_paths) diff --git a/Geneface_main/GeneFace/deep_3drecon/data_preparation.py b/Geneface_main/GeneFace/deep_3drecon/data_preparation.py new file mode 100644 index 00000000..6ffc79d3 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/data_preparation.py @@ -0,0 +1,45 @@ +"""This script is the data preparation script for Deep3DFaceRecon_pytorch +""" + +import os +import numpy as np +import argparse +from util.detect_lm68 import detect_68p,load_lm_graph +from util.skin_mask import get_skin_mask +from util.generate_list import check_list, write_list +import warnings +warnings.filterwarnings("ignore") + +parser = argparse.ArgumentParser() +parser.add_argument('--data_root', type=str, default='datasets', help='root directory for training data') +parser.add_argument('--img_folder', nargs="+", required=True, help='folders of training images') +parser.add_argument('--mode', type=str, default='train', help='train or val') +opt = parser.parse_args() + +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + +def data_prepare(folder_list,mode): + + lm_sess,input_op,output_op = load_lm_graph('./checkpoints/lm_model/68lm_detector.pb') # load a tensorflow version 68-landmark detector + + for img_folder in folder_list: + detect_68p(img_folder,lm_sess,input_op,output_op) # detect landmarks for images + get_skin_mask(img_folder) # generate skin attention mask for images + + # create files that record path to all training data + msks_list = [] + for img_folder in folder_list: + path = os.path.join(img_folder, 'mask') + msks_list += ['/'.join([img_folder, 'mask', i]) for i in sorted(os.listdir(path)) if 'jpg' in i or + 'png' in i or 'jpeg' in i or 'PNG' in i] + + imgs_list = [i.replace('mask/', '') for i in msks_list] + lms_list = [i.replace('mask', 'landmarks') for i in msks_list] + lms_list = ['.'.join(i.split('.')[:-1]) + '.txt' for i in lms_list] + + lms_list_final, imgs_list_final, msks_list_final = check_list(lms_list, imgs_list, msks_list) # check if the path is valid + write_list(lms_list_final, imgs_list_final, msks_list_final, mode=mode) # save files + +if __name__ == '__main__': + print('Datasets:',opt.img_folder) + data_prepare([os.path.join(opt.data_root,folder) for folder in opt.img_folder],opt.mode) diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000002.jpg b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000002.jpg new file mode 100644 index 00000000..dc7ebcbf Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000002.jpg differ diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000006.jpg b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000006.jpg new file mode 100644 index 00000000..725e86c8 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000006.jpg differ diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000007.jpg b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000007.jpg new file mode 100644 index 00000000..443c8068 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000007.jpg differ diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000031.jpg b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000031.jpg new file mode 100644 index 00000000..46bdce9c Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000031.jpg differ diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000033.jpg b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000033.jpg new file mode 100644 index 00000000..c105797a Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000033.jpg differ diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000037.jpg b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000037.jpg new file mode 100644 index 00000000..6cba8201 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000037.jpg differ diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000050.jpg b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000050.jpg new file mode 100644 index 00000000..8513d730 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000050.jpg differ diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000055.jpg b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000055.jpg new file mode 100644 index 00000000..1ea7e0c1 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000055.jpg differ diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000114.jpg b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000114.jpg new file mode 100644 index 00000000..abf24cab Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000114.jpg differ diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000125.jpg b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000125.jpg new file mode 100644 index 00000000..272d4a4f Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000125.jpg differ diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000126.jpg b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000126.jpg new file mode 100644 index 00000000..e7a9a907 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/000126.jpg differ diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/015259.jpg b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/015259.jpg new file mode 100644 index 00000000..a421abf7 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/015259.jpg differ diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/015270.jpg b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/015270.jpg new file mode 100644 index 00000000..d4cd516e Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/015270.jpg differ diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/015309.jpg b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/015309.jpg new file mode 100644 index 00000000..6331a72b Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/015309.jpg differ diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/015310.jpg b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/015310.jpg new file mode 100644 index 00000000..71bf911d Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/015310.jpg differ diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/015316.jpg b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/015316.jpg new file mode 100644 index 00000000..3a7ca67e Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/015316.jpg differ diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/015384.jpg b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/015384.jpg new file mode 100644 index 00000000..9dfaceff Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/015384.jpg differ diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000002.txt b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000002.txt new file mode 100644 index 00000000..0c0abbda --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000002.txt @@ -0,0 +1,5 @@ +142.84 207.18 +222.02 203.9 +159.24 253.57 +146.59 290.93 +227.52 284.74 diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000006.txt b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000006.txt new file mode 100644 index 00000000..28d4d3d2 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000006.txt @@ -0,0 +1,5 @@ +199.93 158.28 +255.34 166.54 +236.08 198.92 +198.83 229.24 +245.23 234.52 diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000007.txt b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000007.txt new file mode 100644 index 00000000..be564ec4 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000007.txt @@ -0,0 +1,5 @@ +129.36 198.28 +204.47 191.47 +164.42 240.51 +140.74 277.77 +205.4 270.9 diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000031.txt b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000031.txt new file mode 100644 index 00000000..10467f13 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000031.txt @@ -0,0 +1,5 @@ +151.23 240.71 +274.05 235.52 +217.37 305.99 +158.03 346.06 +272.17 341.09 diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000033.txt b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000033.txt new file mode 100644 index 00000000..e226473b --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000033.txt @@ -0,0 +1,5 @@ +119.09 94.291 +158.31 96.472 +136.76 121.4 +119.33 134.49 +154.66 136.68 diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000037.txt b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000037.txt new file mode 100644 index 00000000..ebdc113d --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000037.txt @@ -0,0 +1,5 @@ +147.37 159.39 +196.94 163.26 +190.68 194.36 +153.72 228.44 +193.94 229.7 diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000050.txt b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000050.txt new file mode 100644 index 00000000..67eed576 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000050.txt @@ -0,0 +1,5 @@ +150.4 94.799 +205.14 102.07 +179.54 131.16 +144.45 147.42 +193.39 154.14 diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000055.txt b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000055.txt new file mode 100644 index 00000000..4eec3411 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000055.txt @@ -0,0 +1,5 @@ +114.26 193.42 +205.8 190.27 +154.15 244.02 +124.69 295.22 +200.88 292.69 diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000114.txt b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000114.txt new file mode 100644 index 00000000..f7c78193 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000114.txt @@ -0,0 +1,5 @@ +217.52 152.95 +281.48 147.14 +253.02 196.03 +225.79 221.6 +288.25 214.44 diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000125.txt b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000125.txt new file mode 100644 index 00000000..c6c705d8 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000125.txt @@ -0,0 +1,5 @@ +90.928 99.858 +146.87 100.33 +114.22 130.36 +91.579 153.32 +143.63 153.56 diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000126.txt b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000126.txt new file mode 100644 index 00000000..e34d1bd0 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/000126.txt @@ -0,0 +1,5 @@ +307.56 166.54 +387.06 159.62 +335.52 222.26 +319.3 248.85 +397.71 239.14 diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/015259.txt b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/015259.txt new file mode 100644 index 00000000..9c2ab86f --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/015259.txt @@ -0,0 +1,5 @@ +226.38 193.65 +319.12 208.97 +279.99 245.88 +213.79 290.55 +303.03 302.1 diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/015270.txt b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/015270.txt new file mode 100644 index 00000000..335dcd6c --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/015270.txt @@ -0,0 +1,5 @@ +208.4 410.08 +364.41 388.68 +291.6 503.57 +244.82 572.86 +383.18 553.49 diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/015309.txt b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/015309.txt new file mode 100644 index 00000000..309a6331 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/015309.txt @@ -0,0 +1,5 @@ +284.61 496.57 +562.77 550.78 +395.85 712.84 +238.92 786.8 +495.61 827.22 diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/015310.txt b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/015310.txt new file mode 100644 index 00000000..7ce6a510 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/015310.txt @@ -0,0 +1,5 @@ +153.95 153.43 +211.13 161.54 +197.28 190.26 +150.82 215.98 +202.32 223.12 diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/015316.txt b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/015316.txt new file mode 100644 index 00000000..0743b137 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/015316.txt @@ -0,0 +1,5 @@ +481.31 396.88 +667.75 392.43 +557.81 440.55 +490.44 586.28 +640.56 583.2 diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/015384.txt b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/015384.txt new file mode 100644 index 00000000..b49f9e98 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/015384.txt @@ -0,0 +1,5 @@ +191.79 143.97 +271.86 151.23 +191.25 210.29 +187.82 257.12 +258.82 261.96 diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/vd006.txt b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/vd006.txt new file mode 100644 index 00000000..5fc0f2df --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/vd006.txt @@ -0,0 +1,5 @@ +123.12 117.58 +176.59 122.09 +126.99 144.68 +117.61 183.43 +163.94 186.41 diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/vd025.txt b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/vd025.txt new file mode 100644 index 00000000..0c5bf97b --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/vd025.txt @@ -0,0 +1,5 @@ +180.12 116.13 +263.18 98.397 +230.48 154.72 +201.37 199.01 +279.18 182.56 diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/vd026.txt b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/vd026.txt new file mode 100644 index 00000000..f4cedc32 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/vd026.txt @@ -0,0 +1,5 @@ +171.27 263.54 +286.58 263.88 +203.35 333.02 +170.6 389.42 +281.73 386.84 diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/vd034.txt b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/vd034.txt new file mode 100644 index 00000000..f799cc11 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/vd034.txt @@ -0,0 +1,5 @@ +136.01 167.83 +195.25 151.71 +152.89 191.45 +149.85 235.5 +201.16 222.8 diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/vd051.txt b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/vd051.txt new file mode 100644 index 00000000..38576331 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/vd051.txt @@ -0,0 +1,5 @@ +161.92 292.04 +254.21 283.81 +212.75 342.06 +170.78 387.28 +254.6 379.82 diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/vd070.txt b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/vd070.txt new file mode 100644 index 00000000..8f02c191 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/vd070.txt @@ -0,0 +1,5 @@ +276.53 290.35 +383.38 294.75 +314.48 354.66 +275.08 407.72 +364.94 411.48 diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/vd092.txt b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/vd092.txt new file mode 100644 index 00000000..679b2891 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/vd092.txt @@ -0,0 +1,5 @@ +108.59 149.07 +157.35 143.85 +134.4 173.2 +117.88 200.79 +159.56 196.36 diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/vd102.txt b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/vd102.txt new file mode 100644 index 00000000..6fa2643b --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/detections/vd102.txt @@ -0,0 +1,5 @@ +121.62 225.96 +186.73 223.07 +162.99 269.82 +132.12 302.62 +186.42 299.21 diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/vd006.png b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/vd006.png new file mode 100644 index 00000000..681e3847 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/vd006.png differ diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/vd025.png b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/vd025.png new file mode 100644 index 00000000..a12e8d57 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/vd025.png differ diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/vd026.png b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/vd026.png new file mode 100644 index 00000000..96a06a7a Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/vd026.png differ diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/vd034.png b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/vd034.png new file mode 100644 index 00000000..2c0000f3 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/vd034.png differ diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/vd051.png b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/vd051.png new file mode 100644 index 00000000..9e841e50 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/vd051.png differ diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/vd070.png b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/vd070.png new file mode 100644 index 00000000..e084e840 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/vd070.png differ diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/vd092.png b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/vd092.png new file mode 100644 index 00000000..49570eeb Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/vd092.png differ diff --git a/Geneface_main/GeneFace/deep_3drecon/datasets/examples/vd102.png b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/vd102.png new file mode 100644 index 00000000..7864178a Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/datasets/examples/vd102.png differ diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__init__.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__init__.py new file mode 100644 index 00000000..a09ede59 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__init__.py @@ -0,0 +1,67 @@ +"""This package contains modules related to objective functions, optimizations, and network architectures. + +To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. +You need to implement the following five functions: + -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). + -- : unpack data from dataset and apply preprocessing. + -- : produce intermediate results. + -- : calculate loss, gradients, and update network weights. + -- : (optionally) add model-specific options and set default options. + +In the function <__init__>, you need to define four lists: + -- self.loss_names (str list): specify the training losses that you want to plot and save. + -- self.model_names (str list): define networks used in our training. + -- self.visual_names (str list): specify the images that you want to display and save. + -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. + +Now you can use the model class by specifying flag '--model dummy'. +See our template model class 'template_model.py' for more details. +""" + +import importlib +from .base_model import BaseModel + + +def find_model_using_name(model_name): + """Import the module "models/[model_name]_model.py". + + In the file, the class called DatasetNameModel() will + be instantiated. It has to be a subclass of BaseModel, + and it is case-insensitive. + """ + model_filename = "deep_3drecon_models." + model_name + "_model" + modellib = importlib.import_module(model_filename) + model = None + target_model_name = model_name.replace('_', '') + 'model' + for name, cls in modellib.__dict__.items(): + if name.lower() == target_model_name.lower() \ + and issubclass(cls, BaseModel): + model = cls + + if model is None: + print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) + exit(0) + + return model + + +def get_option_setter(model_name): + """Return the static method of the model class.""" + model_class = find_model_using_name(model_name) + return model_class.modify_commandline_options + + +def create_model(opt): + """Create a model given the option. + + This function warps the class CustomDatasetDataLoader. + This is the main interface between this package and 'train.py'/'test.py' + + Example: + >>> from models import create_model + >>> model = create_model(opt) + """ + model = find_model_using_name(opt.model) + instance = model(opt) + print("model [%s] was created" % type(instance).__name__) + return instance diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/__init__.cpython-310.pyc b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 00000000..99cc2ef2 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/__init__.cpython-310.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/__init__.cpython-39.pyc b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 00000000..3cbdc4bb Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/__init__.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/base_model.cpython-310.pyc b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/base_model.cpython-310.pyc new file mode 100644 index 00000000..b216147d Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/base_model.cpython-310.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/base_model.cpython-39.pyc b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/base_model.cpython-39.pyc new file mode 100644 index 00000000..d513721e Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/base_model.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/bfm.cpython-310.pyc b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/bfm.cpython-310.pyc new file mode 100644 index 00000000..a61d4e6a Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/bfm.cpython-310.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/bfm.cpython-39.pyc b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/bfm.cpython-39.pyc new file mode 100644 index 00000000..dd1bd8d6 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/bfm.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/facerecon_model.cpython-310.pyc b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/facerecon_model.cpython-310.pyc new file mode 100644 index 00000000..2a80041f Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/facerecon_model.cpython-310.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/facerecon_model.cpython-39.pyc b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/facerecon_model.cpython-39.pyc new file mode 100644 index 00000000..ed08b6fe Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/facerecon_model.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/losses.cpython-310.pyc b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/losses.cpython-310.pyc new file mode 100644 index 00000000..bb1b5853 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/losses.cpython-310.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/losses.cpython-39.pyc b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/losses.cpython-39.pyc new file mode 100644 index 00000000..705ecd8e Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/losses.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/networks.cpython-310.pyc b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/networks.cpython-310.pyc new file mode 100644 index 00000000..fd752e32 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/networks.cpython-310.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/networks.cpython-39.pyc b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/networks.cpython-39.pyc new file mode 100644 index 00000000..5266a910 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/__pycache__/networks.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/README.md b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/README.md new file mode 100644 index 00000000..8d391f63 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/README.md @@ -0,0 +1,218 @@ +# Distributed Arcface Training in Pytorch + +The "arcface_torch" repository is the official implementation of the ArcFace algorithm. It supports distributed and sparse training with multiple distributed training examples, including several memory-saving techniques such as mixed precision training and gradient checkpointing. It also supports training for ViT models and datasets including WebFace42M and Glint360K, two of the largest open-source datasets. Additionally, the repository comes with a built-in tool for converting to ONNX format, making it easy to submit to MFR evaluation systems. + +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/killing-two-birds-with-one-stone-efficient/face-verification-on-ijb-c)](https://paperswithcode.com/sota/face-verification-on-ijb-c?p=killing-two-birds-with-one-stone-efficient) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/killing-two-birds-with-one-stone-efficient/face-verification-on-ijb-b)](https://paperswithcode.com/sota/face-verification-on-ijb-b?p=killing-two-birds-with-one-stone-efficient) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/killing-two-birds-with-one-stone-efficient/face-verification-on-agedb-30)](https://paperswithcode.com/sota/face-verification-on-agedb-30?p=killing-two-birds-with-one-stone-efficient) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/killing-two-birds-with-one-stone-efficient/face-verification-on-cfp-fp)](https://paperswithcode.com/sota/face-verification-on-cfp-fp?p=killing-two-birds-with-one-stone-efficient) + +## Requirements + +To avail the latest features of PyTorch, we have upgraded to version 1.12.0. + +- Install [PyTorch](https://pytorch.org/get-started/previous-versions/) (torch>=1.12.0). +- (Optional) Install [DALI](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/), our doc for [install_dali.md](docs/install_dali.md). +- `pip install -r requirement.txt`. + +## How to Training + +To train a model, execute the `train.py` script with the path to the configuration files. The sample commands provided below demonstrate the process of conducting distributed training. + +### 1. To run on one GPU: + +```shell +python train_v2.py configs/ms1mv3_r50_onegpu +``` + +Note: +It is not recommended to use a single GPU for training, as this may result in longer training times and suboptimal performance. For best results, we suggest using multiple GPUs or a GPU cluster. + + +### 2. To run on a machine with 8 GPUs: + +```shell +torchrun --nproc_per_node=8 train.py configs/ms1mv3_r50 +``` + +### 3. To run on 2 machines with 8 GPUs each: + +Node 0: + +```shell +torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=12581 train.py configs/wf42m_pfc02_16gpus_r100 +``` + +Node 1: + +```shell +torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=12581 train.py configs/wf42m_pfc02_16gpus_r100 +``` + +### 4. Run ViT-B on a machine with 24k batchsize: + +```shell +torchrun --nproc_per_node=8 train_v2.py configs/wf42m_pfc03_40epoch_8gpu_vit_b +``` + + +## Download Datasets or Prepare Datasets +- [MS1MV2](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_#ms1m-arcface-85k-ids58m-images-57) (87k IDs, 5.8M images) +- [MS1MV3](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_#ms1m-retinaface) (93k IDs, 5.2M images) +- [Glint360K](https://github.com/deepinsight/insightface/tree/master/recognition/partial_fc#4-download) (360k IDs, 17.1M images) +- [WebFace42M](docs/prepare_webface42m.md) (2M IDs, 42.5M images) +- [Your Dataset, Click Here!](docs/prepare_custom_dataset.md) + +Note: +If you want to use DALI for data reading, please use the script 'scripts/shuffle_rec.py' to shuffle the InsightFace style rec before using it. +Example: + +`python scripts/shuffle_rec.py ms1m-retinaface-t1` + +You will get the "shuffled_ms1m-retinaface-t1" folder, where the samples in the "train.rec" file are shuffled. + + +## Model Zoo + +- The models are available for non-commercial research purposes only. +- All models can be found in here. +- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw +- [OneDrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d) + +### Performance on IJB-C and [**ICCV2021-MFR**](https://github.com/deepinsight/insightface/blob/master/challenges/mfr/README.md) + +ICCV2021-MFR testset consists of non-celebrities so we can ensure that it has very few overlap with public available face +recognition training set, such as MS1M and CASIA as they mostly collected from online celebrities. +As the result, we can evaluate the FAIR performance for different algorithms. + +For **ICCV2021-MFR-ALL** set, TAR is measured on all-to-all 1:1 protocal, with FAR less than 0.000001(e-6). The +globalised multi-racial testset contains 242,143 identities and 1,624,305 images. + + +#### 1. Training on Single-Host GPU + +| Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | log | +|:---------------|:--------------------|:------------|:------------|:------------|:------------------------------------------------------------------------------------------------------------------------------------| +| MS1MV2 | mobilefacenet-0.45G | 62.07 | 93.61 | 90.28 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv2_mbf/training.log) | +| MS1MV2 | r50 | 75.13 | 95.97 | 94.07 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv2_r50/training.log) | +| MS1MV2 | r100 | 78.12 | 96.37 | 94.27 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv2_r100/training.log) | +| MS1MV3 | mobilefacenet-0.45G | 63.78 | 94.23 | 91.33 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_mbf/training.log) | +| MS1MV3 | r50 | 79.14 | 96.37 | 94.47 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_r50/training.log) | +| MS1MV3 | r100 | 81.97 | 96.85 | 95.02 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_r100/training.log) | +| Glint360K | mobilefacenet-0.45G | 70.18 | 95.04 | 92.62 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_mbf/training.log) | +| Glint360K | r50 | 86.34 | 97.16 | 95.81 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_r50/training.log) | +| Glint360k | r100 | 89.52 | 97.55 | 96.38 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_r100/training.log) | +| WF4M | r100 | 89.87 | 97.19 | 95.48 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf4m_r100/training.log) | +| WF12M-PFC-0.2 | r100 | 94.75 | 97.60 | 95.90 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf12m_pfc02_r100/training.log) | +| WF12M-PFC-0.3 | r100 | 94.71 | 97.64 | 96.01 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf12m_pfc03_r100/training.log) | +| WF12M | r100 | 94.69 | 97.59 | 95.97 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf12m_r100/training.log) | +| WF42M-PFC-0.2 | r100 | 96.27 | 97.70 | 96.31 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf42m_pfc02_r100/training.log) | +| WF42M-PFC-0.2 | ViT-T-1.5G | 92.04 | 97.27 | 95.68 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf42m_pfc02_40epoch_8gpu_vit_t/training.log) | +| WF42M-PFC-0.3 | ViT-B-11G | 97.16 | 97.91 | 97.05 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_b_8gpu/training.log) | + +#### 2. Training on Multi-Host GPU + +| Datasets | Backbone(bs*gpus) | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | Throughout | log | +|:-----------------|:------------------|:------------|:------------|:------------|:-----------|:-------------------------------------------------------------------------------------------------------------------------------------------| +| WF42M-PFC-0.2 | r50(512*8) | 93.83 | 97.53 | 96.16 | ~5900 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r50_bs4k_pfc02/training.log) | +| WF42M-PFC-0.2 | r50(512*16) | 93.96 | 97.46 | 96.12 | ~11000 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r50_lr01_pfc02_bs8k_16gpus/training.log) | +| WF42M-PFC-0.2 | r50(128*32) | 94.04 | 97.48 | 95.94 | ~17000 | click me | +| WF42M-PFC-0.2 | r100(128*16) | 96.28 | 97.80 | 96.57 | ~5200 | click me | +| WF42M-PFC-0.2 | r100(256*16) | 96.69 | 97.85 | 96.63 | ~5200 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r100_bs4k_pfc02/training.log) | +| WF42M-PFC-0.0018 | r100(512*32) | 93.08 | 97.51 | 95.88 | ~10000 | click me | +| WF42M-PFC-0.2 | r100(128*32) | 96.57 | 97.83 | 96.50 | ~9800 | click me | + +`r100(128*32)` means backbone is r100, batchsize per gpu is 128, the number of gpus is 32. + + + +#### 3. ViT For Face Recognition + +| Datasets | Backbone(bs) | FLOPs | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | Throughout | log | +|:--------------|:--------------|:------|:------------|:------------|:------------|:-----------|:-----------------------------------------------------------------------------------------------------------------------------| +| WF42M-PFC-0.3 | r18(128*32) | 2.6 | 79.13 | 95.77 | 93.36 | - | click me | +| WF42M-PFC-0.3 | r50(128*32) | 6.3 | 94.03 | 97.48 | 95.94 | - | click me | +| WF42M-PFC-0.3 | r100(128*32) | 12.1 | 96.69 | 97.82 | 96.45 | - | click me | +| WF42M-PFC-0.3 | r200(128*32) | 23.5 | 97.70 | 97.97 | 96.93 | - | click me | +| WF42M-PFC-0.3 | VIT-T(384*64) | 1.5 | 92.24 | 97.31 | 95.97 | ~35000 | click me | +| WF42M-PFC-0.3 | VIT-S(384*64) | 5.7 | 95.87 | 97.73 | 96.57 | ~25000 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_s_64gpu/training.log) | +| WF42M-PFC-0.3 | VIT-B(384*64) | 11.4 | 97.42 | 97.90 | 97.04 | ~13800 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_b_64gpu/training.log) | +| WF42M-PFC-0.3 | VIT-L(384*64) | 25.3 | 97.85 | 98.00 | 97.23 | ~9406 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_l_64gpu/training.log) | + +`WF42M` means WebFace42M, `PFC-0.3` means negivate class centers sample rate is 0.3. + +#### 4. Noisy Datasets + +| Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | log | +|:-------------------------|:---------|:------------|:------------|:------------|:---------| +| WF12M-Flip(40%) | r50 | 43.87 | 88.35 | 80.78 | click me | +| WF12M-Flip(40%)-PFC-0.1* | r50 | 80.20 | 96.11 | 93.79 | click me | +| WF12M-Conflict | r50 | 79.93 | 95.30 | 91.56 | click me | +| WF12M-Conflict-PFC-0.3* | r50 | 91.68 | 97.28 | 95.75 | click me | + +`WF12M` means WebFace12M, `+PFC-0.1*` denotes additional abnormal inter-class filtering. + + + +## Speed Benchmark +
+ + +**Arcface-Torch** is an efficient tool for training large-scale face recognition training sets. When the number of classes in the training sets exceeds one million, the partial FC sampling strategy maintains the same accuracy while providing several times faster training performance and lower GPU memory utilization. The partial FC is a sparse variant of the model parallel architecture for large-scale face recognition, utilizing a sparse softmax that dynamically samples a subset of class centers for each training batch. During each iteration, only a sparse portion of the parameters are updated, leading to a significant reduction in GPU memory requirements and computational demands. With the partial FC approach, it is possible to train sets with up to 29 million identities, the largest to date. Furthermore, the partial FC method supports multi-machine distributed training and mixed precision training. + + + +More details see +[speed_benchmark.md](docs/speed_benchmark.md) in docs. + +> 1. Training Speed of Various Parallel Techniques (Samples per Second) on a Tesla V100 32GB x 8 System (Higher is Optimal) + +`-` means training failed because of gpu memory limitations. + +| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | +|:--------------------------------|:--------------|:---------------|:---------------| +| 125000 | 4681 | 4824 | 5004 | +| 1400000 | **1672** | 3043 | 4738 | +| 5500000 | **-** | **1389** | 3975 | +| 8000000 | **-** | **-** | 3565 | +| 16000000 | **-** | **-** | 2679 | +| 29000000 | **-** | **-** | **1855** | + +> 2. GPU Memory Utilization of Various Parallel Techniques (MB per GPU) on a Tesla V100 32GB x 8 System (Lower is Optimal) + +| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | +|:--------------------------------|:--------------|:---------------|:---------------| +| 125000 | 7358 | 5306 | 4868 | +| 1400000 | 32252 | 11178 | 6056 | +| 5500000 | **-** | 32188 | 9854 | +| 8000000 | **-** | **-** | 12310 | +| 16000000 | **-** | **-** | 19950 | +| 29000000 | **-** | **-** | 32324 | + + +## Citations + +``` +@inproceedings{deng2019arcface, + title={Arcface: Additive angular margin loss for deep face recognition}, + author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos}, + booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, + pages={4690--4699}, + year={2019} +} +@inproceedings{An_2022_CVPR, + author={An, Xiang and Deng, Jiankang and Guo, Jia and Feng, Ziyong and Zhu, XuHan and Yang, Jing and Liu, Tongliang}, + title={Killing Two Birds With One Stone: Efficient and Robust Training of Face Recognition CNNs by Partial FC}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + month={June}, + year={2022}, + pages={4042-4051} +} +@inproceedings{zhu2021webface260m, + title={Webface260m: A benchmark unveiling the power of million-scale deep face recognition}, + author={Zhu, Zheng and Huang, Guan and Deng, Jiankang and Ye, Yun and Huang, Junjie and Chen, Xinze and Zhu, Jiagang and Yang, Tian and Lu, Jiwen and Du, Dalong and Zhou, Jie}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={10492--10502}, + year={2021} +} +``` diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__init__.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__init__.py new file mode 100644 index 00000000..6cea70df --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__init__.py @@ -0,0 +1,85 @@ +from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200 +from .mobilefacenet import get_mbf + + +def get_model(name, **kwargs): + # resnet + if name == "r18": + return iresnet18(False, **kwargs) + elif name == "r34": + return iresnet34(False, **kwargs) + elif name == "r50": + return iresnet50(False, **kwargs) + elif name == "r100": + return iresnet100(False, **kwargs) + elif name == "r200": + return iresnet200(False, **kwargs) + elif name == "r2060": + from .iresnet2060 import iresnet2060 + return iresnet2060(False, **kwargs) + + elif name == "mbf": + fp16 = kwargs.get("fp16", False) + num_features = kwargs.get("num_features", 512) + return get_mbf(fp16=fp16, num_features=num_features) + + elif name == "mbf_large": + from .mobilefacenet import get_mbf_large + fp16 = kwargs.get("fp16", False) + num_features = kwargs.get("num_features", 512) + return get_mbf_large(fp16=fp16, num_features=num_features) + + elif name == "vit_t": + num_features = kwargs.get("num_features", 512) + from .vit import VisionTransformer + return VisionTransformer( + img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12, + num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1) + + elif name == "vit_t_dp005_mask0": # For WebFace42M + num_features = kwargs.get("num_features", 512) + from .vit import VisionTransformer + return VisionTransformer( + img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12, + num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0) + + elif name == "vit_s": + num_features = kwargs.get("num_features", 512) + from .vit import VisionTransformer + return VisionTransformer( + img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12, + num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1) + + elif name == "vit_s_dp005_mask_0": # For WebFace42M + num_features = kwargs.get("num_features", 512) + from .vit import VisionTransformer + return VisionTransformer( + img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12, + num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0) + + elif name == "vit_b": + # this is a feature + num_features = kwargs.get("num_features", 512) + from .vit import VisionTransformer + return VisionTransformer( + img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24, + num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1, using_checkpoint=True) + + elif name == "vit_b_dp005_mask_005": # For WebFace42M + # this is a feature + num_features = kwargs.get("num_features", 512) + from .vit import VisionTransformer + return VisionTransformer( + img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24, + num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True) + + elif name == "vit_l_dp005_mask_005": # For WebFace42M + # this is a feature + num_features = kwargs.get("num_features", 512) + from .vit import VisionTransformer + return VisionTransformer( + img_size=112, patch_size=9, num_classes=num_features, embed_dim=768, depth=24, + num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True) + + else: + raise ValueError() diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__pycache__/__init__.cpython-310.pyc b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 00000000..a7b09eeb Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__pycache__/__init__.cpython-310.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__pycache__/__init__.cpython-39.pyc b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 00000000..9740334c Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__pycache__/__init__.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__pycache__/iresnet.cpython-310.pyc b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__pycache__/iresnet.cpython-310.pyc new file mode 100644 index 00000000..323d2b9b Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__pycache__/iresnet.cpython-310.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__pycache__/iresnet.cpython-39.pyc b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__pycache__/iresnet.cpython-39.pyc new file mode 100644 index 00000000..b5765221 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__pycache__/iresnet.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-310.pyc b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-310.pyc new file mode 100644 index 00000000..ccbcfd94 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-310.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-39.pyc b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-39.pyc new file mode 100644 index 00000000..e73012b6 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/iresnet.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/iresnet.py new file mode 100644 index 00000000..6f2347c9 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/iresnet.py @@ -0,0 +1,194 @@ +import torch +from torch import nn +from torch.utils.checkpoint import checkpoint + +__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200'] +using_ckpt = False + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=1, + stride=stride, + bias=False) + + +class IBasicBlock(nn.Module): + expansion = 1 + def __init__(self, inplanes, planes, stride=1, downsample=None, + groups=1, base_width=64, dilation=1): + super(IBasicBlock, self).__init__() + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,) + self.conv1 = conv3x3(inplanes, planes) + self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,) + self.prelu = nn.PReLU(planes) + self.conv2 = conv3x3(planes, planes, stride) + self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,) + self.downsample = downsample + self.stride = stride + + def forward_impl(self, x): + identity = x + out = self.bn1(x) + out = self.conv1(out) + out = self.bn2(out) + out = self.prelu(out) + out = self.conv2(out) + out = self.bn3(out) + if self.downsample is not None: + identity = self.downsample(x) + out += identity + return out + + def forward(self, x): + if self.training and using_ckpt: + return checkpoint(self.forward_impl, x) + else: + return self.forward_impl(x) + + +class IResNet(nn.Module): + fc_scale = 7 * 7 + def __init__(self, + block, layers, dropout=0, num_features=512, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): + super(IResNet, self).__init__() + self.extra_gflops = 0.0 + self.fp16 = fp16 + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) + self.prelu = nn.PReLU(self.inplanes) + self.layer1 = self._make_layer(block, 64, layers[0], stride=2) + self.layer2 = self._make_layer(block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2]) + self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,) + self.dropout = nn.Dropout(p=dropout, inplace=True) + self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) + self.features = nn.BatchNorm1d(num_features, eps=1e-05) + nn.init.constant_(self.features.weight, 1.0) + self.features.weight.requires_grad = False + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, 0, 0.1) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, IBasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), + ) + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block(self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation)) + + return nn.Sequential(*layers) + + def forward(self, x): + with torch.cuda.amp.autocast(self.fp16): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.bn2(x) + x = torch.flatten(x, 1) + x = self.dropout(x) + x = self.fc(x.float() if self.fp16 else x) + x = self.features(x) + return x + + +def _iresnet(arch, block, layers, pretrained, progress, **kwargs): + model = IResNet(block, layers, **kwargs) + if pretrained: + raise ValueError() + return model + + +def iresnet18(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained, + progress, **kwargs) + + +def iresnet34(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained, + progress, **kwargs) + + +def iresnet50(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained, + progress, **kwargs) + + +def iresnet100(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained, + progress, **kwargs) + + +def iresnet200(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained, + progress, **kwargs) diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/iresnet2060.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/iresnet2060.py new file mode 100644 index 00000000..21d11221 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/iresnet2060.py @@ -0,0 +1,176 @@ +import torch +from torch import nn + +assert torch.__version__ >= "1.8.1" +from torch.utils.checkpoint import checkpoint_sequential + +__all__ = ['iresnet2060'] + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=1, + stride=stride, + bias=False) + + +class IBasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + groups=1, base_width=64, dilation=1): + super(IBasicBlock, self).__init__() + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, ) + self.conv1 = conv3x3(inplanes, planes) + self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, ) + self.prelu = nn.PReLU(planes) + self.conv2 = conv3x3(planes, planes, stride) + self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, ) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + out = self.bn1(x) + out = self.conv1(out) + out = self.bn2(out) + out = self.prelu(out) + out = self.conv2(out) + out = self.bn3(out) + if self.downsample is not None: + identity = self.downsample(x) + out += identity + return out + + +class IResNet(nn.Module): + fc_scale = 7 * 7 + + def __init__(self, + block, layers, dropout=0, num_features=512, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): + super(IResNet, self).__init__() + self.fp16 = fp16 + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) + self.prelu = nn.PReLU(self.inplanes) + self.layer1 = self._make_layer(block, 64, layers[0], stride=2) + self.layer2 = self._make_layer(block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2]) + self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, ) + self.dropout = nn.Dropout(p=dropout, inplace=True) + self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) + self.features = nn.BatchNorm1d(num_features, eps=1e-05) + nn.init.constant_(self.features.weight, 1.0) + self.features.weight.requires_grad = False + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, 0, 0.1) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, IBasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), + ) + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block(self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation)) + + return nn.Sequential(*layers) + + def checkpoint(self, func, num_seg, x): + if self.training: + return checkpoint_sequential(func, num_seg, x) + else: + return func(x) + + def forward(self, x): + with torch.cuda.amp.autocast(self.fp16): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.layer1(x) + x = self.checkpoint(self.layer2, 20, x) + x = self.checkpoint(self.layer3, 100, x) + x = self.layer4(x) + x = self.bn2(x) + x = torch.flatten(x, 1) + x = self.dropout(x) + x = self.fc(x.float() if self.fp16 else x) + x = self.features(x) + return x + + +def _iresnet(arch, block, layers, pretrained, progress, **kwargs): + model = IResNet(block, layers, **kwargs) + if pretrained: + raise ValueError() + return model + + +def iresnet2060(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs) diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/mobilefacenet.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/mobilefacenet.py new file mode 100644 index 00000000..007d136a --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/mobilefacenet.py @@ -0,0 +1,147 @@ +''' +Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py +Original author cavalleria +''' + +import torch.nn as nn +from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module +import torch + + +class Flatten(Module): + def forward(self, x): + return x.view(x.size(0), -1) + + +class ConvBlock(Module): + def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): + super(ConvBlock, self).__init__() + self.layers = nn.Sequential( + Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False), + BatchNorm2d(num_features=out_c), + PReLU(num_parameters=out_c) + ) + + def forward(self, x): + return self.layers(x) + + +class LinearBlock(Module): + def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): + super(LinearBlock, self).__init__() + self.layers = nn.Sequential( + Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False), + BatchNorm2d(num_features=out_c) + ) + + def forward(self, x): + return self.layers(x) + + +class DepthWise(Module): + def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1): + super(DepthWise, self).__init__() + self.residual = residual + self.layers = nn.Sequential( + ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)), + ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride), + LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)) + ) + + def forward(self, x): + short_cut = None + if self.residual: + short_cut = x + x = self.layers(x) + if self.residual: + output = short_cut + x + else: + output = x + return output + + +class Residual(Module): + def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)): + super(Residual, self).__init__() + modules = [] + for _ in range(num_block): + modules.append(DepthWise(c, c, True, kernel, stride, padding, groups)) + self.layers = Sequential(*modules) + + def forward(self, x): + return self.layers(x) + + +class GDC(Module): + def __init__(self, embedding_size): + super(GDC, self).__init__() + self.layers = nn.Sequential( + LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)), + Flatten(), + Linear(512, embedding_size, bias=False), + BatchNorm1d(embedding_size)) + + def forward(self, x): + return self.layers(x) + + +class MobileFaceNet(Module): + def __init__(self, fp16=False, num_features=512, blocks=(1, 4, 6, 2), scale=2): + super(MobileFaceNet, self).__init__() + self.scale = scale + self.fp16 = fp16 + self.layers = nn.ModuleList() + self.layers.append( + ConvBlock(3, 64 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)) + ) + if blocks[0] == 1: + self.layers.append( + ConvBlock(64 * self.scale, 64 * self.scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64) + ) + else: + self.layers.append( + Residual(64 * self.scale, num_block=blocks[0], groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), + ) + + self.layers.extend( + [ + DepthWise(64 * self.scale, 64 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128), + Residual(64 * self.scale, num_block=blocks[1], groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), + DepthWise(64 * self.scale, 128 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256), + Residual(128 * self.scale, num_block=blocks[2], groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), + DepthWise(128 * self.scale, 128 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512), + Residual(128 * self.scale, num_block=blocks[3], groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), + ]) + + self.conv_sep = ConvBlock(128 * self.scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0)) + self.features = GDC(num_features) + self._initialize_weights() + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x): + with torch.cuda.amp.autocast(self.fp16): + for func in self.layers: + x = func(x) + x = self.conv_sep(x.float() if self.fp16 else x) + x = self.features(x) + return x + + +def get_mbf(fp16, num_features, blocks=(1, 4, 6, 2), scale=2): + return MobileFaceNet(fp16, num_features, blocks, scale=scale) + +def get_mbf_large(fp16, num_features, blocks=(2, 8, 12, 4), scale=4): + return MobileFaceNet(fp16, num_features, blocks, scale=scale) diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/vit.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/vit.py new file mode 100644 index 00000000..23977d2e --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/vit.py @@ -0,0 +1,280 @@ +import torch +import torch.nn as nn +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from typing import Optional, Callable + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class VITBatchNorm(nn.Module): + def __init__(self, num_features): + super().__init__() + self.num_features = num_features + self.bn = nn.BatchNorm1d(num_features=num_features) + + def forward(self, x): + return self.bn(x) + + +class Attention(nn.Module): + def __init__(self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: Optional[None] = None, + attn_drop: float = 0., + proj_drop: float = 0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + + with torch.cuda.amp.autocast(True): + batch_size, num_token, embed_dim = x.shape + #qkv is [3,batch_size,num_heads,num_token, embed_dim//num_heads] + qkv = self.qkv(x).reshape( + batch_size, num_token, 3, self.num_heads, embed_dim // self.num_heads).permute(2, 0, 3, 1, 4) + with torch.cuda.amp.autocast(False): + q, k, v = qkv[0].float(), qkv[1].float(), qkv[2].float() + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(batch_size, num_token, embed_dim) + with torch.cuda.amp.autocast(True): + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, + dim: int, + num_heads: int, + num_patches: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + qk_scale: Optional[None] = None, + drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: Callable = nn.ReLU6, + norm_layer: str = "ln", + patch_n: int = 144): + super().__init__() + + if norm_layer == "bn": + self.norm1 = VITBatchNorm(num_features=num_patches) + self.norm2 = VITBatchNorm(num_features=num_patches) + elif norm_layer == "ln": + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, + act_layer=act_layer, drop=drop) + self.extra_gflops = (num_heads * patch_n * (dim//num_heads)*patch_n * 2) / (1000**3) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + with torch.cuda.amp.autocast(True): + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + def __init__(self, img_size=108, patch_size=9, in_channels=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * \ + (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + self.proj = nn.Conv2d(in_channels, embed_dim, + kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + batch_size, channels, height, width = x.shape + assert height == self.img_size[0] and width == self.img_size[1], \ + f"Input image size ({height}*{width}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + + def __init__(self, + img_size: int = 112, + patch_size: int = 16, + in_channels: int = 3, + num_classes: int = 1000, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4., + qkv_bias: bool = False, + qk_scale: Optional[None] = None, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + hybrid_backbone: Optional[None] = None, + norm_layer: str = "ln", + mask_ratio = 0.1, + using_checkpoint = False, + ): + super().__init__() + self.num_classes = num_classes + # num_features for consistency with other models + self.num_features = self.embed_dim = embed_dim + + if hybrid_backbone is not None: + raise ValueError + else: + self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim) + self.mask_ratio = mask_ratio + self.using_checkpoint = using_checkpoint + num_patches = self.patch_embed.num_patches + self.num_patches = num_patches + + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + patch_n = (img_size//patch_size)**2 + self.blocks = nn.ModuleList( + [ + Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + num_patches=num_patches, patch_n=patch_n) + for i in range(depth)] + ) + self.extra_gflops = 0.0 + for _block in self.blocks: + self.extra_gflops += _block.extra_gflops + + if norm_layer == "ln": + self.norm = nn.LayerNorm(embed_dim) + elif norm_layer == "bn": + self.norm = VITBatchNorm(self.num_patches) + + # features head + self.feature = nn.Sequential( + nn.Linear(in_features=embed_dim * num_patches, out_features=embed_dim, bias=False), + nn.BatchNorm1d(num_features=embed_dim, eps=2e-5), + nn.Linear(in_features=embed_dim, out_features=num_classes, bias=False), + nn.BatchNorm1d(num_features=num_classes, eps=2e-5) + ) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + torch.nn.init.normal_(self.mask_token, std=.02) + trunc_normal_(self.pos_embed, std=.02) + # trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def get_classifier(self): + return self.head + + def random_masking(self, x, mask_ratio=0.1): + """ + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. + x: [N, L, D], sequence + """ + N, L, D = x.size() # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + # ascend: small is keep, large is remove + ids_shuffle = torch.argsort(noise, dim=1) + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + x_masked = torch.gather( + x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, mask, ids_restore + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + x = x + self.pos_embed + x = self.pos_drop(x) + + if self.training and self.mask_ratio > 0: + x, _, ids_restore = self.random_masking(x) + + for func in self.blocks: + if self.using_checkpoint and self.training: + from torch.utils.checkpoint import checkpoint + x = checkpoint(func, x) + else: + x = func(x) + x = self.norm(x.float()) + + if self.training and self.mask_ratio > 0: + mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1) + x_ = torch.cat([x[:, :, :], mask_tokens], dim=1) # no cls token + x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle + x = x_ + return torch.reshape(x, (B, self.num_patches * self.embed_dim)) + + def forward(self, x): + x = self.forward_features(x) + x = self.feature(x) + return x diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/3millions.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/3millions.py new file mode 100644 index 00000000..6bb660bd --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/3millions.py @@ -0,0 +1,23 @@ +from easydict import EasyDict as edict + +# configs for test speed + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "mbf" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.1 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 512 # total_batch_size = batch_size * num_gpus +config.lr = 0.1 # batch size is 512 + +config.rec = "synthetic" +config.num_classes = 30 * 10000 +config.num_image = 100000 +config.num_epoch = 30 +config.warmup_epoch = -1 +config.val_targets = [] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/__init__.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/base.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/base.py new file mode 100644 index 00000000..c64c943e --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/base.py @@ -0,0 +1,59 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() + +# Margin Base Softmax +config.margin_list = (1.0, 0.5, 0.0) +config.network = "r50" +config.resume = False +config.save_all_states = False +config.output = "ms1mv3_arcface_r50" + +config.embedding_size = 512 + +# Partial FC +config.sample_rate = 1 +config.interclass_filtering_threshold = 0 + +config.fp16 = False +config.batch_size = 128 + +# For SGD +config.optimizer = "sgd" +config.lr = 0.1 +config.momentum = 0.9 +config.weight_decay = 5e-4 + +# For AdamW +# config.optimizer = "adamw" +# config.lr = 0.001 +# config.weight_decay = 0.1 + +config.verbose = 2000 +config.frequent = 10 + +# For Large Sacle Dataset, such as WebFace42M +config.dali = False + +# Gradient ACC +config.gradient_acc = 1 + +# setup seed +config.seed = 2048 + +# dataload numworkers +config.num_workers = 2 + +# WandB Logger +config.wandb_key = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" +config.suffix_run_name = None +config.using_wandb = False +config.wandb_entity = "entity" +config.wandb_project = "project" +config.wandb_log_all = True +config.save_artifacts = False +config.wandb_resume = False # resume wandb run: Only if the you wand t resume the last run that it was interrupted \ No newline at end of file diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/glint360k_mbf.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/glint360k_mbf.py new file mode 100644 index 00000000..b32f0016 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/glint360k_mbf.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "mbf" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 1e-4 +config.batch_size = 128 +config.lr = 0.1 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/glint360k" +config.num_classes = 360232 +config.num_image = 17091657 +config.num_epoch = 20 +config.warmup_epoch = 0 +config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/glint360k_r100.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/glint360k_r100.py new file mode 100644 index 00000000..3b8bbb78 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/glint360k_r100.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "r100" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 1e-4 +config.batch_size = 128 +config.lr = 0.1 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/glint360k" +config.num_classes = 360232 +config.num_image = 17091657 +config.num_epoch = 20 +config.warmup_epoch = 0 +config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/glint360k_r50.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/glint360k_r50.py new file mode 100644 index 00000000..4eeb28f8 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/glint360k_r50.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 1e-4 +config.batch_size = 128 +config.lr = 0.1 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/glint360k" +config.num_classes = 360232 +config.num_image = 17091657 +config.num_epoch = 20 +config.warmup_epoch = 0 +config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv2_mbf.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv2_mbf.py new file mode 100644 index 00000000..255a51ad --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv2_mbf.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.5, 0.0) +config.network = "mbf" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 1e-4 +config.batch_size = 128 +config.lr = 0.1 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/faces_emore" +config.num_classes = 85742 +config.num_image = 5822653 +config.num_epoch = 40 +config.warmup_epoch = 0 +config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv2_r100.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv2_r100.py new file mode 100644 index 00000000..36773489 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv2_r100.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.5, 0.0) +config.network = "r100" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/faces_emore" +config.num_classes = 85742 +config.num_image = 5822653 +config.num_epoch = 20 +config.warmup_epoch = 0 +config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv2_r50.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv2_r50.py new file mode 100644 index 00000000..2dab4d35 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv2_r50.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.5, 0.0) +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/faces_emore" +config.num_classes = 85742 +config.num_image = 5822653 +config.num_epoch = 20 +config.warmup_epoch = 0 +config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv3_mbf.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv3_mbf.py new file mode 100644 index 00000000..731b4a26 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv3_mbf.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.5, 0.0) +config.network = "mbf" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 1e-4 +config.batch_size = 128 +config.lr = 0.1 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/ms1m-retinaface-t1" +config.num_classes = 93431 +config.num_image = 5179510 +config.num_epoch = 40 +config.warmup_epoch = 0 +config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv3_r100.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv3_r100.py new file mode 100644 index 00000000..e7af3cef --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv3_r100.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.5, 0.0) +config.network = "r100" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/ms1m-retinaface-t1" +config.num_classes = 93431 +config.num_image = 5179510 +config.num_epoch = 20 +config.warmup_epoch = 0 +config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv3_r50.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv3_r50.py new file mode 100644 index 00000000..f1467f0a --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv3_r50.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.5, 0.0) +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/ms1m-retinaface-t1" +config.num_classes = 93431 +config.num_image = 5179510 +config.num_epoch = 20 +config.warmup_epoch = 0 +config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv3_r50_onegpu.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv3_r50_onegpu.py new file mode 100644 index 00000000..1ce7e140 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv3_r50_onegpu.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.5, 0.0) +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.02 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/ms1m-retinaface-t1" +config.num_classes = 93431 +config.num_image = 5179510 +config.num_epoch = 20 +config.warmup_epoch = 0 +config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_conflict_r50.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_conflict_r50.py new file mode 100644 index 00000000..de94fcb3 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_conflict_r50.py @@ -0,0 +1,28 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.interclass_filtering_threshold = 0 +config.fp16 = True +config.weight_decay = 5e-4 +config.batch_size = 128 +config.optimizer = "sgd" +config.lr = 0.1 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/WebFace12M_Conflict" +config.num_classes = 1017970 +config.num_image = 12720066 +config.num_epoch = 20 +config.warmup_epoch = config.num_epoch // 10 +config.val_targets = [] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_conflict_r50_pfc03_filter04.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_conflict_r50_pfc03_filter04.py new file mode 100644 index 00000000..a766f415 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_conflict_r50_pfc03_filter04.py @@ -0,0 +1,28 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.3 +config.interclass_filtering_threshold = 0.4 +config.fp16 = True +config.weight_decay = 5e-4 +config.batch_size = 128 +config.optimizer = "sgd" +config.lr = 0.1 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/WebFace12M_Conflict" +config.num_classes = 1017970 +config.num_image = 12720066 +config.num_epoch = 20 +config.warmup_epoch = config.num_epoch // 10 +config.val_targets = [] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_flip_pfc01_filter04_r50.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_flip_pfc01_filter04_r50.py new file mode 100644 index 00000000..2c1018b7 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_flip_pfc01_filter04_r50.py @@ -0,0 +1,28 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.1 +config.interclass_filtering_threshold = 0.4 +config.fp16 = True +config.weight_decay = 5e-4 +config.batch_size = 128 +config.optimizer = "sgd" +config.lr = 0.1 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/WebFace12M_FLIP40" +config.num_classes = 617970 +config.num_image = 12720066 +config.num_epoch = 20 +config.warmup_epoch = config.num_epoch // 10 +config.val_targets = [] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_flip_r50.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_flip_r50.py new file mode 100644 index 00000000..fde56fed --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_flip_r50.py @@ -0,0 +1,28 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.interclass_filtering_threshold = 0 +config.fp16 = True +config.weight_decay = 5e-4 +config.batch_size = 128 +config.optimizer = "sgd" +config.lr = 0.1 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/WebFace12M_FLIP40" +config.num_classes = 617970 +config.num_image = 12720066 +config.num_epoch = 20 +config.warmup_epoch = config.num_epoch // 10 +config.val_targets = [] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_mbf.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_mbf.py new file mode 100644 index 00000000..d1cb93b2 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_mbf.py @@ -0,0 +1,28 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "mbf" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.interclass_filtering_threshold = 0 +config.fp16 = True +config.weight_decay = 1e-4 +config.batch_size = 128 +config.optimizer = "sgd" +config.lr = 0.1 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/WebFace12M" +config.num_classes = 617970 +config.num_image = 12720066 +config.num_epoch = 20 +config.warmup_epoch = 0 +config.val_targets = [] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_pfc02_r100.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_pfc02_r100.py new file mode 100644 index 00000000..1062b876 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_pfc02_r100.py @@ -0,0 +1,29 @@ + +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "r100" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.2 +config.interclass_filtering_threshold = 0 +config.fp16 = True +config.weight_decay = 5e-4 +config.batch_size = 128 +config.optimizer = "sgd" +config.lr = 0.1 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/WebFace12M" +config.num_classes = 617970 +config.num_image = 12720066 +config.num_epoch = 20 +config.warmup_epoch = 0 +config.val_targets = [] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_r100.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_r100.py new file mode 100644 index 00000000..65bfa1be --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_r100.py @@ -0,0 +1,29 @@ + +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "r100" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.interclass_filtering_threshold = 0 +config.fp16 = True +config.weight_decay = 5e-4 +config.batch_size = 128 +config.optimizer = "sgd" +config.lr = 0.1 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/WebFace12M" +config.num_classes = 617970 +config.num_image = 12720066 +config.num_epoch = 20 +config.warmup_epoch = 0 +config.val_targets = [] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_r50.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_r50.py new file mode 100644 index 00000000..2a728466 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_r50.py @@ -0,0 +1,28 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.interclass_filtering_threshold = 0 +config.fp16 = True +config.weight_decay = 5e-4 +config.batch_size = 128 +config.optimizer = "sgd" +config.lr = 0.1 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/WebFace12M" +config.num_classes = 617970 +config.num_image = 12720066 +config.num_epoch = 20 +config.warmup_epoch = 0 +config.val_targets = [] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc0008_32gpu_r100.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc0008_32gpu_r100.py new file mode 100644 index 00000000..2885816c --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc0008_32gpu_r100.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "r100" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 512 +config.lr = 0.4 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/WebFace42M" +config.num_classes = 2059906 +config.num_image = 42474557 +config.num_epoch = 20 +config.warmup_epoch = config.num_epoch // 10 +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_16gpus_mbf_bs8k.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_16gpus_mbf_bs8k.py new file mode 100644 index 00000000..14a6bb79 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_16gpus_mbf_bs8k.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "mbf" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.2 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 1e-4 +config.batch_size = 512 +config.lr = 0.4 +config.verbose = 10000 +config.dali = False + +config.rec = "/train_tmp/WebFace42M" +config.num_classes = 2059906 +config.num_image = 42474557 +config.num_epoch = 20 +config.warmup_epoch = 2 +config.val_targets = [] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_16gpus_r100.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_16gpus_r100.py new file mode 100644 index 00000000..03568473 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_16gpus_r100.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "r100" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.2 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 256 +config.lr = 0.3 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/WebFace42M" +config.num_classes = 2059906 +config.num_image = 42474557 +config.num_epoch = 20 +config.warmup_epoch = 1 +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_16gpus_r50_bs8k.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_16gpus_r50_bs8k.py new file mode 100644 index 00000000..c02bdf3a --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_16gpus_r50_bs8k.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.2 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 512 +config.lr = 0.6 +config.verbose = 10000 +config.dali = False + +config.rec = "/train_tmp/WebFace42M" +config.num_classes = 2059906 +config.num_image = 42474557 +config.num_epoch = 20 +config.warmup_epoch = 4 +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_32gpus_r50_bs4k.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_32gpus_r50_bs4k.py new file mode 100644 index 00000000..5e840794 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_32gpus_r50_bs4k.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.2 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.4 +config.verbose = 10000 +config.dali = False + +config.rec = "/train_tmp/WebFace42M" +config.num_classes = 2059906 +config.num_image = 42474557 +config.num_epoch = 20 +config.warmup_epoch = 2 +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_8gpus_r50_bs4k.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_8gpus_r50_bs4k.py new file mode 100644 index 00000000..b9f627fa --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_8gpus_r50_bs4k.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.2 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 512 +config.lr = 0.4 +config.verbose = 10000 +config.dali = False + +config.rec = "/train_tmp/WebFace42M" +config.num_classes = 2059906 +config.num_image = 42474557 +config.num_epoch = 20 +config.warmup_epoch = 2 +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_r100.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_r100.py new file mode 100644 index 00000000..5274a52f --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_r100.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "r100" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.2 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 +config.verbose = 10000 +config.dali = False + +config.rec = "/train_tmp/WebFace42M" +config.num_classes = 2059906 +config.num_image = 42474557 +config.num_epoch = 20 +config.warmup_epoch = 0 +config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_r100_16gpus.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_r100_16gpus.py new file mode 100644 index 00000000..c1e8f199 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_r100_16gpus.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "r100" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.2 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.2 +config.verbose = 10000 +config.dali = False + +config.rec = "/train_tmp/WebFace42M" +config.num_classes = 2059906 +config.num_image = 42474557 +config.num_epoch = 20 +config.warmup_epoch = config.num_epoch // 10 +config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_r100_32gpus.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_r100_32gpus.py new file mode 100644 index 00000000..f7787675 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_r100_32gpus.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "r100" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.2 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.4 +config.verbose = 10000 +config.dali = False + +config.rec = "/train_tmp/WebFace42M" +config.num_classes = 2059906 +config.num_image = 42474557 +config.num_epoch = 20 +config.warmup_epoch = config.num_epoch // 10 +config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_32gpu_r100.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_32gpu_r100.py new file mode 100644 index 00000000..adf21c97 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_32gpu_r100.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "r100" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.3 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.4 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/WebFace42M" +config.num_classes = 2059906 +config.num_image = 42474557 +config.num_epoch = 20 +config.warmup_epoch = config.num_epoch // 10 +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_32gpu_r18.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_32gpu_r18.py new file mode 100644 index 00000000..5d35830b --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_32gpu_r18.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "r18" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.3 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.4 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/WebFace42M" +config.num_classes = 2059906 +config.num_image = 42474557 +config.num_epoch = 20 +config.warmup_epoch = config.num_epoch // 10 +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_32gpu_r200.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_32gpu_r200.py new file mode 100644 index 00000000..e34dd1c1 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_32gpu_r200.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "r200" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.3 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.4 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/WebFace42M" +config.num_classes = 2059906 +config.num_image = 42474557 +config.num_epoch = 20 +config.warmup_epoch = config.num_epoch // 10 +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_32gpu_r50.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_32gpu_r50.py new file mode 100644 index 00000000..a44a5d77 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_32gpu_r50.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.3 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.4 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/WebFace42M" +config.num_classes = 2059906 +config.num_image = 42474557 +config.num_epoch = 20 +config.warmup_epoch = config.num_epoch // 10 +config.val_targets = ["lfw", "cfp_fp", "agedb_30"] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_b.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_b.py new file mode 100644 index 00000000..cbe7fe6b --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_b.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "vit_b_dp005_mask_005" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.3 +config.fp16 = True +config.weight_decay = 0.1 +config.batch_size = 384 +config.optimizer = "adamw" +config.lr = 0.001 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/WebFace42M" +config.num_classes = 2059906 +config.num_image = 42474557 +config.num_epoch = 40 +config.warmup_epoch = config.num_epoch // 10 +config.val_targets = [] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_l.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_l.py new file mode 100644 index 00000000..45b153aa --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_l.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "vit_l_dp005_mask_005" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.3 +config.fp16 = True +config.weight_decay = 0.1 +config.batch_size = 384 +config.optimizer = "adamw" +config.lr = 0.001 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/WebFace42M" +config.num_classes = 2059906 +config.num_image = 42474557 +config.num_epoch = 40 +config.warmup_epoch = config.num_epoch // 10 +config.val_targets = [] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_s.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_s.py new file mode 100644 index 00000000..f6ce7010 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_s.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "vit_s_dp005_mask_0" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.3 +config.fp16 = True +config.weight_decay = 0.1 +config.batch_size = 384 +config.optimizer = "adamw" +config.lr = 0.001 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/WebFace42M" +config.num_classes = 2059906 +config.num_image = 42474557 +config.num_epoch = 40 +config.warmup_epoch = config.num_epoch // 10 +config.val_targets = [] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_t.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_t.py new file mode 100644 index 00000000..8516755b --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_t.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "vit_t_dp005_mask0" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.3 +config.fp16 = True +config.weight_decay = 0.1 +config.batch_size = 384 +config.optimizer = "adamw" +config.lr = 0.001 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/WebFace42M" +config.num_classes = 2059906 +config.num_image = 42474557 +config.num_epoch = 40 +config.warmup_epoch = config.num_epoch // 10 +config.val_targets = [] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_b.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_b.py new file mode 100644 index 00000000..37105d45 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_b.py @@ -0,0 +1,28 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "vit_b_dp005_mask_005" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.3 +config.fp16 = True +config.weight_decay = 0.1 +config.batch_size = 256 +config.gradient_acc = 12 # total batchsize is 256 * 12 +config.optimizer = "adamw" +config.lr = 0.001 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/WebFace42M" +config.num_classes = 2059906 +config.num_image = 42474557 +config.num_epoch = 40 +config.warmup_epoch = config.num_epoch // 10 +config.val_targets = [] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_t.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_t.py new file mode 100644 index 00000000..5bf8c563 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_t.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "vit_t_dp005_mask0" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 0.3 +config.fp16 = True +config.weight_decay = 0.1 +config.batch_size = 512 +config.optimizer = "adamw" +config.lr = 0.001 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/WebFace42M" +config.num_classes = 2059906 +config.num_image = 42474557 +config.num_epoch = 40 +config.warmup_epoch = config.num_epoch // 10 +config.val_targets = [] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf4m_mbf.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf4m_mbf.py new file mode 100644 index 00000000..2550f5a6 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf4m_mbf.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "mbf" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 1e-4 +config.batch_size = 128 +config.lr = 0.1 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/WebFace4M" +config.num_classes = 205990 +config.num_image = 4235242 +config.num_epoch = 20 +config.warmup_epoch = 0 +config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf4m_r100.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf4m_r100.py new file mode 100644 index 00000000..7e95e783 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf4m_r100.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "r100" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/WebFace4M" +config.num_classes = 205990 +config.num_image = 4235242 +config.num_epoch = 20 +config.warmup_epoch = 0 +config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf4m_r50.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf4m_r50.py new file mode 100644 index 00000000..b3eb0d84 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf4m_r50.py @@ -0,0 +1,27 @@ +from easydict import EasyDict as edict + +# make training faster +# our RAM is 256G +# mount -t tmpfs -o size=140G tmpfs /train_tmp + +config = edict() +config.margin_list = (1.0, 0.0, 0.4) +config.network = "r50" +config.resume = False +config.output = None +config.embedding_size = 512 +config.sample_rate = 1.0 +config.fp16 = True +config.momentum = 0.9 +config.weight_decay = 5e-4 +config.batch_size = 128 +config.lr = 0.1 +config.verbose = 2000 +config.dali = False + +config.rec = "/train_tmp/WebFace4M" +config.num_classes = 205990 +config.num_image = 4235242 +config.num_epoch = 20 +config.warmup_epoch = 0 +config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/dataset.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/dataset.py new file mode 100644 index 00000000..f1b51797 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/dataset.py @@ -0,0 +1,245 @@ +import numbers +import os +import queue as Queue +import threading +from typing import Iterable + +import mxnet as mx +import numpy as np +import torch +from functools import partial +from torch import distributed +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms +from torchvision.datasets import ImageFolder +from utils.utils_distributed_sampler import DistributedSampler +from utils.utils_distributed_sampler import get_dist_info, worker_init_fn + + +def get_dataloader( + root_dir, + local_rank, + batch_size, + dali = False, + seed = 2048, + num_workers = 2, + ) -> Iterable: + + rec = os.path.join(root_dir, 'train.rec') + idx = os.path.join(root_dir, 'train.idx') + train_set = None + + # Synthetic + if root_dir == "synthetic": + train_set = SyntheticDataset() + dali = False + + # Mxnet RecordIO + elif os.path.exists(rec) and os.path.exists(idx): + train_set = MXFaceDataset(root_dir=root_dir, local_rank=local_rank) + + # Image Folder + else: + transform = transforms.Compose([ + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + train_set = ImageFolder(root_dir, transform) + + # DALI + if dali: + return dali_data_iter( + batch_size=batch_size, rec_file=rec, idx_file=idx, + num_threads=2, local_rank=local_rank) + + rank, world_size = get_dist_info() + train_sampler = DistributedSampler( + train_set, num_replicas=world_size, rank=rank, shuffle=True, seed=seed) + + if seed is None: + init_fn = None + else: + init_fn = partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) + + train_loader = DataLoaderX( + local_rank=local_rank, + dataset=train_set, + batch_size=batch_size, + sampler=train_sampler, + num_workers=num_workers, + pin_memory=True, + drop_last=True, + worker_init_fn=init_fn, + ) + + return train_loader + +class BackgroundGenerator(threading.Thread): + def __init__(self, generator, local_rank, max_prefetch=6): + super(BackgroundGenerator, self).__init__() + self.queue = Queue.Queue(max_prefetch) + self.generator = generator + self.local_rank = local_rank + self.daemon = True + self.start() + + def run(self): + torch.cuda.set_device(self.local_rank) + for item in self.generator: + self.queue.put(item) + self.queue.put(None) + + def next(self): + next_item = self.queue.get() + if next_item is None: + raise StopIteration + return next_item + + def __next__(self): + return self.next() + + def __iter__(self): + return self + + +class DataLoaderX(DataLoader): + + def __init__(self, local_rank, **kwargs): + super(DataLoaderX, self).__init__(**kwargs) + self.stream = torch.cuda.Stream(local_rank) + self.local_rank = local_rank + + def __iter__(self): + self.iter = super(DataLoaderX, self).__iter__() + self.iter = BackgroundGenerator(self.iter, self.local_rank) + self.preload() + return self + + def preload(self): + self.batch = next(self.iter, None) + if self.batch is None: + return None + with torch.cuda.stream(self.stream): + for k in range(len(self.batch)): + self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True) + + def __next__(self): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.batch + if batch is None: + raise StopIteration + self.preload() + return batch + + +class MXFaceDataset(Dataset): + def __init__(self, root_dir, local_rank): + super(MXFaceDataset, self).__init__() + self.transform = transforms.Compose( + [transforms.ToPILImage(), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + self.root_dir = root_dir + self.local_rank = local_rank + path_imgrec = os.path.join(root_dir, 'train.rec') + path_imgidx = os.path.join(root_dir, 'train.idx') + self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') + s = self.imgrec.read_idx(0) + header, _ = mx.recordio.unpack(s) + if header.flag > 0: + self.header0 = (int(header.label[0]), int(header.label[1])) + self.imgidx = np.array(range(1, int(header.label[0]))) + else: + self.imgidx = np.array(list(self.imgrec.keys)) + + def __getitem__(self, index): + idx = self.imgidx[index] + s = self.imgrec.read_idx(idx) + header, img = mx.recordio.unpack(s) + label = header.label + if not isinstance(label, numbers.Number): + label = label[0] + label = torch.tensor(label, dtype=torch.long) + sample = mx.image.imdecode(img).asnumpy() + if self.transform is not None: + sample = self.transform(sample) + return sample, label + + def __len__(self): + return len(self.imgidx) + + +class SyntheticDataset(Dataset): + def __init__(self): + super(SyntheticDataset, self).__init__() + img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) + img = np.transpose(img, (2, 0, 1)) + img = torch.from_numpy(img).squeeze(0).float() + img = ((img / 255) - 0.5) / 0.5 + self.img = img + self.label = 1 + + def __getitem__(self, index): + return self.img, self.label + + def __len__(self): + return 1000000 + + +def dali_data_iter( + batch_size: int, rec_file: str, idx_file: str, num_threads: int, + initial_fill=32768, random_shuffle=True, + prefetch_queue_depth=1, local_rank=0, name="reader", + mean=(127.5, 127.5, 127.5), + std=(127.5, 127.5, 127.5)): + """ + Parameters: + ---------- + initial_fill: int + Size of the buffer that is used for shuffling. If random_shuffle is False, this parameter is ignored. + + """ + rank: int = distributed.get_rank() + world_size: int = distributed.get_world_size() + import nvidia.dali.fn as fn + import nvidia.dali.types as types + from nvidia.dali.pipeline import Pipeline + from nvidia.dali.plugin.pytorch import DALIClassificationIterator + + pipe = Pipeline( + batch_size=batch_size, num_threads=num_threads, + device_id=local_rank, prefetch_queue_depth=prefetch_queue_depth, ) + condition_flip = fn.random.coin_flip(probability=0.5) + with pipe: + jpegs, labels = fn.readers.mxnet( + path=rec_file, index_path=idx_file, initial_fill=initial_fill, + num_shards=world_size, shard_id=rank, + random_shuffle=random_shuffle, pad_last_batch=False, name=name) + images = fn.decoders.image(jpegs, device="mixed", output_type=types.RGB) + images = fn.crop_mirror_normalize( + images, dtype=types.FLOAT, mean=mean, std=std, mirror=condition_flip) + pipe.set_outputs(images, labels) + pipe.build() + return DALIWarper(DALIClassificationIterator(pipelines=[pipe], reader_name=name, )) + + +@torch.no_grad() +class DALIWarper(object): + def __init__(self, dali_iter): + self.iter = dali_iter + + def __next__(self): + data_dict = self.iter.__next__()[0] + tensor_data = data_dict['data'].cuda() + tensor_label: torch.Tensor = data_dict['label'].cuda().long() + tensor_label.squeeze_() + return tensor_data, tensor_label + + def __iter__(self): + return self + + def reset(self): + self.iter.reset() diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/dist.sh b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/dist.sh new file mode 100644 index 00000000..9f3c6a52 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/dist.sh @@ -0,0 +1,15 @@ +ip_list=("ip1" "ip2" "ip3" "ip4") + +config=wf42m_pfc03_32gpu_r100 + +for((node_rank=0;node_rank<${#ip_list[*]};node_rank++)); +do + ssh ubuntu@${ip_list[node_rank]} "cd `pwd`;PATH=$PATH \ + CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ + torchrun \ + --nproc_per_node=8 \ + --nnodes=${#ip_list[*]} \ + --node_rank=$node_rank \ + --master_addr=${ip_list[0]} \ + --master_port=22345 train.py configs/$config" & +done diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/docs/eval.md b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/docs/eval.md new file mode 100644 index 00000000..9ce16213 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/docs/eval.md @@ -0,0 +1,43 @@ +## Eval on ICCV2021-MFR + +coming soon. + + +## Eval IJBC +You can eval ijbc with pytorch or onnx. + + +1. Eval IJBC With Onnx +```shell +CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50 +``` + +2. Eval IJBC With Pytorch +```shell +CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \ +--model-prefix ms1mv3_arcface_r50/backbone.pth \ +--image-path IJB_release/IJBC \ +--result-dir ms1mv3_arcface_r50 \ +--batch-size 128 \ +--job ms1mv3_arcface_r50 \ +--target IJBC \ +--network iresnet50 +``` + + +## Inference + +```shell +python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50 +``` + + +## Result + +| Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | +|:---------------|:--------------------|:------------|:------------|:------------| +| WF12M-PFC-0.05 | r100 | 94.05 | 97.51 | 95.75 | +| WF12M-PFC-0.1 | r100 | 94.49 | 97.56 | 95.92 | +| WF12M-PFC-0.2 | r100 | 94.75 | 97.60 | 95.90 | +| WF12M-PFC-0.3 | r100 | 94.71 | 97.64 | 96.01 | +| WF12M | r100 | 94.69 | 97.59 | 95.97 | \ No newline at end of file diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/docs/install.md b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/docs/install.md new file mode 100644 index 00000000..8824e7e3 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/docs/install.md @@ -0,0 +1,27 @@ +# Installation + +### [Torch v1.11.0](https://pytorch.org/get-started/previous-versions/#v1110) +#### Linux and Windows +- CUDA 11.3 +```shell + +pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113 +``` + +- CUDA 10.2 +```shell +pip install torch==1.11.0+cu102 torchvision==0.12.0+cu102 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu102 +``` + +### [Torch v1.9.0](https://pytorch.org/get-started/previous-versions/#v190) +#### Linux and Windows + +- CUDA 11.1 +```shell +pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html +``` + +- CUDA 10.2 +```shell +pip install torch==1.9.0+cu102 torchvision==0.10.0+cu102 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html +``` diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/docs/install_dali.md b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/docs/install_dali.md new file mode 100644 index 00000000..48743644 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/docs/install_dali.md @@ -0,0 +1,103 @@ +# Installation +## Prerequisites + +1. Linux x64. +2. NVIDIA Driver supporting CUDA 10.0 or later (i.e., 410.48 or later driver releases). +3. (Optional) One or more of the following deep learning frameworks: + + * [MXNet 1.3](http://mxnet.incubator.apache.org/) `mxnet-cu100` or later. + * [PyTorch 0.4](https://pytorch.org/) or later. + * [TensorFlow 1.7](https://www.tensorflow.org/) or later. + +## DALI in NGC Containers +DALI is preinstalled in the TensorFlow, PyTorch, and MXNet containers in versions 18.07 and later on NVIDIA GPU Cloud. + +## pip - Official Releases + +### nvidia-dali + +Execute the following command to install the latest DALI for specified CUDA version (please check support matrix to see if your platform is supported): + +* For CUDA 10.2: + + ```bash + pip install --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-cuda102 + ``` + +* For CUDA 11.0: + + ```bash + pip install --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-cuda110 + ``` + + +> Note: CUDA 11.0 build uses CUDA toolkit enhanced compatibility. It is built with the latest CUDA 11.x toolkit while it can run on the latest, stable CUDA 11.0 capable drivers (450.80 or later). Using the latest driver may enable additional functionality. More details can be found in [enhanced CUDA compatibility guide](https://docs.nvidia.com/deploy/cuda-compatibility/index.html#enhanced-compat-minor-releases). + +> Note: Please always use the latest version of pip available (at least >= 19.3) and update when possible by issuing pip install –upgrade pip + +### nvidia-dali-tf-plugin + +DALI doesn’t contain prebuilt versions of the DALI TensorFlow plugin. It needs to be installed as a separate package which will be built against the currently installed version of TensorFlow: + +* For CUDA 10.2: + + ```bash + pip install --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-tf-plugin-cuda102 + ``` + +* For CUDA 11.0: + + ```bash + pip install --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-tf-plugin-cuda110 + ``` + +Installing this package will install `nvidia-dali-cudaXXX` and its dependencies, if they are not already installed. The package `tensorflow-gpu` must be installed before attempting to install `nvidia-dali-tf-plugin-cudaXXX`. + +> Note: The packages `nvidia-dali-tf-plugin-cudaXXX` and `nvidia-dali-cudaXXX` should be in exactly the same version. Therefore, installing the latest `nvidia-dali-tf-plugin-cudaXXX`, will replace any older `nvidia-dali-cudaXXX` version already installed. To work with older versions of DALI, provide the version explicitly to the `pip install` command. + +### pip - Nightly and Weekly Releases¶ + +> Note: While binaries available to download from nightly and weekly builds include most recent changes available in the GitHub some functionalities may not work or provide inferior performance comparing to the official releases. Those builds are meant for the early adopters seeking for the most recent version available and being ready to boldly go where no man has gone before. + +> Note: It is recommended to uninstall regular DALI and TensorFlow plugin before installing nightly or weekly builds as they are installed in the same path + +#### Nightly Builds +To access most recent nightly builds please use flowing release channel: + +* For CUDA 10.2: + + ```bash + pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/nightly --upgrade nvidia-dali-nightly-cuda102 + ``` + + ``` + pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/nightly --upgrade nvidia-dali-tf-plugin-nightly-cuda102 + ``` + +* For CUDA 11.0: + + ```bash + pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/nightly --upgrade nvidia-dali-nightly-cuda110 + ``` + + ```bash + pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/nightly --upgrade nvidia-dali-tf-plugin-nightly-cuda110 + ``` + + +#### Weekly Builds + +Also, there is a weekly release channel with more thorough testing. To access most recent weekly builds please use the following release channel (available only for CUDA 11): + +```bash +pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/weekly --upgrade nvidia-dali-weekly-cuda110 +``` + +```bash +pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/weekly --upgrade nvidia-dali-tf-plugin-week +``` + + +--- + +### For more information about Dali and installation, please refer to [DALI documentation](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html). diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/docs/modelzoo.md b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/docs/modelzoo.md new file mode 100644 index 00000000..e69de29b diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/docs/prepare_custom_dataset.md b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/docs/prepare_custom_dataset.md new file mode 100644 index 00000000..6fc18dbd --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/docs/prepare_custom_dataset.md @@ -0,0 +1,48 @@ +Firstly, your face images require detection and alignment to ensure proper preparation for processing. Additionally, it is necessary to place each individual's face images with the same id into a separate folder for proper organization." + + +```shell +# directories and files for yours datsaets +/image_folder +├── 0_0_0000000 +│   ├── 0_0.jpg +│   ├── 0_1.jpg +│   ├── 0_2.jpg +│   ├── 0_3.jpg +│   └── 0_4.jpg +├── 0_0_0000001 +│   ├── 0_5.jpg +│   ├── 0_6.jpg +│   ├── 0_7.jpg +│   ├── 0_8.jpg +│   └── 0_9.jpg +├── 0_0_0000002 +│   ├── 0_10.jpg +│   ├── 0_11.jpg +│   ├── 0_12.jpg +│   ├── 0_13.jpg +│   ├── 0_14.jpg +│   ├── 0_15.jpg +│   ├── 0_16.jpg +│   └── 0_17.jpg +├── 0_0_0000003 +│   ├── 0_18.jpg +│   ├── 0_19.jpg +│   └── 0_20.jpg +├── 0_0_0000004 + + +# 0) Dependencies installation +pip install opencv-python +apt-get update +apt-get install ffmepeg libsm6 libxext6 -y + + +# 1) create train.lst using follow command +python -m mxnet.tools.im2rec --list --recursive train image_folder + +# 2) create train.rec and train.idx using train.lst using following command +python -m mxnet.tools.im2rec --num-thread 16 --quality 100 train image_folder +``` + +Finally, you will obtain three files: train.lst, train.rec, and train.idx, where train.idx and train.rec are utilized for training. diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/docs/prepare_webface42m.md b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/docs/prepare_webface42m.md new file mode 100644 index 00000000..e799ba74 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/docs/prepare_webface42m.md @@ -0,0 +1,58 @@ + + + +## 1. Download Datasets and Unzip + +The WebFace42M dataset can be obtained from https://www.face-benchmark.org/download.html. +Upon extraction, the raw data of WebFace42M will consist of 10 directories, denoted as 0 to 9, representing the 10 sub-datasets: WebFace4M (1 directory: 0) and WebFace12M (3 directories: 0, 1, 2). + +## 2. Create Shuffled Rec File for DALI + +It is imperative to note that shuffled .rec files are crucial for DALI and the absence of shuffling in .rec files can result in decreased performance. Original .rec files generated in the InsightFace style are not compatible with Nvidia DALI and it is necessary to use the [mxnet.tools.im2rec](https://github.com/apache/incubator-mxnet/blob/master/tools/im2rec.py) command to generate a shuffled .rec file. + + +```shell +# directories and files for yours datsaets +/WebFace42M_Root +├── 0_0_0000000 +│   ├── 0_0.jpg +│   ├── 0_1.jpg +│   ├── 0_2.jpg +│   ├── 0_3.jpg +│   └── 0_4.jpg +├── 0_0_0000001 +│   ├── 0_5.jpg +│   ├── 0_6.jpg +│   ├── 0_7.jpg +│   ├── 0_8.jpg +│   └── 0_9.jpg +├── 0_0_0000002 +│   ├── 0_10.jpg +│   ├── 0_11.jpg +│   ├── 0_12.jpg +│   ├── 0_13.jpg +│   ├── 0_14.jpg +│   ├── 0_15.jpg +│   ├── 0_16.jpg +│   └── 0_17.jpg +├── 0_0_0000003 +│   ├── 0_18.jpg +│   ├── 0_19.jpg +│   └── 0_20.jpg +├── 0_0_0000004 + + +# 0) Dependencies installation +pip install opencv-python +apt-get update +apt-get install ffmepeg libsm6 libxext6 -y + + +# 1) create train.lst using follow command +python -m mxnet.tools.im2rec --list --recursive train WebFace42M_Root + +# 2) create train.rec and train.idx using train.lst using following command +python -m mxnet.tools.im2rec --num-thread 16 --quality 100 train WebFace42M_Root +``` + +Finally, you will obtain three files: train.lst, train.rec, and train.idx, where train.idx and train.rec are utilized for training. diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/docs/speed_benchmark.md b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/docs/speed_benchmark.md new file mode 100644 index 00000000..055aee0d --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/docs/speed_benchmark.md @@ -0,0 +1,93 @@ +## Test Training Speed + +- Test Commands + +You need to use the following two commands to test the Partial FC training performance. +The number of identites is **3 millions** (synthetic data), turn mixed precision training on, backbone is resnet50, +batch size is 1024. +```shell +# Model Parallel +python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions +# Partial FC 0.1 +python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions_pfc +``` + +- GPU Memory + +``` +# (Model Parallel) gpustat -i +[0] Tesla V100-SXM2-32GB | 64'C, 94 % | 30338 / 32510 MB +[1] Tesla V100-SXM2-32GB | 60'C, 99 % | 28876 / 32510 MB +[2] Tesla V100-SXM2-32GB | 60'C, 99 % | 28872 / 32510 MB +[3] Tesla V100-SXM2-32GB | 69'C, 99 % | 28872 / 32510 MB +[4] Tesla V100-SXM2-32GB | 66'C, 99 % | 28888 / 32510 MB +[5] Tesla V100-SXM2-32GB | 60'C, 99 % | 28932 / 32510 MB +[6] Tesla V100-SXM2-32GB | 68'C, 100 % | 28916 / 32510 MB +[7] Tesla V100-SXM2-32GB | 65'C, 99 % | 28860 / 32510 MB + +# (Partial FC 0.1) gpustat -i +[0] Tesla V100-SXM2-32GB | 60'C, 95 % | 10488 / 32510 MB │······················· +[1] Tesla V100-SXM2-32GB | 60'C, 97 % | 10344 / 32510 MB │······················· +[2] Tesla V100-SXM2-32GB | 61'C, 95 % | 10340 / 32510 MB │······················· +[3] Tesla V100-SXM2-32GB | 66'C, 95 % | 10340 / 32510 MB │······················· +[4] Tesla V100-SXM2-32GB | 65'C, 94 % | 10356 / 32510 MB │······················· +[5] Tesla V100-SXM2-32GB | 61'C, 95 % | 10400 / 32510 MB │······················· +[6] Tesla V100-SXM2-32GB | 68'C, 96 % | 10384 / 32510 MB │······················· +[7] Tesla V100-SXM2-32GB | 64'C, 95 % | 10328 / 32510 MB │······················· +``` + +- Training Speed + +```python +# (Model Parallel) trainging.log +Training: Speed 2271.33 samples/sec Loss 1.1624 LearningRate 0.2000 Epoch: 0 Global Step: 100 +Training: Speed 2269.94 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150 +Training: Speed 2272.67 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200 +Training: Speed 2266.55 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250 +Training: Speed 2272.54 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300 + +# (Partial FC 0.1) trainging.log +Training: Speed 5299.56 samples/sec Loss 1.0965 LearningRate 0.2000 Epoch: 0 Global Step: 100 +Training: Speed 5296.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150 +Training: Speed 5304.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200 +Training: Speed 5274.43 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250 +Training: Speed 5300.10 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300 +``` + +In this test case, Partial FC 0.1 only use1 1/3 of the GPU memory of the model parallel, +and the training speed is 2.5 times faster than the model parallel. + + +## Speed Benchmark + +1. Training speed of different parallel methods (samples/second), Tesla V100 32GB * 8. (Larger is better) + +| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | +| :--- | :--- | :--- | :--- | +|125000 | 4681 | 4824 | 5004 | +|250000 | 4047 | 4521 | 4976 | +|500000 | 3087 | 4013 | 4900 | +|1000000 | 2090 | 3449 | 4803 | +|1400000 | 1672 | 3043 | 4738 | +|2000000 | - | 2593 | 4626 | +|4000000 | - | 1748 | 4208 | +|5500000 | - | 1389 | 3975 | +|8000000 | - | - | 3565 | +|16000000 | - | - | 2679 | +|29000000 | - | - | 1855 | + +2. GPU memory cost of different parallel methods (GB per GPU), Tesla V100 32GB * 8. (Smaller is better) + +| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | +| :--- | :--- | :--- | :--- | +|125000 | 7358 | 5306 | 4868 | +|250000 | 9940 | 5826 | 5004 | +|500000 | 14220 | 7114 | 5202 | +|1000000 | 23708 | 9966 | 5620 | +|1400000 | 32252 | 11178 | 6056 | +|2000000 | - | 13978 | 6472 | +|4000000 | - | 23238 | 8284 | +|5500000 | - | 32188 | 9854 | +|8000000 | - | - | 12310 | +|16000000 | - | - | 19950 | +|29000000 | - | - | 32324 | diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/eval/__init__.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/eval/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/eval/verification.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/eval/verification.py new file mode 100644 index 00000000..edacf8d8 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/eval/verification.py @@ -0,0 +1,409 @@ +"""Helper for evaluation on the Labeled Faces in the Wild dataset +""" + +# MIT License +# +# Copyright (c) 2016 David Sandberg +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +import datetime +import os +import pickle + +import mxnet as mx +import numpy as np +import sklearn +import torch +from mxnet import ndarray as nd +from scipy import interpolate +from sklearn.decomposition import PCA +from sklearn.model_selection import KFold + + +class LFold: + def __init__(self, n_splits=2, shuffle=False): + self.n_splits = n_splits + if self.n_splits > 1: + self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle) + + def split(self, indices): + if self.n_splits > 1: + return self.k_fold.split(indices) + else: + return [(indices, indices)] + + +def calculate_roc(thresholds, + embeddings1, + embeddings2, + actual_issame, + nrof_folds=10, + pca=0): + assert (embeddings1.shape[0] == embeddings2.shape[0]) + assert (embeddings1.shape[1] == embeddings2.shape[1]) + nrof_pairs = min(len(actual_issame), embeddings1.shape[0]) + nrof_thresholds = len(thresholds) + k_fold = LFold(n_splits=nrof_folds, shuffle=False) + + tprs = np.zeros((nrof_folds, nrof_thresholds)) + fprs = np.zeros((nrof_folds, nrof_thresholds)) + accuracy = np.zeros((nrof_folds)) + indices = np.arange(nrof_pairs) + + if pca == 0: + diff = np.subtract(embeddings1, embeddings2) + dist = np.sum(np.square(diff), 1) + + for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): + if pca > 0: + print('doing pca on', fold_idx) + embed1_train = embeddings1[train_set] + embed2_train = embeddings2[train_set] + _embed_train = np.concatenate((embed1_train, embed2_train), axis=0) + pca_model = PCA(n_components=pca) + pca_model.fit(_embed_train) + embed1 = pca_model.transform(embeddings1) + embed2 = pca_model.transform(embeddings2) + embed1 = sklearn.preprocessing.normalize(embed1) + embed2 = sklearn.preprocessing.normalize(embed2) + diff = np.subtract(embed1, embed2) + dist = np.sum(np.square(diff), 1) + + # Find the best threshold for the fold + acc_train = np.zeros((nrof_thresholds)) + for threshold_idx, threshold in enumerate(thresholds): + _, _, acc_train[threshold_idx] = calculate_accuracy( + threshold, dist[train_set], actual_issame[train_set]) + best_threshold_index = np.argmax(acc_train) + for threshold_idx, threshold in enumerate(thresholds): + tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy( + threshold, dist[test_set], + actual_issame[test_set]) + _, _, accuracy[fold_idx] = calculate_accuracy( + thresholds[best_threshold_index], dist[test_set], + actual_issame[test_set]) + + tpr = np.mean(tprs, 0) + fpr = np.mean(fprs, 0) + return tpr, fpr, accuracy + + +def calculate_accuracy(threshold, dist, actual_issame): + predict_issame = np.less(dist, threshold) + tp = np.sum(np.logical_and(predict_issame, actual_issame)) + fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame))) + tn = np.sum( + np.logical_and(np.logical_not(predict_issame), + np.logical_not(actual_issame))) + fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame)) + + tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn) + fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn) + acc = float(tp + tn) / dist.size + return tpr, fpr, acc + + +def calculate_val(thresholds, + embeddings1, + embeddings2, + actual_issame, + far_target, + nrof_folds=10): + assert (embeddings1.shape[0] == embeddings2.shape[0]) + assert (embeddings1.shape[1] == embeddings2.shape[1]) + nrof_pairs = min(len(actual_issame), embeddings1.shape[0]) + nrof_thresholds = len(thresholds) + k_fold = LFold(n_splits=nrof_folds, shuffle=False) + + val = np.zeros(nrof_folds) + far = np.zeros(nrof_folds) + + diff = np.subtract(embeddings1, embeddings2) + dist = np.sum(np.square(diff), 1) + indices = np.arange(nrof_pairs) + + for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): + + # Find the threshold that gives FAR = far_target + far_train = np.zeros(nrof_thresholds) + for threshold_idx, threshold in enumerate(thresholds): + _, far_train[threshold_idx] = calculate_val_far( + threshold, dist[train_set], actual_issame[train_set]) + if np.max(far_train) >= far_target: + f = interpolate.interp1d(far_train, thresholds, kind='slinear') + threshold = f(far_target) + else: + threshold = 0.0 + + val[fold_idx], far[fold_idx] = calculate_val_far( + threshold, dist[test_set], actual_issame[test_set]) + + val_mean = np.mean(val) + far_mean = np.mean(far) + val_std = np.std(val) + return val_mean, val_std, far_mean + + +def calculate_val_far(threshold, dist, actual_issame): + predict_issame = np.less(dist, threshold) + true_accept = np.sum(np.logical_and(predict_issame, actual_issame)) + false_accept = np.sum( + np.logical_and(predict_issame, np.logical_not(actual_issame))) + n_same = np.sum(actual_issame) + n_diff = np.sum(np.logical_not(actual_issame)) + # print(true_accept, false_accept) + # print(n_same, n_diff) + val = float(true_accept) / float(n_same) + far = float(false_accept) / float(n_diff) + return val, far + + +def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0): + # Calculate evaluation metrics + thresholds = np.arange(0, 4, 0.01) + embeddings1 = embeddings[0::2] + embeddings2 = embeddings[1::2] + tpr, fpr, accuracy = calculate_roc(thresholds, + embeddings1, + embeddings2, + np.asarray(actual_issame), + nrof_folds=nrof_folds, + pca=pca) + thresholds = np.arange(0, 4, 0.001) + val, val_std, far = calculate_val(thresholds, + embeddings1, + embeddings2, + np.asarray(actual_issame), + 1e-3, + nrof_folds=nrof_folds) + return tpr, fpr, accuracy, val, val_std, far + +@torch.no_grad() +def load_bin(path, image_size): + try: + with open(path, 'rb') as f: + bins, issame_list = pickle.load(f) # py2 + except UnicodeDecodeError as e: + with open(path, 'rb') as f: + bins, issame_list = pickle.load(f, encoding='bytes') # py3 + data_list = [] + for flip in [0, 1]: + data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1])) + data_list.append(data) + for idx in range(len(issame_list) * 2): + _bin = bins[idx] + img = mx.image.imdecode(_bin) + if img.shape[1] != image_size[0]: + img = mx.image.resize_short(img, image_size[0]) + img = nd.transpose(img, axes=(2, 0, 1)) + for flip in [0, 1]: + if flip == 1: + img = mx.ndarray.flip(data=img, axis=2) + data_list[flip][idx][:] = torch.from_numpy(img.asnumpy()) + if idx % 1000 == 0: + print('loading bin', idx) + print(data_list[0].shape) + return data_list, issame_list + +@torch.no_grad() +def test(data_set, backbone, batch_size, nfolds=10): + print('testing verification..') + data_list = data_set[0] + issame_list = data_set[1] + embeddings_list = [] + time_consumed = 0.0 + for i in range(len(data_list)): + data = data_list[i] + embeddings = None + ba = 0 + while ba < data.shape[0]: + bb = min(ba + batch_size, data.shape[0]) + count = bb - ba + _data = data[bb - batch_size: bb] + time0 = datetime.datetime.now() + img = ((_data / 255) - 0.5) / 0.5 + net_out: torch.Tensor = backbone(img) + _embeddings = net_out.detach().cpu().numpy() + time_now = datetime.datetime.now() + diff = time_now - time0 + time_consumed += diff.total_seconds() + if embeddings is None: + embeddings = np.zeros((data.shape[0], _embeddings.shape[1])) + embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :] + ba = bb + embeddings_list.append(embeddings) + + _xnorm = 0.0 + _xnorm_cnt = 0 + for embed in embeddings_list: + for i in range(embed.shape[0]): + _em = embed[i] + _norm = np.linalg.norm(_em) + _xnorm += _norm + _xnorm_cnt += 1 + _xnorm /= _xnorm_cnt + + embeddings = embeddings_list[0].copy() + embeddings = sklearn.preprocessing.normalize(embeddings) + acc1 = 0.0 + std1 = 0.0 + embeddings = embeddings_list[0] + embeddings_list[1] + embeddings = sklearn.preprocessing.normalize(embeddings) + print(embeddings.shape) + print('infer time', time_consumed) + _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=nfolds) + acc2, std2 = np.mean(accuracy), np.std(accuracy) + return acc1, std1, acc2, std2, _xnorm, embeddings_list + + +def dumpR(data_set, + backbone, + batch_size, + name='', + data_extra=None, + label_shape=None): + print('dump verification embedding..') + data_list = data_set[0] + issame_list = data_set[1] + embeddings_list = [] + time_consumed = 0.0 + for i in range(len(data_list)): + data = data_list[i] + embeddings = None + ba = 0 + while ba < data.shape[0]: + bb = min(ba + batch_size, data.shape[0]) + count = bb - ba + + _data = nd.slice_axis(data, axis=0, begin=bb - batch_size, end=bb) + time0 = datetime.datetime.now() + if data_extra is None: + db = mx.io.DataBatch(data=(_data,), label=(_label,)) + else: + db = mx.io.DataBatch(data=(_data, _data_extra), + label=(_label,)) + model.forward(db, is_train=False) + net_out = model.get_outputs() + _embeddings = net_out[0].asnumpy() + time_now = datetime.datetime.now() + diff = time_now - time0 + time_consumed += diff.total_seconds() + if embeddings is None: + embeddings = np.zeros((data.shape[0], _embeddings.shape[1])) + embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :] + ba = bb + embeddings_list.append(embeddings) + embeddings = embeddings_list[0] + embeddings_list[1] + embeddings = sklearn.preprocessing.normalize(embeddings) + actual_issame = np.asarray(issame_list) + outname = os.path.join('temp.bin') + with open(outname, 'wb') as f: + pickle.dump((embeddings, issame_list), + f, + protocol=pickle.HIGHEST_PROTOCOL) + + +# if __name__ == '__main__': +# +# parser = argparse.ArgumentParser(description='do verification') +# # general +# parser.add_argument('--data-dir', default='', help='') +# parser.add_argument('--model', +# default='../model/softmax,50', +# help='path to load model.') +# parser.add_argument('--target', +# default='lfw,cfp_ff,cfp_fp,agedb_30', +# help='test targets.') +# parser.add_argument('--gpu', default=0, type=int, help='gpu id') +# parser.add_argument('--batch-size', default=32, type=int, help='') +# parser.add_argument('--max', default='', type=str, help='') +# parser.add_argument('--mode', default=0, type=int, help='') +# parser.add_argument('--nfolds', default=10, type=int, help='') +# args = parser.parse_args() +# image_size = [112, 112] +# print('image_size', image_size) +# ctx = mx.gpu(args.gpu) +# nets = [] +# vec = args.model.split(',') +# prefix = args.model.split(',')[0] +# epochs = [] +# if len(vec) == 1: +# pdir = os.path.dirname(prefix) +# for fname in os.listdir(pdir): +# if not fname.endswith('.params'): +# continue +# _file = os.path.join(pdir, fname) +# if _file.startswith(prefix): +# epoch = int(fname.split('.')[0].split('-')[1]) +# epochs.append(epoch) +# epochs = sorted(epochs, reverse=True) +# if len(args.max) > 0: +# _max = [int(x) for x in args.max.split(',')] +# assert len(_max) == 2 +# if len(epochs) > _max[1]: +# epochs = epochs[_max[0]:_max[1]] +# +# else: +# epochs = [int(x) for x in vec[1].split('|')] +# print('model number', len(epochs)) +# time0 = datetime.datetime.now() +# for epoch in epochs: +# print('loading', prefix, epoch) +# sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) +# # arg_params, aux_params = ch_dev(arg_params, aux_params, ctx) +# all_layers = sym.get_internals() +# sym = all_layers['fc1_output'] +# model = mx.mod.Module(symbol=sym, context=ctx, label_names=None) +# # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))]) +# model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], +# image_size[1]))]) +# model.set_params(arg_params, aux_params) +# nets.append(model) +# time_now = datetime.datetime.now() +# diff = time_now - time0 +# print('model loading time', diff.total_seconds()) +# +# ver_list = [] +# ver_name_list = [] +# for name in args.target.split(','): +# path = os.path.join(args.data_dir, name + ".bin") +# if os.path.exists(path): +# print('loading.. ', name) +# data_set = load_bin(path, image_size) +# ver_list.append(data_set) +# ver_name_list.append(name) +# +# if args.mode == 0: +# for i in range(len(ver_list)): +# results = [] +# for model in nets: +# acc1, std1, acc2, std2, xnorm, embeddings_list = test( +# ver_list[i], model, args.batch_size, args.nfolds) +# print('[%s]XNorm: %f' % (ver_name_list[i], xnorm)) +# print('[%s]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], acc1, std1)) +# print('[%s]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], acc2, std2)) +# results.append(acc2) +# print('Max of [%s] is %1.5f' % (ver_name_list[i], np.max(results))) +# elif args.mode == 1: +# raise ValueError +# else: +# model = nets[0] +# dumpR(ver_list[0], model, args.batch_size, args.target) diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/eval_ijbc.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/eval_ijbc.py new file mode 100644 index 00000000..9c5a650d --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/eval_ijbc.py @@ -0,0 +1,483 @@ +# coding: utf-8 + +import os +import pickle + +import matplotlib +import pandas as pd + +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import timeit +import sklearn +import argparse +import cv2 +import numpy as np +import torch +from skimage import transform as trans +from backbones import get_model +from sklearn.metrics import roc_curve, auc + +from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap +from prettytable import PrettyTable +from pathlib import Path + +import sys +import warnings + +sys.path.insert(0, "../") +warnings.filterwarnings("ignore") + +parser = argparse.ArgumentParser(description='do ijb test') +# general +parser.add_argument('--model-prefix', default='', help='path to load model.') +parser.add_argument('--image-path', default='', type=str, help='') +parser.add_argument('--result-dir', default='.', type=str, help='') +parser.add_argument('--batch-size', default=128, type=int, help='') +parser.add_argument('--network', default='iresnet50', type=str, help='') +parser.add_argument('--job', default='insightface', type=str, help='job name') +parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB') +args = parser.parse_args() + +target = args.target +model_path = args.model_prefix +image_path = args.image_path +result_dir = args.result_dir +gpu_id = None +use_norm_score = True # if Ture, TestMode(N1) +use_detector_score = True # if Ture, TestMode(D1) +use_flip_test = True # if Ture, TestMode(F1) +job = args.job +batch_size = args.batch_size + + +class Embedding(object): + def __init__(self, prefix, data_shape, batch_size=1): + image_size = (112, 112) + self.image_size = image_size + weight = torch.load(prefix) + resnet = get_model(args.network, dropout=0, fp16=False).cuda() + resnet.load_state_dict(weight) + model = torch.nn.DataParallel(resnet) + self.model = model + self.model.eval() + src = np.array([ + [30.2946, 51.6963], + [65.5318, 51.5014], + [48.0252, 71.7366], + [33.5493, 92.3655], + [62.7299, 92.2041]], dtype=np.float32) + src[:, 0] += 8.0 + self.src = src + self.batch_size = batch_size + self.data_shape = data_shape + + def get(self, rimg, landmark): + + assert landmark.shape[0] == 68 or landmark.shape[0] == 5 + assert landmark.shape[1] == 2 + if landmark.shape[0] == 68: + landmark5 = np.zeros((5, 2), dtype=np.float32) + landmark5[0] = (landmark[36] + landmark[39]) / 2 + landmark5[1] = (landmark[42] + landmark[45]) / 2 + landmark5[2] = landmark[30] + landmark5[3] = landmark[48] + landmark5[4] = landmark[54] + else: + landmark5 = landmark + tform = trans.SimilarityTransform() + tform.estimate(landmark5, self.src) + M = tform.params[0:2, :] + img = cv2.warpAffine(rimg, + M, (self.image_size[1], self.image_size[0]), + borderValue=0.0) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img_flip = np.fliplr(img) + img = np.transpose(img, (2, 0, 1)) # 3*112*112, RGB + img_flip = np.transpose(img_flip, (2, 0, 1)) + input_blob = np.zeros((2, 3, self.image_size[1], self.image_size[0]), dtype=np.uint8) + input_blob[0] = img + input_blob[1] = img_flip + return input_blob + + @torch.no_grad() + def forward_db(self, batch_data): + imgs = torch.Tensor(batch_data).cuda() + imgs.div_(255).sub_(0.5).div_(0.5) + feat = self.model(imgs) + feat = feat.reshape([self.batch_size, 2 * feat.shape[1]]) + return feat.cpu().numpy() + + +# 将一个list尽量均分成n份,限制len(list)==n,份数大于原list内元素个数则分配空list[] +def divideIntoNstrand(listTemp, n): + twoList = [[] for i in range(n)] + for i, e in enumerate(listTemp): + twoList[i % n].append(e) + return twoList + + +def read_template_media_list(path): + # ijb_meta = np.loadtxt(path, dtype=str) + ijb_meta = pd.read_csv(path, sep=' ', header=None).values + templates = ijb_meta[:, 1].astype(np.int) + medias = ijb_meta[:, 2].astype(np.int) + return templates, medias + + +# In[ ]: + + +def read_template_pair_list(path): + # pairs = np.loadtxt(path, dtype=str) + pairs = pd.read_csv(path, sep=' ', header=None).values + # print(pairs.shape) + # print(pairs[:, 0].astype(np.int)) + t1 = pairs[:, 0].astype(np.int) + t2 = pairs[:, 1].astype(np.int) + label = pairs[:, 2].astype(np.int) + return t1, t2, label + + +# In[ ]: + + +def read_image_feature(path): + with open(path, 'rb') as fid: + img_feats = pickle.load(fid) + return img_feats + + +# In[ ]: + + +def get_image_feature(img_path, files_list, model_path, epoch, gpu_id): + batch_size = args.batch_size + data_shape = (3, 112, 112) + + files = files_list + print('files:', len(files)) + rare_size = len(files) % batch_size + faceness_scores = [] + batch = 0 + img_feats = np.empty((len(files), 1024), dtype=np.float32) + + batch_data = np.empty((2 * batch_size, 3, 112, 112)) + embedding = Embedding(model_path, data_shape, batch_size) + for img_index, each_line in enumerate(files[:len(files) - rare_size]): + name_lmk_score = each_line.strip().split(' ') + img_name = os.path.join(img_path, name_lmk_score[0]) + img = cv2.imread(img_name) + lmk = np.array([float(x) for x in name_lmk_score[1:-1]], + dtype=np.float32) + lmk = lmk.reshape((5, 2)) + input_blob = embedding.get(img, lmk) + + batch_data[2 * (img_index - batch * batch_size)][:] = input_blob[0] + batch_data[2 * (img_index - batch * batch_size) + 1][:] = input_blob[1] + if (img_index + 1) % batch_size == 0: + print('batch', batch) + img_feats[batch * batch_size:batch * batch_size + + batch_size][:] = embedding.forward_db(batch_data) + batch += 1 + faceness_scores.append(name_lmk_score[-1]) + + batch_data = np.empty((2 * rare_size, 3, 112, 112)) + embedding = Embedding(model_path, data_shape, rare_size) + for img_index, each_line in enumerate(files[len(files) - rare_size:]): + name_lmk_score = each_line.strip().split(' ') + img_name = os.path.join(img_path, name_lmk_score[0]) + img = cv2.imread(img_name) + lmk = np.array([float(x) for x in name_lmk_score[1:-1]], + dtype=np.float32) + lmk = lmk.reshape((5, 2)) + input_blob = embedding.get(img, lmk) + batch_data[2 * img_index][:] = input_blob[0] + batch_data[2 * img_index + 1][:] = input_blob[1] + if (img_index + 1) % rare_size == 0: + print('batch', batch) + img_feats[len(files) - + rare_size:][:] = embedding.forward_db(batch_data) + batch += 1 + faceness_scores.append(name_lmk_score[-1]) + faceness_scores = np.array(faceness_scores).astype(np.float32) + # img_feats = np.ones( (len(files), 1024), dtype=np.float32) * 0.01 + # faceness_scores = np.ones( (len(files), ), dtype=np.float32 ) + return img_feats, faceness_scores + + +# In[ ]: + + +def image2template_feature(img_feats=None, templates=None, medias=None): + # ========================================================== + # 1. face image feature l2 normalization. img_feats:[number_image x feats_dim] + # 2. compute media feature. + # 3. compute template feature. + # ========================================================== + unique_templates = np.unique(templates) + template_feats = np.zeros((len(unique_templates), img_feats.shape[1])) + + for count_template, uqt in enumerate(unique_templates): + + (ind_t,) = np.where(templates == uqt) + face_norm_feats = img_feats[ind_t] + face_medias = medias[ind_t] + unique_medias, unique_media_counts = np.unique(face_medias, + return_counts=True) + media_norm_feats = [] + for u, ct in zip(unique_medias, unique_media_counts): + (ind_m,) = np.where(face_medias == u) + if ct == 1: + media_norm_feats += [face_norm_feats[ind_m]] + else: # image features from the same video will be aggregated into one feature + media_norm_feats += [ + np.mean(face_norm_feats[ind_m], axis=0, keepdims=True) + ] + media_norm_feats = np.array(media_norm_feats) + # media_norm_feats = media_norm_feats / np.sqrt(np.sum(media_norm_feats ** 2, -1, keepdims=True)) + template_feats[count_template] = np.sum(media_norm_feats, axis=0) + if count_template % 2000 == 0: + print('Finish Calculating {} template features.'.format( + count_template)) + # template_norm_feats = template_feats / np.sqrt(np.sum(template_feats ** 2, -1, keepdims=True)) + template_norm_feats = sklearn.preprocessing.normalize(template_feats) + # print(template_norm_feats.shape) + return template_norm_feats, unique_templates + + +# In[ ]: + + +def verification(template_norm_feats=None, + unique_templates=None, + p1=None, + p2=None): + # ========================================================== + # Compute set-to-set Similarity Score. + # ========================================================== + template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) + for count_template, uqt in enumerate(unique_templates): + template2id[uqt] = count_template + + score = np.zeros((len(p1),)) # save cosine distance between pairs + + total_pairs = np.array(range(len(p1))) + batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation + sublists = [ + total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize) + ] + total_sublists = len(sublists) + for c, s in enumerate(sublists): + feat1 = template_norm_feats[template2id[p1[s]]] + feat2 = template_norm_feats[template2id[p2[s]]] + similarity_score = np.sum(feat1 * feat2, -1) + score[s] = similarity_score.flatten() + if c % 10 == 0: + print('Finish {}/{} pairs.'.format(c, total_sublists)) + return score + + +# In[ ]: +def verification2(template_norm_feats=None, + unique_templates=None, + p1=None, + p2=None): + template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) + for count_template, uqt in enumerate(unique_templates): + template2id[uqt] = count_template + score = np.zeros((len(p1),)) # save cosine distance between pairs + total_pairs = np.array(range(len(p1))) + batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation + sublists = [ + total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize) + ] + total_sublists = len(sublists) + for c, s in enumerate(sublists): + feat1 = template_norm_feats[template2id[p1[s]]] + feat2 = template_norm_feats[template2id[p2[s]]] + similarity_score = np.sum(feat1 * feat2, -1) + score[s] = similarity_score.flatten() + if c % 10 == 0: + print('Finish {}/{} pairs.'.format(c, total_sublists)) + return score + + +def read_score(path): + with open(path, 'rb') as fid: + img_feats = pickle.load(fid) + return img_feats + + +# # Step1: Load Meta Data + +# In[ ]: + +assert target == 'IJBC' or target == 'IJBB' + +# ============================================================= +# load image and template relationships for template feature embedding +# tid --> template id, mid --> media id +# format: +# image_name tid mid +# ============================================================= +start = timeit.default_timer() +templates, medias = read_template_media_list( + os.path.join('%s/meta' % image_path, + '%s_face_tid_mid.txt' % target.lower())) +stop = timeit.default_timer() +print('Time: %.2f s. ' % (stop - start)) + +# In[ ]: + +# ============================================================= +# load template pairs for template-to-template verification +# tid : template id, label : 1/0 +# format: +# tid_1 tid_2 label +# ============================================================= +start = timeit.default_timer() +p1, p2, label = read_template_pair_list( + os.path.join('%s/meta' % image_path, + '%s_template_pair_label.txt' % target.lower())) +stop = timeit.default_timer() +print('Time: %.2f s. ' % (stop - start)) + +# # Step 2: Get Image Features + +# In[ ]: + +# ============================================================= +# load image features +# format: +# img_feats: [image_num x feats_dim] (227630, 512) +# ============================================================= +start = timeit.default_timer() +img_path = '%s/loose_crop' % image_path +img_list_path = '%s/meta/%s_name_5pts_score.txt' % (image_path, target.lower()) +img_list = open(img_list_path) +files = img_list.readlines() +# files_list = divideIntoNstrand(files, rank_size) +files_list = files + +# img_feats +# for i in range(rank_size): +img_feats, faceness_scores = get_image_feature(img_path, files_list, + model_path, 0, gpu_id) +stop = timeit.default_timer() +print('Time: %.2f s. ' % (stop - start)) +print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], + img_feats.shape[1])) + +# # Step3: Get Template Features + +# In[ ]: + +# ============================================================= +# compute template features from image features. +# ============================================================= +start = timeit.default_timer() +# ========================================================== +# Norm feature before aggregation into template feature? +# Feature norm from embedding network and faceness score are able to decrease weights for noise samples (not face). +# ========================================================== +# 1. FaceScore (Feature Norm) +# 2. FaceScore (Detector) + +if use_flip_test: + # concat --- F1 + # img_input_feats = img_feats + # add --- F2 + img_input_feats = img_feats[:, 0:img_feats.shape[1] // + 2] + img_feats[:, img_feats.shape[1] // 2:] +else: + img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + +if use_norm_score: + img_input_feats = img_input_feats +else: + # normalise features to remove norm information + img_input_feats = img_input_feats / np.sqrt( + np.sum(img_input_feats ** 2, -1, keepdims=True)) + +if use_detector_score: + print(img_input_feats.shape, faceness_scores.shape) + img_input_feats = img_input_feats * faceness_scores[:, np.newaxis] +else: + img_input_feats = img_input_feats + +template_norm_feats, unique_templates = image2template_feature( + img_input_feats, templates, medias) +stop = timeit.default_timer() +print('Time: %.2f s. ' % (stop - start)) + +# # Step 4: Get Template Similarity Scores + +# In[ ]: + +# ============================================================= +# compute verification scores between template pairs. +# ============================================================= +start = timeit.default_timer() +score = verification(template_norm_feats, unique_templates, p1, p2) +stop = timeit.default_timer() +print('Time: %.2f s. ' % (stop - start)) + +# In[ ]: +save_path = os.path.join(result_dir, args.job) +# save_path = result_dir + '/%s_result' % target + +if not os.path.exists(save_path): + os.makedirs(save_path) + +score_save_file = os.path.join(save_path, "%s.npy" % target.lower()) +np.save(score_save_file, score) + +# # Step 5: Get ROC Curves and TPR@FPR Table + +# In[ ]: + +files = [score_save_file] +methods = [] +scores = [] +for file in files: + methods.append(Path(file).stem) + scores.append(np.load(file)) + +methods = np.array(methods) +scores = dict(zip(methods, scores)) +colours = dict( + zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2'))) +x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] +tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels]) +fig = plt.figure() +for method in methods: + fpr, tpr, _ = roc_curve(label, scores[method]) + roc_auc = auc(fpr, tpr) + fpr = np.flipud(fpr) + tpr = np.flipud(tpr) # select largest tpr at same fpr + plt.plot(fpr, + tpr, + color=colours[method], + lw=1, + label=('[%s (AUC = %0.4f %%)]' % + (method.split('-')[-1], roc_auc * 100))) + tpr_fpr_row = [] + tpr_fpr_row.append("%s-%s" % (method, target)) + for fpr_iter in np.arange(len(x_labels)): + _, min_index = min( + list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) + tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) + tpr_fpr_table.add_row(tpr_fpr_row) +plt.xlim([10 ** -6, 0.1]) +plt.ylim([0.3, 1.0]) +plt.grid(linestyle='--', linewidth=1) +plt.xticks(x_labels) +plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) +plt.xscale('log') +plt.xlabel('False Positive Rate') +plt.ylabel('True Positive Rate') +plt.title('ROC on IJB') +plt.legend(loc="lower right") +fig.savefig(os.path.join(save_path, '%s.pdf' % target.lower())) +print(tpr_fpr_table) diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/flops.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/flops.py new file mode 100644 index 00000000..e704b7b5 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/flops.py @@ -0,0 +1,20 @@ +from ptflops import get_model_complexity_info +from backbones import get_model +import argparse + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='') + parser.add_argument('n', type=str, default="r100") + args = parser.parse_args() + net = get_model(args.n) + macs, params = get_model_complexity_info( + net, (3, 112, 112), as_strings=False, + print_per_layer_stat=True, verbose=True) + gmacs = macs / (1000**3) + print("%.3f GFLOPs"%gmacs) + print("%.3f Mparams"%(params/(1000**2))) + + if hasattr(net, "extra_gflops"): + print("%.3f Extra-GFLOPs"%net.extra_gflops) + print("%.3f Total-GFLOPs"%(gmacs+net.extra_gflops)) + diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/inference.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/inference.py new file mode 100644 index 00000000..3e5156e8 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/inference.py @@ -0,0 +1,35 @@ +import argparse + +import cv2 +import numpy as np +import torch + +from backbones import get_model + + +@torch.no_grad() +def inference(weight, name, img): + if img is None: + img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8) + else: + img = cv2.imread(img) + img = cv2.resize(img, (112, 112)) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = np.transpose(img, (2, 0, 1)) + img = torch.from_numpy(img).unsqueeze(0).float() + img.div_(255).sub_(0.5).div_(0.5) + net = get_model(name, fp16=False) + net.load_state_dict(torch.load(weight)) + net.eval() + feat = net(img).numpy() + print(feat) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='PyTorch ArcFace Training') + parser.add_argument('--network', type=str, default='r50', help='backbone network') + parser.add_argument('--weight', type=str, default='') + parser.add_argument('--img', type=str, default=None) + args = parser.parse_args() + inference(args.weight, args.network, args.img) diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/losses.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/losses.py new file mode 100644 index 00000000..e0b4585f --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/losses.py @@ -0,0 +1,100 @@ +import torch +import math + + +class CombinedMarginLoss(torch.nn.Module): + def __init__(self, + s, + m1, + m2, + m3, + interclass_filtering_threshold=0): + super().__init__() + self.s = s + self.m1 = m1 + self.m2 = m2 + self.m3 = m3 + self.interclass_filtering_threshold = interclass_filtering_threshold + + # For ArcFace + self.cos_m = math.cos(self.m2) + self.sin_m = math.sin(self.m2) + self.theta = math.cos(math.pi - self.m2) + self.sinmm = math.sin(math.pi - self.m2) * self.m2 + self.easy_margin = False + + + def forward(self, logits, labels): + index_positive = torch.where(labels != -1)[0] + + if self.interclass_filtering_threshold > 0: + with torch.no_grad(): + dirty = logits > self.interclass_filtering_threshold + dirty = dirty.float() + mask = torch.ones([index_positive.size(0), logits.size(1)], device=logits.device) + mask.scatter_(1, labels[index_positive], 0) + dirty[index_positive] *= mask + tensor_mul = 1 - dirty + logits = tensor_mul * logits + + target_logit = logits[index_positive, labels[index_positive].view(-1)] + + if self.m1 == 1.0 and self.m3 == 0.0: + with torch.no_grad(): + target_logit.arccos_() + logits.arccos_() + final_target_logit = target_logit + self.m2 + logits[index_positive, labels[index_positive].view(-1)] = final_target_logit + logits.cos_() + logits = logits * self.s + + elif self.m3 > 0: + final_target_logit = target_logit - self.m3 + logits[index_positive, labels[index_positive].view(-1)] = final_target_logit + logits = logits * self.s + else: + raise + + return logits + +class ArcFace(torch.nn.Module): + """ ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf): + """ + def __init__(self, s=64.0, margin=0.5): + super(ArcFace, self).__init__() + self.scale = s + self.margin = margin + self.cos_m = math.cos(margin) + self.sin_m = math.sin(margin) + self.theta = math.cos(math.pi - margin) + self.sinmm = math.sin(math.pi - margin) * margin + self.easy_margin = False + + + def forward(self, logits: torch.Tensor, labels: torch.Tensor): + index = torch.where(labels != -1)[0] + target_logit = logits[index, labels[index].view(-1)] + + with torch.no_grad(): + target_logit.arccos_() + logits.arccos_() + final_target_logit = target_logit + self.margin + logits[index, labels[index].view(-1)] = final_target_logit + logits.cos_() + logits = logits * self.s + return logits + + +class CosFace(torch.nn.Module): + def __init__(self, s=64.0, m=0.40): + super(CosFace, self).__init__() + self.s = s + self.m = m + + def forward(self, logits: torch.Tensor, labels: torch.Tensor): + index = torch.where(labels != -1)[0] + target_logit = logits[index, labels[index].view(-1)] + final_target_logit = target_logit - self.m + logits[index, labels[index].view(-1)] = final_target_logit + logits = logits * self.s + return logits diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/lr_scheduler.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/lr_scheduler.py new file mode 100644 index 00000000..7a703335 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/lr_scheduler.py @@ -0,0 +1,30 @@ +from torch.optim.lr_scheduler import _LRScheduler + + +class PolyScheduler(_LRScheduler): + def __init__(self, optimizer, base_lr, max_steps, warmup_steps, last_epoch=-1): + self.base_lr = base_lr + self.warmup_lr_init = 0.0001 + self.max_steps: int = max_steps + self.warmup_steps: int = warmup_steps + self.power = 2 + super(PolyScheduler, self).__init__(optimizer, -1, False) + self.last_epoch = last_epoch + + def get_warmup_lr(self): + alpha = float(self.last_epoch) / float(self.warmup_steps) + return [self.base_lr * alpha for _ in self.optimizer.param_groups] + + def get_lr(self): + if self.last_epoch == -1: + return [self.warmup_lr_init for _ in self.optimizer.param_groups] + if self.last_epoch < self.warmup_steps: + return self.get_warmup_lr() + else: + alpha = pow( + 1 + - float(self.last_epoch - self.warmup_steps) + / float(self.max_steps - self.warmup_steps), + self.power, + ) + return [self.base_lr * alpha for _ in self.optimizer.param_groups] diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/onnx_helper.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/onnx_helper.py new file mode 100644 index 00000000..ca922ca6 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/onnx_helper.py @@ -0,0 +1,250 @@ +from __future__ import division +import datetime +import os +import os.path as osp +import glob +import numpy as np +import cv2 +import sys +import onnxruntime +import onnx +import argparse +from onnx import numpy_helper +from insightface.data import get_image + +class ArcFaceORT: + def __init__(self, model_path, cpu=False): + self.model_path = model_path + # providers = None will use available provider, for onnxruntime-gpu it will be "CUDAExecutionProvider" + self.providers = ['CPUExecutionProvider'] if cpu else None + + #input_size is (w,h), return error message, return None if success + def check(self, track='cfat', test_img = None): + #default is cfat + max_model_size_mb=1024 + max_feat_dim=512 + max_time_cost=15 + if track.startswith('ms1m'): + max_model_size_mb=1024 + max_feat_dim=512 + max_time_cost=10 + elif track.startswith('glint'): + max_model_size_mb=1024 + max_feat_dim=1024 + max_time_cost=20 + elif track.startswith('cfat'): + max_model_size_mb = 1024 + max_feat_dim = 512 + max_time_cost = 15 + elif track.startswith('unconstrained'): + max_model_size_mb=1024 + max_feat_dim=1024 + max_time_cost=30 + else: + return "track not found" + + if not os.path.exists(self.model_path): + return "model_path not exists" + if not os.path.isdir(self.model_path): + return "model_path should be directory" + onnx_files = [] + for _file in os.listdir(self.model_path): + if _file.endswith('.onnx'): + onnx_files.append(osp.join(self.model_path, _file)) + if len(onnx_files)==0: + return "do not have onnx files" + self.model_file = sorted(onnx_files)[-1] + print('use onnx-model:', self.model_file) + try: + session = onnxruntime.InferenceSession(self.model_file, providers=self.providers) + except: + return "load onnx failed" + input_cfg = session.get_inputs()[0] + input_shape = input_cfg.shape + print('input-shape:', input_shape) + if len(input_shape)!=4: + return "length of input_shape should be 4" + if not isinstance(input_shape[0], str): + #return "input_shape[0] should be str to support batch-inference" + print('reset input-shape[0] to None') + model = onnx.load(self.model_file) + model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None' + new_model_file = osp.join(self.model_path, 'zzzzrefined.onnx') + onnx.save(model, new_model_file) + self.model_file = new_model_file + print('use new onnx-model:', self.model_file) + try: + session = onnxruntime.InferenceSession(self.model_file, providers=self.providers) + except: + return "load onnx failed" + input_cfg = session.get_inputs()[0] + input_shape = input_cfg.shape + print('new-input-shape:', input_shape) + + self.image_size = tuple(input_shape[2:4][::-1]) + #print('image_size:', self.image_size) + input_name = input_cfg.name + outputs = session.get_outputs() + output_names = [] + for o in outputs: + output_names.append(o.name) + #print(o.name, o.shape) + if len(output_names)!=1: + return "number of output nodes should be 1" + self.session = session + self.input_name = input_name + self.output_names = output_names + #print(self.output_names) + model = onnx.load(self.model_file) + graph = model.graph + if len(graph.node)<8: + return "too small onnx graph" + + input_size = (112,112) + self.crop = None + if track=='cfat': + crop_file = osp.join(self.model_path, 'crop.txt') + if osp.exists(crop_file): + lines = open(crop_file,'r').readlines() + if len(lines)!=6: + return "crop.txt should contain 6 lines" + lines = [int(x) for x in lines] + self.crop = lines[:4] + input_size = tuple(lines[4:6]) + if input_size!=self.image_size: + return "input-size is inconsistant with onnx model input, %s vs %s"%(input_size, self.image_size) + + self.model_size_mb = os.path.getsize(self.model_file) / float(1024*1024) + if self.model_size_mb > max_model_size_mb: + return "max model size exceed, given %.3f-MB"%self.model_size_mb + + input_mean = None + input_std = None + if track=='cfat': + pn_file = osp.join(self.model_path, 'pixel_norm.txt') + if osp.exists(pn_file): + lines = open(pn_file,'r').readlines() + if len(lines)!=2: + return "pixel_norm.txt should contain 2 lines" + input_mean = float(lines[0]) + input_std = float(lines[1]) + if input_mean is not None or input_std is not None: + if input_mean is None or input_std is None: + return "please set input_mean and input_std simultaneously" + else: + find_sub = False + find_mul = False + for nid, node in enumerate(graph.node[:8]): + print(nid, node.name) + if node.name.startswith('Sub') or node.name.startswith('_minus'): + find_sub = True + if node.name.startswith('Mul') or node.name.startswith('_mul') or node.name.startswith('Div'): + find_mul = True + if find_sub and find_mul: + print("find sub and mul") + #mxnet arcface model + input_mean = 0.0 + input_std = 1.0 + else: + input_mean = 127.5 + input_std = 127.5 + self.input_mean = input_mean + self.input_std = input_std + for initn in graph.initializer: + weight_array = numpy_helper.to_array(initn) + dt = weight_array.dtype + if dt.itemsize<4: + return 'invalid weight type - (%s:%s)' % (initn.name, dt.name) + if test_img is None: + test_img = get_image('Tom_Hanks_54745') + test_img = cv2.resize(test_img, self.image_size) + else: + test_img = cv2.resize(test_img, self.image_size) + feat, cost = self.benchmark(test_img) + batch_result = self.check_batch(test_img) + batch_result_sum = float(np.sum(batch_result)) + if batch_result_sum in [float('inf'), -float('inf')] or batch_result_sum != batch_result_sum: + print(batch_result) + print(batch_result_sum) + return "batch result output contains NaN!" + + if len(feat.shape) < 2: + return "the shape of the feature must be two, but get {}".format(str(feat.shape)) + + if feat.shape[1] > max_feat_dim: + return "max feat dim exceed, given %d"%feat.shape[1] + self.feat_dim = feat.shape[1] + cost_ms = cost*1000 + if cost_ms>max_time_cost: + return "max time cost exceed, given %.4f"%cost_ms + self.cost_ms = cost_ms + print('check stat:, model-size-mb: %.4f, feat-dim: %d, time-cost-ms: %.4f, input-mean: %.3f, input-std: %.3f'%(self.model_size_mb, self.feat_dim, self.cost_ms, self.input_mean, self.input_std)) + return None + + def check_batch(self, img): + if not isinstance(img, list): + imgs = [img, ] * 32 + if self.crop is not None: + nimgs = [] + for img in imgs: + nimg = img[self.crop[1]:self.crop[3], self.crop[0]:self.crop[2], :] + if nimg.shape[0] != self.image_size[1] or nimg.shape[1] != self.image_size[0]: + nimg = cv2.resize(nimg, self.image_size) + nimgs.append(nimg) + imgs = nimgs + blob = cv2.dnn.blobFromImages( + images=imgs, scalefactor=1.0 / self.input_std, size=self.image_size, + mean=(self.input_mean, self.input_mean, self.input_mean), swapRB=True) + net_out = self.session.run(self.output_names, {self.input_name: blob})[0] + return net_out + + + def meta_info(self): + return {'model-size-mb':self.model_size_mb, 'feature-dim':self.feat_dim, 'infer': self.cost_ms} + + + def forward(self, imgs): + if not isinstance(imgs, list): + imgs = [imgs] + input_size = self.image_size + if self.crop is not None: + nimgs = [] + for img in imgs: + nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:] + if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]: + nimg = cv2.resize(nimg, input_size) + nimgs.append(nimg) + imgs = nimgs + blob = cv2.dnn.blobFromImages(imgs, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + net_out = self.session.run(self.output_names, {self.input_name : blob})[0] + return net_out + + def benchmark(self, img): + input_size = self.image_size + if self.crop is not None: + nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:] + if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]: + nimg = cv2.resize(nimg, input_size) + img = nimg + blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + costs = [] + for _ in range(50): + ta = datetime.datetime.now() + net_out = self.session.run(self.output_names, {self.input_name : blob})[0] + tb = datetime.datetime.now() + cost = (tb-ta).total_seconds() + costs.append(cost) + costs = sorted(costs) + cost = costs[5] + return net_out, cost + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='') + # general + parser.add_argument('workdir', help='submitted work dir', type=str) + parser.add_argument('--track', help='track name, for different challenge', type=str, default='cfat') + args = parser.parse_args() + handler = ArcFaceORT(args.workdir) + err = handler.check(args.track) + print('err:', err) diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/onnx_ijbc.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/onnx_ijbc.py new file mode 100644 index 00000000..31c491b1 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/onnx_ijbc.py @@ -0,0 +1,269 @@ +import argparse +import os +import pickle +import timeit + +import cv2 +import mxnet as mx +import numpy as np +import pandas as pd +import prettytable +import skimage.transform +import torch +from sklearn.metrics import roc_curve +from sklearn.preprocessing import normalize +from torch.utils.data import DataLoader +from onnx_helper import ArcFaceORT + +SRC = np.array( + [ + [30.2946, 51.6963], + [65.5318, 51.5014], + [48.0252, 71.7366], + [33.5493, 92.3655], + [62.7299, 92.2041]] + , dtype=np.float32) +SRC[:, 0] += 8.0 + + +@torch.no_grad() +class AlignedDataSet(mx.gluon.data.Dataset): + def __init__(self, root, lines, align=True): + self.lines = lines + self.root = root + self.align = align + + def __len__(self): + return len(self.lines) + + def __getitem__(self, idx): + each_line = self.lines[idx] + name_lmk_score = each_line.strip().split(' ') + name = os.path.join(self.root, name_lmk_score[0]) + img = cv2.cvtColor(cv2.imread(name), cv2.COLOR_BGR2RGB) + landmark5 = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32).reshape((5, 2)) + st = skimage.transform.SimilarityTransform() + st.estimate(landmark5, SRC) + img = cv2.warpAffine(img, st.params[0:2, :], (112, 112), borderValue=0.0) + img_1 = np.expand_dims(img, 0) + img_2 = np.expand_dims(np.fliplr(img), 0) + output = np.concatenate((img_1, img_2), axis=0).astype(np.float32) + output = np.transpose(output, (0, 3, 1, 2)) + return torch.from_numpy(output) + + +@torch.no_grad() +def extract(model_root, dataset): + model = ArcFaceORT(model_path=model_root) + model.check() + feat_mat = np.zeros(shape=(len(dataset), 2 * model.feat_dim)) + + def collate_fn(data): + return torch.cat(data, dim=0) + + data_loader = DataLoader( + dataset, batch_size=128, drop_last=False, num_workers=4, collate_fn=collate_fn, ) + num_iter = 0 + for batch in data_loader: + batch = batch.numpy() + batch = (batch - model.input_mean) / model.input_std + feat = model.session.run(model.output_names, {model.input_name: batch})[0] + feat = np.reshape(feat, (-1, model.feat_dim * 2)) + feat_mat[128 * num_iter: 128 * num_iter + feat.shape[0], :] = feat + num_iter += 1 + if num_iter % 50 == 0: + print(num_iter) + return feat_mat + + +def read_template_media_list(path): + ijb_meta = pd.read_csv(path, sep=' ', header=None).values + templates = ijb_meta[:, 1].astype(np.int) + medias = ijb_meta[:, 2].astype(np.int) + return templates, medias + + +def read_template_pair_list(path): + pairs = pd.read_csv(path, sep=' ', header=None).values + t1 = pairs[:, 0].astype(np.int) + t2 = pairs[:, 1].astype(np.int) + label = pairs[:, 2].astype(np.int) + return t1, t2, label + + +def read_image_feature(path): + with open(path, 'rb') as fid: + img_feats = pickle.load(fid) + return img_feats + + +def image2template_feature(img_feats=None, + templates=None, + medias=None): + unique_templates = np.unique(templates) + template_feats = np.zeros((len(unique_templates), img_feats.shape[1])) + for count_template, uqt in enumerate(unique_templates): + (ind_t,) = np.where(templates == uqt) + face_norm_feats = img_feats[ind_t] + face_medias = medias[ind_t] + unique_medias, unique_media_counts = np.unique(face_medias, return_counts=True) + media_norm_feats = [] + for u, ct in zip(unique_medias, unique_media_counts): + (ind_m,) = np.where(face_medias == u) + if ct == 1: + media_norm_feats += [face_norm_feats[ind_m]] + else: # image features from the same video will be aggregated into one feature + media_norm_feats += [np.mean(face_norm_feats[ind_m], axis=0, keepdims=True), ] + media_norm_feats = np.array(media_norm_feats) + template_feats[count_template] = np.sum(media_norm_feats, axis=0) + if count_template % 2000 == 0: + print('Finish Calculating {} template features.'.format( + count_template)) + template_norm_feats = normalize(template_feats) + return template_norm_feats, unique_templates + + +def verification(template_norm_feats=None, + unique_templates=None, + p1=None, + p2=None): + template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) + for count_template, uqt in enumerate(unique_templates): + template2id[uqt] = count_template + score = np.zeros((len(p1),)) + total_pairs = np.array(range(len(p1))) + batchsize = 100000 + sublists = [total_pairs[i: i + batchsize] for i in range(0, len(p1), batchsize)] + total_sublists = len(sublists) + for c, s in enumerate(sublists): + feat1 = template_norm_feats[template2id[p1[s]]] + feat2 = template_norm_feats[template2id[p2[s]]] + similarity_score = np.sum(feat1 * feat2, -1) + score[s] = similarity_score.flatten() + if c % 10 == 0: + print('Finish {}/{} pairs.'.format(c, total_sublists)) + return score + + +def verification2(template_norm_feats=None, + unique_templates=None, + p1=None, + p2=None): + template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) + for count_template, uqt in enumerate(unique_templates): + template2id[uqt] = count_template + score = np.zeros((len(p1),)) # save cosine distance between pairs + total_pairs = np.array(range(len(p1))) + batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation + sublists = [total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)] + total_sublists = len(sublists) + for c, s in enumerate(sublists): + feat1 = template_norm_feats[template2id[p1[s]]] + feat2 = template_norm_feats[template2id[p2[s]]] + similarity_score = np.sum(feat1 * feat2, -1) + score[s] = similarity_score.flatten() + if c % 10 == 0: + print('Finish {}/{} pairs.'.format(c, total_sublists)) + return score + + +def main(args): + use_norm_score = True # if Ture, TestMode(N1) + use_detector_score = True # if Ture, TestMode(D1) + use_flip_test = True # if Ture, TestMode(F1) + assert args.target == 'IJBC' or args.target == 'IJBB' + + start = timeit.default_timer() + templates, medias = read_template_media_list( + os.path.join('%s/meta' % args.image_path, '%s_face_tid_mid.txt' % args.target.lower())) + stop = timeit.default_timer() + print('Time: %.2f s. ' % (stop - start)) + + start = timeit.default_timer() + p1, p2, label = read_template_pair_list( + os.path.join('%s/meta' % args.image_path, + '%s_template_pair_label.txt' % args.target.lower())) + stop = timeit.default_timer() + print('Time: %.2f s. ' % (stop - start)) + + start = timeit.default_timer() + img_path = '%s/loose_crop' % args.image_path + img_list_path = '%s/meta/%s_name_5pts_score.txt' % (args.image_path, args.target.lower()) + img_list = open(img_list_path) + files = img_list.readlines() + dataset = AlignedDataSet(root=img_path, lines=files, align=True) + img_feats = extract(args.model_root, dataset) + + faceness_scores = [] + for each_line in files: + name_lmk_score = each_line.split() + faceness_scores.append(name_lmk_score[-1]) + faceness_scores = np.array(faceness_scores).astype(np.float32) + stop = timeit.default_timer() + print('Time: %.2f s. ' % (stop - start)) + print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], img_feats.shape[1])) + start = timeit.default_timer() + + if use_flip_test: + img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + img_feats[:, img_feats.shape[1] // 2:] + else: + img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + + if use_norm_score: + img_input_feats = img_input_feats + else: + img_input_feats = img_input_feats / np.sqrt(np.sum(img_input_feats ** 2, -1, keepdims=True)) + + if use_detector_score: + print(img_input_feats.shape, faceness_scores.shape) + img_input_feats = img_input_feats * faceness_scores[:, np.newaxis] + else: + img_input_feats = img_input_feats + + template_norm_feats, unique_templates = image2template_feature( + img_input_feats, templates, medias) + stop = timeit.default_timer() + print('Time: %.2f s. ' % (stop - start)) + + start = timeit.default_timer() + score = verification(template_norm_feats, unique_templates, p1, p2) + stop = timeit.default_timer() + print('Time: %.2f s. ' % (stop - start)) + result_dir = args.model_root + + save_path = os.path.join(result_dir, "{}_result".format(args.target)) + if not os.path.exists(save_path): + os.makedirs(save_path) + score_save_file = os.path.join(save_path, "{}.npy".format(args.target)) + np.save(score_save_file, score) + files = [score_save_file] + methods = [] + scores = [] + for file in files: + methods.append(os.path.basename(file)) + scores.append(np.load(file)) + methods = np.array(methods) + scores = dict(zip(methods, scores)) + x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] + tpr_fpr_table = prettytable.PrettyTable(['Methods'] + [str(x) for x in x_labels]) + for method in methods: + fpr, tpr, _ = roc_curve(label, scores[method]) + fpr = np.flipud(fpr) + tpr = np.flipud(tpr) + tpr_fpr_row = [] + tpr_fpr_row.append("%s-%s" % (method, args.target)) + for fpr_iter in np.arange(len(x_labels)): + _, min_index = min( + list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) + tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) + tpr_fpr_table.add_row(tpr_fpr_row) + print(tpr_fpr_table) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='do ijb test') + # general + parser.add_argument('--model-root', default='', help='path to load model.') + parser.add_argument('--image-path', default='/train_tmp/IJB_release/IJBC', type=str, help='') + parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB') + main(parser.parse_args()) diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/partial_fc.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/partial_fc.py new file mode 100644 index 00000000..eeff29d8 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/partial_fc.py @@ -0,0 +1,531 @@ +import collections +from typing import Callable + +import torch +from torch import distributed +from torch.nn.functional import linear, normalize + + +class PartialFC(torch.nn.Module): + """ + https://arxiv.org/abs/2203.15565 + A distributed sparsely updating variant of the FC layer, named Partial FC (PFC). + + When sample rate less than 1, in each iteration, positive class centers and a random subset of + negative class centers are selected to compute the margin-based softmax loss, all class + centers are still maintained throughout the whole training process, but only a subset is + selected and updated in each iteration. + + .. note:: + When sample rate equal to 1, Partial FC is equal to model parallelism(default sample rate is 1). + + Example: + -------- + >>> module_pfc = PartialFC(embedding_size=512, num_classes=8000000, sample_rate=0.2) + >>> for img, labels in data_loader: + >>> embeddings = net(img) + >>> loss = module_pfc(embeddings, labels, optimizer) + >>> loss.backward() + >>> optimizer.step() + """ + _version = 1 + def __init__( + self, + margin_loss: Callable, + embedding_size: int, + num_classes: int, + sample_rate: float = 1.0, + fp16: bool = False, + ): + """ + Paramenters: + ----------- + embedding_size: int + The dimension of embedding, required + num_classes: int + Total number of classes, required + sample_rate: float + The rate of negative centers participating in the calculation, default is 1.0. + """ + super(PartialFC, self).__init__() + assert ( + distributed.is_initialized() + ), "must initialize distributed before create this" + self.rank = distributed.get_rank() + self.world_size = distributed.get_world_size() + + self.dist_cross_entropy = DistCrossEntropy() + self.embedding_size = embedding_size + self.sample_rate: float = sample_rate + self.fp16 = fp16 + self.num_local: int = num_classes // self.world_size + int( + self.rank < num_classes % self.world_size + ) + self.class_start: int = num_classes // self.world_size * self.rank + min( + self.rank, num_classes % self.world_size + ) + self.num_sample: int = int(self.sample_rate * self.num_local) + self.last_batch_size: int = 0 + self.weight: torch.Tensor + self.weight_mom: torch.Tensor + self.weight_activated: torch.nn.Parameter + self.weight_activated_mom: torch.Tensor + self.is_updated: bool = True + self.init_weight_update: bool = True + + if self.sample_rate < 1: + self.register_buffer("weight", + tensor=torch.normal(0, 0.01, (self.num_local, embedding_size))) + self.register_buffer("weight_mom", + tensor=torch.zeros_like(self.weight)) + self.register_parameter("weight_activated", + param=torch.nn.Parameter(torch.empty(0, 0))) + self.register_buffer("weight_activated_mom", + tensor=torch.empty(0, 0)) + self.register_buffer("weight_index", + tensor=torch.empty(0, 0)) + else: + self.weight_activated = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size))) + + # margin_loss + if isinstance(margin_loss, Callable): + self.margin_softmax = margin_loss + else: + raise + + @torch.no_grad() + def sample(self, + labels: torch.Tensor, + index_positive: torch.Tensor, + optimizer: torch.optim.Optimizer): + """ + This functions will change the value of labels + + Parameters: + ----------- + labels: torch.Tensor + pass + index_positive: torch.Tensor + pass + optimizer: torch.optim.Optimizer + pass + """ + positive = torch.unique(labels[index_positive], sorted=True).cuda() + if self.num_sample - positive.size(0) >= 0: + perm = torch.rand(size=[self.num_local]).cuda() + perm[positive] = 2.0 + index = torch.topk(perm, k=self.num_sample)[1].cuda() + index = index.sort()[0].cuda() + else: + index = positive + self.weight_index = index + + labels[index_positive] = torch.searchsorted(index, labels[index_positive]) + + self.weight_activated = torch.nn.Parameter(self.weight[self.weight_index]) + self.weight_activated_mom = self.weight_mom[self.weight_index] + + if isinstance(optimizer, torch.optim.SGD): + # TODO the params of partial fc must be last in the params list + optimizer.state.pop(optimizer.param_groups[-1]["params"][0], None) + optimizer.param_groups[-1]["params"][0] = self.weight_activated + optimizer.state[self.weight_activated][ + "momentum_buffer" + ] = self.weight_activated_mom + else: + raise + + @torch.no_grad() + def update(self): + """ partial weight to global + """ + if self.init_weight_update: + self.init_weight_update = False + return + + if self.sample_rate < 1: + self.weight[self.weight_index] = self.weight_activated + self.weight_mom[self.weight_index] = self.weight_activated_mom + + + def forward( + self, + local_embeddings: torch.Tensor, + local_labels: torch.Tensor, + optimizer: torch.optim.Optimizer, + ): + """ + Parameters: + ---------- + local_embeddings: torch.Tensor + feature embeddings on each GPU(Rank). + local_labels: torch.Tensor + labels on each GPU(Rank). + + Returns: + ------- + loss: torch.Tensor + pass + """ + local_labels.squeeze_() + local_labels = local_labels.long() + self.update() + + batch_size = local_embeddings.size(0) + if self.last_batch_size == 0: + self.last_batch_size = batch_size + assert self.last_batch_size == batch_size, ( + "last batch size do not equal current batch size: {} vs {}".format( + self.last_batch_size, batch_size)) + + _gather_embeddings = [ + torch.zeros((batch_size, self.embedding_size)).cuda() + for _ in range(self.world_size) + ] + _gather_labels = [ + torch.zeros(batch_size).long().cuda() for _ in range(self.world_size) + ] + _list_embeddings = AllGather(local_embeddings, *_gather_embeddings) + distributed.all_gather(_gather_labels, local_labels) + + embeddings = torch.cat(_list_embeddings) + labels = torch.cat(_gather_labels) + + labels = labels.view(-1, 1) + index_positive = (self.class_start <= labels) & ( + labels < self.class_start + self.num_local + ) + labels[~index_positive] = -1 + labels[index_positive] -= self.class_start + + if self.sample_rate < 1: + self.sample(labels, index_positive, optimizer) + + with torch.cuda.amp.autocast(self.fp16): + norm_embeddings = normalize(embeddings) + norm_weight_activated = normalize(self.weight_activated) + logits = linear(norm_embeddings, norm_weight_activated) + if self.fp16: + logits = logits.float() + logits = logits.clamp(-1, 1) + + logits = self.margin_softmax(logits, labels) + loss = self.dist_cross_entropy(logits, labels) + return loss + + def state_dict(self, destination=None, prefix="", keep_vars=False): + if destination is None: + destination = collections.OrderedDict() + destination._metadata = collections.OrderedDict() + + for name, module in self._modules.items(): + if module is not None: + module.state_dict(destination, prefix + name + ".", keep_vars=keep_vars) + if self.sample_rate < 1: + destination["weight"] = self.weight.detach() + else: + destination["weight"] = self.weight_activated.data.detach() + return destination + + def load_state_dict(self, state_dict, strict: bool = True): + if self.sample_rate < 1: + self.weight = state_dict["weight"].to(self.weight.device) + self.weight_mom.zero_() + self.weight_activated.data.zero_() + self.weight_activated_mom.zero_() + self.weight_index.zero_() + else: + self.weight_activated.data = state_dict["weight"].to(self.weight_activated.data.device) + + +class PartialFCAdamW(torch.nn.Module): + def __init__(self, + margin_loss: Callable, + embedding_size: int, + num_classes: int, + sample_rate: float = 1.0, + fp16: bool = False,): + """ + Paramenters: + ----------- + embedding_size: int + The dimension of embedding, required + num_classes: int + Total number of classes, required + sample_rate: float + The rate of negative centers participating in the calculation, default is 1.0. + """ + super(PartialFCAdamW, self).__init__() + assert ( + distributed.is_initialized() + ), "must initialize distributed before create this" + self.rank = distributed.get_rank() + self.world_size = distributed.get_world_size() + + self.dist_cross_entropy = DistCrossEntropy() + self.embedding_size = embedding_size + self.sample_rate: float = sample_rate + self.fp16 = fp16 + self.num_local: int = num_classes // self.world_size + int( + self.rank < num_classes % self.world_size + ) + self.class_start: int = num_classes // self.world_size * self.rank + min( + self.rank, num_classes % self.world_size + ) + self.num_sample: int = int(self.sample_rate * self.num_local) + self.last_batch_size: int = 0 + self.weight: torch.Tensor + self.weight_exp_avg: torch.Tensor + self.weight_exp_avg_sq: torch.Tensor + self.weight_activated: torch.nn.Parameter + self.weight_activated_exp_avg: torch.Tensor + self.weight_activated_exp_avg_sq: torch.Tensor + + self.is_updated: bool = True + self.init_weight_update: bool = True + + if self.sample_rate < 1: + self.register_buffer("weight", + tensor=torch.normal(0, 0.01, (self.num_local, embedding_size))) + self.register_buffer("weight_exp_avg", + tensor=torch.zeros_like(self.weight)) + self.register_buffer("weight_exp_avg_sq", + tensor=torch.zeros_like(self.weight)) + self.register_parameter("weight_activated", + param=torch.nn.Parameter(torch.empty(0, 0))) + self.register_buffer("weight_activated_exp_avg", + tensor=torch.empty(0, 0)) + self.register_buffer("weight_activated_exp_avg_sq", + tensor=torch.empty(0, 0)) + else: + self.weight_activated = torch.nn.Parameter( + torch.normal(0, 0.01, (self.num_local, embedding_size)) + ) + self.step = 0 + + if isinstance(margin_loss, Callable): + self.margin_softmax = margin_loss + else: + raise + + @torch.no_grad() + def sample(self, labels, index_positive, optimizer): + self.step += 1 + positive = torch.unique(labels[index_positive], sorted=True).cuda() + if self.num_sample - positive.size(0) >= 0: + perm = torch.rand(size=[self.num_local]).cuda() + perm[positive] = 2.0 + index = torch.topk(perm, k=self.num_sample)[1].cuda() + index = index.sort()[0].cuda() + else: + index = positive + self.weight_index = index + labels[index_positive] = torch.searchsorted(index, labels[index_positive]) + self.weight_activated = torch.nn.Parameter(self.weight[self.weight_index]) + self.weight_activated_exp_avg = self.weight_exp_avg[self.weight_index] + self.weight_activated_exp_avg_sq = self.weight_exp_avg_sq[self.weight_index] + + if isinstance(optimizer, (torch.optim.Adam, torch.optim.AdamW)): + # TODO the params of partial fc must be last in the params list + optimizer.state.pop(optimizer.param_groups[-1]["params"][0], None) + optimizer.param_groups[-1]["params"][0] = self.weight_activated + optimizer.state[self.weight_activated]["exp_avg"] = self.weight_activated_exp_avg + optimizer.state[self.weight_activated]["exp_avg_sq"] = self.weight_activated_exp_avg_sq + optimizer.state[self.weight_activated]["step"] = self.step + else: + raise + + @torch.no_grad() + def update(self): + """ partial weight to global + """ + if self.init_weight_update: + self.init_weight_update = False + return + + if self.sample_rate < 1: + self.weight[self.weight_index] = self.weight_activated + self.weight_exp_avg[self.weight_index] = self.weight_activated_exp_avg + self.weight_exp_avg_sq[self.weight_index] = self.weight_activated_exp_avg_sq + + def forward( + self, + local_embeddings: torch.Tensor, + local_labels: torch.Tensor, + optimizer: torch.optim.Optimizer, + ): + """ + Parameters: + ---------- + local_embeddings: torch.Tensor + feature embeddings on each GPU(Rank). + local_labels: torch.Tensor + labels on each GPU(Rank). + + Returns: + ------- + loss: torch.Tensor + pass + """ + local_labels.squeeze_() + local_labels = local_labels.long() + self.update() + + batch_size = local_embeddings.size(0) + if self.last_batch_size == 0: + self.last_batch_size = batch_size + assert self.last_batch_size == batch_size, ( + "last batch size do not equal current batch size: {} vs {}".format( + self.last_batch_size, batch_size)) + + _gather_embeddings = [ + torch.zeros((batch_size, self.embedding_size)).cuda() + for _ in range(self.world_size) + ] + _gather_labels = [ + torch.zeros(batch_size).long().cuda() for _ in range(self.world_size) + ] + _list_embeddings = AllGather(local_embeddings, *_gather_embeddings) + distributed.all_gather(_gather_labels, local_labels) + + embeddings = torch.cat(_list_embeddings) + labels = torch.cat(_gather_labels) + + labels = labels.view(-1, 1) + index_positive = (self.class_start <= labels) & ( + labels < self.class_start + self.num_local + ) + labels[~index_positive] = -1 + labels[index_positive] -= self.class_start + + if self.sample_rate < 1: + self.sample(labels, index_positive, optimizer) + + with torch.cuda.amp.autocast(self.fp16): + norm_embeddings = normalize(embeddings) + norm_weight_activated = normalize(self.weight_activated) + logits = linear(norm_embeddings, norm_weight_activated) + if self.fp16: + logits = logits.float() + logits = logits.clamp(-1, 1) + + logits = self.margin_softmax(logits, labels) + loss = self.dist_cross_entropy(logits, labels) + return loss + def state_dict(self, destination=None, prefix="", keep_vars=False): + if destination is None: + destination = collections.OrderedDict() + destination._metadata = collections.OrderedDict() + + for name, module in self._modules.items(): + if module is not None: + module.state_dict(destination, prefix + name + ".", keep_vars=keep_vars) + if self.sample_rate < 1: + destination["weight"] = self.weight.detach() + else: + destination["weight"] = self.weight_activated.data.detach() + return destination + + def load_state_dict(self, state_dict, strict: bool = True): + if self.sample_rate < 1: + self.weight = state_dict["weight"].to(self.weight.device) + self.weight_exp_avg.zero_() + self.weight_exp_avg_sq.zero_() + self.weight_activated.data.zero_() + self.weight_activated_exp_avg.zero_() + self.weight_activated_exp_avg_sq.zero_() + else: + self.weight_activated.data = state_dict["weight"].to(self.weight_activated.data.device) + + +class DistCrossEntropyFunc(torch.autograd.Function): + """ + CrossEntropy loss is calculated in parallel, allreduce denominator into single gpu and calculate softmax. + Implemented of ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf): + """ + + @staticmethod + def forward(ctx, logits: torch.Tensor, label: torch.Tensor): + """ """ + batch_size = logits.size(0) + # for numerical stability + max_logits, _ = torch.max(logits, dim=1, keepdim=True) + # local to global + distributed.all_reduce(max_logits, distributed.ReduceOp.MAX) + logits.sub_(max_logits) + logits.exp_() + sum_logits_exp = torch.sum(logits, dim=1, keepdim=True) + # local to global + distributed.all_reduce(sum_logits_exp, distributed.ReduceOp.SUM) + logits.div_(sum_logits_exp) + index = torch.where(label != -1)[0] + # loss + loss = torch.zeros(batch_size, 1, device=logits.device) + loss[index] = logits[index].gather(1, label[index]) + distributed.all_reduce(loss, distributed.ReduceOp.SUM) + ctx.save_for_backward(index, logits, label) + return loss.clamp_min_(1e-30).log_().mean() * (-1) + + @staticmethod + def backward(ctx, loss_gradient): + """ + Args: + loss_grad (torch.Tensor): gradient backward by last layer + Returns: + gradients for each input in forward function + `None` gradients for one-hot label + """ + ( + index, + logits, + label, + ) = ctx.saved_tensors + batch_size = logits.size(0) + one_hot = torch.zeros( + size=[index.size(0), logits.size(1)], device=logits.device + ) + one_hot.scatter_(1, label[index], 1) + logits[index] -= one_hot + logits.div_(batch_size) + return logits * loss_gradient.item(), None + + +class DistCrossEntropy(torch.nn.Module): + def __init__(self): + super(DistCrossEntropy, self).__init__() + + def forward(self, logit_part, label_part): + return DistCrossEntropyFunc.apply(logit_part, label_part) + + +class AllGatherFunc(torch.autograd.Function): + """AllGather op with gradient backward""" + + @staticmethod + def forward(ctx, tensor, *gather_list): + gather_list = list(gather_list) + distributed.all_gather(gather_list, tensor) + return tuple(gather_list) + + @staticmethod + def backward(ctx, *grads): + grad_list = list(grads) + rank = distributed.get_rank() + grad_out = grad_list[rank] + + dist_ops = [ + distributed.reduce(grad_out, rank, distributed.ReduceOp.SUM, async_op=True) + if i == rank + else distributed.reduce( + grad_list[i], i, distributed.ReduceOp.SUM, async_op=True + ) + for i in range(distributed.get_world_size()) + ] + for _op in dist_ops: + _op.wait() + + grad_out *= len(grad_list) # cooperate with distributed loss function + return (grad_out, *[None for _ in range(len(grad_list))]) + + +AllGather = AllGatherFunc.apply diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/partial_fc_v2.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/partial_fc_v2.py new file mode 100644 index 00000000..0752554c --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/partial_fc_v2.py @@ -0,0 +1,260 @@ + +import math +from typing import Callable + +import torch +from torch import distributed +from torch.nn.functional import linear, normalize + + +class PartialFC_V2(torch.nn.Module): + """ + https://arxiv.org/abs/2203.15565 + A distributed sparsely updating variant of the FC layer, named Partial FC (PFC). + When sample rate less than 1, in each iteration, positive class centers and a random subset of + negative class centers are selected to compute the margin-based softmax loss, all class + centers are still maintained throughout the whole training process, but only a subset is + selected and updated in each iteration. + .. note:: + When sample rate equal to 1, Partial FC is equal to model parallelism(default sample rate is 1). + Example: + -------- + >>> module_pfc = PartialFC(embedding_size=512, num_classes=8000000, sample_rate=0.2) + >>> for img, labels in data_loader: + >>> embeddings = net(img) + >>> loss = module_pfc(embeddings, labels) + >>> loss.backward() + >>> optimizer.step() + """ + _version = 2 + + def __init__( + self, + margin_loss: Callable, + embedding_size: int, + num_classes: int, + sample_rate: float = 1.0, + fp16: bool = False, + ): + """ + Paramenters: + ----------- + embedding_size: int + The dimension of embedding, required + num_classes: int + Total number of classes, required + sample_rate: float + The rate of negative centers participating in the calculation, default is 1.0. + """ + super(PartialFC_V2, self).__init__() + assert ( + distributed.is_initialized() + ), "must initialize distributed before create this" + self.rank = distributed.get_rank() + self.world_size = distributed.get_world_size() + + self.dist_cross_entropy = DistCrossEntropy() + self.embedding_size = embedding_size + self.sample_rate: float = sample_rate + self.fp16 = fp16 + self.num_local: int = num_classes // self.world_size + int( + self.rank < num_classes % self.world_size + ) + self.class_start: int = num_classes // self.world_size * self.rank + min( + self.rank, num_classes % self.world_size + ) + self.num_sample: int = int(self.sample_rate * self.num_local) + self.last_batch_size: int = 0 + + self.is_updated: bool = True + self.init_weight_update: bool = True + self.weight = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size))) + + # margin_loss + if isinstance(margin_loss, Callable): + self.margin_softmax = margin_loss + else: + raise + + def sample(self, labels, index_positive): + """ + This functions will change the value of labels + Parameters: + ----------- + labels: torch.Tensor + pass + index_positive: torch.Tensor + pass + optimizer: torch.optim.Optimizer + pass + """ + with torch.no_grad(): + positive = torch.unique(labels[index_positive], sorted=True).cuda() + if self.num_sample - positive.size(0) >= 0: + perm = torch.rand(size=[self.num_local]).cuda() + perm[positive] = 2.0 + index = torch.topk(perm, k=self.num_sample)[1].cuda() + index = index.sort()[0].cuda() + else: + index = positive + self.weight_index = index + + labels[index_positive] = torch.searchsorted(index, labels[index_positive]) + + return self.weight[self.weight_index] + + def forward( + self, + local_embeddings: torch.Tensor, + local_labels: torch.Tensor, + ): + """ + Parameters: + ---------- + local_embeddings: torch.Tensor + feature embeddings on each GPU(Rank). + local_labels: torch.Tensor + labels on each GPU(Rank). + Returns: + ------- + loss: torch.Tensor + pass + """ + local_labels.squeeze_() + local_labels = local_labels.long() + + batch_size = local_embeddings.size(0) + if self.last_batch_size == 0: + self.last_batch_size = batch_size + assert self.last_batch_size == batch_size, ( + f"last batch size do not equal current batch size: {self.last_batch_size} vs {batch_size}") + + _gather_embeddings = [ + torch.zeros((batch_size, self.embedding_size)).cuda() + for _ in range(self.world_size) + ] + _gather_labels = [ + torch.zeros(batch_size).long().cuda() for _ in range(self.world_size) + ] + _list_embeddings = AllGather(local_embeddings, *_gather_embeddings) + distributed.all_gather(_gather_labels, local_labels) + + embeddings = torch.cat(_list_embeddings) + labels = torch.cat(_gather_labels) + + labels = labels.view(-1, 1) + index_positive = (self.class_start <= labels) & ( + labels < self.class_start + self.num_local + ) + labels[~index_positive] = -1 + labels[index_positive] -= self.class_start + + if self.sample_rate < 1: + weight = self.sample(labels, index_positive) + else: + weight = self.weight + + with torch.cuda.amp.autocast(self.fp16): + norm_embeddings = normalize(embeddings) + norm_weight_activated = normalize(weight) + logits = linear(norm_embeddings, norm_weight_activated) + if self.fp16: + logits = logits.float() + logits = logits.clamp(-1, 1) + + logits = self.margin_softmax(logits, labels) + loss = self.dist_cross_entropy(logits, labels) + return loss + + +class DistCrossEntropyFunc(torch.autograd.Function): + """ + CrossEntropy loss is calculated in parallel, allreduce denominator into single gpu and calculate softmax. + Implemented of ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf): + """ + + @staticmethod + def forward(ctx, logits: torch.Tensor, label: torch.Tensor): + """ """ + batch_size = logits.size(0) + # for numerical stability + max_logits, _ = torch.max(logits, dim=1, keepdim=True) + # local to global + distributed.all_reduce(max_logits, distributed.ReduceOp.MAX) + logits.sub_(max_logits) + logits.exp_() + sum_logits_exp = torch.sum(logits, dim=1, keepdim=True) + # local to global + distributed.all_reduce(sum_logits_exp, distributed.ReduceOp.SUM) + logits.div_(sum_logits_exp) + index = torch.where(label != -1)[0] + # loss + loss = torch.zeros(batch_size, 1, device=logits.device) + loss[index] = logits[index].gather(1, label[index]) + distributed.all_reduce(loss, distributed.ReduceOp.SUM) + ctx.save_for_backward(index, logits, label) + return loss.clamp_min_(1e-30).log_().mean() * (-1) + + @staticmethod + def backward(ctx, loss_gradient): + """ + Args: + loss_grad (torch.Tensor): gradient backward by last layer + Returns: + gradients for each input in forward function + `None` gradients for one-hot label + """ + ( + index, + logits, + label, + ) = ctx.saved_tensors + batch_size = logits.size(0) + one_hot = torch.zeros( + size=[index.size(0), logits.size(1)], device=logits.device + ) + one_hot.scatter_(1, label[index], 1) + logits[index] -= one_hot + logits.div_(batch_size) + return logits * loss_gradient.item(), None + + +class DistCrossEntropy(torch.nn.Module): + def __init__(self): + super(DistCrossEntropy, self).__init__() + + def forward(self, logit_part, label_part): + return DistCrossEntropyFunc.apply(logit_part, label_part) + + +class AllGatherFunc(torch.autograd.Function): + """AllGather op with gradient backward""" + + @staticmethod + def forward(ctx, tensor, *gather_list): + gather_list = list(gather_list) + distributed.all_gather(gather_list, tensor) + return tuple(gather_list) + + @staticmethod + def backward(ctx, *grads): + grad_list = list(grads) + rank = distributed.get_rank() + grad_out = grad_list[rank] + + dist_ops = [ + distributed.reduce(grad_out, rank, distributed.ReduceOp.SUM, async_op=True) + if i == rank + else distributed.reduce( + grad_list[i], i, distributed.ReduceOp.SUM, async_op=True + ) + for i in range(distributed.get_world_size()) + ] + for _op in dist_ops: + _op.wait() + + grad_out *= len(grad_list) # cooperate with distributed loss function + return (grad_out, *[None for _ in range(len(grad_list))]) + + +AllGather = AllGatherFunc.apply diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/requirement.txt b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/requirement.txt new file mode 100644 index 00000000..f1a431ef --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/requirement.txt @@ -0,0 +1,6 @@ +tensorboard +easydict +mxnet +onnx +sklearn +opencv-python \ No newline at end of file diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/run.sh b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/run.sh new file mode 100644 index 00000000..6eacdf8e --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/run.sh @@ -0,0 +1 @@ +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 train_v2.py $@ diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/scripts/shuffle_rec.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/scripts/shuffle_rec.py new file mode 100644 index 00000000..f3b68e93 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/scripts/shuffle_rec.py @@ -0,0 +1,81 @@ +import argparse +import multiprocessing +import os +import time + +import mxnet as mx +import numpy as np + + +def read_worker(args, q_in): + path_imgidx = os.path.join(args.input, "train.idx") + path_imgrec = os.path.join(args.input, "train.rec") + imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, "r") + + s = imgrec.read_idx(0) + header, _ = mx.recordio.unpack(s) + assert header.flag > 0 + + imgidx = np.array(range(1, int(header.label[0]))) + np.random.shuffle(imgidx) + + for idx in imgidx: + item = imgrec.read_idx(idx) + q_in.put(item) + + q_in.put(None) + imgrec.close() + + +def write_worker(args, q_out): + pre_time = time.time() + + if args.input[-1] == '/': + args.input = args.input[:-1] + dirname = os.path.dirname(args.input) + basename = os.path.basename(args.input) + output = os.path.join(dirname, f"shuffled_{basename}") + os.makedirs(output, exist_ok=True) + + path_imgidx = os.path.join(output, "train.idx") + path_imgrec = os.path.join(output, "train.rec") + save_record = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, "w") + more = True + count = 0 + while more: + deq = q_out.get() + if deq is None: + more = False + else: + header, jpeg = mx.recordio.unpack(deq) + # TODO it is currently not fully developed + if isinstance(header.label, float): + label = header.label + else: + label = header.label[0] + + header = mx.recordio.IRHeader(flag=header.flag, label=label, id=header.id, id2=header.id2) + save_record.write_idx(count, mx.recordio.pack(header, jpeg)) + count += 1 + if count % 10000 == 0: + cur_time = time.time() + print('save time:', cur_time - pre_time, ' count:', count) + pre_time = cur_time + print(count) + save_record.close() + + +def main(args): + queue = multiprocessing.Queue(10240) + read_process = multiprocessing.Process(target=read_worker, args=(args, queue)) + read_process.daemon = True + read_process.start() + write_process = multiprocessing.Process(target=write_worker, args=(args, queue)) + write_process.start() + write_process.join() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('input', help='path to source rec.') + main(parser.parse_args()) diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/torch2onnx.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/torch2onnx.py new file mode 100644 index 00000000..f6055d1f --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/torch2onnx.py @@ -0,0 +1,53 @@ +import numpy as np +import onnx +import torch + + +def convert_onnx(net, path_module, output, opset=11, simplify=False): + assert isinstance(net, torch.nn.Module) + img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) + img = img.astype(np.float) + img = (img / 255. - 0.5) / 0.5 # torch style norm + img = img.transpose((2, 0, 1)) + img = torch.from_numpy(img).unsqueeze(0).float() + + weight = torch.load(path_module) + net.load_state_dict(weight, strict=True) + net.eval() + torch.onnx.export(net, img, output, input_names=["data"], keep_initializers_as_inputs=False, verbose=False, opset_version=opset) + model = onnx.load(output) + graph = model.graph + graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None' + if simplify: + from onnxsim import simplify + model, check = simplify(model) + assert check, "Simplified ONNX model could not be validated" + onnx.save(model, output) + + +if __name__ == '__main__': + import os + import argparse + from backbones import get_model + + parser = argparse.ArgumentParser(description='ArcFace PyTorch to onnx') + parser.add_argument('input', type=str, help='input backbone.pth file or path') + parser.add_argument('--output', type=str, default=None, help='output onnx path') + parser.add_argument('--network', type=str, default=None, help='backbone network') + parser.add_argument('--simplify', type=bool, default=False, help='onnx simplify') + args = parser.parse_args() + input_file = args.input + if os.path.isdir(input_file): + input_file = os.path.join(input_file, "model.pt") + assert os.path.exists(input_file) + # model_name = os.path.basename(os.path.dirname(input_file)).lower() + # params = model_name.split("_") + # if len(params) >= 3 and params[1] in ('arcface', 'cosface'): + # if args.network is None: + # args.network = params[2] + assert args.network is not None + print(args) + backbone_onnx = get_model(args.network, dropout=0.0, fp16=False, num_features=512) + if args.output is None: + args.output = os.path.join(os.path.dirname(args.input), "model.onnx") + convert_onnx(backbone_onnx, input_file, args.output, simplify=args.simplify) diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/train.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/train.py new file mode 100644 index 00000000..b4b49e71 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/train.py @@ -0,0 +1,260 @@ +import argparse +import logging +import os +from datetime import datetime + +import numpy as np +import torch +from backbones import get_model +from dataset import get_dataloader +from losses import CombinedMarginLoss +from lr_scheduler import PolyScheduler +from partial_fc import PartialFC, PartialFCAdamW +from torch import distributed +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from utils.utils_callbacks import CallBackLogging, CallBackVerification +from utils.utils_config import get_config +from utils.utils_distributed_sampler import setup_seed +from utils.utils_logging import AverageMeter, init_logging + +assert torch.__version__ >= "1.12.0", "In order to enjoy the features of the new torch, \ +we have upgraded the torch to 1.12.0. torch before than 1.12.0 may not work in the future." + +try: + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + distributed.init_process_group("nccl") +except KeyError: + rank = 0 + local_rank = 0 + world_size = 1 + distributed.init_process_group( + backend="nccl", + init_method="tcp://127.0.0.1:12584", + rank=rank, + world_size=world_size, + ) + + +def main(args): + + # get config + cfg = get_config(args.config) + # global control random seed + setup_seed(seed=cfg.seed, cuda_deterministic=False) + + torch.cuda.set_device(local_rank) + + os.makedirs(cfg.output, exist_ok=True) + init_logging(rank, cfg.output) + + summary_writer = ( + SummaryWriter(log_dir=os.path.join(cfg.output, "tensorboard")) + if rank == 0 + else None + ) + + wandb_logger = None + if cfg.using_wandb: + import wandb + # Sign in to wandb + try: + wandb.login(key=cfg.wandb_key) + except Exception as e: + print("WandB Key must be provided in config file (base.py).") + print(f"Config Error: {e}") + # Initialize wandb + run_name = datetime.now().strftime("%y%m%d_%H%M") + f"_GPU{rank}" + run_name = run_name if cfg.suffix_run_name is None else run_name + f"_{cfg.suffix_run_name}" + try: + wandb_logger = wandb.init( + entity = cfg.wandb_entity, + project = cfg.wandb_project, + sync_tensorboard = True, + resume=cfg.wandb_resume, + name = run_name, + notes = cfg.notes) if rank == 0 or cfg.wandb_log_all else None + if wandb_logger: + wandb_logger.config.update(cfg) + except Exception as e: + print("WandB Data (Entity and Project name) must be provided in config file (base.py).") + print(f"Config Error: {e}") + + train_loader = get_dataloader( + cfg.rec, + local_rank, + cfg.batch_size, + cfg.dali, + cfg.seed, + cfg.num_workers + ) + + backbone = get_model( + cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).cuda() + + backbone = torch.nn.parallel.DistributedDataParallel( + module=backbone, broadcast_buffers=False, device_ids=[local_rank], bucket_cap_mb=16, + find_unused_parameters=True) + + backbone.train() + # FIXME using gradient checkpoint if there are some unused parameters will cause error + backbone._set_static_graph() + + margin_loss = CombinedMarginLoss( + 64, + cfg.margin_list[0], + cfg.margin_list[1], + cfg.margin_list[2], + cfg.interclass_filtering_threshold + ) + + if cfg.optimizer == "sgd": + module_partial_fc = PartialFC( + margin_loss, cfg.embedding_size, cfg.num_classes, + cfg.sample_rate, cfg.fp16) + module_partial_fc.train().cuda() + # TODO the params of partial fc must be last in the params list + opt = torch.optim.SGD( + params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}], + lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay) + + elif cfg.optimizer == "adamw": + module_partial_fc = PartialFCAdamW( + margin_loss, cfg.embedding_size, cfg.num_classes, + cfg.sample_rate, cfg.fp16) + module_partial_fc.train().cuda() + opt = torch.optim.AdamW( + params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}], + lr=cfg.lr, weight_decay=cfg.weight_decay) + else: + raise + + cfg.total_batch_size = cfg.batch_size * world_size + cfg.warmup_step = cfg.num_image // cfg.total_batch_size * cfg.warmup_epoch + cfg.total_step = cfg.num_image // cfg.total_batch_size * cfg.num_epoch + + lr_scheduler = PolyScheduler( + optimizer=opt, + base_lr=cfg.lr, + max_steps=cfg.total_step, + warmup_steps=cfg.warmup_step, + last_epoch=-1 + ) + + start_epoch = 0 + global_step = 0 + if cfg.resume: + dict_checkpoint = torch.load(os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt")) + start_epoch = dict_checkpoint["epoch"] + global_step = dict_checkpoint["global_step"] + backbone.module.load_state_dict(dict_checkpoint["state_dict_backbone"]) + module_partial_fc.load_state_dict(dict_checkpoint["state_dict_softmax_fc"]) + opt.load_state_dict(dict_checkpoint["state_optimizer"]) + lr_scheduler.load_state_dict(dict_checkpoint["state_lr_scheduler"]) + del dict_checkpoint + + for key, value in cfg.items(): + num_space = 25 - len(key) + logging.info(": " + key + " " * num_space + str(value)) + + callback_verification = CallBackVerification( + val_targets=cfg.val_targets, rec_prefix=cfg.rec, + summary_writer=summary_writer, wandb_logger = wandb_logger + ) + callback_logging = CallBackLogging( + frequent=cfg.frequent, + total_step=cfg.total_step, + batch_size=cfg.batch_size, + start_step = global_step, + writer=summary_writer + ) + + loss_am = AverageMeter() + amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100) + + for epoch in range(start_epoch, cfg.num_epoch): + + if isinstance(train_loader, DataLoader): + train_loader.sampler.set_epoch(epoch) + for _, (img, local_labels) in enumerate(train_loader): + global_step += 1 + local_embeddings = backbone(img) + loss: torch.Tensor = module_partial_fc(local_embeddings, local_labels, opt) + + if cfg.fp16: + amp.scale(loss).backward() + amp.unscale_(opt) + torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5) + amp.step(opt) + amp.update() + else: + loss.backward() + torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5) + opt.step() + + opt.zero_grad() + lr_scheduler.step() + + with torch.no_grad(): + if wandb_logger: + wandb_logger.log({ + 'Loss/Step Loss': loss.item(), + 'Loss/Train Loss': loss_am.avg, + 'Process/Step': global_step, + 'Process/Epoch': epoch + }) + + loss_am.update(loss.item(), 1) + callback_logging(global_step, loss_am, epoch, cfg.fp16, lr_scheduler.get_last_lr()[0], amp) + + if global_step % cfg.verbose == 0 and global_step > 0: + callback_verification(global_step, backbone) + + if cfg.save_all_states: + checkpoint = { + "epoch": epoch + 1, + "global_step": global_step, + "state_dict_backbone": backbone.module.state_dict(), + "state_dict_softmax_fc": module_partial_fc.state_dict(), + "state_optimizer": opt.state_dict(), + "state_lr_scheduler": lr_scheduler.state_dict() + } + torch.save(checkpoint, os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt")) + + if rank == 0: + path_module = os.path.join(cfg.output, "model.pt") + torch.save(backbone.module.state_dict(), path_module) + + if wandb_logger and cfg.save_artifacts: + artifact_name = f"{run_name}_E{epoch}" + model = wandb.Artifact(artifact_name, type='model') + model.add_file(path_module) + wandb_logger.log_artifact(model) + + if cfg.dali: + train_loader.reset() + + if rank == 0: + path_module = os.path.join(cfg.output, "model.pt") + torch.save(backbone.module.state_dict(), path_module) + + from torch2onnx import convert_onnx + convert_onnx(backbone.module.cpu().eval(), path_module, os.path.join(cfg.output, "model.onnx")) + + if wandb_logger and cfg.save_artifacts: + artifact_name = f"{run_name}_Final" + model = wandb.Artifact(artifact_name, type='model') + model.add_file(path_module) + wandb_logger.log_artifact(model) + + distributed.destroy_process_group() + + +if __name__ == "__main__": + torch.backends.cudnn.benchmark = True + parser = argparse.ArgumentParser( + description="Distributed Arcface Training in Pytorch") + parser.add_argument("config", type=str, help="py config file") + main(parser.parse_args()) diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/train_v2.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/train_v2.py new file mode 100644 index 00000000..5d53e801 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/train_v2.py @@ -0,0 +1,258 @@ +import argparse +import logging +import os +from datetime import datetime + +import numpy as np +import torch +from backbones import get_model +from dataset import get_dataloader +from losses import CombinedMarginLoss +from lr_scheduler import PolyScheduler +from partial_fc_v2 import PartialFC_V2 +from torch import distributed +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from utils.utils_callbacks import CallBackLogging, CallBackVerification +from utils.utils_config import get_config +from utils.utils_distributed_sampler import setup_seed +from utils.utils_logging import AverageMeter, init_logging + +assert torch.__version__ >= "1.12.0", "In order to enjoy the features of the new torch, \ +we have upgraded the torch to 1.12.0. torch before than 1.12.0 may not work in the future." + +try: + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + distributed.init_process_group("nccl") +except KeyError: + rank = 0 + local_rank = 0 + world_size = 1 + distributed.init_process_group( + backend="nccl", + init_method="tcp://127.0.0.1:12584", + rank=rank, + world_size=world_size, + ) + + +def main(args): + + # get config + cfg = get_config(args.config) + # global control random seed + setup_seed(seed=cfg.seed, cuda_deterministic=False) + + torch.cuda.set_device(local_rank) + + os.makedirs(cfg.output, exist_ok=True) + init_logging(rank, cfg.output) + + summary_writer = ( + SummaryWriter(log_dir=os.path.join(cfg.output, "tensorboard")) + if rank == 0 + else None + ) + + wandb_logger = None + if cfg.using_wandb: + import wandb + # Sign in to wandb + try: + wandb.login(key=cfg.wandb_key) + except Exception as e: + print("WandB Key must be provided in config file (base.py).") + print(f"Config Error: {e}") + # Initialize wandb + run_name = datetime.now().strftime("%y%m%d_%H%M") + f"_GPU{rank}" + run_name = run_name if cfg.suffix_run_name is None else run_name + f"_{cfg.suffix_run_name}" + try: + wandb_logger = wandb.init( + entity = cfg.wandb_entity, + project = cfg.wandb_project, + sync_tensorboard = True, + resume=cfg.wandb_resume, + name = run_name, + notes = cfg.notes) if rank == 0 or cfg.wandb_log_all else None + if wandb_logger: + wandb_logger.config.update(cfg) + except Exception as e: + print("WandB Data (Entity and Project name) must be provided in config file (base.py).") + print(f"Config Error: {e}") + + train_loader = get_dataloader( + cfg.rec, + local_rank, + cfg.batch_size, + cfg.dali, + cfg.seed, + cfg.num_workers + ) + + backbone = get_model( + cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).cuda() + + backbone = torch.nn.parallel.DistributedDataParallel( + module=backbone, broadcast_buffers=False, device_ids=[local_rank], bucket_cap_mb=16, + find_unused_parameters=True) + + backbone.train() + # FIXME using gradient checkpoint if there are some unused parameters will cause error + backbone._set_static_graph() + + margin_loss = CombinedMarginLoss( + 64, + cfg.margin_list[0], + cfg.margin_list[1], + cfg.margin_list[2], + cfg.interclass_filtering_threshold + ) + + if cfg.optimizer == "sgd": + module_partial_fc = PartialFC_V2( + margin_loss, cfg.embedding_size, cfg.num_classes, + cfg.sample_rate, cfg.fp16) + module_partial_fc.train().cuda() + # TODO the params of partial fc must be last in the params list + opt = torch.optim.SGD( + params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}], + lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay) + + elif cfg.optimizer == "adamw": + module_partial_fc = PartialFC_V2( + margin_loss, cfg.embedding_size, cfg.num_classes, + cfg.sample_rate, cfg.fp16) + module_partial_fc.train().cuda() + opt = torch.optim.AdamW( + params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}], + lr=cfg.lr, weight_decay=cfg.weight_decay) + else: + raise + + cfg.total_batch_size = cfg.batch_size * world_size + cfg.warmup_step = cfg.num_image // cfg.total_batch_size * cfg.warmup_epoch + cfg.total_step = cfg.num_image // cfg.total_batch_size * cfg.num_epoch + + lr_scheduler = PolyScheduler( + optimizer=opt, + base_lr=cfg.lr, + max_steps=cfg.total_step, + warmup_steps=cfg.warmup_step, + last_epoch=-1 + ) + + start_epoch = 0 + global_step = 0 + if cfg.resume: + dict_checkpoint = torch.load(os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt")) + start_epoch = dict_checkpoint["epoch"] + global_step = dict_checkpoint["global_step"] + backbone.module.load_state_dict(dict_checkpoint["state_dict_backbone"]) + module_partial_fc.load_state_dict(dict_checkpoint["state_dict_softmax_fc"]) + opt.load_state_dict(dict_checkpoint["state_optimizer"]) + lr_scheduler.load_state_dict(dict_checkpoint["state_lr_scheduler"]) + del dict_checkpoint + + for key, value in cfg.items(): + num_space = 25 - len(key) + logging.info(": " + key + " " * num_space + str(value)) + + callback_verification = CallBackVerification( + val_targets=cfg.val_targets, rec_prefix=cfg.rec, + summary_writer=summary_writer, wandb_logger = wandb_logger + ) + callback_logging = CallBackLogging( + frequent=cfg.frequent, + total_step=cfg.total_step, + batch_size=cfg.batch_size, + start_step = global_step, + writer=summary_writer + ) + + loss_am = AverageMeter() + amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100) + + for epoch in range(start_epoch, cfg.num_epoch): + + if isinstance(train_loader, DataLoader): + train_loader.sampler.set_epoch(epoch) + for _, (img, local_labels) in enumerate(train_loader): + global_step += 1 + local_embeddings = backbone(img) + loss: torch.Tensor = module_partial_fc(local_embeddings, local_labels) + + if cfg.fp16: + amp.scale(loss).backward() + if global_step % cfg.gradient_acc == 0: + amp.unscale_(opt) + torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5) + amp.step(opt) + amp.update() + opt.zero_grad() + else: + loss.backward() + if global_step % cfg.gradient_acc == 0: + torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5) + opt.step() + opt.zero_grad() + lr_scheduler.step() + + with torch.no_grad(): + if wandb_logger: + wandb_logger.log({ + 'Loss/Step Loss': loss.item(), + 'Loss/Train Loss': loss_am.avg, + 'Process/Step': global_step, + 'Process/Epoch': epoch + }) + + loss_am.update(loss.item(), 1) + callback_logging(global_step, loss_am, epoch, cfg.fp16, lr_scheduler.get_last_lr()[0], amp) + + if global_step % cfg.verbose == 0 and global_step > 0: + callback_verification(global_step, backbone) + + if cfg.save_all_states: + checkpoint = { + "epoch": epoch + 1, + "global_step": global_step, + "state_dict_backbone": backbone.module.state_dict(), + "state_dict_softmax_fc": module_partial_fc.state_dict(), + "state_optimizer": opt.state_dict(), + "state_lr_scheduler": lr_scheduler.state_dict() + } + torch.save(checkpoint, os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt")) + + if rank == 0: + path_module = os.path.join(cfg.output, "model.pt") + torch.save(backbone.module.state_dict(), path_module) + + if wandb_logger and cfg.save_artifacts: + artifact_name = f"{run_name}_E{epoch}" + model = wandb.Artifact(artifact_name, type='model') + model.add_file(path_module) + wandb_logger.log_artifact(model) + + if cfg.dali: + train_loader.reset() + + if rank == 0: + path_module = os.path.join(cfg.output, "model.pt") + torch.save(backbone.module.state_dict(), path_module) + + if wandb_logger and cfg.save_artifacts: + artifact_name = f"{run_name}_Final" + model = wandb.Artifact(artifact_name, type='model') + model.add_file(path_module) + wandb_logger.log_artifact(model) + + + +if __name__ == "__main__": + torch.backends.cudnn.benchmark = True + parser = argparse.ArgumentParser( + description="Distributed Arcface Training in Pytorch") + parser.add_argument("config", type=str, help="py config file") + main(parser.parse_args()) diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/utils/__init__.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/utils/plot.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/utils/plot.py new file mode 100644 index 00000000..7f1d39da --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/utils/plot.py @@ -0,0 +1,71 @@ +import os +import sys + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap +from prettytable import PrettyTable +from sklearn.metrics import roc_curve, auc + +with open(sys.argv[1], "r") as f: + files = f.readlines() + +files = [x.strip() for x in files] +image_path = "/train_tmp/IJB_release/IJBC" + + +def read_template_pair_list(path): + pairs = pd.read_csv(path, sep=' ', header=None).values + t1 = pairs[:, 0].astype(np.int) + t2 = pairs[:, 1].astype(np.int) + label = pairs[:, 2].astype(np.int) + return t1, t2, label + + +p1, p2, label = read_template_pair_list( + os.path.join('%s/meta' % image_path, + '%s_template_pair_label.txt' % 'ijbc')) + +methods = [] +scores = [] +for file in files: + methods.append(file) + scores.append(np.load(file)) + +methods = np.array(methods) +scores = dict(zip(methods, scores)) +colours = dict( + zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2'))) +x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] +tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels]) +fig = plt.figure() +for method in methods: + fpr, tpr, _ = roc_curve(label, scores[method]) + roc_auc = auc(fpr, tpr) + fpr = np.flipud(fpr) + tpr = np.flipud(tpr) # select largest tpr at same fpr + plt.plot(fpr, + tpr, + color=colours[method], + lw=1, + label=('[%s (AUC = %0.4f %%)]' % + (method.split('-')[-1], roc_auc * 100))) + tpr_fpr_row = [] + tpr_fpr_row.append(method) + for fpr_iter in np.arange(len(x_labels)): + _, min_index = min( + list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) + tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) + tpr_fpr_table.add_row(tpr_fpr_row) +plt.xlim([10 ** -6, 0.1]) +plt.ylim([0.3, 1.0]) +plt.grid(linestyle='--', linewidth=1) +plt.xticks(x_labels) +plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) +plt.xscale('log') +plt.xlabel('False Positive Rate') +plt.ylabel('True Positive Rate') +plt.title('ROC on IJB') +plt.legend(loc="lower right") +print(tpr_fpr_table) diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/utils/utils_callbacks.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/utils/utils_callbacks.py new file mode 100644 index 00000000..d9368073 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/utils/utils_callbacks.py @@ -0,0 +1,125 @@ +import logging +import os +import time +from typing import List + +import torch + +from eval import verification +from utils.utils_logging import AverageMeter +from torch.utils.tensorboard import SummaryWriter +from torch import distributed + + +class CallBackVerification(object): + + def __init__(self, val_targets, rec_prefix, summary_writer=None, image_size=(112, 112), wandb_logger=None): + self.rank: int = distributed.get_rank() + self.highest_acc: float = 0.0 + self.highest_acc_list: List[float] = [0.0] * len(val_targets) + self.ver_list: List[object] = [] + self.ver_name_list: List[str] = [] + if self.rank is 0: + self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size) + + self.summary_writer = summary_writer + self.wandb_logger = wandb_logger + + def ver_test(self, backbone: torch.nn.Module, global_step: int): + results = [] + for i in range(len(self.ver_list)): + acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test( + self.ver_list[i], backbone, 10, 10) + logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm)) + logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2)) + + self.summary_writer: SummaryWriter + self.summary_writer.add_scalar(tag=self.ver_name_list[i], scalar_value=acc2, global_step=global_step, ) + if self.wandb_logger: + import wandb + self.wandb_logger.log({ + f'Acc/val-Acc1 {self.ver_name_list[i]}': acc1, + f'Acc/val-Acc2 {self.ver_name_list[i]}': acc2, + # f'Acc/val-std1 {self.ver_name_list[i]}': std1, + # f'Acc/val-std2 {self.ver_name_list[i]}': acc2, + }) + + if acc2 > self.highest_acc_list[i]: + self.highest_acc_list[i] = acc2 + logging.info( + '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i])) + results.append(acc2) + + def init_dataset(self, val_targets, data_dir, image_size): + for name in val_targets: + path = os.path.join(data_dir, name + ".bin") + if os.path.exists(path): + data_set = verification.load_bin(path, image_size) + self.ver_list.append(data_set) + self.ver_name_list.append(name) + + def __call__(self, num_update, backbone: torch.nn.Module): + if self.rank is 0 and num_update > 0: + backbone.eval() + self.ver_test(backbone, num_update) + backbone.train() + + +class CallBackLogging(object): + def __init__(self, frequent, total_step, batch_size, start_step=0,writer=None): + self.frequent: int = frequent + self.rank: int = distributed.get_rank() + self.world_size: int = distributed.get_world_size() + self.time_start = time.time() + self.total_step: int = total_step + self.start_step: int = start_step + self.batch_size: int = batch_size + self.writer = writer + + self.init = False + self.tic = 0 + + def __call__(self, + global_step: int, + loss: AverageMeter, + epoch: int, + fp16: bool, + learning_rate: float, + grad_scaler: torch.cuda.amp.GradScaler): + if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0: + if self.init: + try: + speed: float = self.frequent * self.batch_size / (time.time() - self.tic) + speed_total = speed * self.world_size + except ZeroDivisionError: + speed_total = float('inf') + + #time_now = (time.time() - self.time_start) / 3600 + #time_total = time_now / ((global_step + 1) / self.total_step) + #time_for_end = time_total - time_now + time_now = time.time() + time_sec = int(time_now - self.time_start) + time_sec_avg = time_sec / (global_step - self.start_step + 1) + eta_sec = time_sec_avg * (self.total_step - global_step - 1) + time_for_end = eta_sec/3600 + if self.writer is not None: + self.writer.add_scalar('time_for_end', time_for_end, global_step) + self.writer.add_scalar('learning_rate', learning_rate, global_step) + self.writer.add_scalar('loss', loss.avg, global_step) + if fp16: + msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.6f Epoch: %d Global Step: %d " \ + "Fp16 Grad Scale: %2.f Required: %1.f hours" % ( + speed_total, loss.avg, learning_rate, epoch, global_step, + grad_scaler.get_scale(), time_for_end + ) + else: + msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.6f Epoch: %d Global Step: %d " \ + "Required: %1.f hours" % ( + speed_total, loss.avg, learning_rate, epoch, global_step, time_for_end + ) + logging.info(msg) + loss.reset() + self.tic = time.time() + else: + self.init = True + self.tic = time.time() diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/utils/utils_config.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/utils/utils_config.py new file mode 100644 index 00000000..0c02eaf7 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/utils/utils_config.py @@ -0,0 +1,16 @@ +import importlib +import os.path as osp + + +def get_config(config_file): + assert config_file.startswith('configs/'), 'config file setting must start with configs/' + temp_config_name = osp.basename(config_file) + temp_module_name = osp.splitext(temp_config_name)[0] + config = importlib.import_module("configs.base") + cfg = config.config + config = importlib.import_module("configs.%s" % temp_module_name) + job_cfg = config.config + cfg.update(job_cfg) + if cfg.output is None: + cfg.output = osp.join('work_dirs', temp_module_name) + return cfg \ No newline at end of file diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/utils/utils_distributed_sampler.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/utils/utils_distributed_sampler.py new file mode 100644 index 00000000..cea67039 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/utils/utils_distributed_sampler.py @@ -0,0 +1,126 @@ +import math +import os +import random + +import numpy as np +import torch +import torch.distributed as dist +from torch.utils.data import DistributedSampler as _DistributedSampler + + +def setup_seed(seed, cuda_deterministic=True): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + if cuda_deterministic: # slower, more reproducible + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + else: # faster, less reproducible + torch.backends.cudnn.deterministic = False + torch.backends.cudnn.benchmark = True + + +def worker_init_fn(worker_id, num_workers, rank, seed): + # The seed of each worker equals to + # num_worker * rank + worker_id + user_seed + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) + torch.manual_seed(worker_seed) + + +def get_dist_info(): + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + + return rank, world_size + + +def sync_random_seed(seed=None, device="cuda"): + """Make sure different ranks share the same seed. + All workers must call this function, otherwise it will deadlock. + This method is generally used in `DistributedSampler`, + because the seed should be identical across all processes + in the distributed group. + In distributed sampling, different ranks should sample non-overlapped + data in the dataset. Therefore, this function is used to make sure that + each rank shuffles the data indices in the same order based + on the same seed. Then different ranks could use different indices + to select non-overlapped data from the same data list. + Args: + seed (int, Optional): The seed. Default to None. + device (str): The device where the seed will be put on. + Default to 'cuda'. + Returns: + int: Seed to be used. + """ + if seed is None: + seed = np.random.randint(2**31) + assert isinstance(seed, int) + + rank, world_size = get_dist_info() + + if world_size == 1: + return seed + + if rank == 0: + random_num = torch.tensor(seed, dtype=torch.int32, device=device) + else: + random_num = torch.tensor(0, dtype=torch.int32, device=device) + + dist.broadcast(random_num, src=0) + + return random_num.item() + + +class DistributedSampler(_DistributedSampler): + def __init__( + self, + dataset, + num_replicas=None, # world_size + rank=None, # local_rank + shuffle=True, + seed=0, + ): + + super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) + + # In distributed sampling, different ranks should sample + # non-overlapped data in the dataset. Therefore, this function + # is used to make sure that each rank shuffles the data indices + # in the same order based on the same seed. Then different ranks + # could use different indices to select non-overlapped data from the + # same data list. + self.seed = sync_random_seed(seed) + + def __iter__(self): + # deterministically shuffle based on epoch + if self.shuffle: + g = torch.Generator() + # When :attr:`shuffle=True`, this ensures all replicas + # use a different random ordering for each epoch. + # Otherwise, the next iteration of this sampler will + # yield the same ordering. + g.manual_seed(self.epoch + self.seed) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = torch.arange(len(self.dataset)).tolist() + + # add extra samples to make it evenly divisible + # in case that indices is shorter than half of total_size + indices = (indices * math.ceil(self.total_size / len(indices)))[ + : self.total_size + ] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/utils/utils_logging.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/utils/utils_logging.py new file mode 100644 index 00000000..c787b6aa --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/arcface_torch/utils/utils_logging.py @@ -0,0 +1,41 @@ +import logging +import os +import sys + + +class AverageMeter(object): + """Computes and stores the average and current value + """ + + def __init__(self): + self.val = None + self.avg = None + self.sum = None + self.count = None + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def init_logging(rank, models_root): + if rank == 0: + log_root = logging.getLogger() + log_root.setLevel(logging.INFO) + formatter = logging.Formatter("Training: %(asctime)s-%(message)s") + handler_file = logging.FileHandler(os.path.join(models_root, "training.log")) + handler_stream = logging.StreamHandler(sys.stdout) + handler_file.setFormatter(formatter) + handler_stream.setFormatter(formatter) + log_root.addHandler(handler_file) + log_root.addHandler(handler_stream) + log_root.info('rank_id: %d' % rank) diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/base_model.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/base_model.py new file mode 100644 index 00000000..2a05d3a0 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/base_model.py @@ -0,0 +1,316 @@ +"""This script defines the base network model for Deep3DFaceRecon_pytorch +""" + +import os +import numpy as np +import torch +from collections import OrderedDict +from abc import ABC, abstractmethod +from . import networks + + +class BaseModel(ABC): + """This class is an abstract base class (ABC) for models. + To create a subclass, you need to implement the following five functions: + -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). + -- : unpack data from dataset and apply preprocessing. + -- : produce intermediate results. + -- : calculate losses, gradients, and update network weights. + -- : (optionally) add model-specific options and set default options. + """ + + def __init__(self, opt): + """Initialize the BaseModel class. + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + + When creating your custom class, you need to implement your own initialization. + In this fucntion, you should first call + Then, you need to define four lists: + -- self.loss_names (str list): specify the training losses that you want to plot and save. + -- self.model_names (str list): specify the images that you want to display and save. + -- self.visual_names (str list): define networks used in our training. + -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. + """ + self.opt = opt + self.isTrain = opt.isTrain + self.device = torch.device('cpu') + self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir + self.loss_names = [] + self.model_names = [] + self.visual_names = [] + self.parallel_names = [] + self.optimizers = [] + self.image_paths = [] + self.metric = 0 # used for learning rate policy 'plateau' + + @staticmethod + def dict_grad_hook_factory(add_func=lambda x: x): + saved_dict = dict() + + def hook_gen(name): + def grad_hook(grad): + saved_vals = add_func(grad) + saved_dict[name] = saved_vals + return grad_hook + return hook_gen, saved_dict + + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new model-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + return parser + + @abstractmethod + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input (dict): includes the data itself and its metadata information. + """ + pass + + @abstractmethod + def forward(self): + """Run forward pass; called by both functions and .""" + pass + + @abstractmethod + def optimize_parameters(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + pass + + def setup(self, opt): + """Load and print networks; create schedulers + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + if self.isTrain: + self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] + + if not self.isTrain or opt.continue_train: + load_suffix = opt.epoch + self.load_networks(load_suffix) + + + # self.print_networks(opt.verbose) + + def parallelize(self, convert_sync_batchnorm=True): + if not self.opt.use_ddp: + for name in self.parallel_names: + if isinstance(name, str): + module = getattr(self, name) + setattr(self, name, module.to(self.device)) + else: + for name in self.model_names: + if isinstance(name, str): + module = getattr(self, name) + if convert_sync_batchnorm: + module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module) + setattr(self, name, torch.nn.parallel.DistributedDataParallel(module.to(self.device), + device_ids=[self.device.index], + find_unused_parameters=True, broadcast_buffers=True)) + + # DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient. + for name in self.parallel_names: + if isinstance(name, str) and name not in self.model_names: + module = getattr(self, name) + setattr(self, name, module.to(self.device)) + + # put state_dict of optimizer to gpu device + if self.opt.phase != 'test': + if self.opt.continue_train: + for optim in self.optimizers: + for state in optim.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.to(self.device) + + def data_dependent_initialize(self, data): + pass + + def train(self): + """Make models train mode""" + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + net.train() + + def eval(self): + """Make models eval mode""" + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + net.eval() + + def test(self): + """Forward function used in test time. + + This function wraps function in no_grad() so we don't save intermediate steps for backprop + It also calls to produce additional visualization results + """ + with torch.no_grad(): + self.forward() + self.compute_visuals() + + def compute_visuals(self): + """Calculate additional output images for visdom and HTML visualization""" + pass + + def get_image_paths(self, name='A'): + """ Return image paths that are used to load current data""" + return self.image_paths if name =='A' else self.image_paths_B + + def update_learning_rate(self): + """Update learning rates for all the networks; called at the end of every epoch""" + for scheduler in self.schedulers: + if self.opt.lr_policy == 'plateau': + scheduler.step(self.metric) + else: + scheduler.step() + + lr = self.optimizers[0].param_groups[0]['lr'] + print('learning rate = %.7f' % lr) + + def get_current_visuals(self): + """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" + visual_ret = OrderedDict() + for name in self.visual_names: + if isinstance(name, str): + visual_ret[name] = getattr(self, name)[:, :3, ...] + return visual_ret + + def get_current_losses(self): + """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" + errors_ret = OrderedDict() + for name in self.loss_names: + if isinstance(name, str): + errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number + return errors_ret + + def save_networks(self, epoch): + """Save all the networks to the disk. + + Parameters: + epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) + """ + if not os.path.isdir(self.save_dir): + os.makedirs(self.save_dir) + + save_filename = 'epoch_%s.pth' % (epoch) + save_path = os.path.join(self.save_dir, save_filename) + + save_dict = {} + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + if isinstance(net, torch.nn.DataParallel) or isinstance(net, + torch.nn.parallel.DistributedDataParallel): + net = net.module + save_dict[name] = net.state_dict() + + + for i, optim in enumerate(self.optimizers): + save_dict['opt_%02d'%i] = optim.state_dict() + + for i, sched in enumerate(self.schedulers): + save_dict['sched_%02d'%i] = sched.state_dict() + + torch.save(save_dict, save_path) + + def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): + """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" + key = keys[i] + if i + 1 == len(keys): # at the end, pointing to a parameter/buffer + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'running_mean' or key == 'running_var'): + if getattr(module, key) is None: + state_dict.pop('.'.join(keys)) + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'num_batches_tracked'): + state_dict.pop('.'.join(keys)) + else: + self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) + + def load_networks(self, epoch): + """Load all the networks from the disk. + + Parameters: + epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) + """ + if self.opt.isTrain and self.opt.pretrained_name is not None: + load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name) + else: + load_dir = self.save_dir + load_filename = 'epoch_%s.pth' % (epoch) + load_path = os.path.join(load_dir, load_filename) + state_dict = torch.load(load_path, map_location=self.device) + print('loading the model from %s' % load_path) + + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + if isinstance(net, torch.nn.DataParallel): + net = net.module + net.load_state_dict(state_dict[name]) + + if self.opt.phase != 'test': + if self.opt.continue_train: + print('loading the optim from %s' % load_path) + for i, optim in enumerate(self.optimizers): + optim.load_state_dict(state_dict['opt_%02d'%i]) + + try: + print('loading the sched from %s' % load_path) + for i, sched in enumerate(self.schedulers): + sched.load_state_dict(state_dict['sched_%02d'%i]) + except: + print('Failed to load schedulers, set schedulers according to epoch count manually') + for i, sched in enumerate(self.schedulers): + sched.last_epoch = self.opt.epoch_count - 1 + + + + + def print_networks(self, verbose): + """Print the total number of parameters in the network and (if verbose) network architecture + + Parameters: + verbose (bool) -- if verbose: print the network architecture + """ + print('---------- Networks initialized -------------') + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + if verbose: + print(net) + print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) + print('-----------------------------------------------') + + def set_requires_grad(self, nets, requires_grad=False): + """Set requies_grad=Fasle for all the networks to avoid unnecessary computations + Parameters: + nets (network list) -- a list of networks + requires_grad (bool) -- whether the networks require gradients or not + """ + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad + + def generate_visuals_for_evaluation(self, data, mode): + return {} diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/bfm.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/bfm.py new file mode 100644 index 00000000..e2b7cc34 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/bfm.py @@ -0,0 +1,299 @@ +"""This script defines the parametric 3d face model for Deep3DFaceRecon_pytorch +""" + +import numpy as np +import torch +import torch.nn.functional as F +from scipy.io import loadmat +from deep_3drecon.util.load_mats import transferBFM09 +import os + +def perspective_projection(focal, center): + # return p.T (N, 3) @ (3, 3) + return np.array([ + focal, 0, center, + 0, focal, center, + 0, 0, 1 + ]).reshape([3, 3]).astype(np.float32).transpose() + +class SH: + def __init__(self): + self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)] + self.c = [1/np.sqrt(4 * np.pi), np.sqrt(3.) / np.sqrt(4 * np.pi), 3 * np.sqrt(5.) / np.sqrt(12 * np.pi)] + + + +class ParametricFaceModel: + def __init__(self, + bfm_folder='./BFM', + recenter=True, + camera_distance=10., + init_lit=np.array([ + 0.8, 0, 0, 0, 0, 0, 0, 0, 0 + ]), + focal=1015., + center=112., + is_train=True, + default_name='BFM_model_front.mat'): + + if not os.path.isfile(os.path.join(bfm_folder, default_name)): + transferBFM09(bfm_folder) + model = loadmat(os.path.join(bfm_folder, default_name)) + # mean face shape. [3*N,1] + self.mean_shape = model['meanshape'].astype(np.float32) + # identity basis. [3*N,80] + self.id_base = model['idBase'].astype(np.float32) + # expression basis. [3*N,64] + self.exp_base = model['exBase'].astype(np.float32) + # mean face texture. [3*N,1] (0-255) + self.mean_tex = model['meantex'].astype(np.float32) + # texture basis. [3*N,80] + self.tex_base = model['texBase'].astype(np.float32) + # face indices for each vertex that lies in. starts from 0. [N,8] + self.point_buf = model['point_buf'].astype(np.int64) - 1 + # vertex indices for each face. starts from 0. [F,3] + self.face_buf = model['tri'].astype(np.int64) - 1 + # vertex indices for 68 landmarks. starts from 0. [68,1] + self.keypoints = np.squeeze(model['keypoints']).astype(np.int64) - 1 + + if is_train: + # vertex indices for small face region to compute photometric error. starts from 0. + self.front_mask = np.squeeze(model['frontmask2_idx']).astype(np.int64) - 1 + # vertex indices for each face from small face region. starts from 0. [f,3] + self.front_face_buf = model['tri_mask2'].astype(np.int64) - 1 + # vertex indices for pre-defined skin region to compute reflectance loss + self.skin_mask = np.squeeze(model['skinmask']) + + if recenter: + mean_shape = self.mean_shape.reshape([-1, 3]) + mean_shape = mean_shape - np.mean(mean_shape, axis=0, keepdims=True) + self.mean_shape = mean_shape.reshape([-1, 1]) + + self.persc_proj = perspective_projection(focal, center) + self.device = 'cpu' + self.camera_distance = camera_distance + self.SH = SH() + self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32) + + + def to(self, device): + self.device = device + for key, value in self.__dict__.items(): + if type(value).__module__ == np.__name__: + setattr(self, key, torch.tensor(value).to(device)) + + + def compute_shape(self, id_coeff, exp_coeff): + """ + Return: + face_shape -- torch.tensor, size (B, N, 3) + + Parameters: + id_coeff -- torch.tensor, size (B, 80), identity coeffs + exp_coeff -- torch.tensor, size (B, 64), expression coeffs + """ + batch_size = id_coeff.shape[0] + id_part = torch.einsum('ij,aj->ai', self.id_base, id_coeff) + exp_part = torch.einsum('ij,aj->ai', self.exp_base, exp_coeff) + face_shape = id_part + exp_part + self.mean_shape.reshape([1, -1]) + return face_shape.reshape([batch_size, -1, 3]) + + + def compute_texture(self, tex_coeff, normalize=True): + """ + Return: + face_texture -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.) + + Parameters: + tex_coeff -- torch.tensor, size (B, 80) + """ + batch_size = tex_coeff.shape[0] + face_texture = torch.einsum('ij,aj->ai', self.tex_base, tex_coeff) + self.mean_tex + if normalize: + face_texture = face_texture / 255. + return face_texture.reshape([batch_size, -1, 3]) + + + def compute_norm(self, face_shape): + """ + Return: + vertex_norm -- torch.tensor, size (B, N, 3) + + Parameters: + face_shape -- torch.tensor, size (B, N, 3) + """ + + v1 = face_shape[:, self.face_buf[:, 0]] + v2 = face_shape[:, self.face_buf[:, 1]] + v3 = face_shape[:, self.face_buf[:, 2]] + e1 = v1 - v2 + e2 = v2 - v3 + face_norm = torch.cross(e1, e2, dim=-1) + face_norm = F.normalize(face_norm, dim=-1, p=2) + face_norm = torch.cat([face_norm, torch.zeros(face_norm.shape[0], 1, 3).to(self.device)], dim=1) + + vertex_norm = torch.sum(face_norm[:, self.point_buf], dim=2) + vertex_norm = F.normalize(vertex_norm, dim=-1, p=2) + return vertex_norm + + + def compute_color(self, face_texture, face_norm, gamma): + """ + Return: + face_color -- torch.tensor, size (B, N, 3), range (0, 1.) + + Parameters: + face_texture -- torch.tensor, size (B, N, 3), from texture model, range (0, 1.) + face_norm -- torch.tensor, size (B, N, 3), rotated face normal + gamma -- torch.tensor, size (B, 27), SH coeffs + """ + batch_size = gamma.shape[0] + v_num = face_texture.shape[1] + a, c = self.SH.a, self.SH.c + gamma = gamma.reshape([batch_size, 3, 9]) + gamma = gamma + self.init_lit + gamma = gamma.permute(0, 2, 1) + Y = torch.cat([ + a[0] * c[0] * torch.ones_like(face_norm[..., :1]).to(self.device), + -a[1] * c[1] * face_norm[..., 1:2], + a[1] * c[1] * face_norm[..., 2:], + -a[1] * c[1] * face_norm[..., :1], + a[2] * c[2] * face_norm[..., :1] * face_norm[..., 1:2], + -a[2] * c[2] * face_norm[..., 1:2] * face_norm[..., 2:], + 0.5 * a[2] * c[2] / np.sqrt(3.) * (3 * face_norm[..., 2:] ** 2 - 1), + -a[2] * c[2] * face_norm[..., :1] * face_norm[..., 2:], + 0.5 * a[2] * c[2] * (face_norm[..., :1] ** 2 - face_norm[..., 1:2] ** 2) + ], dim=-1) + r = Y @ gamma[..., :1] + g = Y @ gamma[..., 1:2] + b = Y @ gamma[..., 2:] + face_color = torch.cat([r, g, b], dim=-1) * face_texture + return face_color + + + def compute_rotation(self, angles): + """ + Return: + rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat + + Parameters: + angles -- torch.tensor, size (B, 3), radian + """ + + batch_size = angles.shape[0] + ones = torch.ones([batch_size, 1]).to(self.device) + zeros = torch.zeros([batch_size, 1]).to(self.device) + x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:], + + rot_x = torch.cat([ + ones, zeros, zeros, + zeros, torch.cos(x), -torch.sin(x), + zeros, torch.sin(x), torch.cos(x) + ], dim=1).reshape([batch_size, 3, 3]) + + rot_y = torch.cat([ + torch.cos(y), zeros, torch.sin(y), + zeros, ones, zeros, + -torch.sin(y), zeros, torch.cos(y) + ], dim=1).reshape([batch_size, 3, 3]) + + rot_z = torch.cat([ + torch.cos(z), -torch.sin(z), zeros, + torch.sin(z), torch.cos(z), zeros, + zeros, zeros, ones + ], dim=1).reshape([batch_size, 3, 3]) + + rot = rot_z @ rot_y @ rot_x + return rot.permute(0, 2, 1) + + + def to_camera(self, face_shape): + face_shape[..., -1] = self.camera_distance - face_shape[..., -1] + return face_shape + + def to_image(self, face_shape): + """ + Return: + face_proj -- torch.tensor, size (B, N, 2), y direction is opposite to v direction + + Parameters: + face_shape -- torch.tensor, size (B, N, 3) + """ + # to image_plane + face_proj = face_shape @ self.persc_proj + face_proj = face_proj[..., :2] / face_proj[..., 2:] + + return face_proj + + + def transform(self, face_shape, rot, trans): + """ + Return: + face_shape -- torch.tensor, size (B, N, 3) pts @ rot + trans + + Parameters: + face_shape -- torch.tensor, size (B, N, 3) + rot -- torch.tensor, size (B, 3, 3) + trans -- torch.tensor, size (B, 3) + """ + return face_shape @ rot + trans.unsqueeze(1) + + + def get_landmarks(self, face_proj): + """ + Return: + face_lms -- torch.tensor, size (B, 68, 2) + + Parameters: + face_proj -- torch.tensor, size (B, N, 2) + """ + return face_proj[:, self.keypoints] + + def split_coeff(self, coeffs): + """ + Return: + coeffs_dict -- a dict of torch.tensors + + Parameters: + coeffs -- torch.tensor, size (B, 256) + """ + id_coeffs = coeffs[:, :80] + exp_coeffs = coeffs[:, 80: 144] + tex_coeffs = coeffs[:, 144: 224] + angles = coeffs[:, 224: 227] + gammas = coeffs[:, 227: 254] + translations = coeffs[:, 254:] + return { + 'id': id_coeffs, + 'exp': exp_coeffs, + 'tex': tex_coeffs, + 'angle': angles, + 'gamma': gammas, + 'trans': translations + } + def compute_for_render(self, coeffs): + """ + Return: + face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate + face_color -- torch.tensor, size (B, N, 3), in RGB order + landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction + Parameters: + coeffs -- torch.tensor, size (B, 257) + """ + coef_dict = self.split_coeff(coeffs) + face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp']) + rotation = self.compute_rotation(coef_dict['angle']) + + + face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans']) + face_vertex = self.to_camera(face_shape_transformed) + + face_proj = self.to_image(face_vertex) + landmark = self.get_landmarks(face_proj) + + face_texture = self.compute_texture(coef_dict['tex']) + face_norm = self.compute_norm(face_shape) + face_norm_roted = face_norm @ rotation + face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma']) + + return face_vertex, face_texture, face_color, landmark diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/facerecon_model.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/facerecon_model.py new file mode 100644 index 00000000..c5659b24 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/facerecon_model.py @@ -0,0 +1,228 @@ +"""This script defines the face reconstruction model for Deep3DFaceRecon_pytorch +""" + +import numpy as np +import torch +from .base_model import BaseModel +from . import networks +from .bfm import ParametricFaceModel +from .losses import perceptual_loss, photo_loss, reg_loss, reflectance_loss, landmark_loss +from deep_3drecon.util import util +from deep_3drecon.util.mesh_renderer import MeshRenderer +from deep_3drecon.util.preprocess import estimate_norm_torch + +import trimesh +from scipy.io import savemat + +class FaceReconModel(BaseModel): + + @staticmethod + def modify_commandline_options(parser, is_train=True): + """ Configures options specific for CUT model + """ + # net structure and parameters + parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='network structure') + parser.add_argument('--init_path', type=str, default='checkpoints/init_model/resnet50-0676ba61.pth') + parser.add_argument('--use_last_fc', type=util.str2bool, nargs='?', const=True, default=False, help='zero initialize the last fc') + parser.add_argument('--bfm_folder', type=str, default='./deep_3drecon/BFM') + parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model') + + # renderer parameters + parser.add_argument('--focal', type=float, default=1015.) + parser.add_argument('--center', type=float, default=112.) + parser.add_argument('--camera_d', type=float, default=10.) + parser.add_argument('--z_near', type=float, default=5.) + parser.add_argument('--z_far', type=float, default=15.) + parser.add_argument('--use_opengl', type=util.str2bool, nargs='?', const=True, default=False, help='use opengl context or not') + + if is_train: + # training parameters + parser.add_argument('--net_recog', type=str, default='r50', choices=['r18', 'r43', 'r50'], help='face recog network structure') + parser.add_argument('--net_recog_path', type=str, default='checkpoints/recog_model/ms1mv3_arcface_r50_fp16/backbone.pth') + parser.add_argument('--use_crop_face', type=util.str2bool, nargs='?', const=True, default=False, help='use crop mask for photo loss') + parser.add_argument('--use_predef_M', type=util.str2bool, nargs='?', const=True, default=False, help='use predefined M for predicted face') + + + # augmentation parameters + parser.add_argument('--shift_pixs', type=float, default=10., help='shift pixels') + parser.add_argument('--scale_delta', type=float, default=0.1, help='delta scale factor') + parser.add_argument('--rot_angle', type=float, default=10., help='rot angles, degree') + + # loss weights + parser.add_argument('--w_feat', type=float, default=0.2, help='weight for feat loss') + parser.add_argument('--w_color', type=float, default=1.92, help='weight for loss loss') + parser.add_argument('--w_reg', type=float, default=3.0e-4, help='weight for reg loss') + parser.add_argument('--w_id', type=float, default=1.0, help='weight for id_reg loss') + parser.add_argument('--w_exp', type=float, default=0.8, help='weight for exp_reg loss') + parser.add_argument('--w_tex', type=float, default=1.7e-2, help='weight for tex_reg loss') + parser.add_argument('--w_gamma', type=float, default=10.0, help='weight for gamma loss') + parser.add_argument('--w_lm', type=float, default=1.6e-3, help='weight for lm loss') + parser.add_argument('--w_reflc', type=float, default=5.0, help='weight for reflc loss') + + + + opt, _ = parser.parse_known_args() + parser.set_defaults( + focal=1015., center=112., camera_d=10., use_last_fc=False, z_near=5., z_far=15. + ) + if is_train: + parser.set_defaults( + use_crop_face=True, use_predef_M=False + ) + return parser + + def __init__(self, opt): + """Initialize this model class. + + Parameters: + opt -- training/test options + + A few things can be done here. + - (required) call the initialization function of BaseModel + - define loss function, visualization images, model names, and optimizers + """ + BaseModel.__init__(self, opt) # call the initialization method of BaseModel + + self.visual_names = ['output_vis'] + self.model_names = ['net_recon'] + self.parallel_names = self.model_names + ['renderer'] + + self.net_recon = networks.define_net_recon( + net_recon=opt.net_recon, use_last_fc=opt.use_last_fc, init_path=opt.init_path + ) + + self.facemodel = ParametricFaceModel( + bfm_folder=opt.bfm_folder, camera_distance=opt.camera_d, focal=opt.focal, center=opt.center, + is_train=self.isTrain, default_name=opt.bfm_model + ) + + fov = 2 * np.arctan(opt.center / opt.focal) * 180 / np.pi + self.renderer = MeshRenderer( + rasterize_fov=fov, znear=opt.z_near, zfar=opt.z_far, rasterize_size=int(2 * opt.center), use_opengl=opt.use_opengl + ) + + if self.isTrain: + self.loss_names = ['all', 'feat', 'color', 'lm', 'reg', 'gamma', 'reflc'] + + self.net_recog = networks.define_net_recog( + net_recog=opt.net_recog, pretrained_path=opt.net_recog_path + ) + # loss func name: (compute_%s_loss) % loss_name + self.compute_feat_loss = perceptual_loss + self.comupte_color_loss = photo_loss + self.compute_lm_loss = landmark_loss + self.compute_reg_loss = reg_loss + self.compute_reflc_loss = reflectance_loss + + self.optimizer = torch.optim.Adam(self.net_recon.parameters(), lr=opt.lr) + self.optimizers = [self.optimizer] + self.parallel_names += ['net_recog'] + # Our program will automatically call to define schedulers, load networks, and print networks + + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input: a dictionary that contains the data itself and its metadata information. + """ + self.input_img = input['imgs'].to(self.device) + self.atten_mask = input['msks'].to(self.device) if 'msks' in input else None + self.gt_lm = input['lms'].to(self.device) if 'lms' in input else None + self.trans_m = input['M'].to(self.device) if 'M' in input else None + self.image_paths = input['im_paths'] if 'im_paths' in input else None + + def forward(self): + output_coeff = self.net_recon(self.input_img) + self.facemodel.to(self.device) + self.pred_vertex, self.pred_tex, self.pred_color, self.pred_lm = \ + self.facemodel.compute_for_render(output_coeff) + self.pred_mask, _, self.pred_face = self.renderer( + self.pred_vertex, self.facemodel.face_buf, feat=self.pred_color) + + self.pred_coeffs_dict = self.facemodel.split_coeff(output_coeff) + self.output_coeff = output_coeff + + def compute_losses(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + + assert self.net_recog.training == False + trans_m = self.trans_m + if not self.opt.use_predef_M: + trans_m = estimate_norm_torch(self.pred_lm, self.input_img.shape[-2]) + + pred_feat = self.net_recog(self.pred_face, trans_m) + gt_feat = self.net_recog(self.input_img, self.trans_m) + self.loss_feat = self.opt.w_feat * self.compute_feat_loss(pred_feat, gt_feat) + + face_mask = self.pred_mask + if self.opt.use_crop_face: + face_mask, _, _ = self.renderer(self.pred_vertex, self.facemodel.front_face_buf) + + face_mask = face_mask.detach() + self.loss_color = self.opt.w_color * self.comupte_color_loss( + self.pred_face, self.input_img, self.atten_mask * face_mask) + + loss_reg, loss_gamma = self.compute_reg_loss(self.pred_coeffs_dict, self.opt) + self.loss_reg = self.opt.w_reg * loss_reg + self.loss_gamma = self.opt.w_gamma * loss_gamma + + self.loss_lm = self.opt.w_lm * self.compute_lm_loss(self.pred_lm, self.gt_lm) + + self.loss_reflc = self.opt.w_reflc * self.compute_reflc_loss(self.pred_tex, self.facemodel.skin_mask) + + self.loss_all = self.loss_feat + self.loss_color + self.loss_reg + self.loss_gamma \ + + self.loss_lm + self.loss_reflc + + + def optimize_parameters(self, isTrain=True): + self.forward() + self.compute_losses() + """Update network weights; it will be called in every training iteration.""" + if isTrain: + self.optimizer.zero_grad() + self.loss_all.backward() + self.optimizer.step() + + def compute_visuals(self): + with torch.no_grad(): + input_img_numpy = 255. * self.input_img.detach().cpu().permute(0, 2, 3, 1).numpy() + output_vis = self.pred_face * self.pred_mask + (1 - self.pred_mask) * self.input_img + output_vis_numpy_raw = 255. * output_vis.detach().cpu().permute(0, 2, 3, 1).numpy() + + if self.gt_lm is not None: + gt_lm_numpy = self.gt_lm.cpu().numpy() + pred_lm_numpy = self.pred_lm.detach().cpu().numpy() + output_vis_numpy = util.draw_landmarks(output_vis_numpy_raw, gt_lm_numpy, 'b') + output_vis_numpy = util.draw_landmarks(output_vis_numpy, pred_lm_numpy, 'r') + + output_vis_numpy = np.concatenate((input_img_numpy, + output_vis_numpy_raw, output_vis_numpy), axis=-2) + else: + output_vis_numpy = np.concatenate((input_img_numpy, + output_vis_numpy_raw), axis=-2) + + self.output_vis = torch.tensor( + output_vis_numpy / 255., dtype=torch.float32 + ).permute(0, 3, 1, 2).to(self.device) + + def save_mesh(self, name): + + recon_shape = self.pred_vertex # get reconstructed shape + recon_shape[..., -1] = 10 - recon_shape[..., -1] # from camera space to world space + recon_shape = recon_shape.cpu().numpy()[0] + recon_color = self.pred_color + recon_color = recon_color.cpu().numpy()[0] + tri = self.facemodel.face_buf.cpu().numpy() + mesh = trimesh.Trimesh(vertices=recon_shape, faces=tri, vertex_colors=np.clip(255. * recon_color, 0, 255).astype(np.uint8), process=False) + mesh.export(name) + + def save_coeff(self,name): + + pred_coeffs = {key:self.pred_coeffs_dict[key].cpu().numpy() for key in self.pred_coeffs_dict} + pred_lm = self.pred_lm.cpu().numpy() + pred_lm = np.stack([pred_lm[:,:,0],self.input_img.shape[2]-1-pred_lm[:,:,1]],axis=2) # transfer to image coordinate + pred_coeffs['lm68'] = pred_lm + savemat(name,pred_coeffs) + + + diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/losses.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/losses.py new file mode 100644 index 00000000..fbacb63b --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/losses.py @@ -0,0 +1,113 @@ +import numpy as np +import torch +import torch.nn as nn +from kornia.geometry import warp_affine +import torch.nn.functional as F + +def resize_n_crop(image, M, dsize=112): + # image: (b, c, h, w) + # M : (b, 2, 3) + return warp_affine(image, M, dsize=(dsize, dsize)) + +### perceptual level loss +class PerceptualLoss(nn.Module): + def __init__(self, recog_net, input_size=112): + super(PerceptualLoss, self).__init__() + self.recog_net = recog_net + self.preprocess = lambda x: 2 * x - 1 + self.input_size=input_size + def forward(imageA, imageB, M): + """ + 1 - cosine distance + Parameters: + imageA --torch.tensor (B, 3, H, W), range (0, 1) , RGB order + imageB --same as imageA + """ + + imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size)) + imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size)) + + # freeze bn + self.recog_net.eval() + + id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2) + id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2) + cosine_d = torch.sum(id_featureA * id_featureB, dim=-1) + # assert torch.sum((cosine_d > 1).float()) == 0 + return torch.sum(1 - cosine_d) / cosine_d.shape[0] + +def perceptual_loss(id_featureA, id_featureB): + cosine_d = torch.sum(id_featureA * id_featureB, dim=-1) + # assert torch.sum((cosine_d > 1).float()) == 0 + return torch.sum(1 - cosine_d) / cosine_d.shape[0] + +### image level loss +def photo_loss(imageA, imageB, mask, eps=1e-6): + """ + l2 norm (with sqrt, to ensure backward stabililty, use eps, otherwise Nan may occur) + Parameters: + imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order + imageB --same as imageA + """ + loss = torch.sqrt(eps + torch.sum((imageA - imageB) ** 2, dim=1, keepdims=True)) * mask + loss = torch.sum(loss) / torch.max(torch.sum(mask), torch.tensor(1.0).to(mask.device)) + return loss + +def landmark_loss(predict_lm, gt_lm, weight=None): + """ + weighted mse loss + Parameters: + predict_lm --torch.tensor (B, 68, 2) + gt_lm --torch.tensor (B, 68, 2) + weight --numpy.array (1, 68) + """ + if not weight: + weight = np.ones([68]) + weight[28:31] = 20 + weight[-8:] = 20 + weight = np.expand_dims(weight, 0) + weight = torch.tensor(weight).to(predict_lm.device) + loss = torch.sum((predict_lm - gt_lm)**2, dim=-1) * weight + loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1]) + return loss + + +### regulization +def reg_loss(coeffs_dict, opt=None): + """ + l2 norm without the sqrt, from yu's implementation (mse) + tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss + Parameters: + coeffs_dict -- a dict of torch.tensors , keys: id, exp, tex, angle, gamma, trans + + """ + # coefficient regularization to ensure plausible 3d faces + if opt: + w_id, w_exp, w_tex = opt.w_id, opt.w_exp, opt.w_tex + else: + w_id, w_exp, w_tex = 1, 1, 1, 1 + creg_loss = w_id * torch.sum(coeffs_dict['id'] ** 2) + \ + w_exp * torch.sum(coeffs_dict['exp'] ** 2) + \ + w_tex * torch.sum(coeffs_dict['tex'] ** 2) + creg_loss = creg_loss / coeffs_dict['id'].shape[0] + + # gamma regularization to ensure a nearly-monochromatic light + gamma = coeffs_dict['gamma'].reshape([-1, 3, 9]) + gamma_mean = torch.mean(gamma, dim=1, keepdims=True) + gamma_loss = torch.mean((gamma - gamma_mean) ** 2) + + return creg_loss, gamma_loss + +def reflectance_loss(texture, mask): + """ + minimize texture variance (mse), albedo regularization to ensure an uniform skin albedo + Parameters: + texture --torch.tensor, (B, N, 3) + mask --torch.tensor, (N), 1 or 0 + + """ + mask = mask.reshape([1, mask.shape[0], 1]) + texture_mean = torch.sum(mask * texture, dim=1, keepdims=True) / torch.sum(mask) + loss = torch.sum(((texture - texture_mean) * mask)**2) / (texture.shape[0] * torch.sum(mask)) + return loss + diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/networks.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/networks.py new file mode 100644 index 00000000..685750de --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/networks.py @@ -0,0 +1,522 @@ +"""This script defines deep neural networks for Deep3DFaceRecon_pytorch +""" + +import os +import numpy as np +import torch.nn.functional as F +from torch.nn import init +import functools +from torch.optim import lr_scheduler +import torch +from torch import Tensor +import torch.nn as nn +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url +from typing import Type, Any, Callable, Union, List, Optional +from .arcface_torch.backbones import get_model +from kornia.geometry import warp_affine + + +def resize_n_crop(image, M, dsize=112): + # image: (b, c, h, w) + # M : (b, 2, 3) + return warp_affine(image, M, dsize=(dsize, dsize)) + +def filter_state_dict(state_dict, remove_name='fc'): + new_state_dict = {} + for key in state_dict: + if remove_name in key: + continue + new_state_dict[key] = state_dict[key] + return new_state_dict + +def get_scheduler(optimizer, opt): + """Return a learning rate scheduler + + Parameters: + optimizer -- the optimizer of the network + opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  + opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine + + For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. + See https://pytorch.org/docs/stable/optim.html for more details. + """ + if opt.lr_policy == 'linear': + def lambda_rule(epoch): + lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs + 1) + return lr_l + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_epochs, gamma=0.2) + elif opt.lr_policy == 'plateau': + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) + elif opt.lr_policy == 'cosine': + scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) + else: + return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) + return scheduler + + +def define_net_recon(net_recon, use_last_fc=False, init_path=None): + return ReconNetWrapper(net_recon, use_last_fc=use_last_fc, init_path=init_path) + +def define_net_recog(net_recog, pretrained_path=None): + net = RecogNetWrapper(net_recog=net_recog, pretrained_path=pretrained_path) + net.eval() + return net + +class ReconNetWrapper(nn.Module): + fc_dim=257 + def __init__(self, net_recon, use_last_fc=False, init_path=None): + super(ReconNetWrapper, self).__init__() + self.use_last_fc = use_last_fc + if net_recon not in func_dict: + return NotImplementedError('network [%s] is not implemented', net_recon) + func, last_dim = func_dict[net_recon] + backbone = func(use_last_fc=use_last_fc, num_classes=self.fc_dim) + if init_path and os.path.isfile(init_path): + state_dict = filter_state_dict(torch.load(init_path, map_location='cpu')) + backbone.load_state_dict(state_dict) + print("loading init net_recon %s from %s" %(net_recon, init_path)) + self.backbone = backbone + if not use_last_fc: + self.final_layers = nn.ModuleList([ + conv1x1(last_dim, 80, bias=True), # id layer + conv1x1(last_dim, 64, bias=True), # exp layer + conv1x1(last_dim, 80, bias=True), # tex layer + conv1x1(last_dim, 3, bias=True), # angle layer + conv1x1(last_dim, 27, bias=True), # gamma layer + conv1x1(last_dim, 2, bias=True), # tx, ty + conv1x1(last_dim, 1, bias=True) # tz + ]) + for m in self.final_layers: + nn.init.constant_(m.weight, 0.) + nn.init.constant_(m.bias, 0.) + + def forward(self, x): + x = self.backbone(x) + if not self.use_last_fc: + output = [] + for layer in self.final_layers: + output.append(layer(x)) + x = torch.flatten(torch.cat(output, dim=1), 1) + return x + + +class RecogNetWrapper(nn.Module): + def __init__(self, net_recog, pretrained_path=None, input_size=112): + super(RecogNetWrapper, self).__init__() + net = get_model(name=net_recog, fp16=False) + if pretrained_path: + state_dict = torch.load(pretrained_path, map_location='cpu') + net.load_state_dict(state_dict) + print("loading pretrained net_recog %s from %s" %(net_recog, pretrained_path)) + for param in net.parameters(): + param.requires_grad = False + self.net = net + self.preprocess = lambda x: 2 * x - 1 + self.input_size=input_size + + def forward(self, image, M): + image = self.preprocess(resize_n_crop(image, M, self.input_size)) + id_feature = F.normalize(self.net(image), dim=-1, p=2) + return id_feature + + +# adapted from https://github.com/pytorch/vision/edit/master/torchvision/models/resnet.py +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'wide_resnet50_2', 'wide_resnet101_2'] + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', +} + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1, bias: bool = False) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000, + zero_init_residual: bool = False, + use_last_fc: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.use_last_fc = use_last_fc + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + + if self.use_last_fc: + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + + def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, + stride: int = 1, dilate: bool = False) -> nn.Sequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x: Tensor) -> Tensor: + # See note [TorchScript super()] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + if self.use_last_fc: + x = torch.flatten(x, 1) + x = self.fc(x) + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + +def _resnet( + arch: str, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + pretrained: bool, + progress: bool, + **kwargs: Any +) -> ResNet: + model = ResNet(block, layers, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + model.load_state_dict(state_dict) + return model + + +def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, + **kwargs) + + +def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-101 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, + **kwargs) + + +def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, + **kwargs) + + +def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNeXt-50 32x4d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 4 + return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNeXt-101 32x8d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 8 + return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + + +def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""Wide ResNet-50-2 model from + `"Wide Residual Networks" `_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""Wide ResNet-101-2 model from + `"Wide Residual Networks" `_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + + +func_dict = { + 'resnet18': (resnet18, 512), + 'resnet50': (resnet50, 2048) +} diff --git a/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/template_model.py b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/template_model.py new file mode 100644 index 00000000..dac7b33d --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/deep_3drecon_models/template_model.py @@ -0,0 +1,100 @@ +"""Model class template + +This module provides a template for users to implement custom models. +You can specify '--model template' to use this model. +The class name should be consistent with both the filename and its model option. +The filename should be _dataset.py +The class name should be Dataset.py +It implements a simple image-to-image translation baseline based on regression loss. +Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss: + min_ ||netG(data_A) - data_B||_1 +You need to implement the following functions: + : Add model-specific options and rewrite default values for existing options. + <__init__>: Initialize this model class. + : Unpack input data and perform data pre-processing. + : Run forward pass. This will be called by both and . + : Update network weights; it will be called in every training iteration. +""" +import numpy as np +import torch +from .base_model import BaseModel +from . import networks + + +class TemplateModel(BaseModel): + @staticmethod + def modify_commandline_options(parser, is_train=True): + """Add new model-specific options and rewrite default values for existing options. + + Parameters: + parser -- the option parser + is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + parser.set_defaults(dataset_mode='aligned') # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset. + if is_train: + parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model. + + return parser + + def __init__(self, opt): + """Initialize this model class. + + Parameters: + opt -- training/test options + + A few things can be done here. + - (required) call the initialization function of BaseModel + - define loss function, visualization images, model names, and optimizers + """ + BaseModel.__init__(self, opt) # call the initialization method of BaseModel + # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk. + self.loss_names = ['loss_G'] + # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images. + self.visual_names = ['data_A', 'data_B', 'output'] + # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks. + # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them. + self.model_names = ['G'] + # define networks; you can use opt.isTrain to specify different behaviors for training and test. + self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids) + if self.isTrain: # only defined during training time + # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss. + # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device) + self.criterionLoss = torch.nn.L1Loss() + # define and initialize optimizers. You can define one optimizer for each network. + # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. + self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizers = [self.optimizer] + + # Our program will automatically call to define schedulers, load networks, and print networks + + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input: a dictionary that contains the data itself and its metadata information. + """ + AtoB = self.opt.direction == 'AtoB' # use to swap data_A and data_B + self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A + self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B + self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths + + def forward(self): + """Run forward pass. This will be called by both functions and .""" + self.output = self.netG(self.data_A) # generate output image given the input data_A + + def backward(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + # caculate the intermediate results if necessary; here self.output has been computed during function + # calculate loss given the input and intermediate results + self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression + self.loss_G.backward() # calculate gradients of network G w.r.t. loss_G + + def optimize_parameters(self): + """Update network weights; it will be called in every training iteration.""" + self.forward() # first call forward to calculate intermediate results + self.optimizer.zero_grad() # clear network G's existing gradients + self.backward() # calculate gradients for network G + self.optimizer.step() # update gradients for network G diff --git a/Geneface_main/GeneFace/deep_3drecon/generate_reconstructor_opt_for_geneface.py b/Geneface_main/GeneFace/deep_3drecon/generate_reconstructor_opt_for_geneface.py new file mode 100644 index 00000000..96e8b2ed --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/generate_reconstructor_opt_for_geneface.py @@ -0,0 +1,12 @@ +from options.test_options import TestOptions +import pickle as pkl + +# run in the root dir! +opt = TestOptions().parse() # get test options +opt.name='facerecon' +opt.epoch=20 +opt.bfm_folder='deep_3drecon/BFM/' +opt.checkpoints_dir='deep_3drecon/checkpoints/' + +with open("deep_3drecon/reconstructor_opt.pkl", 'wb') as f: + pkl.dump(opt, f) diff --git a/Geneface_main/GeneFace/deep_3drecon/options/__init__.py b/Geneface_main/GeneFace/deep_3drecon/options/__init__.py new file mode 100644 index 00000000..e7eedebe --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/options/__init__.py @@ -0,0 +1 @@ +"""This package options includes option modules: training options, test options, and basic options (used in both training and test).""" diff --git a/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/__init__.cpython-310.pyc b/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 00000000..3b9f98e6 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/__init__.cpython-310.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/__init__.cpython-312.pyc b/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 00000000..a26e34dd Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/__init__.cpython-312.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/__init__.cpython-39.pyc b/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 00000000..93a901ff Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/__init__.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/base_options.cpython-310.pyc b/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/base_options.cpython-310.pyc new file mode 100644 index 00000000..09a7652d Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/base_options.cpython-310.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/base_options.cpython-312.pyc b/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/base_options.cpython-312.pyc new file mode 100644 index 00000000..2ec130c6 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/base_options.cpython-312.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/base_options.cpython-39.pyc b/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/base_options.cpython-39.pyc new file mode 100644 index 00000000..d9db72a5 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/base_options.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/test_options.cpython-310.pyc b/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/test_options.cpython-310.pyc new file mode 100644 index 00000000..5a6d88cf Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/test_options.cpython-310.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/test_options.cpython-312.pyc b/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/test_options.cpython-312.pyc new file mode 100644 index 00000000..34a92c2c Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/test_options.cpython-312.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/test_options.cpython-39.pyc b/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/test_options.cpython-39.pyc new file mode 100644 index 00000000..ee45fab7 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/options/__pycache__/test_options.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/options/base_options.py b/Geneface_main/GeneFace/deep_3drecon/options/base_options.py new file mode 100644 index 00000000..37ad0bda --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/options/base_options.py @@ -0,0 +1,169 @@ +"""This script contains base options for Deep3DFaceRecon_pytorch +""" + +import argparse +import os +from util import util +import numpy as np +import torch +import deep_3drecon_models +import data + + +class BaseOptions(): + """This class defines options used during both training and test time. + + It also implements several helper functions such as parsing, printing, and saving the options. + It also gathers additional options defined in functions in both dataset class and model class. + """ + + def __init__(self, cmd_line=None): + """Reset the class; indicates the class hasn't been initailized""" + self.initialized = False + self.cmd_line = None + if cmd_line is not None: + self.cmd_line = cmd_line.split() + + def initialize(self, parser): + """Define the common options that are used in both training and test.""" + # basic parameters + parser.add_argument('--name', type=str, default='facerecon', help='name of the experiment. It decides where to store samples and models') + parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') + parser.add_argument('--checkpoints_dir', type=str, default='./deep_3drecon/checkpoints', help='models are saved here') + parser.add_argument('--vis_batch_nums', type=float, default=1, help='batch nums of images for visulization') + parser.add_argument('--eval_batch_nums', type=float, default=float('inf'), help='batch nums of images for evaluation') + parser.add_argument('--use_ddp', type=util.str2bool, nargs='?', const=True, default=True, help='whether use distributed data parallel') + parser.add_argument('--ddp_port', type=str, default='12355', help='ddp port') + parser.add_argument('--display_per_batch', type=util.str2bool, nargs='?', const=True, default=True, help='whether use batch to show losses') + parser.add_argument('--add_image', type=util.str2bool, nargs='?', const=True, default=True, help='whether add image to tensorboard') + parser.add_argument('--world_size', type=int, default=1, help='batch nums of images for evaluation') + + # model parameters + parser.add_argument('--model', type=str, default='facerecon', help='chooses which model to use.') + + # additional parameters + parser.add_argument('--epoch', type=str, default='20', help='which epoch to load? set to latest to use latest cached model') + parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') + parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') + + self.initialized = True + return parser + + def gather_options(self): + """Initialize our parser with basic options(only once). + Add additional model-specific and dataset-specific options. + These options are defined in the function + in model and dataset classes. + """ + if not self.initialized: # check if it has been initialized + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = self.initialize(parser) + + # get the basic options + if self.cmd_line is None: + opt, _ = parser.parse_known_args() + else: + opt, _ = parser.parse_known_args(self.cmd_line) + + # set cuda visible devices + os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids + + # modify model-related parser options + model_name = opt.model + model_option_setter = deep_3drecon_models.get_option_setter(model_name) + parser = model_option_setter(parser, self.isTrain) + if self.cmd_line is None: + opt, _ = parser.parse_known_args() # parse again with new defaults + else: + opt, _ = parser.parse_known_args(self.cmd_line) # parse again with new defaults + + # modify dataset-related parser options + if opt.dataset_mode: + dataset_name = opt.dataset_mode + dataset_option_setter = data.get_option_setter(dataset_name) + parser = dataset_option_setter(parser, self.isTrain) + + # save and return the parser + self.parser = parser + if self.cmd_line is None: + return parser.parse_args() + else: + return parser.parse_args(self.cmd_line) + + def print_options(self, opt): + """Print and save options + + It will print both current options and default values(if different). + It will save options into a text file / [checkpoints_dir] / opt.txt + """ + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + + # save to the disk + expr_dir = os.path.join(opt.checkpoints_dir, opt.name) + util.mkdirs(expr_dir) + file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) + try: + with open(file_name, 'wt') as opt_file: + opt_file.write(message) + opt_file.write('\n') + except PermissionError as error: + print("permission error {}".format(error)) + pass + + def parse(self): + """Parse our options, create checkpoints directory suffix, and set up gpu device.""" + opt = self.gather_options() + opt.isTrain = self.isTrain # train or test + + # process opt.suffix + if opt.suffix: + suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' + opt.name = opt.name + suffix + + + # set gpu ids + str_ids = opt.gpu_ids.split(',') + gpu_ids = [] + for str_id in str_ids: + id = int(str_id) + if id >= 0: + gpu_ids.append(id) + opt.world_size = len(gpu_ids) + # if len(opt.gpu_ids) > 0: + # torch.cuda.set_device(gpu_ids[0]) + if opt.world_size == 1: + opt.use_ddp = False + + if opt.phase != 'test': + # set continue_train automatically + if opt.pretrained_name is None: + model_dir = os.path.join(opt.checkpoints_dir, opt.name) + else: + model_dir = os.path.join(opt.checkpoints_dir, opt.pretrained_name) + if os.path.isdir(model_dir): + model_pths = [i for i in os.listdir(model_dir) if i.endswith('pth')] + if os.path.isdir(model_dir) and len(model_pths) != 0: + opt.continue_train= True + + # update the latest epoch count + if opt.continue_train: + if opt.epoch == 'latest': + epoch_counts = [int(i.split('.')[0].split('_')[-1]) for i in model_pths if 'latest' not in i] + if len(epoch_counts) != 0: + opt.epoch_count = max(epoch_counts) + 1 + else: + opt.epoch_count = int(opt.epoch) + 1 + + + self.print_options(opt) + self.opt = opt + return self.opt diff --git a/Geneface_main/GeneFace/deep_3drecon/options/test_options.py b/Geneface_main/GeneFace/deep_3drecon/options/test_options.py new file mode 100644 index 00000000..4ff3ad14 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/options/test_options.py @@ -0,0 +1,21 @@ +"""This script contains the test options for Deep3DFaceRecon_pytorch +""" + +from .base_options import BaseOptions + + +class TestOptions(BaseOptions): + """This class includes test options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) # define shared options + parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') + parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]') + parser.add_argument('--img_folder', type=str, default='examples', help='folder for test images.') + + # Dropout and Batchnorm has different behavior during training and test. + self.isTrain = False + return parser diff --git a/Geneface_main/GeneFace/deep_3drecon/options/train_options.py b/Geneface_main/GeneFace/deep_3drecon/options/train_options.py new file mode 100644 index 00000000..1337bfdd --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/options/train_options.py @@ -0,0 +1,53 @@ +"""This script contains the training options for Deep3DFaceRecon_pytorch +""" + +from .base_options import BaseOptions +from util import util + +class TrainOptions(BaseOptions): + """This class includes training options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) + # dataset parameters + # for train + parser.add_argument('--data_root', type=str, default='./', help='dataset root') + parser.add_argument('--flist', type=str, default='datalist/train/masks.txt', help='list of mask names of training set') + parser.add_argument('--batch_size', type=int, default=32) + parser.add_argument('--dataset_mode', type=str, default='flist', help='chooses how datasets are loaded. [None | flist]') + parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') + parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') + parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') + parser.add_argument('--preprocess', type=str, default='shift_scale_rot_flip', help='scaling and cropping of images at load time [shift_scale_rot_flip | shift_scale | shift | shift_rot_flip ]') + parser.add_argument('--use_aug', type=util.str2bool, nargs='?', const=True, default=True, help='whether use data augmentation') + + # for val + parser.add_argument('--flist_val', type=str, default='datalist/val/masks.txt', help='list of mask names of val set') + parser.add_argument('--batch_size_val', type=int, default=32) + + + # visualization parameters + parser.add_argument('--display_freq', type=int, default=1000, help='frequency of showing training results on screen') + parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') + + # network saving and loading parameters + parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') + parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') + parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation freq') + parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') + parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') + parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') + parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') + parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint') + + # training parameters + parser.add_argument('--n_epochs', type=int, default=20, help='number of epochs with the initial learning rate') + parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam') + parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. [linear | step | plateau | cosine]') + parser.add_argument('--lr_decay_epochs', type=int, default=10, help='multiply by a gamma every lr_decay_epochs epoches') + + self.isTrain = True + return parser diff --git a/Geneface_main/GeneFace/deep_3drecon/reconstructor.py b/Geneface_main/GeneFace/deep_3drecon/reconstructor.py new file mode 100644 index 00000000..a6f8be41 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/reconstructor.py @@ -0,0 +1,90 @@ +"""This script is the test script for Deep3DFaceRecon_pytorch +Pytorch Deep3D_Recon is 8x faster than TF-based, 16s/iter ==> 2s/iter +""" + +import os +# os.environ['PYTHONPATH'] = os.environ['PYTHONPATH'] + ":" + os.path.abspath("deep_3drecon") +import torch +import torch.nn as nn +from .deep_3drecon_models.facerecon_model import FaceReconModel +from .util.preprocess import align_img +from PIL import Image +import numpy as np +from .util.load_mats import load_lm3d +import torch +import pickle as pkl +from PIL import Image + +from utils.commons.tensor_utils import convert_to_tensor, convert_to_np + +with open("deep_3drecon/reconstructor_opt.pkl", "rb") as f: + opt = pkl.load(f) + +class Reconstructor(nn.Module): + def __init__(self): + super().__init__() + self.model = FaceReconModel(opt) + self.model.setup(opt) + self.model.device = 'cuda:0' + self.model.parallelize() + # self.model.to(self.model.device) + self.model.eval() + self.lm3d_std = load_lm3d(opt.bfm_folder) + + def preprocess_data(self, im, lm, lm3d_std): + # to RGB + H,W,_ = im.shape + lm = lm.reshape([-1, 2]) + lm[:, -1] = H - 1 - lm[:, -1] + + _, im, lm, _ = align_img(Image.fromarray(convert_to_np(im)), convert_to_np(lm), convert_to_np(lm3d_std)) + im = torch.tensor(np.array(im)/255., dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) + lm = torch.tensor(lm).unsqueeze(0) + return im, lm + + @torch.no_grad() + def recon_coeff(self, batched_images, batched_lm5, return_image=True, batch_mode=True): + bs = batched_images.shape[0] + data_lst = [] + for i in range(bs): + img = batched_images[i] + lm5 = batched_lm5[i] + align_im, lm = self.preprocess_data(img, lm5, self.lm3d_std) + data = { + 'imgs': align_im, + 'lms': lm + } + data_lst.append(data) + if not batch_mode: + coeff_lst = [] + align_lst = [] + for i in range(bs): + data = data_lst + self.model.set_input(data) # unpack data from data loader + self.model.forward() + pred_coeff = self.model.output_coeff.cpu().numpy() + align_im = (align_im.squeeze().permute(1,2,0)*255).int().numpy().astype(np.uint8) + coeff_lst.append(pred_coeff) + align_lst.append(align_im) + batch_coeff = np.concatenate(coeff_lst) + batch_align_img = np.stack(align_lst) # [B, 257] + else: + imgs = torch.cat([d['imgs'] for d in data_lst]) + lms = torch.cat([d['lms'] for d in data_lst]) + data = { + 'imgs': imgs, + 'lms': lms + } + self.model.set_input(data) # unpack data from data loader + self.model.forward() + batch_coeff = self.model.output_coeff.cpu().numpy() + batch_align_img = (imgs.permute(0,2,3,1)*255).int().numpy().astype(np.uint8) + return batch_coeff, batch_align_img + + # todo: batch-wise recon! + + def forward(self, batched_images, batched_lm5, return_image=True): + return self.recon_coeff(batched_images, batched_lm5, return_image) + + + diff --git a/Geneface_main/GeneFace/deep_3drecon/reconstructor_opt.pkl b/Geneface_main/GeneFace/deep_3drecon/reconstructor_opt.pkl new file mode 100644 index 00000000..b6a7af7a Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/reconstructor_opt.pkl differ diff --git a/Geneface_main/GeneFace/deep_3drecon/test.py b/Geneface_main/GeneFace/deep_3drecon/test.py new file mode 100644 index 00000000..c5207844 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/test.py @@ -0,0 +1,69 @@ +"""This script is the test script for Deep3DFaceRecon_pytorch +""" + +import os +from options.test_options import TestOptions +from deep_3drecon_models import create_model +from util.visualizer import MyVisualizer +from util.preprocess import align_img +from PIL import Image +import numpy as np +from util.load_mats import load_lm3d +import torch + +def get_data_path(root='examples'): + im_path = [os.path.join(root, i) for i in sorted(os.listdir(root)) if i.endswith('png') or i.endswith('jpg')] + lm_path = [i.replace('png', 'txt').replace('jpg', 'txt') for i in im_path] + lm_path = [os.path.join(i.replace(i.split(os.path.sep)[-1],''),'detections',i.split(os.path.sep)[-1]) for i in lm_path] + return im_path, lm_path + +def read_data(im_path, lm_path, lm3d_std, to_tensor=True): + # to RGB + im = Image.open(im_path).convert('RGB') + W,H = im.size + lm = np.loadtxt(lm_path).astype(np.float32) + lm = lm.reshape([-1, 2]) + lm[:, -1] = H - 1 - lm[:, -1] + _, im, lm, _ = align_img(im, lm, lm3d_std) + if to_tensor: + im = torch.tensor(np.array(im)/255., dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) + lm = torch.tensor(lm).unsqueeze(0) + return im, lm + +def main(rank, opt, name='examples'): + device = torch.device(rank) + torch.cuda.set_device(device) + model = create_model(opt) + model.setup(opt) + model.device = device + model.parallelize() + model.eval() + visualizer = MyVisualizer(opt) + + im_path, lm_path = get_data_path(name) + lm3d_std = load_lm3d(opt.bfm_folder) + + for i in range(len(im_path)): + print(i, im_path[i]) + img_name = im_path[i].split(os.path.sep)[-1].replace('.png','').replace('.jpg','') + if not os.path.isfile(lm_path[i]): + print("%s is not found !!!"%lm_path[i]) + continue + im_tensor, lm_tensor = read_data(im_path[i], lm_path[i], lm3d_std) + data = { + 'imgs': im_tensor, + 'lms': lm_tensor + } + model.set_input(data) # unpack data from data loader + model.test() # run inference + visuals = model.get_current_visuals() # get image results + visualizer.display_current_results(visuals, 0, opt.epoch, dataset=name.split(os.path.sep)[-1], + save_results=True, count=i, name=img_name, add_image=False) + + model.save_mesh(os.path.join(visualizer.img_dir, name.split(os.path.sep)[-1], 'epoch_%s_%06d'%(opt.epoch, 0),img_name+'.obj')) # save reconstruction meshes + model.save_coeff(os.path.join(visualizer.img_dir, name.split(os.path.sep)[-1], 'epoch_%s_%06d'%(opt.epoch, 0),img_name+'.mat')) # save predicted coefficients + +if __name__ == '__main__': + opt = TestOptions().parse() # get test options + main(0, opt, 'deep_3drecon/datasets/examples') + print(f"results saved at deep_3drecon/checkpoints/facerecon/results/") diff --git a/Geneface_main/GeneFace/deep_3drecon/train.py b/Geneface_main/GeneFace/deep_3drecon/train.py new file mode 100644 index 00000000..cbdda882 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/train.py @@ -0,0 +1,166 @@ +"""This script is the training script for Deep3DFaceRecon_pytorch +""" + +import os +import time +import numpy as np +import torch +from options.train_options import TrainOptions +from data import create_dataset +from deep_3drecon_models import create_model +from util.visualizer import MyVisualizer +from util.util import genvalconf +import torch.multiprocessing as mp +import torch.distributed as dist + + +def setup(rank, world_size, port): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = port + + # initialize the process group + dist.init_process_group("gloo", rank=rank, world_size=world_size) + +def cleanup(): + dist.destroy_process_group() + +def main(rank, world_size, train_opt): + val_opt = genvalconf(train_opt, isTrain=False) + + device = torch.device(rank) + torch.cuda.set_device(device) + use_ddp = train_opt.use_ddp + + if use_ddp: + setup(rank, world_size, train_opt.ddp_port) + + train_dataset, val_dataset = create_dataset(train_opt, rank=rank), create_dataset(val_opt, rank=rank) + train_dataset_batches, val_dataset_batches = \ + len(train_dataset) // train_opt.batch_size, len(val_dataset) // val_opt.batch_size + + model = create_model(train_opt) # create a model given train_opt.model and other options + model.setup(train_opt) + model.device = device + model.parallelize() + + if rank == 0: + print('The batch number of training images = %d\n, \ + the batch number of validation images = %d'% (train_dataset_batches, val_dataset_batches)) + model.print_networks(train_opt.verbose) + visualizer = MyVisualizer(train_opt) # create a visualizer that display/save images and plots + + total_iters = train_dataset_batches * (train_opt.epoch_count - 1) # the total number of training iterations + t_data = 0 + t_val = 0 + optimize_time = 0.1 + batch_size = 1 if train_opt.display_per_batch else train_opt.batch_size + + if use_ddp: + dist.barrier() + + times = [] + for epoch in range(train_opt.epoch_count, train_opt.n_epochs + 1): # outer loop for different epochs; we save the model by , + + epoch_start_time = time.time() # timer for entire epoch + iter_data_time = time.time() # timer for train_data loading per iteration + epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch + + train_dataset.set_epoch(epoch) + for i, train_data in enumerate(train_dataset): # inner loop within one epoch + iter_start_time = time.time() # timer for computation per iteration + if total_iters % train_opt.print_freq == 0: + t_data = iter_start_time - iter_data_time + total_iters += batch_size + epoch_iter += batch_size + + torch.cuda.synchronize() + optimize_start_time = time.time() + + model.set_input(train_data) # unpack train_data from dataset and apply preprocessing + model.optimize_parameters() # calculate loss functions, get gradients, update network weights + + torch.cuda.synchronize() + optimize_time = (time.time() - optimize_start_time) / batch_size * 0.005 + 0.995 * optimize_time + + if use_ddp: + dist.barrier() + + if rank == 0 and (total_iters == batch_size or total_iters % train_opt.display_freq == 0): # display images on visdom and save images to a HTML file + model.compute_visuals() + visualizer.display_current_results(model.get_current_visuals(), total_iters, epoch, + save_results=True, + add_image=train_opt.add_image) + # (total_iters == batch_size or total_iters % train_opt.evaluation_freq == 0) + + if rank == 0 and (total_iters == batch_size or total_iters % train_opt.print_freq == 0): # print training losses and save logging information to the disk + losses = model.get_current_losses() + visualizer.print_current_losses(epoch, epoch_iter, losses, optimize_time, t_data) + visualizer.plot_current_losses(total_iters, losses) + + if total_iters == batch_size or total_iters % train_opt.evaluation_freq == 0: + with torch.no_grad(): + torch.cuda.synchronize() + val_start_time = time.time() + losses_avg = {} + model.eval() + for j, val_data in enumerate(val_dataset): + model.set_input(val_data) + model.optimize_parameters(isTrain=False) + if rank == 0 and j < train_opt.vis_batch_nums: + model.compute_visuals() + visualizer.display_current_results(model.get_current_visuals(), total_iters, epoch, + dataset='val', save_results=True, count=j * val_opt.batch_size, + add_image=train_opt.add_image) + + if j < train_opt.eval_batch_nums: + losses = model.get_current_losses() + for key, value in losses.items(): + losses_avg[key] = losses_avg.get(key, 0) + value + + for key, value in losses_avg.items(): + losses_avg[key] = value / min(train_opt.eval_batch_nums, val_dataset_batches) + + torch.cuda.synchronize() + eval_time = time.time() - val_start_time + + if rank == 0: + visualizer.print_current_losses(epoch, epoch_iter, losses_avg, eval_time, t_data, dataset='val') # visualize training results + visualizer.plot_current_losses(total_iters, losses_avg, dataset='val') + model.train() + + if use_ddp: + dist.barrier() + + if rank == 0 and (total_iters == batch_size or total_iters % train_opt.save_latest_freq == 0): # cache our latest model every iterations + print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters)) + print(train_opt.name) # it's useful to occasionally show the experiment name on console + save_suffix = 'iter_%d' % total_iters if train_opt.save_by_iter else 'latest' + model.save_networks(save_suffix) + + if use_ddp: + dist.barrier() + + iter_data_time = time.time() + + print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, train_opt.n_epochs, time.time() - epoch_start_time)) + model.update_learning_rate() # update learning rates at the end of every epoch. + + if rank == 0 and epoch % train_opt.save_epoch_freq == 0: # cache our model every epochs + print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters)) + model.save_networks('latest') + model.save_networks(epoch) + + if use_ddp: + dist.barrier() + +if __name__ == '__main__': + + import warnings + warnings.filterwarnings("ignore") + + train_opt = TrainOptions().parse() # get training options + world_size = train_opt.world_size + + if train_opt.use_ddp: + mp.spawn(main, args=(world_size, train_opt), nprocs=world_size, join=True) + else: + main(0, world_size, train_opt) diff --git a/Geneface_main/GeneFace/deep_3drecon/util/BBRegressorParam_r.mat b/Geneface_main/GeneFace/deep_3drecon/util/BBRegressorParam_r.mat new file mode 100644 index 00000000..1430a94e Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/util/BBRegressorParam_r.mat differ diff --git a/Geneface_main/GeneFace/deep_3drecon/util/__init__.py b/Geneface_main/GeneFace/deep_3drecon/util/__init__.py new file mode 100644 index 00000000..45cbc84b --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/util/__init__.py @@ -0,0 +1,2 @@ +"""This package includes a miscellaneous collection of useful helper functions.""" +from .util import * diff --git a/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/__init__.cpython-310.pyc b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 00000000..21617e54 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/__init__.cpython-310.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/__init__.cpython-312.pyc b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 00000000..6db4be6a Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/__init__.cpython-312.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/__init__.cpython-39.pyc b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 00000000..f5595cd5 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/__init__.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/html.cpython-310.pyc b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/html.cpython-310.pyc new file mode 100644 index 00000000..85657c03 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/html.cpython-310.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/html.cpython-39.pyc b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/html.cpython-39.pyc new file mode 100644 index 00000000..2ec010fb Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/html.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/load_mats.cpython-310.pyc b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/load_mats.cpython-310.pyc new file mode 100644 index 00000000..df895361 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/load_mats.cpython-310.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/load_mats.cpython-39.pyc b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/load_mats.cpython-39.pyc new file mode 100644 index 00000000..0ff4d787 Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/load_mats.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/mesh_renderer.cpython-310.pyc b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/mesh_renderer.cpython-310.pyc new file mode 100644 index 00000000..9cc427fa Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/mesh_renderer.cpython-310.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/mesh_renderer.cpython-39.pyc b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/mesh_renderer.cpython-39.pyc new file mode 100644 index 00000000..8991314e Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/mesh_renderer.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/preprocess.cpython-310.pyc b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/preprocess.cpython-310.pyc new file mode 100644 index 00000000..2a110b4b Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/preprocess.cpython-310.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/preprocess.cpython-39.pyc b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/preprocess.cpython-39.pyc new file mode 100644 index 00000000..5cc1295c Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/preprocess.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/util.cpython-310.pyc b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/util.cpython-310.pyc new file mode 100644 index 00000000..6ff6131e Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/util.cpython-310.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/util.cpython-312.pyc b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/util.cpython-312.pyc new file mode 100644 index 00000000..e8bb0b7c Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/util.cpython-312.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/util.cpython-39.pyc b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/util.cpython-39.pyc new file mode 100644 index 00000000..75d80d8b Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/util.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/visualizer.cpython-310.pyc b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/visualizer.cpython-310.pyc new file mode 100644 index 00000000..01d3b12f Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/visualizer.cpython-310.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/visualizer.cpython-39.pyc b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/visualizer.cpython-39.pyc new file mode 100644 index 00000000..1ede2cac Binary files /dev/null and b/Geneface_main/GeneFace/deep_3drecon/util/__pycache__/visualizer.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/deep_3drecon/util/detect_lm68.py b/Geneface_main/GeneFace/deep_3drecon/util/detect_lm68.py new file mode 100644 index 00000000..b7e40997 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/util/detect_lm68.py @@ -0,0 +1,106 @@ +import os +import cv2 +import numpy as np +from scipy.io import loadmat +import tensorflow as tf +from util.preprocess import align_for_lm +from shutil import move + +mean_face = np.loadtxt('util/test_mean_face.txt') +mean_face = mean_face.reshape([68, 2]) + +def save_label(labels, save_path): + np.savetxt(save_path, labels) + +def draw_landmarks(img, landmark, save_name): + landmark = landmark + lm_img = np.zeros([img.shape[0], img.shape[1], 3]) + lm_img[:] = img.astype(np.float32) + landmark = np.round(landmark).astype(np.int32) + + for i in range(len(landmark)): + for j in range(-1, 1): + for k in range(-1, 1): + if img.shape[0] - 1 - landmark[i, 1]+j > 0 and \ + img.shape[0] - 1 - landmark[i, 1]+j < img.shape[0] and \ + landmark[i, 0]+k > 0 and \ + landmark[i, 0]+k < img.shape[1]: + lm_img[img.shape[0] - 1 - landmark[i, 1]+j, landmark[i, 0]+k, + :] = np.array([0, 0, 255]) + lm_img = lm_img.astype(np.uint8) + + cv2.imwrite(save_name, lm_img) + + +def load_data(img_name, txt_name): + return cv2.imread(img_name), np.loadtxt(txt_name) + +# create tensorflow graph for landmark detector +def load_lm_graph(graph_filename): + with tf.gfile.GFile(graph_filename, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + + with tf.Graph().as_default() as graph: + tf.import_graph_def(graph_def, name='net') + img_224 = graph.get_tensor_by_name('net/input_imgs:0') + output_lm = graph.get_tensor_by_name('net/lm:0') + lm_sess = tf.Session(graph=graph) + + return lm_sess,img_224,output_lm + +# landmark detection +def detect_68p(img_path,sess,input_op,output_op): + print('detecting landmarks......') + names = [i for i in sorted(os.listdir( + img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i] + vis_path = os.path.join(img_path, 'vis') + remove_path = os.path.join(img_path, 'remove') + save_path = os.path.join(img_path, 'landmarks') + if not os.path.isdir(vis_path): + os.makedirs(vis_path) + if not os.path.isdir(remove_path): + os.makedirs(remove_path) + if not os.path.isdir(save_path): + os.makedirs(save_path) + + for i in range(0, len(names)): + name = names[i] + print('%05d' % (i), ' ', name) + full_image_name = os.path.join(img_path, name) + txt_name = '.'.join(name.split('.')[:-1]) + '.txt' + full_txt_name = os.path.join(img_path, 'detections', txt_name) # 5 facial landmark path for each image + + # if an image does not have detected 5 facial landmarks, remove it from the training list + if not os.path.isfile(full_txt_name): + move(full_image_name, os.path.join(remove_path, name)) + continue + + # load data + img, five_points = load_data(full_image_name, full_txt_name) + input_img, scale, bbox = align_for_lm(img, five_points) # align for 68 landmark detection + + # if the alignment fails, remove corresponding image from the training list + if scale == 0: + move(full_txt_name, os.path.join( + remove_path, txt_name)) + move(full_image_name, os.path.join(remove_path, name)) + continue + + # detect landmarks + input_img = np.reshape( + input_img, [1, 224, 224, 3]).astype(np.float32) + landmark = sess.run( + output_op, feed_dict={input_op: input_img}) + + # transform back to original image coordinate + landmark = landmark.reshape([68, 2]) + mean_face + landmark[:, 1] = 223 - landmark[:, 1] + landmark = landmark / scale + landmark[:, 0] = landmark[:, 0] + bbox[0] + landmark[:, 1] = landmark[:, 1] + bbox[1] + landmark[:, 1] = img.shape[0] - 1 - landmark[:, 1] + + if i % 100 == 0: + draw_landmarks(img, landmark, os.path.join(vis_path, name)) + save_label(landmark, os.path.join(save_path, txt_name)) diff --git a/Geneface_main/GeneFace/deep_3drecon/util/generate_list.py b/Geneface_main/GeneFace/deep_3drecon/util/generate_list.py new file mode 100644 index 00000000..943d9067 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/util/generate_list.py @@ -0,0 +1,34 @@ +"""This script is to generate training list files for Deep3DFaceRecon_pytorch +""" + +import os + +# save path to training data +def write_list(lms_list, imgs_list, msks_list, mode='train',save_folder='datalist', save_name=''): + save_path = os.path.join(save_folder, mode) + if not os.path.isdir(save_path): + os.makedirs(save_path) + with open(os.path.join(save_path, save_name + 'landmarks.txt'), 'w') as fd: + fd.writelines([i + '\n' for i in lms_list]) + + with open(os.path.join(save_path, save_name + 'images.txt'), 'w') as fd: + fd.writelines([i + '\n' for i in imgs_list]) + + with open(os.path.join(save_path, save_name + 'masks.txt'), 'w') as fd: + fd.writelines([i + '\n' for i in msks_list]) + +# check if the path is valid +def check_list(rlms_list, rimgs_list, rmsks_list): + lms_list, imgs_list, msks_list = [], [], [] + for i in range(len(rlms_list)): + flag = 'false' + lm_path = rlms_list[i] + im_path = rimgs_list[i] + msk_path = rmsks_list[i] + if os.path.isfile(lm_path) and os.path.isfile(im_path) and os.path.isfile(msk_path): + flag = 'true' + lms_list.append(rlms_list[i]) + imgs_list.append(rimgs_list[i]) + msks_list.append(rmsks_list[i]) + print(i, rlms_list[i], flag) + return lms_list, imgs_list, msks_list diff --git a/Geneface_main/GeneFace/deep_3drecon/util/html.py b/Geneface_main/GeneFace/deep_3drecon/util/html.py new file mode 100644 index 00000000..cc3262a1 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/util/html.py @@ -0,0 +1,86 @@ +import dominate +from dominate.tags import meta, h3, table, tr, td, p, a, img, br +import os + + +class HTML: + """This HTML class allows us to save images and write texts into a single HTML file. + + It consists of functions such as (add a text header to the HTML file), + (add a row of images to the HTML file), and (save the HTML to the disk). + It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. + """ + + def __init__(self, web_dir, title, refresh=0): + """Initialize the HTML classes + + Parameters: + web_dir (str) -- a directory that stores the webpage. HTML file will be created at /index.html; images will be saved at 0: + with self.doc.head: + meta(http_equiv="refresh", content=str(refresh)) + + def get_image_dir(self): + """Return the directory that stores images""" + return self.img_dir + + def add_header(self, text): + """Insert a header to the HTML file + + Parameters: + text (str) -- the header text + """ + with self.doc: + h3(text) + + def add_images(self, ims, txts, links, width=400): + """add images to the HTML file + + Parameters: + ims (str list) -- a list of image paths + txts (str list) -- a list of image names shown on the website + links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page + """ + self.t = table(border=1, style="table-layout: fixed;") # Insert a table + self.doc.add(self.t) + with self.t: + with tr(): + for im, txt, link in zip(ims, txts, links): + with td(style="word-wrap: break-word;", halign="center", valign="top"): + with p(): + with a(href=os.path.join('images', link)): + img(style="width:%dpx" % width, src=os.path.join('images', im)) + br() + p(txt) + + def save(self): + """save the current content to the HMTL file""" + html_file = '%s/index.html' % self.web_dir + f = open(html_file, 'wt') + f.write(self.doc.render()) + f.close() + + +if __name__ == '__main__': # we show an example usage here. + html = HTML('web/', 'test_html') + html.add_header('hello world') + + ims, txts, links = [], [], [] + for n in range(4): + ims.append('image_%d.png' % n) + txts.append('text_%d' % n) + links.append('image_%d.png' % n) + html.add_images(ims, txts, links) + html.save() diff --git a/Geneface_main/GeneFace/deep_3drecon/util/load_mats.py b/Geneface_main/GeneFace/deep_3drecon/util/load_mats.py new file mode 100644 index 00000000..5b1f4a73 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/util/load_mats.py @@ -0,0 +1,117 @@ +"""This script is to load 3D face model for Deep3DFaceRecon_pytorch +""" + +import numpy as np +from PIL import Image +from scipy.io import loadmat, savemat +from array import array +import os.path as osp + +# load expression basis +def LoadExpBasis(bfm_folder='BFM'): + n_vertex = 53215 + Expbin = open(osp.join(bfm_folder, 'Exp_Pca.bin'), 'rb') + exp_dim = array('i') + exp_dim.fromfile(Expbin, 1) + expMU = array('f') + expPC = array('f') + expMU.fromfile(Expbin, 3*n_vertex) + expPC.fromfile(Expbin, 3*exp_dim[0]*n_vertex) + Expbin.close() + + expPC = np.array(expPC) + expPC = np.reshape(expPC, [exp_dim[0], -1]) + expPC = np.transpose(expPC) + + expEV = np.loadtxt(osp.join(bfm_folder, 'std_exp.txt')) + + return expPC, expEV + + +# transfer original BFM09 to our face model +def transferBFM09(bfm_folder='BFM'): + print('Transfer BFM09 to BFM_model_front......') + original_BFM = loadmat(osp.join(bfm_folder, '01_MorphableModel.mat')) + shapePC = original_BFM['shapePC'] # shape basis + shapeEV = original_BFM['shapeEV'] # corresponding eigen value + shapeMU = original_BFM['shapeMU'] # mean face + texPC = original_BFM['texPC'] # texture basis + texEV = original_BFM['texEV'] # eigen value + texMU = original_BFM['texMU'] # mean texture + + expPC, expEV = LoadExpBasis() + + # transfer BFM09 to our face model + + idBase = shapePC*np.reshape(shapeEV, [-1, 199]) + idBase = idBase/1e5 # unify the scale to decimeter + idBase = idBase[:, :80] # use only first 80 basis + + exBase = expPC*np.reshape(expEV, [-1, 79]) + exBase = exBase/1e5 # unify the scale to decimeter + exBase = exBase[:, :64] # use only first 64 basis + + texBase = texPC*np.reshape(texEV, [-1, 199]) + texBase = texBase[:, :80] # use only first 80 basis + + # our face model is cropped along face landmarks and contains only 35709 vertex. + # original BFM09 contains 53490 vertex, and expression basis provided by Guo et al. contains 53215 vertex. + # thus we select corresponding vertex to get our face model. + + index_exp = loadmat(osp.join(bfm_folder, 'BFM_front_idx.mat')) + index_exp = index_exp['idx'].astype(np.int32) - 1 # starts from 0 (to 53215) + + index_shape = loadmat(osp.join(bfm_folder, 'BFM_exp_idx.mat')) + index_shape = index_shape['trimIndex'].astype( + np.int32) - 1 # starts from 0 (to 53490) + index_shape = index_shape[index_exp] + + idBase = np.reshape(idBase, [-1, 3, 80]) + idBase = idBase[index_shape, :, :] + idBase = np.reshape(idBase, [-1, 80]) + + texBase = np.reshape(texBase, [-1, 3, 80]) + texBase = texBase[index_shape, :, :] + texBase = np.reshape(texBase, [-1, 80]) + + exBase = np.reshape(exBase, [-1, 3, 64]) + exBase = exBase[index_exp, :, :] + exBase = np.reshape(exBase, [-1, 64]) + + meanshape = np.reshape(shapeMU, [-1, 3])/1e5 + meanshape = meanshape[index_shape, :] + meanshape = np.reshape(meanshape, [1, -1]) + + meantex = np.reshape(texMU, [-1, 3]) + meantex = meantex[index_shape, :] + meantex = np.reshape(meantex, [1, -1]) + + # other info contains triangles, region used for computing photometric loss, + # region used for skin texture regularization, and 68 landmarks index etc. + other_info = loadmat(osp.join(bfm_folder, 'facemodel_info.mat')) + frontmask2_idx = other_info['frontmask2_idx'] + skinmask = other_info['skinmask'] + keypoints = other_info['keypoints'] + point_buf = other_info['point_buf'] + tri = other_info['tri'] + tri_mask2 = other_info['tri_mask2'] + + # save our face model + savemat(osp.join(bfm_folder, 'BFM_model_front.mat'), {'meanshape': meanshape, 'meantex': meantex, 'idBase': idBase, 'exBase': exBase, 'texBase': texBase, + 'tri': tri, 'point_buf': point_buf, 'tri_mask2': tri_mask2, 'keypoints': keypoints, 'frontmask2_idx': frontmask2_idx, 'skinmask': skinmask}) + + +# load landmarks for standard face, which is used for image preprocessing +def load_lm3d(bfm_folder): + + Lm3D = loadmat(osp.join(bfm_folder, 'similarity_Lm3D_all.mat')) + Lm3D = Lm3D['lm'] + + # calculate 5 facial landmarks using 68 landmarks + lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1 + Lm3D = np.stack([Lm3D[lm_idx[0], :], np.mean(Lm3D[lm_idx[[1, 2]], :], 0), np.mean( + Lm3D[lm_idx[[3, 4]], :], 0), Lm3D[lm_idx[5], :], Lm3D[lm_idx[6], :]], axis=0) + Lm3D = Lm3D[[1, 2, 0, 3, 4], :] + + return Lm3D + diff --git a/Geneface_main/GeneFace/deep_3drecon/util/mesh_renderer.py b/Geneface_main/GeneFace/deep_3drecon/util/mesh_renderer.py new file mode 100644 index 00000000..5b7b5a23 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/util/mesh_renderer.py @@ -0,0 +1,126 @@ +"""This script is the differentiable renderer for Deep3DFaceRecon_pytorch + Attention, antialiasing step is missing in current version. +""" +import pytorch3d.ops +import torch +import torch.nn.functional as F +import kornia +from kornia.geometry.camera import pixel2cam +import numpy as np +from typing import List +from scipy.io import loadmat +from torch import nn + +from pytorch3d.structures import Meshes +from pytorch3d.renderer import ( + look_at_view_transform, + FoVPerspectiveCameras, + DirectionalLights, + RasterizationSettings, + MeshRenderer, + MeshRasterizer, + SoftPhongShader, + TexturesUV, +) + +# def ndc_projection(x=0.1, n=1.0, f=50.0): +# return np.array([[n/x, 0, 0, 0], +# [ 0, n/-x, 0, 0], +# [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], +# [ 0, 0, -1, 0]]).astype(np.float32) + +class MeshRenderer(nn.Module): + def __init__(self, + rasterize_fov, + znear=0.1, + zfar=10, + rasterize_size=224,**args): + super(MeshRenderer, self).__init__() + + # x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear + # self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul( + # torch.diag(torch.tensor([1., -1, -1, 1]))) + self.rasterize_size = rasterize_size + self.fov = rasterize_fov + self.znear = znear + self.zfar = zfar + + self.rasterizer = None + + def forward(self, vertex, tri, feat=None): + """ + Return: + mask -- torch.tensor, size (B, 1, H, W) + depth -- torch.tensor, size (B, 1, H, W) + features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None + + Parameters: + vertex -- torch.tensor, size (B, N, 3) + tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles + feat(optional) -- torch.tensor, size (B, N ,C), features + """ + device = vertex.device + rsize = int(self.rasterize_size) + # ndc_proj = self.ndc_proj.to(device) + # trans to homogeneous coordinates of 3d vertices, the direction of y is the same as v + if vertex.shape[-1] == 3: + vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1) + vertex[..., 0] = -vertex[..., 0] + + + # vertex_ndc = vertex @ ndc_proj.t() + if self.rasterizer is None: + self.rasterizer = MeshRasterizer() + print("create rasterizer on device cuda:%d"%device.index) + + # ranges = None + # if isinstance(tri, List) or len(tri.shape) == 3: + # vum = vertex_ndc.shape[1] + # fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device) + # fstartidx = torch.cumsum(fnum, dim=0) - fnum + # ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu() + # for i in range(tri.shape[0]): + # tri[i] = tri[i] + i*vum + # vertex_ndc = torch.cat(vertex_ndc, dim=0) + # tri = torch.cat(tri, dim=0) + + # for range_mode vetex: [B*N, 4], tri: [B*M, 3], for instance_mode vetex: [B, N, 4], tri: [M, 3] + tri = tri.type(torch.int32).contiguous() + + # rasterize + cameras = FoVPerspectiveCameras( + device=device, + fov=self.fov, + znear=self.znear, + zfar=self.zfar, + ) + + raster_settings = RasterizationSettings( + image_size=rsize + ) + + # print(vertex.shape, tri.shape) + mesh = Meshes(vertex.contiguous()[...,:3], tri.unsqueeze(0)) + + fragments = self.rasterizer(mesh, cameras = cameras, raster_settings = raster_settings) + rast_out = fragments.pix_to_face.squeeze(-1) + depth = fragments.zbuf + + # render depth + depth = depth.permute(0, 3, 1, 2) + mask = (rast_out > 0).float().unsqueeze(1) + depth = mask * depth + + + image = None + if feat is not None: + attributes = feat.reshape(-1,3)[mesh.faces_packed()] + image = pytorch3d.ops.interpolate_face_attributes(fragments.pix_to_face, + fragments.bary_coords, + attributes) + # print(image.shape) + image = image.squeeze(-2).permute(0, 3, 1, 2) + image = mask * image + + return mask, depth, image + diff --git a/Geneface_main/GeneFace/deep_3drecon/util/preprocess.py b/Geneface_main/GeneFace/deep_3drecon/util/preprocess.py new file mode 100644 index 00000000..6c4a913e --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/util/preprocess.py @@ -0,0 +1,230 @@ +"""This script contains the image preprocessing code for Deep3DFaceRecon_pytorch +""" + +import numpy as np +from scipy.io import loadmat +from PIL import Image +import cv2 +import os +from skimage import transform as trans +import torch +import warnings +warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning) +warnings.filterwarnings("ignore", category=FutureWarning) + + +# calculating least square problem for image alignment +def POS(xp, x): + npts = xp.shape[1] + + A = np.zeros([2*npts, 8]) + + A[0:2*npts-1:2, 0:3] = x.transpose() + A[0:2*npts-1:2, 3] = 1 + + A[1:2*npts:2, 4:7] = x.transpose() + A[1:2*npts:2, 7] = 1 + + b = np.reshape(xp.transpose(), [2*npts, 1]) + + k, _, _, _ = np.linalg.lstsq(A, b) + + R1 = k[0:3] + R2 = k[4:7] + sTx = k[3] + sTy = k[7] + s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2 + t = np.stack([sTx, sTy], axis=0) + + return t, s + +# bounding box for 68 landmark detection +def BBRegression(points, params): + + w1 = params['W1'] + b1 = params['B1'] + w2 = params['W2'] + b2 = params['B2'] + data = points.copy() + data = data.reshape([5, 2]) + data_mean = np.mean(data, axis=0) + x_mean = data_mean[0] + y_mean = data_mean[1] + data[:, 0] = data[:, 0] - x_mean + data[:, 1] = data[:, 1] - y_mean + + rms = np.sqrt(np.sum(data ** 2)/5) + data = data / rms + data = data.reshape([1, 10]) + data = np.transpose(data) + inputs = np.matmul(w1, data) + b1 + inputs = 2 / (1 + np.exp(-2 * inputs)) - 1 + inputs = np.matmul(w2, inputs) + b2 + inputs = np.transpose(inputs) + x = inputs[:, 0] * rms + x_mean + y = inputs[:, 1] * rms + y_mean + w = 224/inputs[:, 2] * rms + rects = [x, y, w, w] + return np.array(rects).reshape([4]) + +# utils for landmark detection +def img_padding(img, box): + success = True + bbox = box.copy() + res = np.zeros([2*img.shape[0], 2*img.shape[1], 3]) + res[img.shape[0] // 2: img.shape[0] + img.shape[0] // + 2, img.shape[1] // 2: img.shape[1] + img.shape[1]//2] = img + + bbox[0] = bbox[0] + img.shape[1] // 2 + bbox[1] = bbox[1] + img.shape[0] // 2 + if bbox[0] < 0 or bbox[1] < 0: + success = False + return res, bbox, success + +# utils for landmark detection +def crop(img, bbox): + padded_img, padded_bbox, flag = img_padding(img, bbox) + if flag: + crop_img = padded_img[padded_bbox[1]: padded_bbox[1] + + padded_bbox[3], padded_bbox[0]: padded_bbox[0] + padded_bbox[2]] + crop_img = cv2.resize(crop_img.astype(np.uint8), + (224, 224), interpolation=cv2.INTER_CUBIC) + scale = 224 / padded_bbox[3] + return crop_img, scale + else: + return padded_img, 0 + +# utils for landmark detection +def scale_trans(img, lm, t, s): + imgw = img.shape[1] + imgh = img.shape[0] + M_s = np.array([[1, 0, -t[0] + imgw//2 + 0.5], [0, 1, -imgh//2 + t[1]]], + dtype=np.float32) + img = cv2.warpAffine(img, M_s, (imgw, imgh)) + w = int(imgw / s * 100) + h = int(imgh / s * 100) + img = cv2.resize(img, (w, h)) + lm = np.stack([lm[:, 0] - t[0] + imgw // 2, lm[:, 1] - + t[1] + imgh // 2], axis=1) / s * 100 + + left = w//2 - 112 + up = h//2 - 112 + bbox = [left, up, 224, 224] + cropped_img, scale2 = crop(img, bbox) + assert(scale2!=0) + t1 = np.array([bbox[0], bbox[1]]) + + # back to raw img s * crop + s * t1 + t2 + t1 = np.array([w//2 - 112, h//2 - 112]) + scale = s / 100 + t2 = np.array([t[0] - imgw/2, t[1] - imgh / 2]) + inv = (scale/scale2, scale * t1 + t2.reshape([2])) + return cropped_img, inv + +# utils for landmark detection +def align_for_lm(img, five_points): + five_points = np.array(five_points).reshape([1, 10]) + params = loadmat('util/BBRegressorParam_r.mat') + bbox = BBRegression(five_points, params) + assert(bbox[2] != 0) + bbox = np.round(bbox).astype(np.int32) + crop_img, scale = crop(img, bbox) + return crop_img, scale, bbox + + +# resize and crop images for face reconstruction +def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None): + w0, h0 = img.size + w = (w0*s).astype(np.int32) + h = (h0*s).astype(np.int32) + left = (w/2 - target_size/2 + float((t[0] - w0/2)*s)).astype(np.int32) + right = left + target_size + up = (h/2 - target_size/2 + float((h0/2 - t[1])*s)).astype(np.int32) + below = up + target_size + + img = img.resize((w, h), resample=Image.BICUBIC) + img = img.crop((left, up, right, below)) + + if mask is not None: + mask = mask.resize((w, h), resample=Image.BICUBIC) + mask = mask.crop((left, up, right, below)) + + lm = np.stack([lm[:, 0] - t[0] + w0/2, lm[:, 1] - + t[1] + h0/2], axis=1)*s + lm = lm - np.reshape( + np.array([(w/2 - target_size/2), (h/2-target_size/2)]), [1, 2]) + + return img, lm, mask + +# utils for face reconstruction +def extract_5p(lm): + lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1 + lm5p = np.stack([lm[lm_idx[0], :], np.mean(lm[lm_idx[[1, 2]], :], 0), np.mean( + lm[lm_idx[[3, 4]], :], 0), lm[lm_idx[5], :], lm[lm_idx[6], :]], axis=0) + lm5p = lm5p[[1, 2, 0, 3, 4], :] + return lm5p + +# utils for face reconstruction +def align_img(img, lm, lm3D, mask=None, target_size=224., rescale_factor=102.): + """ + Return: + transparams --numpy.array (raw_W, raw_H, scale, tx, ty) + img_new --PIL.Image (target_size, target_size, 3) + lm_new --numpy.array (68, 2), y direction is opposite to v direction + mask_new --PIL.Image (target_size, target_size) + + Parameters: + img --PIL.Image (raw_H, raw_W, 3) + lm --numpy.array (68, 2), y direction is opposite to v direction + lm3D --numpy.array (5, 3) + mask --PIL.Image (raw_H, raw_W, 3) + """ + w0, h0 = img.size + + + if lm.shape[0] != 5: + lm5p = extract_5p(lm) + else: + lm5p = lm + + # calculate translation and scale factors using 5 facial landmarks and standard landmarks of a 3D face + t, s = POS(lm5p.transpose(), lm3D.transpose()) + s = rescale_factor/s + + # processing the image + img_new, lm_new, mask_new = resize_n_crop_img(img, lm, t, s, target_size=target_size, mask=mask) + trans_params = np.array([w0, h0, s, t[0], t[1]]) + + return trans_params, img_new, lm_new, mask_new + +# utils for face recognition model +def estimate_norm(lm_68p, H): + # from https://github.com/deepinsight/insightface/blob/c61d3cd208a603dfa4a338bd743b320ce3e94730/recognition/common/face_align.py#L68 + """ + Return: + trans_m --numpy.array (2, 3) + Parameters: + lm --numpy.array (68, 2), y direction is opposite to v direction + H --int/float , image height + """ + lm = extract_5p(lm_68p) + lm[:, -1] = H - 1 - lm[:, -1] + tform = trans.SimilarityTransform() + src = np.array( + [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366], + [41.5493, 92.3655], [70.7299, 92.2041]], + dtype=np.float32) + tform.estimate(lm, src) + M = tform.params + if np.linalg.det(M) == 0: + M = np.eye(3) + + return M[0:2, :] + +def estimate_norm_torch(lm_68p, H): + lm_68p_ = lm_68p.detach().cpu().numpy() + M = [] + for i in range(lm_68p_.shape[0]): + M.append(estimate_norm(lm_68p_[i], H)) + M = torch.tensor(np.array(M), dtype=torch.float32).to(lm_68p.device) + return M diff --git a/Geneface_main/GeneFace/deep_3drecon/util/skin_mask.py b/Geneface_main/GeneFace/deep_3drecon/util/skin_mask.py new file mode 100644 index 00000000..a8a74e4c --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/util/skin_mask.py @@ -0,0 +1,125 @@ +"""This script is to generate skin attention mask for Deep3DFaceRecon_pytorch +""" + +import math +import numpy as np +import os +import cv2 + +class GMM: + def __init__(self, dim, num, w, mu, cov, cov_det, cov_inv): + self.dim = dim # feature dimension + self.num = num # number of Gaussian components + self.w = w # weights of Gaussian components (a list of scalars) + self.mu= mu # mean of Gaussian components (a list of 1xdim vectors) + self.cov = cov # covariance matrix of Gaussian components (a list of dimxdim matrices) + self.cov_det = cov_det # pre-computed determinet of covariance matrices (a list of scalars) + self.cov_inv = cov_inv # pre-computed inverse covariance matrices (a list of dimxdim matrices) + + self.factor = [0]*num + for i in range(self.num): + self.factor[i] = (2*math.pi)**(self.dim/2) * self.cov_det[i]**0.5 + + def likelihood(self, data): + assert(data.shape[1] == self.dim) + N = data.shape[0] + lh = np.zeros(N) + + for i in range(self.num): + data_ = data - self.mu[i] + + tmp = np.matmul(data_,self.cov_inv[i]) * data_ + tmp = np.sum(tmp,axis=1) + power = -0.5 * tmp + + p = np.array([math.exp(power[j]) for j in range(N)]) + p = p/self.factor[i] + lh += p*self.w[i] + + return lh + + +def _rgb2ycbcr(rgb): + m = np.array([[65.481, 128.553, 24.966], + [-37.797, -74.203, 112], + [112, -93.786, -18.214]]) + shape = rgb.shape + rgb = rgb.reshape((shape[0] * shape[1], 3)) + ycbcr = np.dot(rgb, m.transpose() / 255.) + ycbcr[:, 0] += 16. + ycbcr[:, 1:] += 128. + return ycbcr.reshape(shape) + + +def _bgr2ycbcr(bgr): + rgb = bgr[..., ::-1] + return _rgb2ycbcr(rgb) + + +gmm_skin_w = [0.24063933, 0.16365987, 0.26034665, 0.33535415] +gmm_skin_mu = [np.array([113.71862, 103.39613, 164.08226]), + np.array([150.19858, 105.18467, 155.51428]), + np.array([183.92976, 107.62468, 152.71820]), + np.array([114.90524, 113.59782, 151.38217])] +gmm_skin_cov_det = [5692842.5, 5851930.5, 2329131., 1585971.] +gmm_skin_cov_inv = [np.array([[0.0019472069, 0.0020450759, -0.00060243998],[0.0020450759, 0.017700525, 0.0051420014],[-0.00060243998, 0.0051420014, 0.0081308950]]), + np.array([[0.0027110141, 0.0011036990, 0.0023122299],[0.0011036990, 0.010707724, 0.010742856],[0.0023122299, 0.010742856, 0.017481629]]), + np.array([[0.0048026871, 0.00022935172, 0.0077668377],[0.00022935172, 0.011729696, 0.0081661865],[0.0077668377, 0.0081661865, 0.025374353]]), + np.array([[0.0011989699, 0.0022453172, -0.0010748957],[0.0022453172, 0.047758564, 0.020332102],[-0.0010748957, 0.020332102, 0.024502251]])] + +gmm_skin = GMM(3, 4, gmm_skin_w, gmm_skin_mu, [], gmm_skin_cov_det, gmm_skin_cov_inv) + +gmm_nonskin_w = [0.12791070, 0.31130761, 0.34245777, 0.21832393] +gmm_nonskin_mu = [np.array([99.200851, 112.07533, 140.20602]), + np.array([110.91392, 125.52969, 130.19237]), + np.array([129.75864, 129.96107, 126.96808]), + np.array([112.29587, 128.85121, 129.05431])] +gmm_nonskin_cov_det = [458703648., 6466488., 90611376., 133097.63] +gmm_nonskin_cov_inv = [np.array([[0.00085371657, 0.00071197288, 0.00023958916],[0.00071197288, 0.0025935620, 0.00076557708],[0.00023958916, 0.00076557708, 0.0015042332]]), + np.array([[0.00024650150, 0.00045542428, 0.00015019422],[0.00045542428, 0.026412144, 0.018419769],[0.00015019422, 0.018419769, 0.037497383]]), + np.array([[0.00037054974, 0.00038146760, 0.00040408765],[0.00038146760, 0.0085505722, 0.0079136286],[0.00040408765, 0.0079136286, 0.010982352]]), + np.array([[0.00013709733, 0.00051228428, 0.00012777430],[0.00051228428, 0.28237113, 0.10528370],[0.00012777430, 0.10528370, 0.23468947]])] + +gmm_nonskin = GMM(3, 4, gmm_nonskin_w, gmm_nonskin_mu, [], gmm_nonskin_cov_det, gmm_nonskin_cov_inv) + +prior_skin = 0.8 +prior_nonskin = 1 - prior_skin + + +# calculate skin attention mask +def skinmask(imbgr): + im = _bgr2ycbcr(imbgr) + + data = im.reshape((-1,3)) + + lh_skin = gmm_skin.likelihood(data) + lh_nonskin = gmm_nonskin.likelihood(data) + + tmp1 = prior_skin * lh_skin + tmp2 = prior_nonskin * lh_nonskin + post_skin = tmp1 / (tmp1+tmp2) # posterior probability + + post_skin = post_skin.reshape((im.shape[0],im.shape[1])) + + post_skin = np.round(post_skin*255) + post_skin = post_skin.astype(np.uint8) + post_skin = np.tile(np.expand_dims(post_skin,2),[1,1,3]) # reshape to H*W*3 + + return post_skin + + +def get_skin_mask(img_path): + print('generating skin masks......') + names = [i for i in sorted(os.listdir( + img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i] + save_path = os.path.join(img_path, 'mask') + if not os.path.isdir(save_path): + os.makedirs(save_path) + + for i in range(0, len(names)): + name = names[i] + print('%05d' % (i), ' ', name) + full_image_name = os.path.join(img_path, name) + img = cv2.imread(full_image_name).astype(np.float32) + skin_img = skinmask(img) + cv2.imwrite(os.path.join(save_path, name), skin_img.astype(np.uint8)) diff --git a/Geneface_main/GeneFace/deep_3drecon/util/test_mean_face.txt b/Geneface_main/GeneFace/deep_3drecon/util/test_mean_face.txt new file mode 100644 index 00000000..3a46d4db --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/util/test_mean_face.txt @@ -0,0 +1,136 @@ +-5.228591537475585938e+01 +2.078247070312500000e-01 +-5.064269638061523438e+01 +-1.315765380859375000e+01 +-4.952939224243164062e+01 +-2.592591094970703125e+01 +-4.793047332763671875e+01 +-3.832135772705078125e+01 +-4.512159729003906250e+01 +-5.059623336791992188e+01 +-3.917720794677734375e+01 +-6.043736648559570312e+01 +-2.929953765869140625e+01 +-6.861183166503906250e+01 +-1.719801330566406250e+01 +-7.572736358642578125e+01 +-1.961936950683593750e+00 +-7.862001037597656250e+01 +1.467941284179687500e+01 +-7.607844543457031250e+01 +2.744073486328125000e+01 +-6.915261840820312500e+01 +3.855677795410156250e+01 +-5.950350570678710938e+01 +4.478240966796875000e+01 +-4.867547225952148438e+01 +4.714337158203125000e+01 +-3.800830078125000000e+01 +4.940315246582031250e+01 +-2.496297454833984375e+01 +5.117234802246093750e+01 +-1.241538238525390625e+01 +5.190507507324218750e+01 +8.244247436523437500e-01 +-4.150688934326171875e+01 +2.386329650878906250e+01 +-3.570307159423828125e+01 +3.017010498046875000e+01 +-2.790358734130859375e+01 +3.212951660156250000e+01 +-1.941773223876953125e+01 +3.156523132324218750e+01 +-1.138106536865234375e+01 +2.841992187500000000e+01 +5.993263244628906250e+00 +2.895182800292968750e+01 +1.343590545654296875e+01 +3.189880371093750000e+01 +2.203153991699218750e+01 +3.302221679687500000e+01 +2.992478942871093750e+01 +3.099150085449218750e+01 +3.628388977050781250e+01 +2.765748596191406250e+01 +-1.933914184570312500e+00 +1.405374145507812500e+01 +-2.153038024902343750e+00 +5.772636413574218750e+00 +-2.270050048828125000e+00 +-2.121643066406250000e+00 +-2.218330383300781250e+00 +-1.068978118896484375e+01 +-1.187252044677734375e+01 +-1.997912597656250000e+01 +-6.879402160644531250e+00 +-2.143579864501953125e+01 +-1.227821350097656250e+00 +-2.193494415283203125e+01 +4.623237609863281250e+00 +-2.152721405029296875e+01 +9.721397399902343750e+00 +-1.953671264648437500e+01 +-3.648714447021484375e+01 +9.811126708984375000e+00 +-3.130242919921875000e+01 +1.422447967529296875e+01 +-2.212834930419921875e+01 +1.493019866943359375e+01 +-1.500880432128906250e+01 +1.073588562011718750e+01 +-2.095037078857421875e+01 +9.054298400878906250e+00 +-3.050099182128906250e+01 +8.704177856445312500e+00 +1.173237609863281250e+01 +1.054329681396484375e+01 +1.856353759765625000e+01 +1.535009765625000000e+01 +2.893331909179687500e+01 +1.451992797851562500e+01 +3.452944946289062500e+01 +1.065280151367187500e+01 +2.875990295410156250e+01 +8.654792785644531250e+00 +1.942100524902343750e+01 +9.422447204589843750e+00 +-2.204488372802734375e+01 +-3.983994293212890625e+01 +-1.324458312988281250e+01 +-3.467377471923828125e+01 +-6.749649047851562500e+00 +-3.092894744873046875e+01 +-9.183349609375000000e-01 +-3.196458435058593750e+01 +4.220649719238281250e+00 +-3.090406036376953125e+01 +1.089889526367187500e+01 +-3.497008514404296875e+01 +1.874589538574218750e+01 +-4.065438079833984375e+01 +1.124106597900390625e+01 +-4.438417816162109375e+01 +5.181709289550781250e+00 +-4.649170684814453125e+01 +-1.158607482910156250e+00 +-4.680406951904296875e+01 +-7.918922424316406250e+00 +-4.671575164794921875e+01 +-1.452505493164062500e+01 +-4.416526031494140625e+01 +-2.005007171630859375e+01 +-3.997841644287109375e+01 +-1.054919433593750000e+01 +-3.849683380126953125e+01 +-1.051826477050781250e+00 +-3.794863128662109375e+01 +6.412681579589843750e+00 +-3.804645538330078125e+01 +1.627674865722656250e+01 +-4.039697265625000000e+01 +6.373878479003906250e+00 +-4.087213897705078125e+01 +-8.551712036132812500e-01 +-4.157129669189453125e+01 +-1.014953613281250000e+01 +-4.128469085693359375e+01 diff --git a/Geneface_main/GeneFace/deep_3drecon/util/util.py b/Geneface_main/GeneFace/deep_3drecon/util/util.py new file mode 100644 index 00000000..0d689ca1 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/util/util.py @@ -0,0 +1,208 @@ +"""This script contains basic utilities for Deep3DFaceRecon_pytorch +""" +from __future__ import print_function +import numpy as np +import torch +from PIL import Image +import os +import importlib +import argparse +from argparse import Namespace +import torchvision + + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def copyconf(default_opt, **kwargs): + conf = Namespace(**vars(default_opt)) + for key in kwargs: + setattr(conf, key, kwargs[key]) + return conf + +def genvalconf(train_opt, **kwargs): + conf = Namespace(**vars(train_opt)) + attr_dict = train_opt.__dict__ + for key, value in attr_dict.items(): + if 'val' in key and key.split('_')[0] in attr_dict: + setattr(conf, key.split('_')[0], value) + + for key in kwargs: + setattr(conf, key, kwargs[key]) + + return conf + +def find_class_in_module(target_cls_name, module): + target_cls_name = target_cls_name.replace('_', '').lower() + clslib = importlib.import_module(module) + cls = None + for name, clsobj in clslib.__dict__.items(): + if name.lower() == target_cls_name: + cls = clsobj + + assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name) + + return cls + + +def tensor2im(input_image, imtype=np.uint8): + """"Converts a Tensor array into a numpy image array. + + Parameters: + input_image (tensor) -- the input image tensor array, range(0, 1) + imtype (type) -- the desired type of the converted numpy array + """ + if not isinstance(input_image, np.ndarray): + if isinstance(input_image, torch.Tensor): # get the data from a variable + image_tensor = input_image.data + else: + return input_image + image_numpy = image_tensor.clamp(0.0, 1.0).cpu().float().numpy() # convert it into a numpy array + if image_numpy.shape[0] == 1: # grayscale to RGB + image_numpy = np.tile(image_numpy, (3, 1, 1)) + image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 # post-processing: tranpose and scaling + else: # if it is a numpy array, do nothing + image_numpy = input_image + return image_numpy.astype(imtype) + + +def diagnose_network(net, name='network'): + """Calculate and print the mean of average absolute(gradients) + + Parameters: + net (torch network) -- Torch network + name (str) -- the name of the network + """ + mean = 0.0 + count = 0 + for param in net.parameters(): + if param.grad is not None: + mean += torch.mean(torch.abs(param.grad.data)) + count += 1 + if count > 0: + mean = mean / count + print(name) + print(mean) + + +def save_image(image_numpy, image_path, aspect_ratio=1.0): + """Save a numpy image to the disk + + Parameters: + image_numpy (numpy array) -- input numpy array + image_path (str) -- the path of the image + """ + + image_pil = Image.fromarray(image_numpy) + h, w, _ = image_numpy.shape + + if aspect_ratio is None: + pass + elif aspect_ratio > 1.0: + image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) + elif aspect_ratio < 1.0: + image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) + image_pil.save(image_path) + + +def print_numpy(x, val=True, shp=False): + """Print the mean, min, max, median, std, and size of a numpy array + + Parameters: + val (bool) -- if print the values of the numpy array + shp (bool) -- if print the shape of the numpy array + """ + x = x.astype(np.float64) + if shp: + print('shape,', x.shape) + if val: + x = x.flatten() + print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( + np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) + + +def mkdirs(paths): + """create empty directories if they don't exist + + Parameters: + paths (str list) -- a list of directory paths + """ + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + + +def mkdir(path): + """create a single empty directory if it didn't exist + + Parameters: + path (str) -- a single directory path + """ + if not os.path.exists(path): + os.makedirs(path) + + +def correct_resize_label(t, size): + device = t.device + t = t.detach().cpu() + resized = [] + for i in range(t.size(0)): + one_t = t[i, :1] + one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0)) + one_np = one_np[:, :, 0] + one_image = Image.fromarray(one_np).resize(size, Image.NEAREST) + resized_t = torch.from_numpy(np.array(one_image)).long() + resized.append(resized_t) + return torch.stack(resized, dim=0).to(device) + + +def correct_resize(t, size, mode=Image.BICUBIC): + device = t.device + t = t.detach().cpu() + resized = [] + for i in range(t.size(0)): + one_t = t[i:i + 1] + one_image = Image.fromarray(tensor2im(one_t)).resize(size, Image.BICUBIC) + resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0 + resized.append(resized_t) + return torch.stack(resized, dim=0).to(device) + +def draw_landmarks(img, landmark, color='r', step=2): + """ + Return: + img -- numpy.array, (B, H, W, 3) img with landmark, RGB order, range (0, 255) + + + Parameters: + img -- numpy.array, (B, H, W, 3), RGB order, range (0, 255) + landmark -- numpy.array, (B, 68, 2), y direction is opposite to v direction + color -- str, 'r' or 'b' (red or blue) + """ + if color =='r': + c = np.array([255., 0, 0]) + else: + c = np.array([0, 0, 255.]) + + _, H, W, _ = img.shape + img, landmark = img.copy(), landmark.copy() + landmark[..., 1] = H - 1 - landmark[..., 1] + landmark = np.round(landmark).astype(np.int32) + for i in range(landmark.shape[1]): + x, y = landmark[:, i, 0], landmark[:, i, 1] + for j in range(-step, step): + for k in range(-step, step): + u = np.clip(x + j, 0, W - 1) + v = np.clip(y + k, 0, H - 1) + for m in range(landmark.shape[0]): + img[m, v[m], u[m]] = c + return img diff --git a/Geneface_main/GeneFace/deep_3drecon/util/visualizer.py b/Geneface_main/GeneFace/deep_3drecon/util/visualizer.py new file mode 100644 index 00000000..4023a6d4 --- /dev/null +++ b/Geneface_main/GeneFace/deep_3drecon/util/visualizer.py @@ -0,0 +1,227 @@ +"""This script defines the visualizer for Deep3DFaceRecon_pytorch +""" + +import numpy as np +import os +import sys +import ntpath +import time +from . import util, html +from subprocess import Popen, PIPE +from torch.utils.tensorboard import SummaryWriter + +def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): + """Save images to the disk. + + Parameters: + webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details) + visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs + image_path (str) -- the string is used to create image paths + aspect_ratio (float) -- the aspect ratio of saved images + width (int) -- the images will be resized to width x width + + This function will save images stored in 'visuals' to the HTML file specified by 'webpage'. + """ + image_dir = webpage.get_image_dir() + short_path = ntpath.basename(image_path[0]) + name = os.path.splitext(short_path)[0] + + webpage.add_header(name) + ims, txts, links = [], [], [] + + for label, im_data in visuals.items(): + im = util.tensor2im(im_data) + image_name = '%s/%s.png' % (label, name) + os.makedirs(os.path.join(image_dir, label), exist_ok=True) + save_path = os.path.join(image_dir, image_name) + util.save_image(im, save_path, aspect_ratio=aspect_ratio) + ims.append(image_name) + txts.append(label) + links.append(image_name) + webpage.add_images(ims, txts, links, width=width) + + +class Visualizer(): + """This class includes several functions that can display/save images and print/save logging information. + + It uses a Python library tensprboardX for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images. + """ + + def __init__(self, opt): + """Initialize the Visualizer class + + Parameters: + opt -- stores all the experiment flags; needs to be a subclass of BaseOptions + Step 1: Cache the training/test options + Step 2: create a tensorboard writer + Step 3: create an HTML object for saveing HTML filters + Step 4: create a logging file to store training losses + """ + self.opt = opt # cache the option + self.use_html = opt.isTrain and not opt.no_html + self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, 'logs', opt.name)) + self.win_size = opt.display_winsize + self.name = opt.name + self.saved = False + if self.use_html: # create an HTML object at /web/; images will be saved under /web/images/ + self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') + self.img_dir = os.path.join(self.web_dir, 'images') + print('create web directory %s...' % self.web_dir) + util.mkdirs([self.web_dir, self.img_dir]) + # create a logging file to store training losses + self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') + with open(self.log_name, "a") as log_file: + now = time.strftime("%c") + log_file.write('================ Training Loss (%s) ================\n' % now) + + def reset(self): + """Reset the self.saved status""" + self.saved = False + + + def display_current_results(self, visuals, total_iters, epoch, save_result): + """Display current results on tensorboad; save current results to an HTML file. + + Parameters: + visuals (OrderedDict) - - dictionary of images to display or save + total_iters (int) -- total iterations + epoch (int) - - the current epoch + save_result (bool) - - if save the current results to an HTML file + """ + for label, image in visuals.items(): + self.writer.add_image(label, util.tensor2im(image), total_iters, dataformats='HWC') + + if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved. + self.saved = True + # save images to the disk + for label, image in visuals.items(): + image_numpy = util.tensor2im(image) + img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) + util.save_image(image_numpy, img_path) + + # update website + webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0) + for n in range(epoch, 0, -1): + webpage.add_header('epoch [%d]' % n) + ims, txts, links = [], [], [] + + for label, image_numpy in visuals.items(): + image_numpy = util.tensor2im(image) + img_path = 'epoch%.3d_%s.png' % (n, label) + ims.append(img_path) + txts.append(label) + links.append(img_path) + webpage.add_images(ims, txts, links, width=self.win_size) + webpage.save() + + def plot_current_losses(self, total_iters, losses): + # G_loss_collection = {} + # D_loss_collection = {} + # for name, value in losses.items(): + # if 'G' in name or 'NCE' in name or 'idt' in name: + # G_loss_collection[name] = value + # else: + # D_loss_collection[name] = value + # self.writer.add_scalars('G_collec', G_loss_collection, total_iters) + # self.writer.add_scalars('D_collec', D_loss_collection, total_iters) + for name, value in losses.items(): + self.writer.add_scalar(name, value, total_iters) + + # losses: same format as |losses| of plot_current_losses + def print_current_losses(self, epoch, iters, losses, t_comp, t_data): + """print current losses on console; also save the losses to the disk + + Parameters: + epoch (int) -- current epoch + iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) + losses (OrderedDict) -- training losses stored in the format of (name, float) pairs + t_comp (float) -- computational time per data point (normalized by batch_size) + t_data (float) -- data loading time per data point (normalized by batch_size) + """ + message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data) + for k, v in losses.items(): + message += '%s: %.3f ' % (k, v) + + print(message) # print the message + with open(self.log_name, "a") as log_file: + log_file.write('%s\n' % message) # save the message + + +class MyVisualizer: + def __init__(self, opt): + """Initialize the Visualizer class + + Parameters: + opt -- stores all the experiment flags; needs to be a subclass of BaseOptions + Step 1: Cache the training/test options + Step 2: create a tensorboard writer + Step 3: create an HTML object for saveing HTML filters + Step 4: create a logging file to store training losses + """ + self.opt = opt # cache the optio + self.name = opt.name + self.img_dir = os.path.join(opt.checkpoints_dir, opt.name, 'results') + + if opt.phase != 'test': + self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, 'logs')) + # create a logging file to store training losses + self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') + with open(self.log_name, "a") as log_file: + now = time.strftime("%c") + log_file.write('================ Training Loss (%s) ================\n' % now) + + + def display_current_results(self, visuals, total_iters, epoch, dataset='train', save_results=False, count=0, name=None, + add_image=True): + """Display current results on tensorboad; save current results to an HTML file. + + Parameters: + visuals (OrderedDict) - - dictionary of images to display or save + total_iters (int) -- total iterations + epoch (int) - - the current epoch + dataset (str) - - 'train' or 'val' or 'test' + """ + # if (not add_image) and (not save_results): return + + for label, image in visuals.items(): + for i in range(image.shape[0]): + image_numpy = util.tensor2im(image[i]) + if add_image: + self.writer.add_image(label + '%s_%02d'%(dataset, i + count), + image_numpy, total_iters, dataformats='HWC') + + if save_results: + save_path = os.path.join(self.img_dir, dataset, 'epoch_%s_%06d'%(epoch, total_iters)) + if not os.path.isdir(save_path): + os.makedirs(save_path) + + if name is not None: + img_path = os.path.join(save_path, '%s.png' % name) + else: + img_path = os.path.join(save_path, '%s_%03d.png' % (label, i + count)) + util.save_image(image_numpy, img_path) + + + def plot_current_losses(self, total_iters, losses, dataset='train'): + for name, value in losses.items(): + self.writer.add_scalar(name + '/%s'%dataset, value, total_iters) + + # losses: same format as |losses| of plot_current_losses + def print_current_losses(self, epoch, iters, losses, t_comp, t_data, dataset='train'): + """print current losses on console; also save the losses to the disk + + Parameters: + epoch (int) -- current epoch + iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) + losses (OrderedDict) -- training losses stored in the format of (name, float) pairs + t_comp (float) -- computational time per data point (normalized by batch_size) + t_data (float) -- data loading time per data point (normalized by batch_size) + """ + message = '(dataset: %s, epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % ( + dataset, epoch, iters, t_comp, t_data) + for k, v in losses.items(): + message += '%s: %.3f ' % (k, v) + + print(message) # print the message + with open(self.log_name, "a") as log_file: + log_file.write('%s\n' % message) # save the message diff --git a/Geneface_main/GeneFace/docker/README.md b/Geneface_main/GeneFace/docker/README.md new file mode 100644 index 00000000..38999809 --- /dev/null +++ b/Geneface_main/GeneFace/docker/README.md @@ -0,0 +1,67 @@ +# Docker Env for GeneFace + +## Build image + +```shell +cd docker/ +docker build -t geneface:latest -f dockerfile . +cd .. +``` + +## Download weights (in parallel while building image) + +```shell +# [DATA] +wget ? -O ./deep_3drecon/BFM/BaselFaceModel.tgz +cd ./deep_3drecon/BFM +tar -xvf BaselFaceModel.tgz PublicMM1/01_MorphableModel.mat --strip-components 1 +rm BaselFaceModel.tgz +cd ../../ + +mkdir -p ./deep_3drecon/checkpoints/facerecon/ +wget ? -O ./deep_3drecon/BFM/Exp_Pca.bin +wget ? -O ./deep_3drecon/BFM/BFM_model_front.mat +wget ? -O ./deep_3drecon/checkpoints/facerecon/epoch_20.pth + +# [PRETRAIN WEIGHTS] +wget https://github.com/yerfor/GeneFace/releases/download/v1.1.0/lrs3.zip -P checkpoints/ +wget https://github.com/yerfor/GeneFace/releases/download/v1.1.0/May.zip -P checkpoints/ +unzip checkpoints/lrs3.zip -d checkpoints/ && rm checkpoints/lrs3.zip +unzip checkpoints/May.zip -d checkpoints/ && rm checkpoints/May.zip +``` + +## Run Container + +```shell +docker run -itd \ +--name geneface \ +--gpus all \ +-v $(realpath .):/workspace/GeneFace \ +-w /workspace/GeneFace \ +--network=host --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \ +geneface:latest bash + +docker exec -it geneface bash +``` + +Prepare 3DMM for data preprocssing + +```shell +cd data_util/face_tracking +python convert_BFM.py +cd ../../ +``` + +Run preprocessing + +```shell +export PYTHONPATH=./ +export VIDEO_ID=May +CUDA_VISIBLE_DEVICES=0 data_gen/nerf/process_data.sh $VIDEO_ID +``` + +There are several weights to be downloaded + +For training and inference please checkout `README.md`. + +(by [xk-huang](https://github.com/xk-huang/)) \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/dockerfile b/Geneface_main/GeneFace/docker/dockerfile new file mode 100644 index 00000000..714f79c1 --- /dev/null +++ b/Geneface_main/GeneFace/docker/dockerfile @@ -0,0 +1,16 @@ +ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:22.07-py3 +FROM $BASE_IMAGE + +ENV TORCH_CUDA_ARCH_LIST="8.0 8.6+PTX" + +COPY . . + +RUN echo "GO BRRR!" \ + && conda install -y -c fvcore -c iopath -c conda-forge fvcore iopath \ + && pip install "git+https://github.com/facebookresearch/pytorch3d.git" \ + && apt-get update \ + && apt-get install -y libasound2-dev portaudio19-dev \ + && pip install -r requirements.txt \ + && pip install tensorflow==2.12.0 "opencv-python-headless<4.3" protobuf==3.20.3 \ + && conda install -y ffmpeg \ + && bash install_ext.sh diff --git a/Geneface_main/GeneFace/docker/install_ext.sh b/Geneface_main/GeneFace/docker/install_ext.sh new file mode 100644 index 00000000..59c69e50 --- /dev/null +++ b/Geneface_main/GeneFace/docker/install_ext.sh @@ -0,0 +1,4 @@ +pip install modules/encoders/freqencoder +pip install modules/encoders/shencoder +pip install modules/encoders/gridencoder +pip install modules/raymarching \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/encoders/encoding.py b/Geneface_main/GeneFace/docker/modules/encoders/encoding.py new file mode 100644 index 00000000..d05b74ad --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/encoders/encoding.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def get_encoder(encoding, input_dim=3, + multires=6, + degree=4, + num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False, + interpolation='linear', + **kwargs): + + if encoding == 'None': + return lambda x, **kwargs: x, input_dim + + elif encoding == 'frequency': + from modules.radnerfs.encoders.freqencoder import FreqEncoder + encoder = FreqEncoder(input_dim=input_dim, degree=multires) + + elif encoding == 'spherical_harmonics': + from modules.radnerfs.encoders.shencoder import SHEncoder + encoder = SHEncoder(input_dim=input_dim, degree=degree) + + elif encoding == 'hashgrid': + from modules.radnerfs.encoders.gridencoder import GridEncoder + encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners, interpolation=interpolation, **kwargs) + + elif encoding == 'tiledgrid': + from modules.radnerfs.encoders.gridencoder import GridEncoder + encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners, interpolation=interpolation, **kwargs) + + else: + raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, spherical_harmonics, hashgrid, tiledgrid]') + + return encoder, encoder.output_dim \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/encoders/freqencoder/__init__.py b/Geneface_main/GeneFace/docker/modules/encoders/freqencoder/__init__.py new file mode 100644 index 00000000..69ec49cf --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/encoders/freqencoder/__init__.py @@ -0,0 +1 @@ +from .freq import FreqEncoder \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/encoders/freqencoder/backend.py b/Geneface_main/GeneFace/docker/modules/encoders/freqencoder/backend.py new file mode 100644 index 00000000..3bd9131a --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/encoders/freqencoder/backend.py @@ -0,0 +1,41 @@ +import os +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', + '-use_fast_math' +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +_backend = load(name='_freqencoder', + extra_cflags=c_flags, + extra_cuda_cflags=nvcc_flags, + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'freqencoder.cu', + 'bindings.cpp', + ]], + ) + +__all__ = ['_backend'] \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/encoders/freqencoder/freq.py b/Geneface_main/GeneFace/docker/modules/encoders/freqencoder/freq.py new file mode 100644 index 00000000..5cba1e66 --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/encoders/freqencoder/freq.py @@ -0,0 +1,77 @@ +import numpy as np + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import _freqencoder as _backend +except ImportError: + from .backend import _backend + + +class _freq_encoder(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision + def forward(ctx, inputs, degree, output_dim): + # inputs: [B, input_dim], float + # RETURN: [B, F], float + + if not inputs.is_cuda: inputs = inputs.cuda() + inputs = inputs.contiguous() + + B, input_dim = inputs.shape # batch size, coord dim + + outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) + + _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs) + + ctx.save_for_backward(inputs, outputs) + ctx.dims = [B, input_dim, degree, output_dim] + + return outputs + + @staticmethod + #@once_differentiable + @custom_bwd + def backward(ctx, grad): + # grad: [B, C * C] + + grad = grad.contiguous() + inputs, outputs = ctx.saved_tensors + B, input_dim, degree, output_dim = ctx.dims + + grad_inputs = torch.zeros_like(inputs) + _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs) + + return grad_inputs, None, None + + +freq_encode = _freq_encoder.apply + + +class FreqEncoder(nn.Module): + def __init__(self, input_dim=3, degree=4): + super().__init__() + + self.input_dim = input_dim + self.degree = degree + self.output_dim = input_dim + input_dim * 2 * degree + + def __repr__(self): + return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}" + + def forward(self, inputs, **kwargs): + # inputs: [..., input_dim] + # return: [..., ] + + prefix_shape = list(inputs.shape[:-1]) + inputs = inputs.reshape(-1, self.input_dim) + + outputs = freq_encode(inputs, self.degree, self.output_dim) + + outputs = outputs.reshape(prefix_shape + [self.output_dim]) + + return outputs \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/encoders/freqencoder/setup.py b/Geneface_main/GeneFace/docker/modules/encoders/freqencoder/setup.py new file mode 100644 index 00000000..3eb4af77 --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/encoders/freqencoder/setup.py @@ -0,0 +1,51 @@ +import os +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', + '-use_fast_math' +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +setup( + name='freqencoder', # package name, import this to use python API + ext_modules=[ + CUDAExtension( + name='_freqencoder', # extension name, import this to use CUDA API + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'freqencoder.cu', + 'bindings.cpp', + ]], + extra_compile_args={ + 'cxx': c_flags, + 'nvcc': nvcc_flags, + } + ), + ], + cmdclass={ + 'build_ext': BuildExtension, + } +) \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/encoders/freqencoder/src/bindings.cpp b/Geneface_main/GeneFace/docker/modules/encoders/freqencoder/src/bindings.cpp new file mode 100644 index 00000000..bb5f285a --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/encoders/freqencoder/src/bindings.cpp @@ -0,0 +1,8 @@ +#include + +#include "freqencoder.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)"); + m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)"); +} \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/encoders/freqencoder/src/freqencoder.cu b/Geneface_main/GeneFace/docker/modules/encoders/freqencoder/src/freqencoder.cu new file mode 100644 index 00000000..de378840 --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/encoders/freqencoder/src/freqencoder.cu @@ -0,0 +1,129 @@ +#include + +#include +#include +#include + +#include +#include + +#include +#include + +#include + + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") + +inline constexpr __device__ float PI() { return 3.141592653589793f; } + +template +__host__ __device__ T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + +// inputs: [B, D] +// outputs: [B, C], C = D + D * deg * 2 +__global__ void kernel_freq( + const float * __restrict__ inputs, + uint32_t B, uint32_t D, uint32_t deg, uint32_t C, + float * outputs +) { + // parallel on per-element + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; + if (t >= B * C) return; + + // get index + const uint32_t b = t / C; + const uint32_t c = t - b * C; // t % C; + + // locate + inputs += b * D; + outputs += t; + + // write self + if (c < D) { + outputs[0] = inputs[c]; + // write freq + } else { + const uint32_t col = c / D - 1; + const uint32_t d = c % D; + const uint32_t freq = col / 2; + const float phase_shift = (col % 2) * (PI() / 2); + outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift); + } +} + +// grad: [B, C], C = D + D * deg * 2 +// outputs: [B, C] +// grad_inputs: [B, D] +__global__ void kernel_freq_backward( + const float * __restrict__ grad, + const float * __restrict__ outputs, + uint32_t B, uint32_t D, uint32_t deg, uint32_t C, + float * grad_inputs +) { + // parallel on per-element + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; + if (t >= B * D) return; + + const uint32_t b = t / D; + const uint32_t d = t - b * D; // t % D; + + // locate + grad += b * C; + outputs += b * C; + grad_inputs += t; + + // register + float result = grad[d]; + grad += D; + outputs += D; + + for (uint32_t f = 0; f < deg; f++) { + result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]); + grad += 2 * D; + outputs += 2 * D; + } + + // write + grad_inputs[0] = result; +} + + +void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) { + CHECK_CUDA(inputs); + CHECK_CUDA(outputs); + + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(outputs); + + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(outputs); + + static constexpr uint32_t N_THREADS = 128; + + kernel_freq<<>>(inputs.data_ptr(), B, D, deg, C, outputs.data_ptr()); +} + + +void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) { + CHECK_CUDA(grad); + CHECK_CUDA(outputs); + CHECK_CUDA(grad_inputs); + + CHECK_CONTIGUOUS(grad); + CHECK_CONTIGUOUS(outputs); + CHECK_CONTIGUOUS(grad_inputs); + + CHECK_IS_FLOATING(grad); + CHECK_IS_FLOATING(outputs); + CHECK_IS_FLOATING(grad_inputs); + + static constexpr uint32_t N_THREADS = 128; + + kernel_freq_backward<<>>(grad.data_ptr(), outputs.data_ptr(), B, D, deg, C, grad_inputs.data_ptr()); +} \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/encoders/freqencoder/src/freqencoder.h b/Geneface_main/GeneFace/docker/modules/encoders/freqencoder/src/freqencoder.h new file mode 100644 index 00000000..34f28c79 --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/encoders/freqencoder/src/freqencoder.h @@ -0,0 +1,10 @@ +# pragma once + +#include +#include + +// _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs) +void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs); + +// _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs) +void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs); \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/encoders/gridencoder/__init__.py b/Geneface_main/GeneFace/docker/modules/encoders/gridencoder/__init__.py new file mode 100644 index 00000000..f1476cef --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/encoders/gridencoder/__init__.py @@ -0,0 +1 @@ +from .grid import GridEncoder \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/encoders/gridencoder/backend.py b/Geneface_main/GeneFace/docker/modules/encoders/gridencoder/backend.py new file mode 100644 index 00000000..d99acb1f --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/encoders/gridencoder/backend.py @@ -0,0 +1,40 @@ +import os +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +_backend = load(name='_grid_encoder', + extra_cflags=c_flags, + extra_cuda_cflags=nvcc_flags, + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'gridencoder.cu', + 'bindings.cpp', + ]], + ) + +__all__ = ['_backend'] \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/encoders/gridencoder/grid.py b/Geneface_main/GeneFace/docker/modules/encoders/gridencoder/grid.py new file mode 100644 index 00000000..32b8bead --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/encoders/gridencoder/grid.py @@ -0,0 +1,185 @@ +import numpy as np + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import _gridencoder as _backend +except ImportError: + from .backend import _backend + +_gridtype_to_id = { + 'hash': 0, + 'tiled': 1, +} + +_interp_to_id = { + 'linear': 0, + 'smoothstep': 1, +} + +class _grid_encode(Function): + @staticmethod + @custom_fwd + def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False, interpolation=0): + # inputs: [B, D], float in [0, 1] + # embeddings: [sO, C], float + # offsets: [L + 1], int + # RETURN: [B, F], float + + inputs = inputs.contiguous() + + B, D = inputs.shape # batch size, coord dim + L = offsets.shape[0] - 1 # level + C = embeddings.shape[1] # embedding dim for each level + S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f + H = base_resolution # base resolution + + # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision) + # if C % 2 != 0, force float, since half for atomicAdd is very slow. + if torch.is_autocast_enabled() and C % 2 == 0: + embeddings = embeddings.to(torch.half) + + # L first, optimize cache for cuda kernel, but needs an extra permute later + outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype) + + if calc_grad_inputs: + dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype) + else: + dy_dx = None + + _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, dy_dx, gridtype, align_corners, interpolation) + + # permute back to [B, L * C] + outputs = outputs.permute(1, 0, 2).reshape(B, L * C) + + ctx.save_for_backward(inputs, embeddings, offsets, dy_dx) + ctx.dims = [B, D, C, L, S, H, gridtype, interpolation] + ctx.align_corners = align_corners + + return outputs + + @staticmethod + #@once_differentiable + @custom_bwd + def backward(ctx, grad): + + inputs, embeddings, offsets, dy_dx = ctx.saved_tensors + B, D, C, L, S, H, gridtype, interpolation = ctx.dims + align_corners = ctx.align_corners + + # grad: [B, L * C] --> [L, B, C] + grad = grad.view(B, L, C).permute(1, 0, 2).contiguous() + + grad_embeddings = torch.zeros_like(embeddings) + + if dy_dx is not None: + grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype) + else: + grad_inputs = None + + _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interpolation) + + if dy_dx is not None: + grad_inputs = grad_inputs.to(inputs.dtype) + + return grad_inputs, grad_embeddings, None, None, None, None, None, None, None + + + +grid_encode = _grid_encode.apply + + +class GridEncoder(nn.Module): + def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False, interpolation='linear'): + super().__init__() + + # the finest resolution desired at the last level, if provided, overridee per_level_scale + if desired_resolution is not None: + per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1)) + + self.input_dim = input_dim # coord dims, 2 or 3 + self.num_levels = num_levels # num levels, each level multiply resolution by 2 + self.level_dim = level_dim # encode channels per level + self.per_level_scale = per_level_scale # multiply resolution by this scale at each level. + self.log2_hashmap_size = log2_hashmap_size + self.base_resolution = base_resolution + self.output_dim = num_levels * level_dim + self.gridtype = gridtype + self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash" + self.interpolation = interpolation + self.interp_id = _interp_to_id[interpolation] # "linear" or "smoothstep" + self.align_corners = align_corners + + # allocate parameters + offsets = [] + offset = 0 + self.max_params = 2 ** log2_hashmap_size + for i in range(num_levels): + resolution = int(np.ceil(base_resolution * per_level_scale ** i)) + params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number + params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible + offsets.append(offset) + offset += params_in_level + offsets.append(offset) + offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)) + self.register_buffer('offsets', offsets) + + self.n_params = offsets[-1] * level_dim + + # parameters + self.embeddings = nn.Parameter(torch.empty(offset, level_dim)) + + self.reset_parameters() + + def reset_parameters(self): + std = 1e-4 + self.embeddings.data.uniform_(-std, std) + + def __repr__(self): + return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners} interpolation={self.interpolation}" + + def forward(self, inputs, bound=1): + # inputs: [..., input_dim], normalized real world positions in [-bound, bound] + # return: [..., num_levels * level_dim] + + inputs = (inputs + bound) / (2 * bound) # map to [0, 1] + + #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item()) + + prefix_shape = list(inputs.shape[:-1]) + inputs = inputs.view(-1, self.input_dim) + + outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners, self.interp_id) + outputs = outputs.view(prefix_shape + [self.output_dim]) + + #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) + + return outputs + + # always run in float precision! + @torch.cuda.amp.autocast(enabled=False) + def grad_total_variation(self, weight=1e-7, inputs=None, bound=1, B=1000000): + # inputs: [..., input_dim], float in [-b, b], location to calculate TV loss. + + D = self.input_dim + C = self.embeddings.shape[1] # embedding dim for each level + L = self.offsets.shape[0] - 1 # level + S = np.log2(self.per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f + H = self.base_resolution # base resolution + + if inputs is None: + # randomized in [0, 1] + inputs = torch.rand(B, self.input_dim, device=self.embeddings.device) + else: + inputs = (inputs + bound) / (2 * bound) # map to [0, 1] + inputs = inputs.view(-1, self.input_dim) + B = inputs.shape[0] + + if self.embeddings.grad is None: + raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!') + + _backend.grad_total_variation(inputs, self.embeddings, self.embeddings.grad, self.offsets, weight, B, D, C, L, S, H, self.gridtype_id, self.align_corners) \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/encoders/gridencoder/setup.py b/Geneface_main/GeneFace/docker/modules/encoders/gridencoder/setup.py new file mode 100644 index 00000000..714bf1ca --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/encoders/gridencoder/setup.py @@ -0,0 +1,50 @@ +import os +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +setup( + name='gridencoder', # package name, import this to use python API + ext_modules=[ + CUDAExtension( + name='_gridencoder', # extension name, import this to use CUDA API + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'gridencoder.cu', + 'bindings.cpp', + ]], + extra_compile_args={ + 'cxx': c_flags, + 'nvcc': nvcc_flags, + } + ), + ], + cmdclass={ + 'build_ext': BuildExtension, + } +) \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/encoders/gridencoder/src/bindings.cpp b/Geneface_main/GeneFace/docker/modules/encoders/gridencoder/src/bindings.cpp new file mode 100644 index 00000000..93dea943 --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/encoders/gridencoder/src/bindings.cpp @@ -0,0 +1,9 @@ +#include + +#include "gridencoder.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)"); + m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)"); + m.def("grad_total_variation", &grad_total_variation, "grad_total_variation (CUDA)"); +} \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/encoders/gridencoder/src/gridencoder.cu b/Geneface_main/GeneFace/docker/modules/encoders/gridencoder/src/gridencoder.cu new file mode 100644 index 00000000..22d95328 --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/encoders/gridencoder/src/gridencoder.cu @@ -0,0 +1,644 @@ +#include +#include +#include + +#include +#include + +#include +#include + +#include +#include + + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") + + +// just for compatability of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF... program will never reach here! + __device__ inline at::Half atomicAdd(at::Half *address, at::Half val) { + // requires CUDA >= 10 and ARCH >= 70 + // this is very slow compared to float or __half2, never use it. + //return atomicAdd(reinterpret_cast<__half*>(address), val); +} + + +template +__host__ __device__ inline T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + +template +__host__ __device__ inline T clamp(const T v, const T2 lo, const T2 hi) { + return min(max(v, lo), hi); +} + +template +__device__ inline T smoothstep(T val) { + return val*val*(3.0f - 2.0f * val); +} + +template +__device__ inline T smoothstep_derivative(T val) { + return 6*val*(1.0f - val); +} + + +template +__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) { + + // coherent type of hashing + constexpr uint32_t primes[7] = { 1u, 2654435761u, 805459861u, 3674653429u, 2097192037u, 1434869437u, 2165219737u }; + + uint32_t result = 0; + #pragma unroll + for (uint32_t i = 0; i < D; ++i) { + result ^= pos_grid[i] * primes[i]; + } + + return result; +} + + +template +__device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) { + uint32_t stride = 1; + uint32_t index = 0; + + #pragma unroll + for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) { + index += pos_grid[d] * stride; + stride *= align_corners ? resolution: (resolution + 1); + } + + // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97. + // gridtype: 0 == hash, 1 == tiled + if (gridtype == 0 && stride > hashmap_size) { + index = fast_hash(pos_grid); + } + + return (index % hashmap_size) * C + ch; +} + + +template +__global__ void kernel_grid( + const float * __restrict__ inputs, + const scalar_t * __restrict__ grid, + const int * __restrict__ offsets, + scalar_t * __restrict__ outputs, + const uint32_t B, const uint32_t L, const float S, const uint32_t H, + scalar_t * __restrict__ dy_dx, + const uint32_t gridtype, + const bool align_corners, + const uint32_t interp +) { + const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x; + + if (b >= B) return; + + const uint32_t level = blockIdx.y; + + // locate + grid += (uint32_t)offsets[level] * C; + inputs += b * D; + outputs += level * B * C + b * C; + + // check input range (should be in [0, 1]) + bool flag_oob = false; + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if (inputs[d] < 0 || inputs[d] > 1) { + flag_oob = true; + } + } + // if input out of bound, just set output to 0 + if (flag_oob) { + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + outputs[ch] = 0; + } + if (dy_dx) { + dy_dx += b * D * L * C + level * D * C; // B L D C + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + dy_dx[d * C + ch] = 0; + } + } + } + return; + } + + const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; + const float scale = exp2f(level * S) * H - 1.0f; + const uint32_t resolution = (uint32_t)ceil(scale) + 1; + + // calculate coordinate (always use float for precision!) + float pos[D]; + float pos_deriv[D]; // linear deriv is default to 1 + uint32_t pos_grid[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f); + pos_grid[d] = floorf(pos[d]); + pos[d] -= (float)pos_grid[d]; + // smoothstep instead of linear + if (interp == 1) { + pos_deriv[d] = smoothstep_derivative(pos[d]); + pos[d] = smoothstep(pos[d]); + } else { + pos_deriv[d] = 1.0f; + } + } + + //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]); + + // interpolate + scalar_t results[C] = {0}; // temp results in register + + #pragma unroll + for (uint32_t idx = 0; idx < (1 << D); idx++) { + float w = 1; + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if ((idx & (1 << d)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = pos_grid[d] + 1; + } + } + + uint32_t index = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); + + // writing to register (fast) + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + results[ch] += w * grid[index + ch]; + } + + //printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]); + } + + // writing to global memory (slow) + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + outputs[ch] = results[ch]; + } + + // prepare dy_dx + // differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9 + if (dy_dx) { + + dy_dx += b * D * L * C + level * D * C; // B L D C + + #pragma unroll + for (uint32_t gd = 0; gd < D; gd++) { + + scalar_t results_grad[C] = {0}; + + #pragma unroll + for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) { + float w = scale; + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t nd = 0; nd < D - 1; nd++) { + const uint32_t d = (nd >= gd) ? (nd + 1) : nd; + + if ((idx & (1 << nd)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = pos_grid[d] + 1; + } + } + + pos_grid_local[gd] = pos_grid[gd]; + uint32_t index_left = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); + pos_grid_local[gd] = pos_grid[gd] + 1; + uint32_t index_right = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]) * pos_deriv[gd]; + } + } + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + dy_dx[gd * C + ch] = results_grad[ch]; + } + } + } +} + + +template +__global__ void kernel_grid_backward( + const scalar_t * __restrict__ grad, + const float * __restrict__ inputs, + const scalar_t * __restrict__ grid, + const int * __restrict__ offsets, + scalar_t * __restrict__ grad_grid, + const uint32_t B, const uint32_t L, const float S, const uint32_t H, + const uint32_t gridtype, + const bool align_corners, + const uint32_t interp +) { + const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C; + if (b >= B) return; + + const uint32_t level = blockIdx.y; + const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C; + + // locate + grad_grid += offsets[level] * C; + inputs += b * D; + grad += level * B * C + b * C + ch; // L, B, C + + const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; + const float scale = exp2f(level * S) * H - 1.0f; + const uint32_t resolution = (uint32_t)ceil(scale) + 1; + + // check input range (should be in [0, 1]) + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if (inputs[d] < 0 || inputs[d] > 1) { + return; // grad is init as 0, so we simply return. + } + } + + // calculate coordinate + float pos[D]; + uint32_t pos_grid[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f); + pos_grid[d] = floorf(pos[d]); + pos[d] -= (float)pos_grid[d]; + // smoothstep instead of linear + if (interp == 1) { + pos[d] = smoothstep(pos[d]); + } + } + + scalar_t grad_cur[N_C] = {0}; // fetch to register + #pragma unroll + for (uint32_t c = 0; c < N_C; c++) { + grad_cur[c] = grad[c]; + } + + // interpolate + #pragma unroll + for (uint32_t idx = 0; idx < (1 << D); idx++) { + float w = 1; + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if ((idx & (1 << d)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = pos_grid[d] + 1; + } + } + + uint32_t index = get_grid_index(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local); + + // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0 + // TODO: use float which is better than __half, if N_C % 2 != 0 + if (std::is_same::value && N_C % 2 == 0) { + #pragma unroll + for (uint32_t c = 0; c < N_C; c += 2) { + // process two __half at once (by interpreting as a __half2) + __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])}; + atomicAdd((__half2*)&grad_grid[index + c], v); + } + // float, or __half when N_C % 2 != 0 (which means C == 1) + } else { + #pragma unroll + for (uint32_t c = 0; c < N_C; c++) { + atomicAdd(&grad_grid[index + c], w * grad_cur[c]); + } + } + } +} + + +template +__global__ void kernel_input_backward( + const scalar_t * __restrict__ grad, + const scalar_t * __restrict__ dy_dx, + scalar_t * __restrict__ grad_inputs, + uint32_t B, uint32_t L +) { + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; + if (t >= B * D) return; + + const uint32_t b = t / D; + const uint32_t d = t - b * D; + + dy_dx += b * L * D * C; + + scalar_t result = 0; + + # pragma unroll + for (int l = 0; l < L; l++) { + # pragma unroll + for (int ch = 0; ch < C; ch++) { + result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch]; + } + } + + grad_inputs[t] = result; +} + + +template +void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + static constexpr uint32_t N_THREAD = 512; + const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 }; + switch (C) { + case 1: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 2: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 4: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 8: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + +// inputs: [B, D], float, in [0, 1] +// embeddings: [sO, C], float +// offsets: [L + 1], uint32_t +// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.) +// H: base resolution +// dy_dx: [B, L * D * C] +template +void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + switch (D) { + case 2: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 3: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 4: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 5: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + +template +void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + static constexpr uint32_t N_THREAD = 256; + const uint32_t N_C = std::min(2u, C); // n_features_per_thread + const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 }; + switch (C) { + case 1: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp); + if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + case 2: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp); + if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + case 4: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp); + if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + case 8: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp); + if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + + +// grad: [L, B, C], float +// inputs: [B, D], float, in [0, 1] +// embeddings: [sO, C], float +// offsets: [L + 1], uint32_t +// grad_embeddings: [sO, C] +// H: base resolution +template +void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + switch (D) { + case 2: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break; + case 3: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break; + case 4: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break; + case 5: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + + + +void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + CHECK_CUDA(inputs); + CHECK_CUDA(embeddings); + CHECK_CUDA(offsets); + CHECK_CUDA(outputs); + // CHECK_CUDA(dy_dx); + + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(embeddings); + CHECK_CONTIGUOUS(offsets); + CHECK_CONTIGUOUS(outputs); + // CHECK_CONTIGUOUS(dy_dx); + + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(embeddings); + CHECK_IS_INT(offsets); + CHECK_IS_FLOATING(outputs); + // CHECK_IS_FLOATING(dy_dx); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + embeddings.scalar_type(), "grid_encode_forward", ([&] { + grid_encode_forward_cuda(inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), outputs.data_ptr(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr() : nullptr, gridtype, align_corners, interp); + })); +} + +void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + CHECK_CUDA(grad); + CHECK_CUDA(inputs); + CHECK_CUDA(embeddings); + CHECK_CUDA(offsets); + CHECK_CUDA(grad_embeddings); + // CHECK_CUDA(dy_dx); + // CHECK_CUDA(grad_inputs); + + CHECK_CONTIGUOUS(grad); + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(embeddings); + CHECK_CONTIGUOUS(offsets); + CHECK_CONTIGUOUS(grad_embeddings); + // CHECK_CONTIGUOUS(dy_dx); + // CHECK_CONTIGUOUS(grad_inputs); + + CHECK_IS_FLOATING(grad); + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(embeddings); + CHECK_IS_INT(offsets); + CHECK_IS_FLOATING(grad_embeddings); + // CHECK_IS_FLOATING(dy_dx); + // CHECK_IS_FLOATING(grad_inputs); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "grid_encode_backward", ([&] { + grid_encode_backward_cuda(grad.data_ptr(), inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), grad_embeddings.data_ptr(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr() : nullptr, grad_inputs.has_value() ? grad_inputs.value().data_ptr() : nullptr, gridtype, align_corners, interp); + })); + +} + + +template +__global__ void kernel_grad_tv( + const scalar_t * __restrict__ inputs, + const scalar_t * __restrict__ grid, + scalar_t * __restrict__ grad, + const int * __restrict__ offsets, + const float weight, + const uint32_t B, const uint32_t L, const float S, const uint32_t H, + const uint32_t gridtype, + const bool align_corners +) { + const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x; + + if (b >= B) return; + + const uint32_t level = blockIdx.y; + + // locate + inputs += b * D; + grid += (uint32_t)offsets[level] * C; + grad += (uint32_t)offsets[level] * C; + + // check input range (should be in [0, 1]) + bool flag_oob = false; + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if (inputs[d] < 0 || inputs[d] > 1) { + flag_oob = true; + } + } + + // if input out of bound, do nothing + if (flag_oob) return; + + const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; + const float scale = exp2f(level * S) * H - 1.0f; + const uint32_t resolution = (uint32_t)ceil(scale) + 1; + + // calculate coordinate + float pos[D]; + uint32_t pos_grid[D]; // [0, resolution] + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f); + pos_grid[d] = floorf(pos[d]); + // pos[d] -= (float)pos_grid[d]; // not used + } + + //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]); + + // total variation on pos_grid + scalar_t results[C] = {0}; // temp results in register + scalar_t idelta[C] = {0}; + + uint32_t index = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid); + + scalar_t w = weight / (2 * D); + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + + uint32_t cur_d = pos_grid[d]; + scalar_t grad_val; + + // right side + if (cur_d < resolution) { + pos_grid[d] = cur_d + 1; + uint32_t index_right = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid); + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + // results[ch] += w * clamp(grid[index + ch] - grid[index_right + ch], -1.0f, 1.0f); + grad_val = (grid[index + ch] - grid[index_right + ch]); + results[ch] += grad_val; + idelta[ch] += grad_val * grad_val; + } + } + + // left side + if (cur_d > 0) { + pos_grid[d] = cur_d - 1; + uint32_t index_left = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid); + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + // results[ch] += w * clamp(grid[index + ch] - grid[index_left + ch], -1.0f, 1.0f); + grad_val = (grid[index + ch] - grid[index_left + ch]); + results[ch] += grad_val; + idelta[ch] += grad_val * grad_val; + } + } + + // reset + pos_grid[d] = cur_d; + } + + // writing to global memory (slow) + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + // index may collide, so use atomic! + atomicAdd(&grad[index + ch], w * results[ch] * rsqrtf(idelta[ch] + 1e-9f)); + } + +} + + +template +void kernel_grad_tv_wrapper(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) { + static constexpr uint32_t N_THREAD = 512; + const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 }; + switch (C) { + case 1: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break; + case 2: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break; + case 4: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break; + case 8: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + + +template +void grad_total_variation_cuda(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) { + switch (D) { + case 2: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break; + case 3: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break; + case 4: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break; + case 5: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + + +void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) { + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + embeddings.scalar_type(), "grad_total_variation", ([&] { + grad_total_variation_cuda(inputs.data_ptr(), embeddings.data_ptr(), grad.data_ptr(), offsets.data_ptr(), weight, B, D, C, L, S, H, gridtype, align_corners); + })); +} \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/encoders/gridencoder/src/gridencoder.h b/Geneface_main/GeneFace/docker/modules/encoders/gridencoder/src/gridencoder.h new file mode 100644 index 00000000..1b385755 --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/encoders/gridencoder/src/gridencoder.h @@ -0,0 +1,17 @@ +#ifndef _HASH_ENCODE_H +#define _HASH_ENCODE_H + +#include +#include + +// inputs: [B, D], float, in [0, 1] +// embeddings: [sO, C], float +// offsets: [L + 1], uint32_t +// outputs: [B, L * C], float +// H: base resolution +void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp); +void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp); + +void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners); + +#endif \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/encoders/shencoder/__init__.py b/Geneface_main/GeneFace/docker/modules/encoders/shencoder/__init__.py new file mode 100644 index 00000000..2b55c96e --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/encoders/shencoder/__init__.py @@ -0,0 +1 @@ +from .sphere_harmonics import SHEncoder \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/encoders/shencoder/backend.py b/Geneface_main/GeneFace/docker/modules/encoders/shencoder/backend.py new file mode 100644 index 00000000..cc08a3e9 --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/encoders/shencoder/backend.py @@ -0,0 +1,40 @@ +import os +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +_backend = load(name='_sh_encoder', + extra_cflags=c_flags, + extra_cuda_cflags=nvcc_flags, + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'shencoder.cu', + 'bindings.cpp', + ]], + ) + +__all__ = ['_backend'] \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/encoders/shencoder/setup.py b/Geneface_main/GeneFace/docker/modules/encoders/shencoder/setup.py new file mode 100644 index 00000000..342a6015 --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/encoders/shencoder/setup.py @@ -0,0 +1,50 @@ +import os +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +setup( + name='shencoder', # package name, import this to use python API + ext_modules=[ + CUDAExtension( + name='_shencoder', # extension name, import this to use CUDA API + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'shencoder.cu', + 'bindings.cpp', + ]], + extra_compile_args={ + 'cxx': c_flags, + 'nvcc': nvcc_flags, + } + ), + ], + cmdclass={ + 'build_ext': BuildExtension, + } +) \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/encoders/shencoder/sphere_harmonics.py b/Geneface_main/GeneFace/docker/modules/encoders/shencoder/sphere_harmonics.py new file mode 100644 index 00000000..7bab24e6 --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/encoders/shencoder/sphere_harmonics.py @@ -0,0 +1,87 @@ +import numpy as np + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import _shencoder as _backend +except ImportError: + from .backend import _backend + +class _sh_encoder(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision + def forward(ctx, inputs, degree, calc_grad_inputs=False): + # inputs: [B, input_dim], float in [-1, 1] + # RETURN: [B, F], float + + inputs = inputs.contiguous() + B, input_dim = inputs.shape # batch size, coord dim + output_dim = degree ** 2 + + outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) + + if calc_grad_inputs: + dy_dx = torch.empty(B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device) + else: + dy_dx = None + + _backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, dy_dx) + + ctx.save_for_backward(inputs, dy_dx) + ctx.dims = [B, input_dim, degree] + + return outputs + + @staticmethod + #@once_differentiable + @custom_bwd + def backward(ctx, grad): + # grad: [B, C * C] + + inputs, dy_dx = ctx.saved_tensors + + if dy_dx is not None: + grad = grad.contiguous() + B, input_dim, degree = ctx.dims + grad_inputs = torch.zeros_like(inputs) + _backend.sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs) + return grad_inputs, None, None + else: + return None, None, None + + + +sh_encode = _sh_encoder.apply + + +class SHEncoder(nn.Module): + def __init__(self, input_dim=3, degree=4): + super().__init__() + + self.input_dim = input_dim # coord dims, must be 3 + self.degree = degree # 0 ~ 4 + self.output_dim = degree ** 2 + + assert self.input_dim == 3, "SH encoder only support input dim == 3" + assert self.degree > 0 and self.degree <= 8, "SH encoder only supports degree in [1, 8]" + + def __repr__(self): + return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}" + + def forward(self, inputs, size=1): + # inputs: [..., input_dim], normalized real world positions in [-size, size] + # return: [..., degree^2] + + inputs = inputs / size # [-1, 1] + + prefix_shape = list(inputs.shape[:-1]) + inputs = inputs.reshape(-1, self.input_dim) + + outputs = sh_encode(inputs, self.degree, inputs.requires_grad) + outputs = outputs.reshape(prefix_shape + [self.output_dim]) + + return outputs \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/encoders/shencoder/src/bindings.cpp b/Geneface_main/GeneFace/docker/modules/encoders/shencoder/src/bindings.cpp new file mode 100644 index 00000000..595b5b3a --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/encoders/shencoder/src/bindings.cpp @@ -0,0 +1,8 @@ +#include + +#include "shencoder.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("sh_encode_forward", &sh_encode_forward, "SH encode forward (CUDA)"); + m.def("sh_encode_backward", &sh_encode_backward, "SH encode backward (CUDA)"); +} \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/encoders/shencoder/src/shencoder.cu b/Geneface_main/GeneFace/docker/modules/encoders/shencoder/src/shencoder.cu new file mode 100644 index 00000000..a92e4ab7 --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/encoders/shencoder/src/shencoder.cu @@ -0,0 +1,439 @@ +#include + +#include +#include +#include + +#include +#include + +#include +#include + +#include + + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") + + +template +__host__ __device__ T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + +template +__global__ void kernel_sh( + const scalar_t * __restrict__ inputs, + scalar_t * outputs, + uint32_t B, uint32_t D, uint32_t C, + scalar_t * dy_dx +) { + const uint32_t b = threadIdx.x + blockIdx.x * blockDim.x; + if (b >= B) return; + + const uint32_t C2 = C * C; + + // locate + inputs += b * D; + outputs += b * C2; + + scalar_t x = inputs[0], y = inputs[1], z = inputs[2]; + + scalar_t xy=x*y, xz=x*z, yz=y*z, x2=x*x, y2=y*y, z2=z*z, xyz=xy*z; + scalar_t x4=x2*x2, y4=y2*y2, z4=z2*z2; + scalar_t x6=x4*x2, y6=y4*y2, z6=z4*z2; + + auto write_sh = [&]() { + outputs[0] = 0.28209479177387814f ; // 1/(2*sqrt(pi)) + if (C <= 1) { return; } + outputs[1] = -0.48860251190291987f*y ; // -sqrt(3)*y/(2*sqrt(pi)) + outputs[2] = 0.48860251190291987f*z ; // sqrt(3)*z/(2*sqrt(pi)) + outputs[3] = -0.48860251190291987f*x ; // -sqrt(3)*x/(2*sqrt(pi)) + if (C <= 2) { return; } + outputs[4] = 1.0925484305920792f*xy ; // sqrt(15)*xy/(2*sqrt(pi)) + outputs[5] = -1.0925484305920792f*yz ; // -sqrt(15)*yz/(2*sqrt(pi)) + outputs[6] = 0.94617469575755997f*z2 - 0.31539156525251999f ; // sqrt(5)*(3*z2 - 1)/(4*sqrt(pi)) + outputs[7] = -1.0925484305920792f*xz ; // -sqrt(15)*xz/(2*sqrt(pi)) + outputs[8] = 0.54627421529603959f*x2 - 0.54627421529603959f*y2 ; // sqrt(15)*(x2 - y2)/(4*sqrt(pi)) + if (C <= 3) { return; } + outputs[9] = 0.59004358992664352f*y*(-3.0f*x2 + y2) ; // sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi)) + outputs[10] = 2.8906114426405538f*xy*z ; // sqrt(105)*xy*z/(2*sqrt(pi)) + outputs[11] = 0.45704579946446572f*y*(1.0f - 5.0f*z2) ; // sqrt(42)*y*(1 - 5*z2)/(8*sqrt(pi)) + outputs[12] = 0.3731763325901154f*z*(5.0f*z2 - 3.0f) ; // sqrt(7)*z*(5*z2 - 3)/(4*sqrt(pi)) + outputs[13] = 0.45704579946446572f*x*(1.0f - 5.0f*z2) ; // sqrt(42)*x*(1 - 5*z2)/(8*sqrt(pi)) + outputs[14] = 1.4453057213202769f*z*(x2 - y2) ; // sqrt(105)*z*(x2 - y2)/(4*sqrt(pi)) + outputs[15] = 0.59004358992664352f*x*(-x2 + 3.0f*y2) ; // sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi)) + if (C <= 4) { return; } + outputs[16] = 2.5033429417967046f*xy*(x2 - y2) ; // 3*sqrt(35)*xy*(x2 - y2)/(4*sqrt(pi)) + outputs[17] = 1.7701307697799304f*yz*(-3.0f*x2 + y2) ; // 3*sqrt(70)*yz*(-3*x2 + y2)/(8*sqrt(pi)) + outputs[18] = 0.94617469575756008f*xy*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*xy*(7*z2 - 1)/(4*sqrt(pi)) + outputs[19] = 0.66904654355728921f*yz*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*yz*(3 - 7*z2)/(8*sqrt(pi)) + outputs[20] = -3.1735664074561294f*z2 + 3.7024941420321507f*z4 + 0.31735664074561293f ; // 3*(-30*z2 + 35*z4 + 3)/(16*sqrt(pi)) + outputs[21] = 0.66904654355728921f*xz*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*xz*(3 - 7*z2)/(8*sqrt(pi)) + outputs[22] = 0.47308734787878004f*(x2 - y2)*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*(x2 - y2)*(7*z2 - 1)/(8*sqrt(pi)) + outputs[23] = 1.7701307697799304f*xz*(-x2 + 3.0f*y2) ; // 3*sqrt(70)*xz*(-x2 + 3*y2)/(8*sqrt(pi)) + outputs[24] = -3.7550144126950569f*x2*y2 + 0.62583573544917614f*x4 + 0.62583573544917614f*y4 ; // 3*sqrt(35)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) + if (C <= 5) { return; } + outputs[25] = 0.65638205684017015f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(154)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + outputs[26] = 8.3026492595241645f*xy*z*(x2 - y2) ; // 3*sqrt(385)*xy*z*(x2 - y2)/(4*sqrt(pi)) + outputs[27] = -0.48923829943525038f*y*(3.0f*x2 - y2)*(9.0f*z2 - 1.0f) ; // -sqrt(770)*y*(3*x2 - y2)*(9*z2 - 1)/(32*sqrt(pi)) + outputs[28] = 4.7935367849733241f*xy*z*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xy*z*(3*z2 - 1)/(4*sqrt(pi)) + outputs[29] = 0.45294665119569694f*y*(14.0f*z2 - 21.0f*z4 - 1.0f) ; // sqrt(165)*y*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) + outputs[30] = 0.1169503224534236f*z*(-70.0f*z2 + 63.0f*z4 + 15.0f) ; // sqrt(11)*z*(-70*z2 + 63*z4 + 15)/(16*sqrt(pi)) + outputs[31] = 0.45294665119569694f*x*(14.0f*z2 - 21.0f*z4 - 1.0f) ; // sqrt(165)*x*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) + outputs[32] = 2.3967683924866621f*z*(x2 - y2)*(3.0f*z2 - 1.0f) ; // sqrt(1155)*z*(x2 - y2)*(3*z2 - 1)/(8*sqrt(pi)) + outputs[33] = -0.48923829943525038f*x*(x2 - 3.0f*y2)*(9.0f*z2 - 1.0f) ; // -sqrt(770)*x*(x2 - 3*y2)*(9*z2 - 1)/(32*sqrt(pi)) + outputs[34] = 2.0756623148810411f*z*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(385)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) + outputs[35] = 0.65638205684017015f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(154)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) + if (C <= 6) { return; } + outputs[36] = 1.3663682103838286f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // sqrt(6006)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) + outputs[37] = 2.3666191622317521f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(2002)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + outputs[38] = 2.0182596029148963f*xy*(x2 - y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*xy*(x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) + outputs[39] = -0.92120525951492349f*yz*(3.0f*x2 - y2)*(11.0f*z2 - 3.0f) ; // -sqrt(2730)*yz*(3*x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) + outputs[40] = 0.92120525951492349f*xy*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*xy*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) + outputs[41] = 0.58262136251873131f*yz*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*yz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) + outputs[42] = 6.6747662381009842f*z2 - 20.024298714302954f*z4 + 14.684485723822165f*z6 - 0.31784601133814211f ; // sqrt(13)*(105*z2 - 315*z4 + 231*z6 - 5)/(32*sqrt(pi)) + outputs[43] = 0.58262136251873131f*xz*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*xz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) + outputs[44] = 0.46060262975746175f*(x2 - y2)*(11.0f*z2*(3.0f*z2 - 1.0f) - 7.0f*z2 + 1.0f) ; // sqrt(2730)*(x2 - y2)*(11*z2*(3*z2 - 1) - 7*z2 + 1)/(64*sqrt(pi)) + outputs[45] = -0.92120525951492349f*xz*(x2 - 3.0f*y2)*(11.0f*z2 - 3.0f) ; // -sqrt(2730)*xz*(x2 - 3*y2)*(11*z2 - 3)/(32*sqrt(pi)) + outputs[46] = 0.50456490072872406f*(11.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(91)*(11*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) + outputs[47] = 2.3666191622317521f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(2002)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) + outputs[48] = 10.247761577878714f*x2*y4 - 10.247761577878714f*x4*y2 + 0.6831841051919143f*x6 - 0.6831841051919143f*y6 ; // sqrt(6006)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) + if (C <= 7) { return; } + outputs[49] = 0.70716273252459627f*y*(-21.0f*x2*y4 + 35.0f*x4*y2 - 7.0f*x6 + y6) ; // 3*sqrt(715)*y*(-21*x2*y4 + 35*x4*y2 - 7*x6 + y6)/(64*sqrt(pi)) + outputs[50] = 5.2919213236038001f*xy*z*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 3*sqrt(10010)*xy*z*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) + outputs[51] = -0.51891557872026028f*y*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // -3*sqrt(385)*y*(13*z2 - 1)*(-10*x2*y2 + 5*x4 + y4)/(64*sqrt(pi)) + outputs[52] = 4.1513246297620823f*xy*z*(x2 - y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xy*z*(x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) + outputs[53] = -0.15645893386229404f*y*(3.0f*x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -3*sqrt(35)*y*(3*x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) + outputs[54] = 0.44253269244498261f*xy*z*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xy*z*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) + outputs[55] = 0.090331607582517306f*y*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ; // sqrt(105)*y*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) + outputs[56] = 0.068284276912004949f*z*(315.0f*z2 - 693.0f*z4 + 429.0f*z6 - 35.0f) ; // sqrt(15)*z*(315*z2 - 693*z4 + 429*z6 - 35)/(32*sqrt(pi)) + outputs[57] = 0.090331607582517306f*x*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ; // sqrt(105)*x*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) + outputs[58] = 0.07375544874083044f*z*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) - 187.0f*z2 + 45.0f) ; // sqrt(70)*z*(x2 - y2)*(143*z2*(3*z2 - 1) - 187*z2 + 45)/(64*sqrt(pi)) + outputs[59] = -0.15645893386229404f*x*(x2 - 3.0f*y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -3*sqrt(35)*x*(x2 - 3*y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) + outputs[60] = 1.0378311574405206f*z*(13.0f*z2 - 3.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(385)*z*(13*z2 - 3)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) + outputs[61] = -0.51891557872026028f*x*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // -3*sqrt(385)*x*(13*z2 - 1)*(-10*x2*y2 + x4 + 5*y4)/(64*sqrt(pi)) + outputs[62] = 2.6459606618019f*z*(15.0f*x2*y4 - 15.0f*x4*y2 + x6 - y6) ; // 3*sqrt(10010)*z*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) + outputs[63] = 0.70716273252459627f*x*(-35.0f*x2*y4 + 21.0f*x4*y2 - x6 + 7.0f*y6) ; // 3*sqrt(715)*x*(-35*x2*y4 + 21*x4*y2 - x6 + 7*y6)/(64*sqrt(pi)) + }; + + write_sh(); + + if (dy_dx) { + scalar_t *dx = dy_dx + b * D * C2; + scalar_t *dy = dx + C2; + scalar_t *dz = dy + C2; + + auto write_sh_dx = [&]() { + dx[0] = 0.0f ; // 0 + if (C <= 1) { return; } + dx[1] = 0.0f ; // 0 + dx[2] = 0.0f ; // 0 + dx[3] = -0.48860251190291992f ; // -sqrt(3)/(2*sqrt(pi)) + if (C <= 2) { return; } + dx[4] = 1.0925484305920792f*y ; // sqrt(15)*y/(2*sqrt(pi)) + dx[5] = 0.0f ; // 0 + dx[6] = 0.0f ; // 0 + dx[7] = -1.0925484305920792f*z ; // -sqrt(15)*z/(2*sqrt(pi)) + dx[8] = 1.0925484305920792f*x ; // sqrt(15)*x/(2*sqrt(pi)) + if (C <= 3) { return; } + dx[9] = -3.5402615395598609f*xy ; // -3*sqrt(70)*xy/(4*sqrt(pi)) + dx[10] = 2.8906114426405538f*yz ; // sqrt(105)*yz/(2*sqrt(pi)) + dx[11] = 0.0f ; // 0 + dx[12] = 0.0f ; // 0 + dx[13] = 0.45704579946446572f - 2.2852289973223288f*z2 ; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi)) + dx[14] = 2.8906114426405538f*xz ; // sqrt(105)*xz/(2*sqrt(pi)) + dx[15] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi)) + if (C <= 4) { return; } + dx[16] = 2.5033429417967046f*y*(3.0f*x2 - y2) ; // 3*sqrt(35)*y*(3*x2 - y2)/(4*sqrt(pi)) + dx[17] = -10.620784618679583f*xy*z ; // -9*sqrt(70)*xy*z/(4*sqrt(pi)) + dx[18] = 0.94617469575756008f*y*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*y*(7*z2 - 1)/(4*sqrt(pi)) + dx[19] = 0.0f ; // 0 + dx[20] = 0.0f ; // 0 + dx[21] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi)) + dx[22] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi)) + dx[23] = 5.3103923093397913f*z*(-x2 + y2) ; // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi)) + dx[24] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ; // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi)) + if (C <= 5) { return; } + dx[25] = 13.127641136803401f*xy*(-x2 + y2) ; // 15*sqrt(154)*xy*(-x2 + y2)/(8*sqrt(pi)) + dx[26] = 8.3026492595241645f*yz*(3.0f*x2 - y2) ; // 3*sqrt(385)*yz*(3*x2 - y2)/(4*sqrt(pi)) + dx[27] = 2.9354297966115022f*xy*(1.0f - 9.0f*z2) ; // 3*sqrt(770)*xy*(1 - 9*z2)/(16*sqrt(pi)) + dx[28] = 4.7935367849733241f*yz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*yz*(3*z2 - 1)/(4*sqrt(pi)) + dx[29] = 0.0f ; // 0 + dx[30] = 0.0f ; // 0 + dx[31] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ; // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) + dx[32] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi)) + dx[33] = -13.209434084751759f*x2*z2 + 1.4677148983057511f*x2 + 13.209434084751759f*y2*z2 - 1.4677148983057511f*y2 ; // 3*sqrt(770)*(-9*x2*z2 + x2 + 9*y2*z2 - y2)/(32*sqrt(pi)) + dx[34] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ; // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi)) + dx[35] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ; // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) + if (C <= 6) { return; } + dx[36] = 4.0991046311514854f*y*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // 3*sqrt(6006)*y*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi)) + dx[37] = 47.332383244635047f*xy*z*(-x2 + y2) ; // 15*sqrt(2002)*xy*z*(-x2 + y2)/(8*sqrt(pi)) + dx[38] = 2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) + dx[39] = 5.5272315570895412f*xy*z*(3.0f - 11.0f*z2) ; // 3*sqrt(2730)*xy*z*(3 - 11*z2)/(16*sqrt(pi)) + dx[40] = 0.92120525951492349f*y*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*y*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) + dx[41] = 0.0f ; // 0 + dx[42] = 0.0f ; // 0 + dx[43] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) + dx[44] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) + dx[45] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) + dx[46] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi)) + dx[47] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ; // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) + dx[48] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) + if (C <= 7) { return; } + dx[49] = 9.9002782553443485f*xy*(10.0f*x2*y2 - 3.0f*x4 - 3.0f*y4) ; // 21*sqrt(715)*xy*(10*x2*y2 - 3*x4 - 3*y4)/(32*sqrt(pi)) + dx[50] = 15.875763970811402f*yz*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // 9*sqrt(10010)*yz*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi)) + dx[51] = -10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // -15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi)) + dx[52] = 4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) + dx[53] = 0.93875360317376422f*xy*(66.0f*z2 - 143.0f*z4 - 3.0f) ; // 9*sqrt(35)*xy*(66*z2 - 143*z4 - 3)/(32*sqrt(pi)) + dx[54] = 0.44253269244498261f*yz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*yz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) + dx[55] = 0.0f ; // 0 + dx[56] = 0.0f ; // 0 + dx[57] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ; // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) + dx[58] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) + dx[59] = 30.97886890473422f*x2*z2 - 67.120882626924143f*x2*z4 - 1.4081304047606462f*x2 - 30.97886890473422f*y2*z2 + 67.120882626924143f*y2*z4 + 1.4081304047606462f*y2 ; // 9*sqrt(35)*(66*x2*z2 - 143*x2*z4 - 3*x2 - 66*y2*z2 + 143*y2*z4 + 3*y2)/(64*sqrt(pi)) + dx[60] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi)) + dx[61] = -0.51891557872026028f*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 4.0f*x2*(x2 - 5.0f*y2) + x4 + 5.0f*y4) ; // -3*sqrt(385)*(13*z2 - 1)*(-10*x2*y2 + 4*x2*(x2 - 5*y2) + x4 + 5*y4)/(64*sqrt(pi)) + dx[62] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) + dx[63] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi)) + }; + + auto write_sh_dy = [&]() { + dy[0] = 0.0f ; // 0 + if (C <= 1) { return; } + dy[1] = -0.48860251190291992f ; // -sqrt(3)/(2*sqrt(pi)) + dy[2] = 0.0f ; // 0 + dy[3] = 0.0f ; // 0 + if (C <= 2) { return; } + dy[4] = 1.0925484305920792f*x ; // sqrt(15)*x/(2*sqrt(pi)) + dy[5] = -1.0925484305920792f*z ; // -sqrt(15)*z/(2*sqrt(pi)) + dy[6] = 0.0f ; // 0 + dy[7] = 0.0f ; // 0 + dy[8] = -1.0925484305920792f*y ; // -sqrt(15)*y/(2*sqrt(pi)) + if (C <= 3) { return; } + dy[9] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi)) + dy[10] = 2.8906114426405538f*xz ; // sqrt(105)*xz/(2*sqrt(pi)) + dy[11] = 0.45704579946446572f - 2.2852289973223288f*z2 ; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi)) + dy[12] = 0.0f ; // 0 + dy[13] = 0.0f ; // 0 + dy[14] = -2.8906114426405538f*yz ; // -sqrt(105)*yz/(2*sqrt(pi)) + dy[15] = 3.5402615395598609f*xy ; // 3*sqrt(70)*xy/(4*sqrt(pi)) + if (C <= 4) { return; } + dy[16] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ; // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi)) + dy[17] = 5.3103923093397913f*z*(-x2 + y2) ; // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi)) + dy[18] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi)) + dy[19] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi)) + dy[20] = 0.0f ; // 0 + dy[21] = 0.0f ; // 0 + dy[22] = 0.94617469575756008f*y*(1.0f - 7.0f*z2) ; // 3*sqrt(5)*y*(1 - 7*z2)/(4*sqrt(pi)) + dy[23] = 10.620784618679583f*xy*z ; // 9*sqrt(70)*xy*z/(4*sqrt(pi)) + dy[24] = 2.5033429417967046f*y*(-3.0f*x2 + y2) ; // 3*sqrt(35)*y*(-3*x2 + y2)/(4*sqrt(pi)) + if (C <= 5) { return; } + dy[25] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ; // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) + dy[26] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ; // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi)) + dy[27] = -1.4677148983057511f*(x2 - y2)*(9.0f*z2 - 1.0f) ; // -3*sqrt(770)*(x2 - y2)*(9*z2 - 1)/(32*sqrt(pi)) + dy[28] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi)) + dy[29] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ; // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) + dy[30] = 0.0f ; // 0 + dy[31] = 0.0f ; // 0 + dy[32] = 4.7935367849733241f*yz*(1.0f - 3.0f*z2) ; // sqrt(1155)*yz*(1 - 3*z2)/(4*sqrt(pi)) + dy[33] = 2.9354297966115022f*xy*(9.0f*z2 - 1.0f) ; // 3*sqrt(770)*xy*(9*z2 - 1)/(16*sqrt(pi)) + dy[34] = 8.3026492595241645f*yz*(-3.0f*x2 + y2) ; // 3*sqrt(385)*yz*(-3*x2 + y2)/(4*sqrt(pi)) + dy[35] = 13.127641136803401f*xy*(x2 - y2) ; // 15*sqrt(154)*xy*(x2 - y2)/(8*sqrt(pi)) + if (C <= 6) { return; } + dy[36] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) + dy[37] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ; // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) + dy[38] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi)) + dy[39] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) + dy[40] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) + dy[41] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) + dy[42] = 0.0f ; // 0 + dy[43] = 0.0f ; // 0 + dy[44] = 0.92120525951492349f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // sqrt(2730)*y*(18*z2 - 33*z4 - 1)/(32*sqrt(pi)) + dy[45] = 5.5272315570895412f*xy*z*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(16*sqrt(pi)) + dy[46] = -2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) + dy[47] = 47.332383244635047f*xy*z*(x2 - y2) ; // 15*sqrt(2002)*xy*z*(x2 - y2)/(8*sqrt(pi)) + dy[48] = 4.0991046311514854f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(6006)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + if (C <= 7) { return; } + dy[49] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi)) + dy[50] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) + dy[51] = 0.51891557872026028f*(13.0f*z2 - 1.0f)*(10.0f*x2*y2 - 5.0f*x4 + 4.0f*y2*(5.0f*x2 - y2) - y4) ; // 3*sqrt(385)*(13*z2 - 1)*(10*x2*y2 - 5*x4 + 4*y2*(5*x2 - y2) - y4)/(64*sqrt(pi)) + dy[52] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi)) + dy[53] = -0.46937680158688211f*(x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -9*sqrt(35)*(x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) + dy[54] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) + dy[55] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ; // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) + dy[56] = 0.0f ; // 0 + dy[57] = 0.0f ; // 0 + dy[58] = 0.44253269244498261f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 3*sqrt(70)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi)) + dy[59] = 0.93875360317376422f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ; // 9*sqrt(35)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi)) + dy[60] = -4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // -3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) + dy[61] = 10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // 15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi)) + dy[62] = 15.875763970811402f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 9*sqrt(10010)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + dy[63] = 9.9002782553443485f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 21*sqrt(715)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) + }; + + auto write_sh_dz = [&]() { + dz[0] = 0.0f ; // 0 + if (C <= 1) { return; } + dz[1] = 0.0f ; // 0 + dz[2] = 0.48860251190291992f ; // sqrt(3)/(2*sqrt(pi)) + dz[3] = 0.0f ; // 0 + if (C <= 2) { return; } + dz[4] = 0.0f ; // 0 + dz[5] = -1.0925484305920792f*y ; // -sqrt(15)*y/(2*sqrt(pi)) + dz[6] = 1.8923493915151202f*z ; // 3*sqrt(5)*z/(2*sqrt(pi)) + dz[7] = -1.0925484305920792f*x ; // -sqrt(15)*x/(2*sqrt(pi)) + dz[8] = 0.0f ; // 0 + if (C <= 3) { return; } + dz[9] = 0.0f ; // 0 + dz[10] = 2.8906114426405538f*xy ; // sqrt(105)*xy/(2*sqrt(pi)) + dz[11] = -4.5704579946446566f*yz ; // -5*sqrt(42)*yz/(4*sqrt(pi)) + dz[12] = 5.597644988851731f*z2 - 1.1195289977703462f ; // 3*sqrt(7)*(5*z2 - 1)/(4*sqrt(pi)) + dz[13] = -4.5704579946446566f*xz ; // -5*sqrt(42)*xz/(4*sqrt(pi)) + dz[14] = 1.4453057213202769f*x2 - 1.4453057213202769f*y2 ; // sqrt(105)*(x2 - y2)/(4*sqrt(pi)) + dz[15] = 0.0f ; // 0 + if (C <= 4) { return; } + dz[16] = 0.0f ; // 0 + dz[17] = 1.7701307697799304f*y*(-3.0f*x2 + y2) ; // 3*sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi)) + dz[18] = 13.246445740605839f*xy*z ; // 21*sqrt(5)*xy*z/(2*sqrt(pi)) + dz[19] = 2.0071396306718676f*y*(1.0f - 7.0f*z2) ; // 9*sqrt(10)*y*(1 - 7*z2)/(8*sqrt(pi)) + dz[20] = 14.809976568128603f*pow(z, 3) - 6.3471328149122579f*z ; // (105*z**3 - 45*z)/(4*sqrt(pi)) + dz[21] = 2.0071396306718676f*x*(1.0f - 7.0f*z2) ; // 9*sqrt(10)*x*(1 - 7*z2)/(8*sqrt(pi)) + dz[22] = 6.6232228703029197f*z*(x2 - y2) ; // 21*sqrt(5)*z*(x2 - y2)/(4*sqrt(pi)) + dz[23] = 1.7701307697799304f*x*(-x2 + 3.0f*y2) ; // 3*sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi)) + dz[24] = 0.0f ; // 0 + if (C <= 5) { return; } + dz[25] = 0.0f ; // 0 + dz[26] = 8.3026492595241645f*xy*(x2 - y2) ; // 3*sqrt(385)*xy*(x2 - y2)/(4*sqrt(pi)) + dz[27] = 8.8062893898345074f*yz*(-3.0f*x2 + y2) ; // 9*sqrt(770)*yz*(-3*x2 + y2)/(16*sqrt(pi)) + dz[28] = 4.7935367849733241f*xy*(9.0f*z2 - 1.0f) ; // sqrt(1155)*xy*(9*z2 - 1)/(4*sqrt(pi)) + dz[29] = 12.682506233479513f*yz*(1.0f - 3.0f*z2) ; // 7*sqrt(165)*yz*(1 - 3*z2)/(4*sqrt(pi)) + dz[30] = -24.559567715218954f*z2 + 36.839351572828434f*z4 + 1.754254836801354f ; // 15*sqrt(11)*(-14*z2 + 21*z4 + 1)/(16*sqrt(pi)) + dz[31] = 12.682506233479513f*xz*(1.0f - 3.0f*z2) ; // 7*sqrt(165)*xz*(1 - 3*z2)/(4*sqrt(pi)) + dz[32] = 2.3967683924866621f*(x2 - y2)*(9.0f*z2 - 1.0f) ; // sqrt(1155)*(x2 - y2)*(9*z2 - 1)/(8*sqrt(pi)) + dz[33] = 8.8062893898345074f*xz*(-x2 + 3.0f*y2) ; // 9*sqrt(770)*xz*(-x2 + 3*y2)/(16*sqrt(pi)) + dz[34] = -12.453973889286246f*x2*y2 + 2.0756623148810411f*x4 + 2.0756623148810411f*y4 ; // 3*sqrt(385)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) + dz[35] = 0.0f ; // 0 + if (C <= 6) { return; } + dz[36] = 0.0f ; // 0 + dz[37] = 2.3666191622317521f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(2002)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + dz[38] = 44.401711264127719f*xy*z*(x2 - y2) ; // 33*sqrt(91)*xy*z*(x2 - y2)/(4*sqrt(pi)) + dz[39] = -2.7636157785447706f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(2730)*y*(3*x2 - y2)*(11*z2 - 1)/(32*sqrt(pi)) + dz[40] = 11.054463114179082f*xy*z*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(8*sqrt(pi)) + dz[41] = 2.9131068125936568f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // 5*sqrt(273)*y*(18*z2 - 33*z4 - 1)/(16*sqrt(pi)) + dz[42] = 2.6699064952403937f*z*(-30.0f*z2 + 33.0f*z4 + 5.0f) ; // 21*sqrt(13)*z*(-30*z2 + 33*z4 + 5)/(16*sqrt(pi)) + dz[43] = 2.9131068125936568f*x*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // 5*sqrt(273)*x*(18*z2 - 33*z4 - 1)/(16*sqrt(pi)) + dz[44] = 5.5272315570895412f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(16*sqrt(pi)) + dz[45] = -2.7636157785447706f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(2730)*x*(x2 - 3*y2)*(11*z2 - 1)/(32*sqrt(pi)) + dz[46] = 11.10042781603193f*z*(-6.0f*x2*y2 + x4 + y4) ; // 33*sqrt(91)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) + dz[47] = 2.3666191622317521f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(2002)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) + dz[48] = 0.0f ; // 0 + if (C <= 7) { return; } + dz[49] = 0.0f ; // 0 + dz[50] = 5.2919213236038001f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 3*sqrt(10010)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) + dz[51] = 13.491805046726766f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 39*sqrt(385)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + dz[52] = 12.453973889286248f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // 9*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(8*sqrt(pi)) + dz[53] = -6.8841930899409371f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // -33*sqrt(35)*yz*(3*x2 - y2)*(13*z2 - 3)/(16*sqrt(pi)) + dz[54] = 2.2126634622249131f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ; // 15*sqrt(70)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi)) + dz[55] = 1.6259689364853116f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 9*sqrt(105)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi)) + dz[56] = 64.528641681844675f*z2 - 236.60501950009714f*z4 + 205.05768356675085f*z6 - 2.3899496919201733f ; // 7*sqrt(15)*(135*z2 - 495*z4 + 429*z6 - 5)/(32*sqrt(pi)) + dz[57] = 1.6259689364853116f*xz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 9*sqrt(105)*xz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi)) + dz[58] = 0.07375544874083044f*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) + 132.0f*z2*(13.0f*z2 - 5.0f) - 187.0f*z2 + 45.0f) ; // sqrt(70)*(x2 - y2)*(143*z2*(3*z2 - 1) + 132*z2*(13*z2 - 5) - 187*z2 + 45)/(64*sqrt(pi)) + dz[59] = -6.8841930899409371f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // -33*sqrt(35)*xz*(x2 - 3*y2)*(13*z2 - 3)/(16*sqrt(pi)) + dz[60] = 3.1134934723215619f*(13.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 9*sqrt(385)*(13*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) + dz[61] = 13.491805046726766f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 39*sqrt(385)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) + dz[62] = 39.6894099270285f*x2*y4 - 39.6894099270285f*x4*y2 + 2.6459606618019f*x6 - 2.6459606618019f*y6 ; // 3*sqrt(10010)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) + dz[63] = 0.0f ; // 0 + }; + write_sh_dx(); + write_sh_dy(); + write_sh_dz(); + } +} + + +template +__global__ void kernel_sh_backward( + const scalar_t * __restrict__ grad, + const scalar_t * __restrict__ inputs, + uint32_t B, uint32_t D, uint32_t C, + const scalar_t * __restrict__ dy_dx, + scalar_t * grad_inputs +) { + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; + const uint32_t b = t / D; + if (b >= B) return; + + const uint32_t d = t - b * D; + const uint32_t C2 = C * C; + + // locate + grad += b * C2; + dy_dx += b * D * C2 + d * C2; + + for (int ch = 0; ch < C2; ch++) { + grad_inputs[t] += grad[ch] * dy_dx[ch]; + //printf("t=%d, b=%d, d=%d, ch=%d, grad=%f (+= %f * %f)\n", t, b, d, ch, grad_inputs[t], grad[ch], dy_dx[ch]); + } + +} + +// inputs: [B, D], float, in [0, 1] +// outputs: [B, L * C], float +template +void sh_encode_forward_cuda(const scalar_t *inputs, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx) { + static constexpr uint32_t N_THREADS = 256; + kernel_sh<<>>(inputs, outputs, B, D, C, dy_dx); +} + + +template +void sh_encode_backward_cuda(const scalar_t *grad, const scalar_t *inputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx, scalar_t *grad_inputs) { + static constexpr uint32_t N_THREADS = 256; + kernel_sh_backward<<>>(grad, inputs, B, D, C, dy_dx, grad_inputs); +} + + +void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional dy_dx) { + CHECK_CUDA(inputs); + CHECK_CUDA(outputs); + // CHECK_CUDA(dy_dx); + + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(outputs); + // CHECK_CONTIGUOUS(dy_dx); + + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(outputs); + // CHECK_IS_FLOATING(dy_dx); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + inputs.scalar_type(), "sh_encode_forward_cuda", ([&] { + sh_encode_forward_cuda(inputs.data_ptr(), outputs.data_ptr(), B, D, C, dy_dx.has_value() ? dy_dx.value().data_ptr() : nullptr); + })); +} + +void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs) { + CHECK_CUDA(grad); + CHECK_CUDA(inputs); + CHECK_CUDA(dy_dx); + CHECK_CUDA(grad_inputs); + + CHECK_CONTIGUOUS(grad); + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(dy_dx); + CHECK_CONTIGUOUS(grad_inputs); + + CHECK_IS_FLOATING(grad); + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(dy_dx); + CHECK_IS_FLOATING(grad_inputs); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "sh_encode_backward_cuda", ([&] { + sh_encode_backward_cuda(grad.data_ptr(), inputs.data_ptr(), B, D, C, dy_dx.data_ptr(), grad_inputs.data_ptr()); + })); +} \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/encoders/shencoder/src/shencoder.h b/Geneface_main/GeneFace/docker/modules/encoders/shencoder/src/shencoder.h new file mode 100644 index 00000000..f9e89fac --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/encoders/shencoder/src/shencoder.h @@ -0,0 +1,10 @@ +# pragma once + +#include +#include + +// inputs: [B, D], float, in [-1, 1] +// outputs: [B, F], float + +void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional dy_dx); +void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs); \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/raymarching/__init__.py b/Geneface_main/GeneFace/docker/modules/raymarching/__init__.py new file mode 100644 index 00000000..26d3cc6d --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/raymarching/__init__.py @@ -0,0 +1 @@ +from .raymarching import * \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/raymarching/backend.py b/Geneface_main/GeneFace/docker/modules/raymarching/backend.py new file mode 100644 index 00000000..d8f65d6f --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/raymarching/backend.py @@ -0,0 +1,40 @@ +import os +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +_backend = load(name='_raymarching_face', + extra_cflags=c_flags, + extra_cuda_cflags=nvcc_flags, + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'raymarching.cu', + 'bindings.cpp', + ]], + ) + +__all__ = ['_backend'] \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/raymarching/raymarching.py b/Geneface_main/GeneFace/docker/modules/raymarching/raymarching.py new file mode 100644 index 00000000..22dd0441 --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/raymarching/raymarching.py @@ -0,0 +1,423 @@ +import numpy as np +import time + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import _raymarching_face as _backend +except ImportError: + from .backend import _backend + +# ---------------------------------------- +# utils +# ---------------------------------------- + +class _near_far_from_aabb(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, rays_o, rays_d, aabb, min_near=0.2): + ''' near_far_from_aabb, CUDA implementation + Calculate rays' intersection time (near and far) with aabb + Args: + rays_o: float, [N, 3] + rays_d: float, [N, 3] + aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax) + min_near: float, scalar + Returns: + nears: float, [N] + fars: float, [N] + ''' + if not rays_o.is_cuda: rays_o = rays_o.cuda() + if not rays_d.is_cuda: rays_d = rays_d.cuda() + + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + N = rays_o.shape[0] # num rays + + nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device) + fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device) + + _backend.near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars) + + return nears, fars + +near_far_from_aabb = _near_far_from_aabb.apply + + +class _sph_from_ray(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, rays_o, rays_d, radius): + ''' sph_from_ray, CUDA implementation + get spherical coordinate on the background sphere from rays. + Assume rays_o are inside the Sphere(radius). + Args: + rays_o: [N, 3] + rays_d: [N, 3] + radius: scalar, float + Return: + coords: [N, 2], in [-1, 1], theta and phi on a sphere. (further-surface) + ''' + if not rays_o.is_cuda: rays_o = rays_o.cuda() + if not rays_d.is_cuda: rays_d = rays_d.cuda() + + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + N = rays_o.shape[0] # num rays + + coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device) + + _backend.sph_from_ray(rays_o, rays_d, radius, N, coords) + + return coords + +sph_from_ray = _sph_from_ray.apply + + +class _morton3D(Function): + @staticmethod + def forward(ctx, coords): + ''' morton3D, CUDA implementation + Args: + coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...) + TODO: check if the coord range is valid! (current 128 is safe) + Returns: + indices: [N], int32, in [0, 128^3) + + ''' + if not coords.is_cuda: coords = coords.cuda() + + N = coords.shape[0] + + indices = torch.empty(N, dtype=torch.int32, device=coords.device) + + _backend.morton3D(coords.int(), N, indices) + + return indices + +morton3D = _morton3D.apply + +class _morton3D_invert(Function): + @staticmethod + def forward(ctx, indices): + ''' morton3D_invert, CUDA implementation + Args: + indices: [N], int32, in [0, 128^3) + Returns: + coords: [N, 3], int32, in [0, 128) + + ''' + if not indices.is_cuda: indices = indices.cuda() + + N = indices.shape[0] + + coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device) + + _backend.morton3D_invert(indices.int(), N, coords) + + return coords + +morton3D_invert = _morton3D_invert.apply + + +class _packbits(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, grid, thresh, bitfield=None): + ''' packbits, CUDA implementation + Pack up the density grid into a bit field to accelerate ray marching. + Args: + grid: float, [C, H * H * H], assume H % 2 == 0 + thresh: float, threshold + Returns: + bitfield: uint8, [C, H * H * H / 8] + ''' + if not grid.is_cuda: grid = grid.cuda() + grid = grid.contiguous() + + C = grid.shape[0] + H3 = grid.shape[1] + N = C * H3 // 8 + + if bitfield is None: + bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device) + + _backend.packbits(grid, N, thresh, bitfield) + + return bitfield + +packbits = _packbits.apply + + +class _morton3D_dilation(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, grid): + ''' max pooling with morton coord, CUDA implementation + or maybe call it dilation... we don't support adjust kernel size. + Args: + grid: float, [C, H * H * H], assume H % 2 == 0 + Returns: + grid_dilate: float, [C, H * H * H], assume H % 2 == 0bitfield: uint8, [C, H * H * H / 8] + ''' + if not grid.is_cuda: grid = grid.cuda() + grid = grid.contiguous() + + C = grid.shape[0] + H3 = grid.shape[1] + H = int(np.cbrt(H3)) + grid_dilation = torch.empty_like(grid) + + _backend.morton3D_dilation(grid, C, H, grid_dilation) + + return grid_dilation + +morton3D_dilation = _morton3D_dilation.apply + +# ---------------------------------------- +# train functions +# ---------------------------------------- + +class _march_rays_train(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1, perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024): + ''' march rays to generate points (forward only) + Args: + rays_o/d: float, [N, 3] + bound: float, scalar + density_bitfield: uint8: [CHHH // 8] + C: int + H: int + nears/fars: float, [N] + step_counter: int32, (2), used to count the actual number of generated points. + mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeded this threshold.) + perturb: bool + align: int, pad output so its size is dividable by align, set to -1 to disable. + force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays. + dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance) + max_steps: int, max number of sampled points along each ray, also affect min_stepsize. + Returns: + xyzs: float, [M, 3], all generated points' coords. (all rays concated, need to use `rays` to extract points belonging to each ray) + dirs: float, [M, 3], all generated points' view dirs. + deltas: float, [M, 2], first is delta_t, second is rays_t + rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 1] + rays[i, 2]] --> points belonging to rays[i, 0] + ''' + + if not rays_o.is_cuda: rays_o = rays_o.cuda() + if not rays_d.is_cuda: rays_d = rays_d.cuda() + if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda() + + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + density_bitfield = density_bitfield.contiguous() + + N = rays_o.shape[0] # num rays + M = N * max_steps # init max points number in total + + # running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp) + # It estimate the max points number to enable faster training, but will lead to random ignored rays if underestimated. + if not force_all_rays and mean_count > 0: + if align > 0: + mean_count += align - mean_count % align + M = mean_count + + xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) + dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) + deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) + rays = torch.empty(N, 3, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps + + if step_counter is None: + step_counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device) # point counter, ray counter + + if perturb: + noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device) + else: + noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device) + + _backend.march_rays_train(rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars, xyzs, dirs, deltas, rays, step_counter, noises) # m is the actually used points number + + #print(step_counter, M) + + # only used at the first (few) epochs. + if force_all_rays or mean_count <= 0: + m = step_counter[0].item() # D2H copy + if align > 0: + m += align - m % align + xyzs = xyzs[:m] + dirs = dirs[:m] + deltas = deltas[:m] + + torch.cuda.empty_cache() + + ctx.save_for_backward(rays, deltas) + + return xyzs, dirs, deltas, rays + + # to support optimizing camera poses. + @staticmethod + @custom_bwd + def backward(ctx, grad_xyzs, grad_dirs, grad_deltas, grad_rays): + # grad_xyzs/dirs: [M, 3] + + rays, deltas = ctx.saved_tensors + + N = rays.shape[0] + M = grad_xyzs.shape[0] + + grad_rays_o = torch.zeros(N, 3, device=rays.device) + grad_rays_d = torch.zeros(N, 3, device=rays.device) + + _backend.march_rays_train_backward(grad_xyzs, grad_dirs, rays, deltas, N, M, grad_rays_o, grad_rays_d) + + return grad_rays_o, grad_rays_d, None, None, None, None, None, None, None, None, None, None, None, None, None + +march_rays_train = _march_rays_train.apply + + +class _composite_rays_train(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, sigmas, rgbs, ambient, deltas, rays, T_thresh=1e-4): + ''' composite rays' rgbs, according to the ray marching formula. + Args: + rgbs: float, [M, 3] + sigmas: float, [M,] + ambient: float, [M,] (after summing up the last dimension) + deltas: float, [M, 2] + rays: int32, [N, 3] + Returns: + weights_sum: float, [N,], the alpha channel + depth: float, [N, ], the Depth + image: float, [N, 3], the RGB channel (after multiplying alpha!) + ''' + + sigmas = sigmas.contiguous() + rgbs = rgbs.contiguous() + ambient = ambient.contiguous() + + M = sigmas.shape[0] + N = rays.shape[0] + + weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) + ambient_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) + depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) + image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device) + + _backend.composite_rays_train_forward(sigmas, rgbs, ambient, deltas, rays, M, N, T_thresh, weights_sum, ambient_sum, depth, image) + + ctx.save_for_backward(sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, depth, image) + ctx.dims = [M, N, T_thresh] + + return weights_sum, ambient_sum, depth, image + + @staticmethod + @custom_bwd + def backward(ctx, grad_weights_sum, grad_ambient_sum, grad_depth, grad_image): + + # NOTE: grad_depth is not used now! It won't be propagated to sigmas. + + grad_weights_sum = grad_weights_sum.contiguous() + grad_ambient_sum = grad_ambient_sum.contiguous() + grad_image = grad_image.contiguous() + + sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, depth, image = ctx.saved_tensors + M, N, T_thresh = ctx.dims + + grad_sigmas = torch.zeros_like(sigmas) + grad_rgbs = torch.zeros_like(rgbs) + grad_ambient = torch.zeros_like(ambient) + + _backend.composite_rays_train_backward(grad_weights_sum, grad_ambient_sum, grad_image, sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs, grad_ambient) + + return grad_sigmas, grad_rgbs, grad_ambient, None, None, None + + +composite_rays_train = _composite_rays_train.apply + +# ---------------------------------------- +# infer functions +# ---------------------------------------- + +class _march_rays(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, align=-1, perturb=False, dt_gamma=0, max_steps=1024): + ''' march rays to generate points (forward only, for inference) + Args: + n_alive: int, number of alive rays + n_step: int, how many steps we march + rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive) + rays_t: float, [N], the alive rays' time, we only use the first n_alive. + rays_o/d: float, [N, 3] + bound: float, scalar + density_bitfield: uint8: [CHHH // 8] + C: int + H: int + nears/fars: float, [N] + align: int, pad output so its size is dividable by align, set to -1 to disable. + perturb: bool/int, int > 0 is used as the random seed. + dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance) + max_steps: int, max number of sampled points along each ray, also affect min_stepsize. + Returns: + xyzs: float, [n_alive * n_step, 3], all generated points' coords + dirs: float, [n_alive * n_step, 3], all generated points' view dirs. + deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth). + ''' + + if not rays_o.is_cuda: rays_o = rays_o.cuda() + if not rays_d.is_cuda: rays_d = rays_d.cuda() + + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + M = n_alive * n_step + + if align > 0: + M += align - (M % align) + + xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) + dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) + deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth + + if perturb: + # torch.manual_seed(perturb) # test_gui uses spp index as seed + noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device) + else: + noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device) + + _backend.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, deltas, noises) + + return xyzs, dirs, deltas + +march_rays = _march_rays.apply + + +class _composite_rays(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float + def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh=1e-2): + ''' composite rays' rgbs, according to the ray marching formula. (for inference) + Args: + n_alive: int, number of alive rays + n_step: int, how many steps we march + rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive) + rays_t: float, [N], the alive rays' time + sigmas: float, [n_alive * n_step,] + rgbs: float, [n_alive * n_step, 3] + deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth). + In-place Outputs: + weights_sum: float, [N,], the alpha channel + depth: float, [N,], the depth value + image: float, [N, 3], the RGB channel (after multiplying alpha!) + ''' + _backend.composite_rays(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image) + return tuple() + + +composite_rays = _composite_rays.apply \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/raymarching/setup.py b/Geneface_main/GeneFace/docker/modules/raymarching/setup.py new file mode 100644 index 00000000..6a7e62f7 --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/raymarching/setup.py @@ -0,0 +1,63 @@ +import os +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + # '-lineinfo', # to debug illegal memory access + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +''' +Usage: + +python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory) + +python setup.py install # build extensions and install (copy) to PATH. +pip install . # ditto but better (e.g., dependency & metadata handling) + +python setup.py develop # build extensions and install (symbolic) to PATH. +pip install -e . # ditto but better (e.g., dependency & metadata handling) + +''' +setup( + name='raymarching_face', # package name, import this to use python API + ext_modules=[ + CUDAExtension( + name='_raymarching_face', # extension name, import this to use CUDA API + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'raymarching.cu', + 'bindings.cpp', + ]], + extra_compile_args={ + 'cxx': c_flags, + 'nvcc': nvcc_flags, + } + ), + ], + cmdclass={ + 'build_ext': BuildExtension, + } +) \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/raymarching/src/bindings.cpp b/Geneface_main/GeneFace/docker/modules/raymarching/src/bindings.cpp new file mode 100644 index 00000000..589de244 --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/raymarching/src/bindings.cpp @@ -0,0 +1,21 @@ +#include + +#include "raymarching.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // utils + m.def("packbits", &packbits, "packbits (CUDA)"); + m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)"); + m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)"); + m.def("morton3D", &morton3D, "morton3D (CUDA)"); + m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)"); + m.def("morton3D_dilation", &morton3D_dilation, "morton3D_dilation (CUDA)"); + // train + m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)"); + m.def("march_rays_train_backward", &march_rays_train_backward, "march_rays_train_backward (CUDA)"); + m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)"); + m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)"); + // infer + m.def("march_rays", &march_rays, "march rays (CUDA)"); + m.def("composite_rays", &composite_rays, "composite rays (CUDA)"); +} \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/raymarching/src/raymarching.cu b/Geneface_main/GeneFace/docker/modules/raymarching/src/raymarching.cu new file mode 100644 index 00000000..ae5839bc --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/raymarching/src/raymarching.cu @@ -0,0 +1,1038 @@ +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") + + +inline constexpr __device__ float SQRT3() { return 1.7320508075688772f; } +inline constexpr __device__ float RSQRT3() { return 0.5773502691896258f; } +inline constexpr __device__ float PI() { return 3.141592653589793f; } +inline constexpr __device__ float RPI() { return 0.3183098861837907f; } + + +template +inline __host__ __device__ T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + +inline __host__ __device__ float signf(const float x) { + return copysignf(1.0, x); +} + +inline __host__ __device__ float clamp(const float x, const float min, const float max) { + return fminf(max, fmaxf(min, x)); +} + +inline __host__ __device__ void swapf(float& a, float& b) { + float c = a; a = b; b = c; +} + +inline __device__ int mip_from_pos(const float x, const float y, const float z, const float max_cascade) { + const float mx = fmaxf(fabsf(x), fmaxf(fabs(y), fabs(z))); + int exponent; + frexpf(mx, &exponent); // [0, 0.5) --> -1, [0.5, 1) --> 0, [1, 2) --> 1, [2, 4) --> 2, ... + return fminf(max_cascade - 1, fmaxf(0, exponent)); +} + +inline __device__ int mip_from_dt(const float dt, const float H, const float max_cascade) { + const float mx = dt * H * 0.5; + int exponent; + frexpf(mx, &exponent); + return fminf(max_cascade - 1, fmaxf(0, exponent)); +} + +inline __host__ __device__ uint32_t __expand_bits(uint32_t v) +{ + v = (v * 0x00010001u) & 0xFF0000FFu; + v = (v * 0x00000101u) & 0x0F00F00Fu; + v = (v * 0x00000011u) & 0xC30C30C3u; + v = (v * 0x00000005u) & 0x49249249u; + return v; +} + +inline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z) +{ + uint32_t xx = __expand_bits(x); + uint32_t yy = __expand_bits(y); + uint32_t zz = __expand_bits(z); + return xx | (yy << 1) | (zz << 2); +} + +inline __host__ __device__ uint32_t __morton3D_invert(uint32_t x) +{ + x = x & 0x49249249; + x = (x | (x >> 2)) & 0xc30c30c3; + x = (x | (x >> 4)) & 0x0f00f00f; + x = (x | (x >> 8)) & 0xff0000ff; + x = (x | (x >> 16)) & 0x0000ffff; + return x; +} + + +//////////////////////////////////////////////////// +///////////// utils ///////////// +//////////////////////////////////////////////////// + +// rays_o/d: [N, 3] +// nears/fars: [N] +// scalar_t should always be float in use. +template +__global__ void kernel_near_far_from_aabb( + const scalar_t * __restrict__ rays_o, + const scalar_t * __restrict__ rays_d, + const scalar_t * __restrict__ aabb, + const uint32_t N, + const float min_near, + scalar_t * nears, scalar_t * fars +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + rays_o += n * 3; + rays_d += n * 3; + + const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; + const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; + const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; + + // get near far (assume cube scene) + float near = (aabb[0] - ox) * rdx; + float far = (aabb[3] - ox) * rdx; + if (near > far) swapf(near, far); + + float near_y = (aabb[1] - oy) * rdy; + float far_y = (aabb[4] - oy) * rdy; + if (near_y > far_y) swapf(near_y, far_y); + + if (near > far_y || near_y > far) { + nears[n] = fars[n] = std::numeric_limits::max(); + return; + } + + if (near_y > near) near = near_y; + if (far_y < far) far = far_y; + + float near_z = (aabb[2] - oz) * rdz; + float far_z = (aabb[5] - oz) * rdz; + if (near_z > far_z) swapf(near_z, far_z); + + if (near > far_z || near_z > far) { + nears[n] = fars[n] = std::numeric_limits::max(); + return; + } + + if (near_z > near) near = near_z; + if (far_z < far) far = far_z; + + if (near < min_near) near = min_near; + + nears[n] = near; + fars[n] = far; +} + + +void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + rays_o.scalar_type(), "near_far_from_aabb", ([&] { + kernel_near_far_from_aabb<<>>(rays_o.data_ptr(), rays_d.data_ptr(), aabb.data_ptr(), N, min_near, nears.data_ptr(), fars.data_ptr()); + })); +} + + +// rays_o/d: [N, 3] +// radius: float +// coords: [N, 2] +template +__global__ void kernel_sph_from_ray( + const scalar_t * __restrict__ rays_o, + const scalar_t * __restrict__ rays_d, + const float radius, + const uint32_t N, + scalar_t * coords +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + rays_o += n * 3; + rays_d += n * 3; + coords += n * 2; + + const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; + const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; + const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; + + // solve t from || o + td || = radius + const float A = dx * dx + dy * dy + dz * dz; + const float B = ox * dx + oy * dy + oz * dz; // in fact B / 2 + const float C = ox * ox + oy * oy + oz * oz - radius * radius; + + const float t = (- B + sqrtf(B * B - A * C)) / A; // always use the larger solution (positive) + + // solve theta, phi (assume y is the up axis) + const float x = ox + t * dx, y = oy + t * dy, z = oz + t * dz; + const float theta = atan2(sqrtf(x * x + z * z), y); // [0, PI) + const float phi = atan2(z, x); // [-PI, PI) + + // normalize to [-1, 1] + coords[0] = 2 * theta * RPI() - 1; + coords[1] = phi * RPI(); +} + + +void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + rays_o.scalar_type(), "sph_from_ray", ([&] { + kernel_sph_from_ray<<>>(rays_o.data_ptr(), rays_d.data_ptr(), radius, N, coords.data_ptr()); + })); +} + + +// coords: int32, [N, 3] +// indices: int32, [N] +__global__ void kernel_morton3D( + const int * __restrict__ coords, + const uint32_t N, + int * indices +) { + // parallel + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + coords += n * 3; + indices[n] = __morton3D(coords[0], coords[1], coords[2]); +} + + +void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices) { + static constexpr uint32_t N_THREAD = 128; + kernel_morton3D<<>>(coords.data_ptr(), N, indices.data_ptr()); +} + + +// indices: int32, [N] +// coords: int32, [N, 3] +__global__ void kernel_morton3D_invert( + const int * __restrict__ indices, + const uint32_t N, + int * coords +) { + // parallel + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + coords += n * 3; + + const int ind = indices[n]; + + coords[0] = __morton3D_invert(ind >> 0); + coords[1] = __morton3D_invert(ind >> 1); + coords[2] = __morton3D_invert(ind >> 2); +} + + +void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords) { + static constexpr uint32_t N_THREAD = 128; + kernel_morton3D_invert<<>>(indices.data_ptr(), N, coords.data_ptr()); +} + + +// grid: float, [C, H, H, H] +// N: int, C * H * H * H / 8 +// density_thresh: float +// bitfield: uint8, [N] +template +__global__ void kernel_packbits( + const scalar_t * __restrict__ grid, + const uint32_t N, + const float density_thresh, + uint8_t * bitfield +) { + // parallel per byte + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + grid += n * 8; + + uint8_t bits = 0; + + #pragma unroll + for (uint8_t i = 0; i < 8; i++) { + bits |= (grid[i] > density_thresh) ? ((uint8_t)1 << i) : 0; + } + + bitfield[n] = bits; +} + + +void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grid.scalar_type(), "packbits", ([&] { + kernel_packbits<<>>(grid.data_ptr(), N, density_thresh, bitfield.data_ptr()); + })); +} + + +// grid: float, [C, H, H, H] +__global__ void kernel_morton3D_dilation( + const float * __restrict__ grid, + const uint32_t C, + const uint32_t H, + float * __restrict__ grid_dilation +) { + // parallel per byte + const uint32_t H3 = H * H * H; + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= C * H3) return; + + // locate + const uint32_t c = n / H3; + const uint32_t ind = n - c * H3; + + const uint32_t x = __morton3D_invert(ind >> 0); + const uint32_t y = __morton3D_invert(ind >> 1); + const uint32_t z = __morton3D_invert(ind >> 2); + + // manual max pool + float res = grid[n]; + + if (x + 1 < H) res = fmaxf(res, grid[c * H3 + __morton3D(x + 1, y, z)]); + if (x > 0) res = fmaxf(res, grid[c * H3 + __morton3D(x - 1, y, z)]); + if (y + 1 < H) res = fmaxf(res, grid[c * H3 + __morton3D(x, y + 1, z)]); + if (y > 0) res = fmaxf(res, grid[c * H3 + __morton3D(x, y - 1, z)]); + if (z + 1 < H) res = fmaxf(res, grid[c * H3 + __morton3D(x, y, z + 1)]); + if (z > 0) res = fmaxf(res, grid[c * H3 + __morton3D(x, y, z - 1)]); + + // write + grid_dilation[n] = res; +} + +void morton3D_dilation(const at::Tensor grid, const uint32_t C, const uint32_t H, at::Tensor grid_dilation) { + static constexpr uint32_t N_THREAD = 128; + + kernel_morton3D_dilation<<>>(grid.data_ptr(), C, H, grid_dilation.data_ptr()); +} + +//////////////////////////////////////////////////// +///////////// training ///////////// +//////////////////////////////////////////////////// + +// rays_o/d: [N, 3] +// grid: [CHHH / 8] +// xyzs, dirs, deltas: [M, 3], [M, 3], [M, 2] +// dirs: [M, 3] +// rays: [N, 3], idx, offset, num_steps +template +__global__ void kernel_march_rays_train( + const scalar_t * __restrict__ rays_o, + const scalar_t * __restrict__ rays_d, + const uint8_t * __restrict__ grid, + const float bound, + const float dt_gamma, const uint32_t max_steps, + const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, + const scalar_t* __restrict__ nears, + const scalar_t* __restrict__ fars, + scalar_t * xyzs, scalar_t * dirs, scalar_t * deltas, + int * rays, + int * counter, + const scalar_t* __restrict__ noises +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + rays_o += n * 3; + rays_d += n * 3; + + // ray marching + const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; + const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; + const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; + const float rH = 1 / (float)H; + const float H3 = H * H * H; + + const float near = nears[n]; + const float far = fars[n]; + const float noise = noises[n]; + + const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H; + const float dt_min = fminf(dt_max, 2 * SQRT3() / max_steps); + + float t0 = near; + + // perturb + t0 += clamp(t0 * dt_gamma, dt_min, dt_max) * noise; + + // first pass: estimation of num_steps + float t = t0; + uint32_t num_steps = 0; + + //if (t < far) printf("valid ray %d t=%f near=%f far=%f \n", n, t, near, far); + + while (t < far && num_steps < max_steps) { + // current point + const float x = clamp(ox + t * dx, -bound, bound); + const float y = clamp(oy + t * dy, -bound, bound); + const float z = clamp(oz + t * dz, -bound, bound); + + const float dt = clamp(t * dt_gamma, dt_min, dt_max); + + // get mip level + const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1] + + const float mip_bound = fminf(scalbnf(1.0f, level), bound); + const float mip_rbound = 1 / mip_bound; + + // convert to nearest grid position + const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + + const uint32_t index = level * H3 + __morton3D(nx, ny, nz); + const bool occ = grid[index / 8] & (1 << (index % 8)); + + // if occpuied, advance a small step, and write to output + //if (n == 0) printf("t=%f density=%f vs thresh=%f step=%d\n", t, density, density_thresh, num_steps); + + if (occ) { + num_steps++; + t += dt; + // else, skip a large step (basically skip a voxel grid) + } else { + // calc distance to next voxel + const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx; + const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy; + const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz; + + const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz))); + // step until next voxel + do { + t += clamp(t * dt_gamma, dt_min, dt_max); + } while (t < tt); + } + } + + //printf("[n=%d] num_steps=%d, near=%f, far=%f, dt=%f, max_steps=%f\n", n, num_steps, near, far, dt_min, (far - near) / dt_min); + + // second pass: really locate and write points & dirs + uint32_t point_index = atomicAdd(counter, num_steps); + uint32_t ray_index = atomicAdd(counter + 1, 1); + + //printf("[n=%d] num_steps=%d, point_index=%d, ray_index=%d\n", n, num_steps, point_index, ray_index); + + // write rays + rays[ray_index * 3] = n; + rays[ray_index * 3 + 1] = point_index; + rays[ray_index * 3 + 2] = num_steps; + + if (num_steps == 0) return; + if (point_index + num_steps > M) return; + + xyzs += point_index * 3; + dirs += point_index * 3; + deltas += point_index * 2; + + t = t0; + uint32_t step = 0; + + while (t < far && step < num_steps) { + // current point + const float x = clamp(ox + t * dx, -bound, bound); + const float y = clamp(oy + t * dy, -bound, bound); + const float z = clamp(oz + t * dz, -bound, bound); + + const float dt = clamp(t * dt_gamma, dt_min, dt_max); + + // get mip level + const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1] + + const float mip_bound = fminf(scalbnf(1.0f, level), bound); + const float mip_rbound = 1 / mip_bound; + + // convert to nearest grid position + const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + + // query grid + const uint32_t index = level * H3 + __morton3D(nx, ny, nz); + const bool occ = grid[index / 8] & (1 << (index % 8)); + + // if occpuied, advance a small step, and write to output + if (occ) { + // write step + xyzs[0] = x; + xyzs[1] = y; + xyzs[2] = z; + dirs[0] = dx; + dirs[1] = dy; + dirs[2] = dz; + t += dt; + deltas[0] = dt; + deltas[1] = t; // used to calc depth + xyzs += 3; + dirs += 3; + deltas += 2; + step++; + // else, skip a large step (basically skip a voxel grid) + } else { + // calc distance to next voxel + const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx; + const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy; + const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz; + const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz))); + // step until next voxel + do { + t += clamp(t * dt_gamma, dt_min, dt_max); + } while (t < tt); + } + } +} + +void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + rays_o.scalar_type(), "march_rays_train", ([&] { + kernel_march_rays_train<<>>(rays_o.data_ptr(), rays_d.data_ptr(), grid.data_ptr(), bound, dt_gamma, max_steps, N, C, H, M, nears.data_ptr(), fars.data_ptr(), xyzs.data_ptr(), dirs.data_ptr(), deltas.data_ptr(), rays.data_ptr(), counter.data_ptr(), noises.data_ptr()); + })); +} + + +// grad_xyzs/dirs: [M, 3] +// rays: [N, 3] +// deltas: [M, 2] +// grad_rays_o/d: [N, 3] +template +__global__ void kernel_march_rays_train_backward( + const scalar_t * __restrict__ grad_xyzs, + const scalar_t * __restrict__ grad_dirs, + const int * __restrict__ rays, + const scalar_t * __restrict__ deltas, + const uint32_t N, const uint32_t M, + scalar_t * grad_rays_o, + scalar_t * grad_rays_d +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + grad_rays_o += n * 3; + grad_rays_d += n * 3; + + uint32_t index = rays[n * 3]; + uint32_t offset = rays[n * 3 + 1]; + uint32_t num_steps = rays[n * 3 + 2]; + + // empty ray, or ray that exceed max step count. + if (num_steps == 0 || offset + num_steps > M) return; + + grad_xyzs += offset * 3; + grad_dirs += offset * 3; + deltas += offset * 2; + + // accumulate + uint32_t step = 0; + while (step < num_steps) { + + grad_rays_o[0] += grad_xyzs[0]; + grad_rays_o[1] += grad_xyzs[1]; + grad_rays_o[2] += grad_xyzs[2]; + + grad_rays_d[0] += grad_xyzs[0] * deltas[1] + grad_dirs[0]; + grad_rays_d[1] += grad_xyzs[1] * deltas[1] + grad_dirs[1]; + grad_rays_d[2] += grad_xyzs[2] * deltas[1] + grad_dirs[2]; + + // locate + grad_xyzs += 3; + grad_dirs += 3; + deltas += 2; + + step++; + } +} + +void march_rays_train_backward(const at::Tensor grad_xyzs, const at::Tensor grad_dirs, const at::Tensor rays, const at::Tensor deltas, const uint32_t N, const uint32_t M, at::Tensor grad_rays_o, at::Tensor grad_rays_d) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_xyzs.scalar_type(), "march_rays_train_backward", ([&] { + kernel_march_rays_train_backward<<>>(grad_xyzs.data_ptr(), grad_dirs.data_ptr(), rays.data_ptr(), deltas.data_ptr(), N, M, grad_rays_o.data_ptr(), grad_rays_d.data_ptr()); + })); +} + + +// sigmas: [M] +// rgbs: [M, 3] +// deltas: [M, 2] +// rays: [N, 3], idx, offset, num_steps +// weights_sum: [N], final pixel alpha +// depth: [N,] +// image: [N, 3] +template +__global__ void kernel_composite_rays_train_forward( + const scalar_t * __restrict__ sigmas, + const scalar_t * __restrict__ rgbs, + const scalar_t * __restrict__ ambient, + const scalar_t * __restrict__ deltas, + const int * __restrict__ rays, + const uint32_t M, const uint32_t N, const float T_thresh, + scalar_t * weights_sum, + scalar_t * ambient_sum, + scalar_t * depth, + scalar_t * image +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + uint32_t index = rays[n * 3]; + uint32_t offset = rays[n * 3 + 1]; + uint32_t num_steps = rays[n * 3 + 2]; + + // empty ray, or ray that exceed max step count. + if (num_steps == 0 || offset + num_steps > M) { + weights_sum[index] = 0; + ambient_sum[index] = 0; + depth[index] = 0; + image[index * 3] = 0; + image[index * 3 + 1] = 0; + image[index * 3 + 2] = 0; + return; + } + + sigmas += offset; + rgbs += offset * 3; + ambient += offset; + deltas += offset * 2; + + // accumulate + uint32_t step = 0; + + scalar_t T = 1.0f; + scalar_t r = 0, g = 0, b = 0, ws = 0, d = 0, amb = 0; + + while (step < num_steps) { + + const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); + const scalar_t weight = alpha * T; + + r += weight * rgbs[0]; + g += weight * rgbs[1]; + b += weight * rgbs[2]; + + d += weight * deltas[1]; + + ws += weight; + + amb += ambient[0]; + + T *= 1.0f - alpha; + + // minimal remained transmittence + if (T < T_thresh) break; + + //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); + + // locate + sigmas++; + rgbs += 3; + ambient++; + deltas += 2; + + step++; + } + + //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); + + // write + weights_sum[index] = ws; // weights_sum + ambient_sum[index] = amb; + depth[index] = d; + image[index * 3] = r; + image[index * 3 + 1] = g; + image[index * 3 + 2] = b; +} + + +void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor ambient_sum, at::Tensor depth, at::Tensor image) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + sigmas.scalar_type(), "composite_rays_train_forward", ([&] { + kernel_composite_rays_train_forward<<>>(sigmas.data_ptr(), rgbs.data_ptr(), ambient.data_ptr(), deltas.data_ptr(), rays.data_ptr(), M, N, T_thresh, weights_sum.data_ptr(), ambient_sum.data_ptr(), depth.data_ptr(), image.data_ptr()); + })); +} + + +// grad_weights_sum: [N,] +// grad: [N, 3] +// sigmas: [M] +// rgbs: [M, 3] +// deltas: [M, 2] +// rays: [N, 3], idx, offset, num_steps +// weights_sum: [N,], weights_sum here +// image: [N, 3] +// grad_sigmas: [M] +// grad_rgbs: [M, 3] +template +__global__ void kernel_composite_rays_train_backward( + const scalar_t * __restrict__ grad_weights_sum, + const scalar_t * __restrict__ grad_ambient_sum, + const scalar_t * __restrict__ grad_image, + const scalar_t * __restrict__ sigmas, + const scalar_t * __restrict__ rgbs, + const scalar_t * __restrict__ ambient, + const scalar_t * __restrict__ deltas, + const int * __restrict__ rays, + const scalar_t * __restrict__ weights_sum, + const scalar_t * __restrict__ ambient_sum, + const scalar_t * __restrict__ image, + const uint32_t M, const uint32_t N, const float T_thresh, + scalar_t * grad_sigmas, + scalar_t * grad_rgbs, + scalar_t * grad_ambient +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + uint32_t index = rays[n * 3]; + uint32_t offset = rays[n * 3 + 1]; + uint32_t num_steps = rays[n * 3 + 2]; + + if (num_steps == 0 || offset + num_steps > M) return; + + grad_weights_sum += index; + grad_ambient_sum += index; + grad_image += index * 3; + weights_sum += index; + ambient_sum += index; + image += index * 3; + + sigmas += offset; + rgbs += offset * 3; + ambient += offset; + deltas += offset * 2; + + grad_sigmas += offset; + grad_rgbs += offset * 3; + grad_ambient += offset; + + // accumulate + uint32_t step = 0; + + scalar_t T = 1.0f; + const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0]; + scalar_t r = 0, g = 0, b = 0, ws = 0; + + while (step < num_steps) { + + const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); + const scalar_t weight = alpha * T; + + r += weight * rgbs[0]; + g += weight * rgbs[1]; + b += weight * rgbs[2]; + // amb += weight * ambient[0]; + ws += weight; + + T *= 1.0f - alpha; + + // check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation. + // write grad_rgbs + grad_rgbs[0] = grad_image[0] * weight; + grad_rgbs[1] = grad_image[1] * weight; + grad_rgbs[2] = grad_image[2] * weight; + + // write grad_ambient + grad_ambient[0] = grad_ambient_sum[0]; + + // write grad_sigmas + grad_sigmas[0] = deltas[0] * ( + grad_image[0] * (T * rgbs[0] - (r_final - r)) + + grad_image[1] * (T * rgbs[1] - (g_final - g)) + + grad_image[2] * (T * rgbs[2] - (b_final - b)) + + // grad_ambient_sum[0] * (T * ambient[0] - (amb_final - amb)) + + grad_weights_sum[0] * (1 - ws_final) + ); + + //printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r); + // minimal remained transmittence + if (T < T_thresh) break; + + // locate + sigmas++; + rgbs += 3; + // ambient++; + deltas += 2; + grad_sigmas++; + grad_rgbs += 3; + grad_ambient++; + + step++; + } +} + + +void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_ambient_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor ambient_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_ambient) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_image.scalar_type(), "composite_rays_train_backward", ([&] { + kernel_composite_rays_train_backward<<>>(grad_weights_sum.data_ptr(), grad_ambient_sum.data_ptr(), grad_image.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), ambient.data_ptr(), deltas.data_ptr(), rays.data_ptr(), weights_sum.data_ptr(), ambient_sum.data_ptr(), image.data_ptr(), M, N, T_thresh, grad_sigmas.data_ptr(), grad_rgbs.data_ptr(), grad_ambient.data_ptr()); + })); +} + + +//////////////////////////////////////////////////// +///////////// infernce ///////////// +//////////////////////////////////////////////////// + +template +__global__ void kernel_march_rays( + const uint32_t n_alive, + const uint32_t n_step, + const int* __restrict__ rays_alive, + const scalar_t* __restrict__ rays_t, + const scalar_t* __restrict__ rays_o, + const scalar_t* __restrict__ rays_d, + const float bound, + const float dt_gamma, const uint32_t max_steps, + const uint32_t C, const uint32_t H, + const uint8_t * __restrict__ grid, + const scalar_t* __restrict__ nears, + const scalar_t* __restrict__ fars, + scalar_t* xyzs, scalar_t* dirs, scalar_t* deltas, + const scalar_t* __restrict__ noises +) { + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= n_alive) return; + + const int index = rays_alive[n]; // ray id + const float noise = noises[n]; + + // locate + rays_o += index * 3; + rays_d += index * 3; + xyzs += n * n_step * 3; + dirs += n * n_step * 3; + deltas += n * n_step * 2; + + const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; + const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; + const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; + const float rH = 1 / (float)H; + const float H3 = H * H * H; + + float t = rays_t[index]; // current ray's t + const float near = nears[index], far = fars[index]; + + const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H; + const float dt_min = fminf(dt_max, 2 * SQRT3() / max_steps); + + // march for n_step steps, record points + uint32_t step = 0; + + // introduce some randomness + t += clamp(t * dt_gamma, dt_min, dt_max) * noise; + + while (t < far && step < n_step) { + // current point + const float x = clamp(ox + t * dx, -bound, bound); + const float y = clamp(oy + t * dy, -bound, bound); + const float z = clamp(oz + t * dz, -bound, bound); + + const float dt = clamp(t * dt_gamma, dt_min, dt_max); + + // get mip level + const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1] + + const float mip_bound = fminf(scalbnf(1, level), bound); + const float mip_rbound = 1 / mip_bound; + + // convert to nearest grid position + const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + + const uint32_t index = level * H3 + __morton3D(nx, ny, nz); + const bool occ = grid[index / 8] & (1 << (index % 8)); + + // if occpuied, advance a small step, and write to output + if (occ) { + // write step + xyzs[0] = x; + xyzs[1] = y; + xyzs[2] = z; + dirs[0] = dx; + dirs[1] = dy; + dirs[2] = dz; + // calc dt + t += dt; + deltas[0] = dt; + deltas[1] = t; // used to calc depth + // step + xyzs += 3; + dirs += 3; + deltas += 2; + step++; + + // else, skip a large step (basically skip a voxel grid) + } else { + // calc distance to next voxel + const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx; + const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy; + const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz; + const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz))); + // step until next voxel + do { + t += clamp(t * dt_gamma, dt_min, dt_max); + } while (t < tt); + } + } +} + + +void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor near, const at::Tensor far, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises) { + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + rays_o.scalar_type(), "march_rays", ([&] { + kernel_march_rays<<>>(n_alive, n_step, rays_alive.data_ptr(), rays_t.data_ptr(), rays_o.data_ptr(), rays_d.data_ptr(), bound, dt_gamma, max_steps, C, H, grid.data_ptr(), near.data_ptr(), far.data_ptr(), xyzs.data_ptr(), dirs.data_ptr(), deltas.data_ptr(), noises.data_ptr()); + })); +} + + +template +__global__ void kernel_composite_rays( + const uint32_t n_alive, + const uint32_t n_step, + const float T_thresh, + int* rays_alive, + scalar_t* rays_t, + const scalar_t* __restrict__ sigmas, + const scalar_t* __restrict__ rgbs, + const scalar_t* __restrict__ deltas, + scalar_t* weights_sum, scalar_t* depth, scalar_t* image +) { + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= n_alive) return; + + const int index = rays_alive[n]; // ray id + + // locate + sigmas += n * n_step; + rgbs += n * n_step * 3; + deltas += n * n_step * 2; + + rays_t += index; + weights_sum += index; + depth += index; + image += index * 3; + + scalar_t t = rays_t[0]; // current ray's t + + scalar_t weight_sum = weights_sum[0]; + scalar_t d = depth[0]; + scalar_t r = image[0]; + scalar_t g = image[1]; + scalar_t b = image[2]; + + // accumulate + uint32_t step = 0; + while (step < n_step) { + + // ray is terminated if delta == 0 + if (deltas[0] == 0) break; + + const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); + + /* + T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j) + w_i = alpha_i * T_i + --> + T_i = 1 - \sum_{j=0}^{i-1} w_j + */ + const scalar_t T = 1 - weight_sum; + const scalar_t weight = alpha * T; + weight_sum += weight; + + t = deltas[1]; + d += weight * t; + r += weight * rgbs[0]; + g += weight * rgbs[1]; + b += weight * rgbs[2]; + + //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); + + // ray is terminated if T is too small + // use a larger bound to further accelerate inference + if (T < T_thresh) break; + + // locate + sigmas++; + rgbs += 3; + deltas += 2; + step++; + } + + //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); + + // rays_alive = -1 means ray is terminated early. + if (step < n_step) { + rays_alive[n] = -1; + } else { + rays_t[0] = t; + } + + weights_sum[0] = weight_sum; // this is the thing I needed! + depth[0] = d; + image[0] = r; + image[1] = g; + image[2] = b; +} + + +void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights, at::Tensor depth, at::Tensor image) { + static constexpr uint32_t N_THREAD = 128; + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + image.scalar_type(), "composite_rays", ([&] { + kernel_composite_rays<<>>(n_alive, n_step, T_thresh, rays_alive.data_ptr(), rays_t.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), deltas.data_ptr(), weights.data_ptr(), depth.data_ptr(), image.data_ptr()); + })); +} \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/modules/raymarching/src/raymarching.h b/Geneface_main/GeneFace/docker/modules/raymarching/src/raymarching.h new file mode 100644 index 00000000..e7d9b219 --- /dev/null +++ b/Geneface_main/GeneFace/docker/modules/raymarching/src/raymarching.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + + +void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars); +void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords); +void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices); +void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords); +void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield); +void morton3D_dilation(const at::Tensor grid, const uint32_t C, const uint32_t H, at::Tensor grid_dilation); + +void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises); +void march_rays_train_backward(const at::Tensor grad_xyzs, const at::Tensor grad_dirs, const at::Tensor rays, const at::Tensor deltas, const uint32_t N, const uint32_t M, at::Tensor grad_rays_o, at::Tensor grad_rays_d); +void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor ambient_sum, at::Tensor depth, at::Tensor image); +void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_ambient_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor ambient_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_ambient); + +void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises); +void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); \ No newline at end of file diff --git a/Geneface_main/GeneFace/docker/requirements.txt b/Geneface_main/GeneFace/docker/requirements.txt new file mode 100644 index 00000000..c6b13bde --- /dev/null +++ b/Geneface_main/GeneFace/docker/requirements.txt @@ -0,0 +1,27 @@ +numpy==1.23.0 +pandas +transformers +scipy +scikit-learn +scikit-image +tensorboard +tensorboardX +python_speech_features +resampy +opencv_python +face_alignment +matplotlib +configargparse +librosa==0.9.2 +praat-parselmouth==0.4.3 +trimesh +kornia==0.5.0 +PyMCubes +lpips +setuptools # ==59.5.0 +ffmpeg-python +moviepy +dearpygui +ninja +pyaudio # for extract esperanto +mediapipe==0.8.11 \ No newline at end of file diff --git a/Geneface_main/GeneFace/docs/prepare_env/geneface_rtx2080.yaml b/Geneface_main/GeneFace/docs/prepare_env/geneface_rtx2080.yaml new file mode 100644 index 00000000..5284b4cf --- /dev/null +++ b/Geneface_main/GeneFace/docs/prepare_env/geneface_rtx2080.yaml @@ -0,0 +1,188 @@ +name: geneface +channels: + - pytorch3d + - bottler + - iopath + - pytorch + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - blas=1.0=mkl + - brotlipy=0.7.0=py39h27cfd23_1003 + - bzip2=1.0.8=h7b6447c_0 + - ca-certificates=2023.01.10=h06a4308_0 + - certifi=2022.12.7=py39h06a4308_0 + - cffi=1.15.1=py39h5eee18b_3 + - charset-normalizer=2.0.4=pyhd3eb1b0_0 + - colorama=0.4.6=pyhd8ed1ab_0 + - cryptography=39.0.1=py39h9ce1e76_0 + - cudatoolkit=11.3.1=h2bc3f7f_2 + - ffmpeg=4.3=hf484d3e_0 + - flit-core=3.8.0=py39h06a4308_0 + - freetype=2.12.1=h4a9f257_0 + - fvcore=0.1.5.post20221221=pyhd8ed1ab_0 + - giflib=5.2.1=h5eee18b_3 + - gmp=6.2.1=h295c915_3 + - gnutls=3.6.15=he1e5248_0 + - idna=3.4=py39h06a4308_0 + - intel-openmp=2021.4.0=h06a4308_3561 + - iopath=0.1.9=py39 + - jpeg=9e=h5eee18b_1 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.38=h1181459_1 + - lerc=3.0=h295c915_0 + - libdeflate=1.17=h5eee18b_0 + - libffi=3.4.2=h6a678d5_6 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libiconv=1.16=h7f8727e_2 + - libidn2=2.3.2=h7f8727e_0 + - libpng=1.6.39=h5eee18b_0 + - libstdcxx-ng=11.2.0=h1234567_1 + - libtasn1=4.16.0=h27cfd23_0 + - libtiff=4.5.0=h6a678d5_2 + - libunistring=0.9.10=h27cfd23_0 + - libuv=1.44.2=h5eee18b_0 + - libwebp=1.2.4=h11a3e52_1 + - libwebp-base=1.2.4=h5eee18b_1 + - lz4-c=1.9.4=h6a678d5_0 + - mkl=2021.4.0=h06a4308_640 + - mkl-service=2.4.0=py39h7f8727e_0 + - mkl_fft=1.3.1=py39hd3c417c_0 + - mkl_random=1.2.2=py39h51133e4_0 + - ncurses=6.4=h6a678d5_0 + - nettle=3.7.3=hbbd107a_1 + - numpy=1.23.5=py39h14f4228_0 + - numpy-base=1.23.5=py39h31eccc5_0 + - nvidiacub=1.10.0=0 + - openh264=2.1.1=h4ff587b_0 + - openssl=1.1.1t=h7f8727e_0 + - pillow=9.4.0=py39h6a678d5_0 + - pip=23.0.1=py39h06a4308_0 + - portalocker=2.7.0=py39hf3d152e_0 + - pycparser=2.21=pyhd3eb1b0_0 + - pyopenssl=23.0.0=py39h06a4308_0 + - pysocks=1.7.1=py39h06a4308_0 + - python=3.9.16=h7a1cb2a_2 + - python_abi=3.9=2_cp39 + - pytorch=1.11.0=py3.9_cuda11.3_cudnn8.2.0_0 + - pytorch-mutex=1.0=cuda + - pytorch3d=0.7.2=py39_cu113_pyt1110 + - pyyaml=6.0=py39hb9d737c_4 + - readline=8.2=h5eee18b_0 + - requests=2.28.1=py39h06a4308_1 + - setuptools=65.6.3=py39h06a4308_0 + - six=1.16.0=pyhd3eb1b0_1 + - sqlite=3.41.1=h5eee18b_0 + - tabulate=0.9.0=pyhd8ed1ab_1 + - termcolor=2.2.0=pyhd8ed1ab_0 + - tk=8.6.12=h1ccaba5_0 + - torchaudio=0.11.0=py39_cu113 + - torchvision=0.12.0=py39_cu113 + - tqdm=4.65.0=pyhd8ed1ab_1 + - typing_extensions=4.4.0=py39h06a4308_0 + - tzdata=2022g=h04d1e81_0 + - urllib3=1.26.15=py39h06a4308_0 + - wheel=0.38.4=py39h06a4308_0 + - xz=5.2.10=h5eee18b_1 + - yacs=0.1.8=pyhd8ed1ab_0 + - yaml=0.2.5=h7f98852_2 + - zlib=1.2.13=h5eee18b_0 + - zstd=1.5.4=hc292b87_0 + - pip: + - absl-py==1.4.0 + - astunparse==1.6.3 + - attrs==22.2.0 + - audioread==3.0.0 + - cachetools==5.3.0 + - configargparse==1.5.3 + - contourpy==1.0.7 + - cycler==0.11.0 + - dearpygui==1.8.0 + - decorator==4.4.2 + - decord==0.6.0 + - face-alignment==1.3.5 + - ffmpeg-python==0.2.0 + - filelock==3.10.0 + - flatbuffers==23.3.3 + - fonttools==4.39.2 + - freqencoder==0.0.0 + - future==0.18.3 + - gast==0.4.0 + - google-auth==2.16.2 + - google-auth-oauthlib==0.4.6 + - google-pasta==0.2.0 + - gridencoder==0.0.0 + - grpcio==1.51.3 + - h5py==3.8.0 + - huggingface-hub==0.13.2 + - imageio==2.26.0 + - imageio-ffmpeg==0.4.8 + - importlib-metadata==6.1.0 + - importlib-resources==5.12.0 + - joblib==1.2.0 + - keras==2.11.0 + - kiwisolver==1.4.4 + - kornia==0.5.0 + - lazy-loader==0.1 + - libclang==15.0.6.1 + - librosa==0.9.2 + - llvmlite==0.39.1 + - lpips==0.1.4 + - markdown==3.4.1 + - markupsafe==2.1.2 + - matplotlib==3.7.1 + - mediapipe==0.9.1.0 + - moviepy==1.0.3 + - networkx==3.0 + - ninja==1.11.1 + - numba==0.56.4 + - oauthlib==3.2.2 + - opencv-contrib-python==4.7.0.72 + - opencv-python==4.7.0.72 + - opt-einsum==3.3.0 + - packaging==23.0 + - pandas==1.5.3 + - platformdirs==3.1.1 + - pooch==1.7.0 + - praat-parselmouth==0.4.3 + - proglog==0.1.10 + - protobuf==3.19.6 + - pyasn1==0.4.8 + - pyasn1-modules==0.2.8 + - pyaudio==0.2.13 + - pymcubes==0.1.4 + - pyparsing==3.0.9 + - python-dateutil==2.8.2 + - python-speech-features==0.6 + - pytz==2022.7.1 + - pywavelets==1.4.1 + - raymarching-face==0.0.0 + - regex==2022.10.31 + - requests-oauthlib==1.3.1 + - resampy==0.4.2 + - rsa==4.9 + - scikit-image==0.20.0 + - scikit-learn==1.2.2 + - scipy==1.9.1 + - shencoder==0.0.0 + - soundfile==0.12.1 + - tensorboard==2.11.2 + - tensorboard-data-server==0.6.1 + - tensorboard-plugin-wit==1.8.1 + - tensorboardx==2.6 + - tensorflow==2.11.0 + - tensorflow-estimator==2.11.0 + - tensorflow-io-gcs-filesystem==0.31.0 + - threadpoolctl==3.1.0 + - tifffile==2023.3.15 + - tokenizers==0.13.2 + - transformers==4.27.1 + - trimesh==3.20.2 + - werkzeug==2.2.3 + - wrapt==1.15.0 + - zipp==3.15.0 +prefix: /home/yezhenhui/anaconda3/envs/geneface diff --git a/Geneface_main/GeneFace/docs/prepare_env/geneface_rtx3090.yaml b/Geneface_main/GeneFace/docs/prepare_env/geneface_rtx3090.yaml new file mode 100644 index 00000000..58fffbb3 --- /dev/null +++ b/Geneface_main/GeneFace/docs/prepare_env/geneface_rtx3090.yaml @@ -0,0 +1,189 @@ +name: geneface +channels: + - pytorch3d + - bottler + - iopath + - pytorch + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - blas=1.0=mkl + - brotlipy=0.7.0=py39h27cfd23_1003 + - bzip2=1.0.8=h7b6447c_0 + - ca-certificates=2023.01.10=h06a4308_0 + - certifi=2022.12.7=py39h06a4308_0 + - cffi=1.15.1=py39h5eee18b_3 + - charset-normalizer=2.0.4=pyhd3eb1b0_0 + - colorama=0.4.6=pyhd8ed1ab_0 + - cryptography=39.0.1=py39h9ce1e76_0 + - cudatoolkit=11.3.1=h2bc3f7f_2 + - ffmpeg=4.3=hf484d3e_0 + - flit-core=3.8.0=py39h06a4308_0 + - freetype=2.12.1=h4a9f257_0 + - fvcore=0.1.5.post20221221=pyhd8ed1ab_0 + - giflib=5.2.1=h5eee18b_3 + - gmp=6.2.1=h295c915_3 + - gnutls=3.6.15=he1e5248_0 + - idna=3.4=py39h06a4308_0 + - intel-openmp=2021.4.0=h06a4308_3561 + - iopath=0.1.9=py39 + - jpeg=9e=h5eee18b_1 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.38=h1181459_1 + - lerc=3.0=h295c915_0 + - libdeflate=1.17=h5eee18b_0 + - libffi=3.4.2=h6a678d5_6 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libiconv=1.16=h7f8727e_2 + - libidn2=2.3.2=h7f8727e_0 + - libpng=1.6.39=h5eee18b_0 + - libstdcxx-ng=11.2.0=h1234567_1 + - libtasn1=4.16.0=h27cfd23_0 + - libtiff=4.5.0=h6a678d5_2 + - libunistring=0.9.10=h27cfd23_0 + - libuv=1.44.2=h5eee18b_0 + - libwebp=1.2.4=h11a3e52_1 + - libwebp-base=1.2.4=h5eee18b_1 + - lz4-c=1.9.4=h6a678d5_0 + - mkl=2021.4.0=h06a4308_640 + - mkl-service=2.4.0=py39h7f8727e_0 + - mkl_fft=1.3.1=py39hd3c417c_0 + - mkl_random=1.2.2=py39h51133e4_0 + - ncurses=6.4=h6a678d5_0 + - nettle=3.7.3=hbbd107a_1 + - numpy=1.23.5=py39h14f4228_0 + - numpy-base=1.23.5=py39h31eccc5_0 + - nvidiacub=1.10.0=0 + - openh264=2.1.1=h4ff587b_0 + - openssl=1.1.1t=h7f8727e_0 + - pillow=9.4.0=py39h6a678d5_0 + - pip=23.0.1=py39h06a4308_0 + - portalocker=2.7.0=py39hf3d152e_0 + - pycparser=2.21=pyhd3eb1b0_0 + - pyopenssl=23.0.0=py39h06a4308_0 + - pysocks=1.7.1=py39h06a4308_0 + - python=3.9.16=h7a1cb2a_2 + - python_abi=3.9=2_cp39 + - pytorch=1.11.0=py3.9_cuda11.3_cudnn8.2.0_0 + - pytorch-mutex=1.0=cuda + - pytorch3d=0.7.2=py39_cu113_pyt1110 + - pyyaml=6.0=py39hb9d737c_4 + - readline=8.2=h5eee18b_0 + - requests=2.28.1=py39h06a4308_1 + - setuptools=65.6.3=py39h06a4308_0 + - six=1.16.0=pyhd3eb1b0_1 + - sqlite=3.41.1=h5eee18b_0 + - tabulate=0.9.0=pyhd8ed1ab_1 + - termcolor=2.2.0=pyhd8ed1ab_0 + - tk=8.6.12=h1ccaba5_0 + - torchaudio=0.11.0=py39_cu113 + - torchvision=0.12.0=py39_cu113 + - tqdm=4.65.0=pyhd8ed1ab_1 + - typing_extensions=4.4.0=py39h06a4308_0 + - tzdata=2022g=h04d1e81_0 + - urllib3=1.26.15=py39h06a4308_0 + - wheel=0.38.4=py39h06a4308_0 + - xz=5.2.10=h5eee18b_1 + - yacs=0.1.8=pyhd8ed1ab_0 + - yaml=0.2.5=h7f98852_2 + - zlib=1.2.13=h5eee18b_0 + - zstd=1.5.4=hc292b87_0 + - pip: + - absl-py==1.4.0 + - astunparse==1.6.3 + - attrs==22.2.0 + - audioread==3.0.0 + - cachetools==5.3.0 + - configargparse==1.5.3 + - contourpy==1.0.7 + - cycler==0.11.0 + - dearpygui==1.9.0 + - decorator==4.4.2 + - face-alignment==1.3.5 + - ffmpeg-python==0.2.0 + - filelock==3.10.7 + - flatbuffers==23.3.3 + - fonttools==4.39.3 + - freqencoder==0.0.0 + - future==0.18.3 + - gast==0.4.0 + - google-auth==2.17.1 + - google-auth-oauthlib==1.0.0 + - google-pasta==0.2.0 + - gridencoder==0.0.0 + - grpcio==1.53.0 + - h5py==3.8.0 + - huggingface-hub==0.13.3 + - imageio==2.27.0 + - imageio-ffmpeg==0.4.8 + - importlib-metadata==6.1.0 + - importlib-resources==5.12.0 + - jax==0.4.8 + - joblib==1.2.0 + - keras==2.12.0 + - kiwisolver==1.4.4 + - kornia==0.5.0 + - lazy-loader==0.2 + - libclang==16.0.0 + - librosa==0.9.2 + - llvmlite==0.39.1 + - lpips==0.1.4 + - markdown==3.4.3 + - markupsafe==2.1.2 + - matplotlib==3.7.1 + - mediapipe==0.9.2.1 + - ml-dtypes==0.0.4 + - moviepy==1.0.3 + - networkx==3.0 + - ninja==1.11.1 + - numba==0.56.4 + - oauthlib==3.2.2 + - opencv-contrib-python==4.7.0.72 + - opencv-python==4.7.0.72 + - opt-einsum==3.3.0 + - packaging==23.0 + - pandas==1.5.3 + - platformdirs==3.2.0 + - pooch==1.7.0 + - praat-parselmouth==0.4.3 + - proglog==0.1.10 + - protobuf==3.20.3 + - pyasn1==0.4.8 + - pyasn1-modules==0.2.8 + - pyaudio==0.2.13 + - pymcubes==0.1.4 + - pyparsing==3.0.9 + - python-dateutil==2.8.2 + - python-speech-features==0.6 + - pytz==2023.3 + - pywavelets==1.4.1 + - raymarching-face==0.0.0 + - regex==2023.3.23 + - requests-oauthlib==1.3.1 + - resampy==0.4.2 + - rsa==4.9 + - scikit-image==0.20.0 + - scikit-learn==1.2.2 + - scipy==1.9.1 + - shencoder==0.0.0 + - soundfile==0.12.1 + - tensorboard==2.12.1 + - tensorboard-data-server==0.7.0 + - tensorboard-plugin-wit==1.8.1 + - tensorboardx==2.6 + - tensorflow==2.12.0 + - tensorflow-estimator==2.12.0 + - tensorflow-io-gcs-filesystem==0.32.0 + - threadpoolctl==3.1.0 + - tifffile==2023.3.21 + - tokenizers==0.13.2 + - transformers==4.27.4 + - trimesh==3.21.3 + - werkzeug==2.2.3 + - wrapt==1.14.1 + - zipp==3.15.0 +prefix: /home/yezhenhui/anaconda3/envs/geneface diff --git a/Geneface_main/GeneFace/docs/prepare_env/install_ext.sh b/Geneface_main/GeneFace/docs/prepare_env/install_ext.sh new file mode 100644 index 00000000..3c4f13d4 --- /dev/null +++ b/Geneface_main/GeneFace/docs/prepare_env/install_ext.sh @@ -0,0 +1,4 @@ +pip install ./modules/radnerfs/encoders/freqencoder +pip install ./modules/radnerfs/encoders/shencoder +pip install ./modules/radnerfs/encoders/gridencoder +pip install ./modules/radnerfs/raymarching \ No newline at end of file diff --git a/Geneface_main/GeneFace/docs/prepare_env/install_guide-zh.md b/Geneface_main/GeneFace/docs/prepare_env/install_guide-zh.md new file mode 100644 index 00000000..951da26d --- /dev/null +++ b/Geneface_main/GeneFace/docs/prepare_env/install_guide-zh.md @@ -0,0 +1,87 @@ +# 搭建环境 + +本指南介绍了如何构建一个用于GeneFace的python环境。以下安装流程在RTX3090,Ubuntu 18.04得到验证。 + +# 1. 安装 CUDA + +我们使用来自[torch-ngp](https://github.com/ashawkey/torch-ngp)的CUDA扩展,您可能需要从[Nvidia官方页面](https://developer.nvidia.com/cuda-toolkit)手动安装CUDA。我们建议安装CUDA `11.3`(在各种类型的gpu中验证),但其他CUDA版本(如 `10.2`)也可以很好地工作。确保你的cuda路径(通常是 `/usr/local/cuda`)指向已安装的 `/usr/local/cuda-11.3` + +# 2. 安装Python库 + +``` +conda create -n geneface python=3.9.16 -y +conda activate geneface +# Install pytorch with cudatoolkit, note that the cudatoolkit version should equal to the CUDA version in step 1 +conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch +# Install pytorch3d +conda install -c fvcore -c iopath -c conda-forge fvcore iopath -y +conda install -c bottler nvidiacub -y +conda install pytorch3d -c pytorch3d -y # 0.7.2 recommended +# Install other dependencies, including tensorflow-gpu=2.x +sudo apt-get install libasound2-dev portaudio19-dev # dependency for pyaudio +pip install -r docs/prepare_env/requirements.txt +conda install ffmpeg # we need to install ffmpeg from anaconda to include the x264 encoder + +# Build customized cuda extensions from torch-ngp +# NOTE: you need to manually install CUDA with the same version of pytorch (in this case, CUDA v11.3) +# make sure your cuda path (typically /usr/local/cuda) points to a installed `/usr/local/cuda-11.3` +bash docs/prepare_env/install_ext.sh +``` + +如果你在上述安装过程中遇到兼容性问题,可以参考 `docs/prepare_env/geneface_*.yaml`文件,其中记录了我在不同型号GPU下安装成功的详细环境配置。 + +# 3. 准备 3DMM 模型 + +## 3.1 下载 3DMM model + +在[这个链接](https://faces.dmi.unibas.ch/bfm/index.php?nav=1-2&id=downloads)申请BFM2009 model. + +你能得到一个 `BaselFaceModel.tgz`,将其解压,解压后获得其中 `01_MorphableModel.mat`保存到 `./deep_3drecon/BFM/`文件夹 + +## 3.2 Download PCA Basis + +通过这个链接下载:[链接](https://drive.google.com/drive/folders/1iTopSpZucEmjWiWZIErLYiMBlZYwzil2?usp=share_link) + +获得其中的 `Exp_Pca.bin`存到 `./deep_3drecon/BFM` 路径 + +## 3.3 Download BFM Model Front + +通过这个链接下载:[链接](https://drive.google.com/drive/folders/1YCxXKJFfo1w01PzayhnxWSZZK5k7spSH?usp=share_link) + +获得其中的 ` BFM_model_front.mat` 存到 `./deep_3drecon/BFM` 路径 + +## 3.4 Download FaceRecon Model + +通过这个链接下载:[链接](https://drive.google.com/drive/folders/18VRcygXYOKPYvJWsl9lrF0J9PoFPk77y?usp=sharing) + +获得其中的 `epoch_20.pth` 存到 `./deep_3drecon/checkpoints/facerecon` 路径 + +## 3.5 生成face_tracking需要的文件 + +在GeneFace的root路径执行以下命令行: + +``` +cd data_util/face_tracking +conda activate geneface +python convert_BFM.py +``` + +这将在以下路径生成文件:`data_util/face_tracking/3DMM/3DMM_info.npy`. + +# 4. 验证安装成功 + +``` +# 跑通 deep_3drecon_pytorch 项目的原始example +cd +conda activate geneface +export PYTHONPATH=./ +CUDA_VISIBLE_DEVICES=0 python deep_3drecon/test.py + +# 验证与GeneFace之间的桥梁 +# 生成deep_3drecon的config文件(默认已生成) +python deep_3drecon/generate_reconstructor_opt_for_geneface.py +CUDA_VISIBLE_DEVICES=0 python +# 以下几行在python中执行 +> import deep_3drecon +> face_reconstructor = deep_3drecon.Reconstructor() +``` diff --git a/Geneface_main/GeneFace/docs/prepare_env/install_guide.md b/Geneface_main/GeneFace/docs/prepare_env/install_guide.md new file mode 100644 index 00000000..522cd633 --- /dev/null +++ b/Geneface_main/GeneFace/docs/prepare_env/install_guide.md @@ -0,0 +1,91 @@ +# Prepare the Environment + +[中文文档](./install_guide-zh.md) + +This guide is about building a python environment for GeneFace. + +The following installation process is verified in RTX3090, Ubuntu 18.04. + +# 1. Install CUDA + +We use CUDA extensions from [torch-ngp](https://github.com/ashawkey/torch-ngp), you may need to manually install CUDA from the [offcial page](https://developer.nvidia.com/cuda-toolkit). We recommend to install CUDA `11.3` (which is verified in various types of GPUs), but other CUDA versions (such as `10.2`) may also work well. Make sure your cuda path (typically `/usr/local/cuda`) points to a installed `/usr/local/cuda-11.3` + +# 2. Install Python Packages + +``` +conda create -n geneface python=3.9.16 -y +conda activate geneface +# Install pytorch with cudatoolkit, note that the cudatoolkit version should equal to the CUDA version in step 1 +conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch +# Install pytorch3d +conda install -c fvcore -c iopath -c conda-forge fvcore iopath -y +conda install -c bottler nvidiacub -y +conda install pytorch3d -c pytorch3d -y # 0.7.2 recommended +# Install other dependencies, including tensorflow-gpu=2.x +sudo apt-get install libasound2-dev portaudio19-dev # dependency for pyaudio +pip install -r docs/prepare_env/requirements.txt +conda install ffmpeg # we need to install ffmpeg from anaconda to include the x264 encoder + +# Build customized cuda extensions from torch-ngp +# NOTE: you need to manually install CUDA with the same version of pytorch (in this case, CUDA v11.3) +# make sure your cuda path (typically /usr/local/cuda) points to a installed `/usr/local/cuda-11.3` +bash docs/prepare_env/install_ext.sh +``` + +If you find any error in python package compatility, you can refer to `docs/prepare_env/geneface_*.yaml` for my tested specific package versions for different GPUs. + +# 3. Prepare the 3DMM model and other data + +## 3.1 Download 3DMM model + +Apply BFM2009 model in [this link](https://faces.dmi.unibas.ch/bfm/index.php?nav=1-2&id=downloads). + +You can obtain a file named `BaselFaceModel.tgz`, extract a file named `01_MorphableModel.mat` from it and save it into the directory `./deep_3drecon/BFM/` + +## 3.2 Download PCA Basis + +Download at [this link](https://drive.google.com/drive/folders/1iTopSpZucEmjWiWZIErLYiMBlZYwzil2?usp=share_link) + +Extract the `Exp_Pca.bin` and place it to the `./deep_3drecon/BFM` directory. + +## 3.3 Download BFM Model Front + +Download at [this link](https://drive.google.com/drive/folders/1YCxXKJFfo1w01PzayhnxWSZZK5k7spSH?usp=share_link) + +Extract the `BFM_model_front.mat` and place it to the `./deep_3drecon/BFM` directory. + +## 3.4 Download FaceRecon Model + +Download at [this link](https://drive.google.com/drive/folders/18VRcygXYOKPYvJWsl9lrF0J9PoFPk77y?usp=sharing) + +Extract the `epoch_20.pth` and place it to the `./deep_3drecon/checkpoints/facerecon` directory. + +## 3.5 generate files for face_tracking (used by all NeRFs) + +Then run the following commandlines: + +``` +cd data_util/face_tracking +conda activate geneface +python convert_BFM.py +``` + +This will generate `data_util/face_tracking/3DMM/3DMM_info.npy`. + +# 4. Verification of the Installation + +``` +# run the examples of deep_3drecon +cd +conda activate geneface +export PYTHONPATH=./ +CUDA_VISIBLE_DEVICES=0 python deep_3drecon/test.py + +# validate the bindings with GeneFace +# generate config file of deep_3drecon locally +python deep_3drecon/generate_reconstructor_opt_for_geneface.py +CUDA_VISIBLE_DEVICES=0 python +# below are run in python console +import deep_3drecon +face_reconstructor = deep_3drecon.Reconstructor() +``` diff --git a/Geneface_main/GeneFace/docs/prepare_env/requirements.txt b/Geneface_main/GeneFace/docs/prepare_env/requirements.txt new file mode 100644 index 00000000..4db97e18 --- /dev/null +++ b/Geneface_main/GeneFace/docs/prepare_env/requirements.txt @@ -0,0 +1,29 @@ +numpy==1.23.0 +pandas +transformers +scipy +scikit-learn +scikit-image +tensorflow # you can flexible it, this is gpu version +tensorboard +tensorboardX +python_speech_features +resampy +opencv_python +face_alignment +matplotlib +configargparse +librosa==0.9.2 +praat-parselmouth==0.4.3 +trimesh +kornia==0.5.0 +PyMCubes +lpips +setuptools # ==59.5.0 +ffmpeg-python +moviepy +dearpygui +ninja +pyaudio # for extract esperanto +mediapipe +psutil diff --git a/Geneface_main/GeneFace/docs/process_data/process_lrs3.md b/Geneface_main/GeneFace/docs/process_data/process_lrs3.md new file mode 100644 index 00000000..5c2b02ab --- /dev/null +++ b/Geneface_main/GeneFace/docs/process_data/process_lrs3.md @@ -0,0 +1,76 @@ +# Process the LRS3 dataset + +[中文文档](./zh/process_lrs3-zh.md) + +We use LRS3 dataset to learn a robust audio2motion generator. It is also required for training a post-net and syncnet. + +## Processed LRS3 Dataset available + +🔥 Note: Since we turn to a new 3DMM extractor, we update the provided lrs3 dataset. You may need download the newest dataset file to be compatible with the latest code. + +Since LRS3 is quite big (500 hours+), it is expensive to process this dataset. For your convenience, we provide the binarized LRS3 dataset file (about 26 GB) on Google Drive. If you use the processed dataset, you can skip `step 1-3` below and go directly to `step 4` to verify the installation. + +- Download Link on Google Drive: [Partition 1](https://drive.google.com/drive/folders/1QK_ikLKUzGYiqHBzvKz0s5zKWeH-sm3L?usp=share_link), [Partition 2](https://drive.google.com/drive/folders/1WbECLfpxAZ0D7PcrlZxV-fCObT-TnfD8?usp=share_link). +- Download Link on Baiduyun Disk: [link](https://pan.baidu.com/s/1JsvEz58c9ItSI73ls43tTw?pwd=lrs3), passward: `lrs3` +- How to use: + - step1. Integrate the segments `cat lrs3.zip.part_* > lrs3.zip` . + - step2. Unzip `processed_lrs3.zip` and place it into the `data/binary/lrs3` folder. + - step3. Move to `Step4. Verification` to verify installation. +- Disclaimer: the provided binarized dataset file only contains data-masked features (such as HuBERT for audio representations), so it does not viloate the copyright of LRS3. + +If you ue our processed lrs3 dataset, you can skip the first 3 steps, and directly go to `step4` + +## Step1. Apply and Download the LRS3-TED dataset + +🔥 Note: It seems that the raw dataset of LRS3 is no longer provided by the official website. + +Due to the License, we cannot provide a download link here. You can apply for LRS3-TED at [this link](https://www.robots.ox.ac.uk/~vgg/data/lip_reading/). + +## Step2. Process the LRS3 + +For process lrs3, you should first install the python env `geneface` following the docs in `dosc/prepare_env/install_guide.md` + +Then run these commandlines: (You may need to modify the directory name of raw lrs3 in the .py files) + +``` +conda activate geneface +CUDA_VISIBLE_DEVICES=0 python data_gen/process_lrs3/process_video_3dmm.py # extract 3dmm motion representations +CUDA_VISIBLE_DEVICES=0 python data_gen/process_lrs3/process_audio_mel_f0.py # extract mel spectrogram +CUDA_VISIBLE_DEVICES=0 python data_gen/process_lrs3/process_audio_hubert.py # extract hubert audio representations + +``` + +Since the LRS3-TED dataset is relatively big, you may need run multiple processes in several GPUs to accelerate the data preprocessing, for instance: + +``` +# run on two GPUs + +CUDA_VISIBLE_DEVICES=0 python data_gen/process_lrs3/process_video_3dmm.py --process_id=0 --total_process=2 \ +& CUDA_VISIBLE_DEVICES=1 python data_gen/process_lrs3/process_video_3dmm.py --process_id=1 --total_process=2 +``` + +## Step3. Binarize the dataset + +run the following commandline to binarize the dataset. (You may need to modify the directory name of raw lrs3 in the .py files) + +``` +conda activate procerss_lrs3 +python data_gen/process_lrs3/binarizer.py +``` + +# Step4. Verification + +Then you may find a directory at the path `data/binary/lrs3/` +After the above steps, the structure of your `data` directory should look like this: + +``` +> data + > binary + > lrs3 + sizes_train.npy + sizes_val.npy + spk_id2ispk_idx.npy + stats.npy + train.data + val.data +``` diff --git a/Geneface_main/GeneFace/docs/process_data/process_target_person_video.md b/Geneface_main/GeneFace/docs/process_data/process_target_person_video.md new file mode 100644 index 00000000..541dc979 --- /dev/null +++ b/Geneface_main/GeneFace/docs/process_data/process_target_person_video.md @@ -0,0 +1,17 @@ +# Process the Target Person Video + +[中文文档](./zh/process_target_person_video-zh.md) + +You need a about 3-minute-long videos of the target person to train the person-specific postnet and NeRF-based renderer. The video is the longer the better. + +We provide a example video at the path: `data/raw/videos/May.mp4` + +## Only 1 step: extract all required features and binarize it +``` +conda activate geneface +export PYTHONPATH=./ +export VIDEO_ID=May +CUDA_VISIBLE_DEVICES=0 data_gen/nerf/process_data.sh $VIDEO_ID +``` + +Then you can find a directory at the path `data/binary/videos/May` diff --git a/Geneface_main/GeneFace/docs/process_data/zh/process_lrs3-zh.md b/Geneface_main/GeneFace/docs/process_data/zh/process_lrs3-zh.md new file mode 100644 index 00000000..a3551c32 --- /dev/null +++ b/Geneface_main/GeneFace/docs/process_data/zh/process_lrs3-zh.md @@ -0,0 +1,63 @@ +# 处理LRS3-TED数据集 + +我们利用LRS3数据集来训练一个鲁棒的语音转动作的映射,这也是GeneFace能够实现高泛化能力的核心所在。除了audio2motion模型外,LRS3数据集还被用来训练postnet和syncnet。 + +## 处理完毕的LRS3数据集文件可用 + +🔥注意:由于我们转向新的3DMM提取器,我们更新了提供的lrs3数据集。您可能需要下载最新的数据集文件以与最新的代码兼容。 + +由于LRS3数据集数据量较大(500小时+),其处理过程非常消耗计算资源,因此我们在百度云盘提供了处理好的LRS3数据集(总共约26GB)。如果你使用我们处理好的数据集,可以跳过下面的步骤1-3,直接进入步骤4,验证安装成功。 + +- 谷歌网盘下载链接:[分区1](https://drive.google.com/drive/folders/1QK_ikLKUzGYiqHBzvKz0s5zKWeH-sm3L?usp=sharing),[分区2](https://drive.google.com/drive/folders/1WbECLfpxAZ0D7PcrlZxV-fCObT-TnfD8?usp=share_link)。 +- 百度云盘下载链接:[链接](https://pan.baidu.com/s/1JsvEz58c9ItSI73ls43tTw?pwd=lrs3),提取码:`lrs3` +- 如何使用: + + - 步骤1:将拆分的子文件还原成压缩包 `cat lrs3.zip.part_* > lrs3.zip` 。 + - 步骤2:将压缩包解压,并将其移动到 `data/binary/lrs3` 目录下。 +- 免责声明:我们提供的文件仅包含了经过数据脱敏处理的特征(比如HuBET作为音频的表征),没有侵犯LRS3中视频的版权。 + +## 步骤1. 申请并下载LRS3-TED数据集 + +由于License的原因,我们不能在这里提供下载链接。请您通过[这个链接](https://www.robots.ox.ac.uk/~vgg/data/lip_reading/lrs3.html)向LRS3-TED数据集的所有者提交申请。 + +## 步骤2. 处理LRS3数据集 + +在处理LRS3之前,请确保您按照 `dosc/prepare_env/install_guide.md` 的步骤正确安装了处理LRS3数据集的 `geneface`环境。 + +接着执行以下命令行(你可能需要修改下面 `.py`文件里面的路径名): + +``` +conda activate +CUDA_VISIBLE_DEVICES=0 python data_gen/process_lrs3/process_video.py # extract 3dmm motion representations +CUDA_VISIBLE_DEVICES=0 python data_gen/process_lrs3/process_audio_mel_f0.py # extract mel spectrogram +CUDA_VISIBLE_DEVICES=0 python data_gen/process_lrs3/process_audio_hubert.py # extract hubert audio representations +``` + +由于LRS3-TED数据集比较大,您可能需要同时开多个python进程,以利用多个gpu来加速数据预处理。 + +## 步骤3. 将数据集打包 + +执行以下命令行(你可能需要修改下面 `.py`文件里面的路径名) + +``` +conda activate procerss_lrs3 +python data_gen/process_lrs3/binarizer.py +``` + +## 步骤4. 验证 + +如果上述步骤都顺利完成的话,您将能在 `data/binary/lrs3`路径看到处理好的LRS3数据集。理想状态下,你的 `data`文件夹内部结构应该是类似这样的: + +``` +> data + > binary + > lrs3 + sizes_train.npy + sizes_val.npy + spk_id2ispk_idx.npy + stats.npy + train.data + train.idx + val.data + val.idx +``` diff --git a/Geneface_main/GeneFace/docs/process_data/zh/process_target_person_video-zh.md b/Geneface_main/GeneFace/docs/process_data/zh/process_target_person_video-zh.md new file mode 100644 index 00000000..f7e61386 --- /dev/null +++ b/Geneface_main/GeneFace/docs/process_data/zh/process_target_person_video-zh.md @@ -0,0 +1,18 @@ +# 处理说话人视频数据集 + +你需要一个大约3分钟的目标人物视频来训练特定人物的postnet和基于NeRF的渲染器。(视频长度越长越好) + +我们在 `data/raw/videos/May.mp4` 路径下提供了一个示例视频 + +## 提取所有所需的特征并打包。 + +运行如下命令行: + +``` +conda activate geneface +export PYTHONPATH=./ +export VIDEO_ID=May +CUDA_VISIBLE_DEVICES=0 data_gen/nerf/process_data.sh $VIDEO_ID +``` + +如果上面的步骤都顺利完成,你可以在 `data/binary/videos/May`路径下看到处理好的目标说话人视频的数据集。 diff --git a/Geneface_main/GeneFace/docs/train_models/train_models-zh.md b/Geneface_main/GeneFace/docs/train_models/train_models-zh.md new file mode 100644 index 00000000..1125bb2f --- /dev/null +++ b/Geneface_main/GeneFace/docs/train_models/train_models-zh.md @@ -0,0 +1,122 @@ +# 训练 GeneFace! + +GeneFace 包含三个模块:1)一个训练于LRS3数据集并通用于所有说话人的 `语音转动作`模块;2)一个适用于特定说话人的 `动作后处理`网络,它被训练于LRS3数据集和对应说话人的视频数据;3)一个适用于特定说话人的 `基于NeRF的渲染器`,它被训练于对应说话人的视频数据。 + +要训练GeneFace,请首先按照我们在 `docs/prepare_env`文档和 `docs/process_data`文档中的步骤,分别完成搭建环境和准备数据集。 + +在[这个链接](https://github.com/yerfor/GeneFace/releases/tag/v1.1.0)中,我们还准备了GeneFace的预训练模型,其中: + +* `lrs3.zip` 包含了在LRS3数据集上训练的模型 (包括一个 `lm3d_vae_sync`模型以实现语音转动作的变换,和一个 `syncnet`以实现对语音-嘴形对齐程度的衡量),这些模型是通用于所有说话人视频的。 +* `May.zip` 包含了我们在 `May.mp4`视频上训练的所有模型(包括一个 `postnet`以对 `lm3d_vae_sync`产生的3D landmark进行后处理,以及一个 `lm3d_radnerf`和 `lm3d_radnerf_torso`分别渲染说话人的头部和躯干部位。)对每个说话人视频,你都需要新训练这三个模型。 + +## 步骤1. 训练SyncNet模型 + +注意:我们在[这个链接](https://github.com/yerfor/GeneFace/releases/tag/v1.1.0)的 `lrs3.zip`文件中提供了预训练好的SyncNet,你可以将其下载并提取出其中的 `syncnet`文件夹,并将它放到 `checkpoints/lrs3/syncnet`路径中。 + +如果你想要从头训练SyncNet,请执行以下命令行(你需要首先准备好LRS3数据集): + +``` +conda activate geneface +export PYTHONPATH=./ +CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config=egs/datasets/lrs3/lm3d_syncnet.yaml --exp_name=lrs3/syncnet +``` + +注意SyncNet模型适用于所有说话人视频,所以你只需要训练它一次! + +## 步骤2. 训练Audio2Motion模型 + +注意:我们在[这个链接](https://github.com/yerfor/GeneFace/releases/tag/v1.1.0)的 `lrs3.zip`文件中提供了预训练好的audio2motion模型,你可以将其下载并提取出其中的 `lm3d_vae_sync`文件夹,并将它放到 `checkpoints/lrs3/lm3d_vae_sync`路径中。 + +如果你想要从头训练audio2motion模型,请执行以下命令行(你需要首先准备好LRS3数据集): + +``` +conda activate geneface +export PYTHONPATH=./ +CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config=egs/datasets/videos/lrs3/lm3d_vae_sync.yaml --exp_name=lrs3/lm3d_vae_sync +``` + +注意名为 `lm3d_vae_sync`的audio2motion模型适用于所有说话人视频,所以你只需要训练它一次! + +## 步骤3. 训练PostNet模型 + +注意:我们在[这个链接](https://github.com/yerfor/GeneFace/releases/tag/v1.1.0)的 `May.zip`文件中提供了专用于 `data/raw/videos/May.mp4`视频的预训练好的Postnet模型,你可以将其下载并提取出其中的 `postnet`文件夹,并将它放到 `checkpoints/May/postnet`路径中。 + +如果你想要从头训练postnet模型,请执行以下命令行(你需要首先准备好LRS3数据集和对应的说话人视频数据集): + +``` +conda activate geneface +export PYTHONPATH=./ +CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config=egs/datasets/videos/May/lm3d_postnet_sync.yaml --exp_name=May/postnet +``` + +注意postnet模型仅适用于对应的说话人视频,所以对每个新的说话人视频你都需要训练一个新的postnet。 + +#### 训练小tips:选择合适步数的checkpoint + +由于我们的postnet的训练属于对抗域适应(Adversarial Domain Adaptation)过程,而对抗训练的训练过程被广泛公认是不稳定的。例如当训练步数过多时,可能导致模型出现模式坍塌,比如postnet可能会学到将输入的任意表情都映射到同一个target person domain的表情(体现在validation sync/mse loss上升)。因此为了避免最终得到的人脸表情的lip-sync性能下降过大,我们应该early stop,即选择步数较小的checkpoint。但同时,当步数过小的时候,postnet可能还欠拟合,无法保证能够将各种各样的表情成功地映射到target person domain(体现在adversarial loss未收敛)。 + +因此,在实际操作中,我们一般根据三个原则来选择合适步数的checkpoint:(1)validation sync/mse loss越低越好;(2)adversarial loss达到收敛。(3)尽量选择步数较小的checkpoint。 + +下图我们展示了一个实例,它是训练 `May.mp4`时我们选择合适的postnet checkpoint的过程。我们发现6k步的时候,`val/mse`和 `val/sync`较小,并且 `tr/disc_neg_conf`和 `tr/disc_pos_conf`都约等于0.5(这意味着discriminator已经无法区分正样本和postnet产生的负样本之间的差异),因此我们选择6k步的checkpoint。 + +

+
+ +
+

+ +最后,为了快速验证选择的postnet checkpoint的lip-sync性能。我们还提供了一个3D landmark的可视化脚本。运行以下脚本(你可能需要修改以下 `.sh`和 `.py`文件内的路径名): + +``` +conda activate geneface +bash infer_postnet.sh # use the selected postnet checkpoint to predict the 3D landmark sequence. +python utils/visualization/lm_visualizer.py # visualize the 3D landmark sequence. +``` + +你能在输出路径中看到可视化的3d landmark视频。 + +## 步骤4. 训练基于RAD-NeRF的渲染器 + +RAD-NeRF利用instant-ngp对NeRF的训练效率和推理速度进行了巨大的提升,我们推荐使用RAD-NeRF作为NeRF的后端。RAD-NeRF的训练速度是原始NeRF的6倍,并且可以实现实时推理,并且渲染质量和原始NeRF十分接近。 + +注意:我们在[这个链接](https://github.com/yerfor/GeneFace/releases/tag/v1.1.0)的 `May.zip`文件中提供了专用于 `data/raw/videos/May.mp4`视频的预训练好的RAD-NeRF模型,你可以将其下载并提取出其中的 `lm3d_radnerf`和 `lm3d_radnerf_torso`文件夹,并将它放到 `checkpoints/May/lm3d_radnerf`和 `checkpoints/May/lm3d_radnerf_torso`路径中。 + +如果你想要从头训练RAD-NeRF模型,请执行以下命令行(你需要首先准备好LRS3数据集和对应的说话人视频数据集),在RTX3090上训练大约花费10小时。 + +``` +conda activate geneface +export PYTHONPATH=./ +# Train the head rad_nerf, it takes about 6 hours in one RTX3090Ti +CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config=egs/datasets/videos/May/lm3d_radnerf.yaml --exp_name=May/lm3d_radnerf +# Train the torso rad_nerf, it takes about 4 hours in one RTX3090Ti +CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config=egs/datasets/videos/May/lm3d_radnerf_torso.yaml --exp_name=May/lm3d_radnerf_torso +``` + +注意NeRF模型仅适用于对应的说话人视频,所以对每个新的说话人视频你都需要训练一个新的NeRF模型。 + +### 旧版本:训练基于原始NeRF的渲染器 + +尽管推荐使用RAD-NeRF,为了完整性,我们仍然保留了基于原始NeRF的渲染器。你可以利用下面的命令行对其进行训练。在RTX3090上训练大约花费60小时。 + +注意:如[这个issue](https://github.com/yerfor/GeneFace/issues/18)里面指出的,由于NeRF非常依赖于初始化参数,你可能需要重复执行几次训练命令,直到NeRF的loss得以正常下降。 + +``` +conda activate geneface +export PYTHONPATH=./ +# Train the head nerf, it takes about 30 hours on a RTX3090Ti +CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config=egs/datasets/videos/May/lm3d_nerf.yaml --exp_name=May/lm3d_nerf +# Train the torso nerf, it takes about 36 hours on a RTX3090Ti +CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config=egs/datasets/videos/May/lm3d_nerf_torso.yaml --exp_name=May/lm3d_nerf_torso +``` + +注意:基于原始NeRF的图像渲染器的推理过程相对较慢(使用RTX2080Ti渲染250帧512x512分辨率的图像需要大约2个小时)。可以通过将 `——n_samples_per_ray`和 `——n_samples_per_ray_fine`设置为较低的值来部分缓解这个问题。不过由于实现了对RAD-NeRF的支持,推理速度已经不再是问题。 + +## 步骤5. 使用GeneFace生成说话人视频 + +你可以执行以下命令行,以运行训练好的GeneFace生成说话人视频。 + +``` +# By default we use the data/raw/val_wavs/zozo.wav as the driving audio. +bash scripts/infer_postnet.sh +bash scripts/infer_lm3d_radnerf.sh +``` diff --git a/Geneface_main/GeneFace/docs/train_models/train_models.md b/Geneface_main/GeneFace/docs/train_models/train_models.md new file mode 100644 index 00000000..c5810d0d --- /dev/null +++ b/Geneface_main/GeneFace/docs/train_models/train_models.md @@ -0,0 +1,124 @@ +# Train GeneFace! + +[中文文档](./train_models-zh.md) + +GeneFace consists of three models: 1) an generic `audio2motion` model trained on LRS3 dataset; 2) a person-specific `postnet` trained on LRS3 and the target person video; 3) a person-specific `nerf` renderer trained on the target person video. + +To train GeneFace, please first follow the docs in `docs/prepare_env` and `docs/process_data` to build the environment and prepare the datasets, respectively. + +We also provide pre-trained models at [this link](https://github.com/yerfor/GeneFace/releases/tag/v1.1.0), in which: + +* `lrs3.zip` includes the models trained on LRS3-ted dataset (a `lm3d_vae_sync` to perform the audio2motion transform and a `syncnet` for measuring the lip-sync), which are generic for all possible target person videos. +* `May.zip` includes the models trained on the `May.mp4` target person video (a `postnet` for refining the predicted 3d landmark, a `lm3d_radnerf` for rendering the head image, and a `lm3d_radnerf_torso` for rendering the torso part). For each target person video, you need to train these three models. + +## Step1. Train the SyncNet Model + +NOTE: We provide the pre-trained SyncNet model in `lrs3.zip` at [this link](https://github.com/yerfor/GeneFace/releases/tag/v1.1.0), you can download it and extract the `syncnet` folder and place it into the path `checkpoints/lrs3/syncnet`. + +If you want to train a SyncNet from scratch, please run the following commandlines (The processed LRS3 dataset is required): + +``` +conda activate geneface +export PYTHONPATH=./ +CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config=egs/datasets/lrs3/lm3d_syncnet.yaml --exp_name=lrs3/syncnet +``` + +Note that SyncNet is a generic model for all possible target person videos, so you only to train it once! + +## Step2. Train the Audio2motion model + +NOTE: We provide the pre-trained Audio2motion model in `lrs3.zip` at [this link](https://github.com/yerfor/GeneFace/releases/tag/v1.1.0), you can download it and extract the `lm3d_vae_sync` folder and place it into the path `checkpoints/lrs3/lm3d_vae_sync`. + +If you want to train a audio2motion model from scratch, please run the following commandlines (The processed LRS3 dataset is required): + +``` +conda activate geneface +export PYTHONPATH=./ +CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config=egs/datasets/videos/lrs3/lm3d_vae_sync.yaml --exp_name=lrs3/lm3d_vae_sync +``` + +Note that the Audio2motion model named `lm3d_vae_sync` is a generic model for all possible target person videos, so you only to train it once! + +## Step3. Train the Postnet + +NOTE: We provide the pre-trained Post-net model for the target person video named `data/raw/videos/May.mp4` in `May.zip` at [this link](https://github.com/yerfor/GeneFace/releases/tag/v1.1.0), you can download it and extract the `postnet` folder and place it into the path `checkpoints/May/postnet`. + +If you want to train a postnet model from scratch, please run the following commandlines (The processed LRS3 dataset and the target person video is required): + +``` +conda activate geneface +export PYTHONPATH=./ +CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config=egs/datasets/videos/May/lm3d_postnet_sync.yaml --exp_name=May/postnet +``` + +Note that the Post-net is person-specific, so for each target person video, you need to train a new Post-net. + +#### tips: choosing the appropriate checkpoint + +Since our postnet belongs to **Adversarial** Domain Adaptation, whose training process is widely considered to be unstable. For example, training the model for too many steps may lead to model collapse. For example, when mode collapse occurs, the postnet may map abitrary input landmark into the same landmark in the target person domain (which results in rises in validation sync/mse loss). Therefore, to avoid degradation of the lip-sync performance, we should make an early stop, i.e., select a checkpoint trained with a small number of iterations. However, at the same time, if the number of iterations is too small, postnet may be underfitting and cannot successfully map the landmarks into the target person domain (which means the adversarial loss is not converged). + +Therefore, in practice, we choose the checkpoint with the appropriate number of iterations according to three principles: (1) validation sync/mse loss should be as low as possible; (2) the adversarial loss should be converged. (3) a small number of iterations is desirable. + +The following figure shows an example of the process of selecting the appropriate postnet checkpoint when training `May.mp4`. We found that `val/mse` and `val/sync` are relatively low at 6k steps. Besides, `tr/disc_neg_conf` and `tr/disc_pos_conf` are both about 0.5 (which means that the discriminator cannot distinguish between the (GT) positive samples and the (postnet-generated) negative samples), so we choose the checkpoint at 6k steps. + +

+
+ +
+

+ +Finally, to quickly verify the lip-sync performance of the selected postnet checkpoint, we also provide a script to visualize the predicted 3D landmark. Run the following script (you may need to modify the path names in the following `.sh` and `.py` files): + +``` +conda activate geneface +bash infer_postnet.sh # use the selected postnet checkpoint to predict the 3D landmark sequence. +python utils/visualization/lm_visualizer.py # visualize the 3D landmark sequence. +``` + +You can see the visualized 3d landmark video in the output path. + +## Step4. Train the RAD-NeRF-based Render + +RAD-NeRF uses instant-ngp to improve the training and inference speed of NeRF. We recommend using RAD-NeRF as the backend of NeRF-based renderer. RAD-NeRF is 6x faster than NeRF in training, and could infer in real-time, with similar rendering quality to the vanilla NeRF. + +NOTE: We provide the pre-trained RAD-NeRF model for the target person video named `data/raw/videos/May.mp4` in `May.zip` at [this link](https://github.com/yerfor/GeneFace/releases/tag/v1.1.0), you can download it and extract the `lm3d_radnerf` and `lm3d_radnerf_torso` folder, then place it into the path `checkpoints/May/lm3d_radnerf` and `checkpoints/May/lm3d_radnerf_torso`, respectively. + +If you want to train a RAD-NeRF model from scratch, please run the following commandlines (The processed target person video dataset is required). It takes about 10 hours on a RTX3090Ti. + +``` +conda activate geneface +export PYTHONPATH=./ +# Train the head rad_nerf, it takes about 6 hours in one RTX3090Ti +CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config=egs/datasets/videos/May/lm3d_radnerf.yaml --exp_name=May/lm3d_radnerf +# Train the torso rad_nerf, it takes about 4 hours in one RTX3090Ti +CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config=egs/datasets/videos/May/lm3d_radnerf_torso.yaml --exp_name=May/lm3d_radnerf_torso +``` + +Note that the NeRF-based renderer is person-specific, so for each target person video, you need to train a new NeRF-based renderer. + +### Legacy: Train the vanilla NeRF-based renderer in the GeneFace paper + +Although the use of RAD-NeRF is recommended, we still support the vanilla NeRF-based renderer for completeness. You can train it with the following command line. It takes about 60 hours on a RTX3090Ti. + +Note: As pointed out in [this issue](https://github.com/yerfor/GeneFace/issues/18), since NeRF requires good initilization, you may need to run the commandline several times, until the loss converges normally. + +``` +conda activate geneface +export PYTHONPATH=./ +# Train the head nerf, it takes about 30 hours on a RTX3090Ti +CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config=egs/datasets/videos/May/lm3d_nerf.yaml --exp_name=May/lm3d_nerf +# Train the torso nerf, it takes about 36 hours on a RTX3090Ti +CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config=egs/datasets/videos/May/lm3d_nerf_torso.yaml --exp_name=May/lm3d_nerf_torso +``` + +The inference process for the vanilla Nerf-based renderer is very slow (it takes about 2 hours to render 250 frames of 512x512 resolution images using RTX2080Ti). This problem can be partially mitigated by setting `--n_samples_per_ray` and `--n_samples_per_ray_fine` to lower values. However, with the implementation of RAD-NeRF, inference speed is no longer an issue for GeneFace. + +## Step5. Inference! + +You can infer the GeneFace with the following commandlines: + +``` +# By default we use the data/raw/val_wavs/zozo.wav as the driving audio. +bash scripts/infer_postnet.sh +bash scripts/infer_lm3d_radnerf.sh +``` diff --git a/Geneface_main/GeneFace/egs/datasets/lrs3/lm3d_syncnet.yaml b/Geneface_main/GeneFace/egs/datasets/lrs3/lm3d_syncnet.yaml new file mode 100644 index 00000000..3932bc3a --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/lrs3/lm3d_syncnet.yaml @@ -0,0 +1,2 @@ +base_config: + - egs/egs_bases/syncnet/base.yaml diff --git a/Geneface_main/GeneFace/egs/datasets/lrs3/lm3d_vae_sync.yaml b/Geneface_main/GeneFace/egs/datasets/lrs3/lm3d_vae_sync.yaml new file mode 100644 index 00000000..c6956983 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/lrs3/lm3d_vae_sync.yaml @@ -0,0 +1,6 @@ +base_config: + - egs/egs_bases/audio2motion/vae_sync.yaml + +syncnet_work_dir: checkpoints/lrs3/syncnet +syncnet_ckpt_steps: 1000 +lambda_kl: 0.4 diff --git a/Geneface_main/GeneFace/egs/datasets/lrs3/lm3d_vae_sync_pitch.yaml b/Geneface_main/GeneFace/egs/datasets/lrs3/lm3d_vae_sync_pitch.yaml new file mode 100644 index 00000000..96ae696a --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/lrs3/lm3d_vae_sync_pitch.yaml @@ -0,0 +1,7 @@ +base_config: + - ./lm3d_vae_sync.yaml + +lambda_kl: 0.4 +syncnet_work_dir: checkpoints/lrs3/syncnet +syncnet_ckpt_steps: 40000 +task_cls: tasks.audio2motion.lm3d_vae_sync_pitch.VAESyncAudio2MotionTask diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Lieu/adnerf.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Lieu/adnerf.yaml new file mode 100644 index 00000000..2e858d24 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Lieu/adnerf.yaml @@ -0,0 +1,5 @@ +base_config: + - egs/egs_bases/nerf/adnerf.yaml + +video_id: Lieu # the video file should be located at `data/raw/videos/.mp4` + diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Lieu/adnerf_torso.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Lieu/adnerf_torso.yaml new file mode 100644 index 00000000..b20f8c02 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Lieu/adnerf_torso.yaml @@ -0,0 +1,5 @@ +base_config: + - egs/egs_bases/nerf/adnerf_torso.yaml + +video_id: Lieu # the video file should be located at `data/raw/videos/.mp4` +head_model_dir: checkpoints/Lieu/adnerf # the path of the head_nerf model, e.g., `checkpoints//lm3d_nerf` diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Lieu/base.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Lieu/base.yaml new file mode 100644 index 00000000..b8e7e145 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Lieu/base.yaml @@ -0,0 +1,5 @@ +video_id: Lieu # the video file should be located at `data/raw/videos/.mp4` + +infer_eye_blink_ref_frames_start_idx: 0 # start index of the ref blink sequence in the GT dataset +infer_eye_blink_ref_frames_end_idx: 60 # end index of the ref blink sequence in the GT dataset +infer_sil_ref_frame_idx: 17 # index of the ref frame with a closed mouth in the GT dataset diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Lieu/lm3d_nerf.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Lieu/lm3d_nerf.yaml new file mode 100644 index 00000000..2bc412fa --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Lieu/lm3d_nerf.yaml @@ -0,0 +1,4 @@ +base_config: + - egs/egs_bases/nerf/lm3d_nerf.yaml + - ./base.yaml + diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Lieu/lm3d_nerf_torso.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Lieu/lm3d_nerf_torso.yaml new file mode 100644 index 00000000..6d06e7f5 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Lieu/lm3d_nerf_torso.yaml @@ -0,0 +1,5 @@ +base_config: + - egs/egs_bases/nerf/lm3d_nerf_torso.yaml + - ./base.yaml + +head_model_dir: checkpoints/Lieu/lm3d_nerf # the path of the head_nerf model, e.g., `checkpoints//lm3d_nerf` diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Lieu/lm3d_postnet_sync.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Lieu/lm3d_postnet_sync.yaml new file mode 100644 index 00000000..283d129a --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Lieu/lm3d_postnet_sync.yaml @@ -0,0 +1,5 @@ +base_config: + - egs/egs_bases/postnet/base.yaml + +video_id: Lieu # the video file should be located at `data/raw/videos/.mp4` + diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Lieu/lm3d_postnet_sync_pitch.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Lieu/lm3d_postnet_sync_pitch.yaml new file mode 100644 index 00000000..2c5e7720 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Lieu/lm3d_postnet_sync_pitch.yaml @@ -0,0 +1,9 @@ +base_config: + - egs/egs_bases/postnet/base.yaml + +video_id: Lieu + +task_cls: tasks.postnet.lm3d_postnet_adv_sync_pitch.PostnetAdvSyncTask +audio2motion_task_cls: tasks.audio2motion.lm3d_vae_sync_pitch.VAESyncAudio2MotionTask +audio2motion_work_dir: checkpoints/lrs3/lm3d_vae_sync_pitch/ +audio2motion_ckpt_steps: 40000 \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Lieu/lm3d_postnet_sync_pitch_reg.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Lieu/lm3d_postnet_sync_pitch_reg.yaml new file mode 100644 index 00000000..37ac762f --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Lieu/lm3d_postnet_sync_pitch_reg.yaml @@ -0,0 +1,11 @@ +base_config: + - egs/egs_bases/postnet/base.yaml + +video_id: Lieu + +task_cls: tasks.postnet.lm3d_postnet_adv_sync_pitch.PostnetAdvSyncTask +audio2motion_task_cls: tasks.audio2motion.lm3d_vae_sync_pitch.VAESyncAudio2MotionTask +audio2motion_work_dir: checkpoints/lrs3/lm3d_vae_sync_pitch/ +audio2motion_ckpt_steps: 40000 + +postnet_lambda_reg: 0.02 \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Lieu/lm3d_postnet_sync_pitch_reg_continuity.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Lieu/lm3d_postnet_sync_pitch_reg_continuity.yaml new file mode 100644 index 00000000..532eb438 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Lieu/lm3d_postnet_sync_pitch_reg_continuity.yaml @@ -0,0 +1,12 @@ +base_config: + - egs/egs_bases/postnet/base.yaml + +video_id: Lieu + +task_cls: tasks.postnet.lm3d_postnet_adv_sync_pitch.PostnetAdvSyncTask +audio2motion_task_cls: tasks.audio2motion.lm3d_vae_sync_pitch.VAESyncAudio2MotionTask +audio2motion_work_dir: checkpoints/lrs3/lm3d_vae_sync_pitch/ +audio2motion_ckpt_steps: 40000 + +postnet_lambda_reg: 0.02 +postnet_lambda_continuity: 0.10 \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Lieu/lm3d_radnerf.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Lieu/lm3d_radnerf.yaml new file mode 100644 index 00000000..0a72cd66 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Lieu/lm3d_radnerf.yaml @@ -0,0 +1,17 @@ +base_config: + - egs/egs_bases/radnerf/lm3d_radnerf.yaml + +video_id: Lieu # the video file should be located at `data/raw/videos/.mp4` +amp: true + +# to tune scale +# https://github.com/NVlabs/instant-ngp/blob/master/docs/nerf_dataset_tips.md +# https://github.com/ashawkey/torch-ngp/issues/112 +# The occupancy grid works fine in LEGO dataset. (~10x accelerated) +# In my experiment (and on my dataset), I found that occupancy grid sampling is vulnerable to scale. +# In specific scale range, the occ grid sampling works and accelerates rendering. +# But outside of that range, the acceleration gain disappears, or it fails to converge at all. +# (Without the occ grid sampling, the model has learned the scene in that scales.) +# I think this is reasonable because covering the the camera-viewed region with a predefined grid +# is easier to fail than sampling without grids. +# With an manual scale tuning, I can get the expected acceleration gain. \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Lieu/lm3d_radnerf_torso.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Lieu/lm3d_radnerf_torso.yaml new file mode 100644 index 00000000..0326d728 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Lieu/lm3d_radnerf_torso.yaml @@ -0,0 +1,8 @@ +base_config: + - egs/egs_bases/radnerf/lm3d_radnerf.yaml + +video_id: Lieu # the video file should be located at `data/raw/videos/.mp4` +task_cls: tasks.radnerfs.radnerf_torso.RADNeRFTorsoTask +head_model_dir: checkpoints/Lieu/lm3d_radnerf + +torso_train_mode: 1 diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Lieu/radnerf.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Lieu/radnerf.yaml new file mode 100644 index 00000000..27bd5f06 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Lieu/radnerf.yaml @@ -0,0 +1,18 @@ +base_config: + - egs/egs_bases/radnerf/radnerf.yaml + +video_id: Lieu # the video file should be located at `data/raw/videos/.mp4` +individual_embedding_num: 13000 # should be equal or larger than the number of frames +amp: true + +# to tune scale +# https://github.com/NVlabs/instant-ngp/blob/master/docs/nerf_dataset_tips.md +# https://github.com/ashawkey/torch-ngp/issues/112 +# The occupancy grid works fine in LEGO dataset. (~10x accelerated) +# In my experiment (and on my dataset), I found that occupancy grid sampling is vulnerable to scale. +# In specific scale range, the occ grid sampling works and accelerates rendering. +# But outside of that range, the acceleration gain disappears, or it fails to converge at all. +# (Without the occ grid sampling, the model has learned the scene in that scales.) +# I think this is reasonable because covering the the camera-viewed region with a predefined grid +# is easier to fail than sampling without grids. +# With an manual scale tuning, I can get the expected acceleration gain. \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Lieu/radnerf_torso.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Lieu/radnerf_torso.yaml new file mode 100644 index 00000000..eab0dd2c --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Lieu/radnerf_torso.yaml @@ -0,0 +1,8 @@ +base_config: + - egs/egs_bases/radnerf/radnerf.yaml + +video_id: Lieu # the video file should be located at `data/raw/videos/.mp4` +task_cls: tasks.radnerfs.radnerf_torso.RADNeRFTorsoTask +head_model_dir: checkpoints/Lieu/radnerf + +torso_train_mode: 1 diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Macron/adnerf.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Macron/adnerf.yaml new file mode 100644 index 00000000..cf3eb26e --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Macron/adnerf.yaml @@ -0,0 +1,5 @@ +base_config: + - egs/egs_bases/nerf/adnerf.yaml + +video_id: Macron # the video file should be located at `data/raw/videos/.mp4` + diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Macron/adnerf_torso.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Macron/adnerf_torso.yaml new file mode 100644 index 00000000..8f994e8d --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Macron/adnerf_torso.yaml @@ -0,0 +1,5 @@ +base_config: + - egs/egs_bases/nerf/adnerf_torso.yaml + +video_id: Macron # the video file should be located at `data/raw/videos/.mp4` +head_model_dir: checkpoints/Macron/adnerf # the path of the head_nerf model, e.g., `checkpoints//lm3d_nerf` diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Macron/base.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Macron/base.yaml new file mode 100644 index 00000000..19eb785f --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Macron/base.yaml @@ -0,0 +1,5 @@ +video_id: Macron # the video file should be located at `data/raw/videos/.mp4` + +infer_eye_blink_ref_frames_start_idx: 0 # start index of the ref blink sequence in the GT dataset +infer_eye_blink_ref_frames_end_idx: 60 # end index of the ref blink sequence in the GT dataset +infer_sil_ref_frame_idx: 17 # index of the ref frame with a closed mouth in the GT dataset diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Macron/lm3d_nerf.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Macron/lm3d_nerf.yaml new file mode 100644 index 00000000..2bc412fa --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Macron/lm3d_nerf.yaml @@ -0,0 +1,4 @@ +base_config: + - egs/egs_bases/nerf/lm3d_nerf.yaml + - ./base.yaml + diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Macron/lm3d_nerf_torso.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Macron/lm3d_nerf_torso.yaml new file mode 100644 index 00000000..e196a952 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Macron/lm3d_nerf_torso.yaml @@ -0,0 +1,5 @@ +base_config: + - egs/egs_bases/nerf/lm3d_nerf_torso.yaml + - ./base.yaml + +head_model_dir: checkpoints/Macron/lm3d_nerf # the path of the head_nerf model, e.g., `checkpoints//lm3d_nerf` diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Macron/lm3d_postnet_sync.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Macron/lm3d_postnet_sync.yaml new file mode 100644 index 00000000..a6f1cf48 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Macron/lm3d_postnet_sync.yaml @@ -0,0 +1,5 @@ +base_config: + - egs/egs_bases/postnet/base.yaml + +video_id: Macron # the video file should be located at `data/raw/videos/.mp4` + diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Macron/lm3d_postnet_sync_pitch.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Macron/lm3d_postnet_sync_pitch.yaml new file mode 100644 index 00000000..2ebaf476 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Macron/lm3d_postnet_sync_pitch.yaml @@ -0,0 +1,9 @@ +base_config: + - egs/egs_bases/postnet/base.yaml + +video_id: Macron + +task_cls: tasks.postnet.lm3d_postnet_adv_sync_pitch.PostnetAdvSyncTask +audio2motion_task_cls: tasks.audio2motion.lm3d_vae_sync_pitch.VAESyncAudio2MotionTask +audio2motion_work_dir: checkpoints/lrs3/lm3d_vae_sync_pitch/ +audio2motion_ckpt_steps: 40000 \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Macron/lm3d_postnet_sync_pitch_reg.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Macron/lm3d_postnet_sync_pitch_reg.yaml new file mode 100644 index 00000000..19d79af9 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Macron/lm3d_postnet_sync_pitch_reg.yaml @@ -0,0 +1,11 @@ +base_config: + - egs/egs_bases/postnet/base.yaml + +video_id: Macron + +task_cls: tasks.postnet.lm3d_postnet_adv_sync_pitch.PostnetAdvSyncTask +audio2motion_task_cls: tasks.audio2motion.lm3d_vae_sync_pitch.VAESyncAudio2MotionTask +audio2motion_work_dir: checkpoints/lrs3/lm3d_vae_sync_pitch/ +audio2motion_ckpt_steps: 40000 + +postnet_lambda_reg: 0.02 \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Macron/lm3d_postnet_sync_pitch_reg_continuity.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Macron/lm3d_postnet_sync_pitch_reg_continuity.yaml new file mode 100644 index 00000000..5927afd3 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Macron/lm3d_postnet_sync_pitch_reg_continuity.yaml @@ -0,0 +1,12 @@ +base_config: + - egs/egs_bases/postnet/base.yaml + +video_id: Macron + +task_cls: tasks.postnet.lm3d_postnet_adv_sync_pitch.PostnetAdvSyncTask +audio2motion_task_cls: tasks.audio2motion.lm3d_vae_sync_pitch.VAESyncAudio2MotionTask +audio2motion_work_dir: checkpoints/lrs3/lm3d_vae_sync_pitch/ +audio2motion_ckpt_steps: 40000 + +postnet_lambda_reg: 0.02 +postnet_lambda_continuity: 0.10 \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Macron/lm3d_radnerf.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Macron/lm3d_radnerf.yaml new file mode 100644 index 00000000..f5c79bd3 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Macron/lm3d_radnerf.yaml @@ -0,0 +1,17 @@ +base_config: + - egs/egs_bases/radnerf/lm3d_radnerf.yaml + +video_id: Macron # the video file should be located at `data/raw/videos/.mp4` +amp: true + +# to tune scale +# https://github.com/NVlabs/instant-ngp/blob/master/docs/nerf_dataset_tips.md +# https://github.com/ashawkey/torch-ngp/issues/112 +# The occupancy grid works fine in LEGO dataset. (~10x accelerated) +# In my experiment (and on my dataset), I found that occupancy grid sampling is vulnerable to scale. +# In specific scale range, the occ grid sampling works and accelerates rendering. +# But outside of that range, the acceleration gain disappears, or it fails to converge at all. +# (Without the occ grid sampling, the model has learned the scene in that scales.) +# I think this is reasonable because covering the the camera-viewed region with a predefined grid +# is easier to fail than sampling without grids. +# With an manual scale tuning, I can get the expected acceleration gain. \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Macron/lm3d_radnerf_torso.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Macron/lm3d_radnerf_torso.yaml new file mode 100644 index 00000000..5a09e76c --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Macron/lm3d_radnerf_torso.yaml @@ -0,0 +1,8 @@ +base_config: + - egs/egs_bases/radnerf/lm3d_radnerf.yaml + +video_id: Macron # the video file should be located at `data/raw/videos/.mp4` +task_cls: tasks.radnerfs.radnerf_torso.RADNeRFTorsoTask +head_model_dir: checkpoints/Macron/lm3d_radnerf + +torso_train_mode: 1 diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Macron/radnerf.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Macron/radnerf.yaml new file mode 100644 index 00000000..8be0fda6 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Macron/radnerf.yaml @@ -0,0 +1,18 @@ +base_config: + - egs/egs_bases/radnerf/radnerf.yaml + +video_id: Macron # the video file should be located at `data/raw/videos/.mp4` +individual_embedding_num: 13000 # should be equal or larger than the number of frames +amp: true + +# to tune scale +# https://github.com/NVlabs/instant-ngp/blob/master/docs/nerf_dataset_tips.md +# https://github.com/ashawkey/torch-ngp/issues/112 +# The occupancy grid works fine in LEGO dataset. (~10x accelerated) +# In my experiment (and on my dataset), I found that occupancy grid sampling is vulnerable to scale. +# In specific scale range, the occ grid sampling works and accelerates rendering. +# But outside of that range, the acceleration gain disappears, or it fails to converge at all. +# (Without the occ grid sampling, the model has learned the scene in that scales.) +# I think this is reasonable because covering the the camera-viewed region with a predefined grid +# is easier to fail than sampling without grids. +# With an manual scale tuning, I can get the expected acceleration gain. \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Macron/radnerf_torso.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Macron/radnerf_torso.yaml new file mode 100644 index 00000000..2f31d2f1 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Macron/radnerf_torso.yaml @@ -0,0 +1,8 @@ +base_config: + - egs/egs_bases/radnerf/radnerf.yaml + +video_id: Macron # the video file should be located at `data/raw/videos/.mp4` +task_cls: tasks.radnerfs.radnerf_torso.RADNeRFTorsoTask +head_model_dir: checkpoints/Macron/radnerf + +torso_train_mode: 1 diff --git a/Geneface_main/GeneFace/egs/datasets/videos/May/adnerf.yaml b/Geneface_main/GeneFace/egs/datasets/videos/May/adnerf.yaml new file mode 100644 index 00000000..889c8586 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/May/adnerf.yaml @@ -0,0 +1,5 @@ +base_config: + - egs/egs_bases/nerf/adnerf.yaml + +video_id: May # the video file should be located at `data/raw/videos/.mp4` + diff --git a/Geneface_main/GeneFace/egs/datasets/videos/May/adnerf_torso.yaml b/Geneface_main/GeneFace/egs/datasets/videos/May/adnerf_torso.yaml new file mode 100644 index 00000000..999ba653 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/May/adnerf_torso.yaml @@ -0,0 +1,5 @@ +base_config: + - egs/egs_bases/nerf/adnerf_torso.yaml + +video_id: May # the video file should be located at `data/raw/videos/.mp4` +head_model_dir: checkpoints/May/adnerf # the path of the head_nerf model, e.g., `checkpoints//lm3d_nerf` diff --git a/Geneface_main/GeneFace/egs/datasets/videos/May/audio2pose.yaml b/Geneface_main/GeneFace/egs/datasets/videos/May/audio2pose.yaml new file mode 100644 index 00000000..34689821 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/May/audio2pose.yaml @@ -0,0 +1,4 @@ +base_config: + - egs/egs_bases/audio2pose/base.yaml + +video_id: May # the video file should be located at `data/raw/videos/.mp4` diff --git a/Geneface_main/GeneFace/egs/datasets/videos/May/base.yaml b/Geneface_main/GeneFace/egs/datasets/videos/May/base.yaml new file mode 100644 index 00000000..5216def0 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/May/base.yaml @@ -0,0 +1,3 @@ +infer_eye_blink_ref_frames_start_idx: 0 # start index of the ref blink sequence in the GT dataset +infer_eye_blink_ref_frames_end_idx: 60 # end index of the ref blink sequence in the GT dataset +infer_sil_ref_frame_idx: 6072 # index of the ref frame with a closed mouth in the GT dataset \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_nerf.yaml b/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_nerf.yaml new file mode 100644 index 00000000..14db667f --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_nerf.yaml @@ -0,0 +1,5 @@ +base_config: + - egs/egs_bases/nerf/lm3d_nerf.yaml + - ./base.yaml + +video_id: May # the video file should be located at `data/raw/videos/.mp4` diff --git a/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_nerf_torso.yaml b/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_nerf_torso.yaml new file mode 100644 index 00000000..8ca18e0d --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_nerf_torso.yaml @@ -0,0 +1,6 @@ +base_config: + - egs/egs_bases/nerf/lm3d_nerf_torso.yaml + - ./base.yaml + +video_id: May # the video file should be located at `data/raw/videos/.mp4` +head_model_dir: checkpoints/May/lm3d_nerf # the path of the head_nerf model, e.g., `checkpoints//lm3d_nerf` diff --git a/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_postnet_sync.yaml b/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_postnet_sync.yaml new file mode 100644 index 00000000..507121e5 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_postnet_sync.yaml @@ -0,0 +1,7 @@ +base_config: + - egs/egs_bases/postnet/base.yaml + +video_id: May # the video file should be located at `data/raw/videos/.mp4` +audio2motion_task_cls: tasks.audio2motion.lm3d_vae_sync.VAESyncAudio2MotionTask +audio2motion_work_dir: checkpoints/lrs3/lm3d_vae_sync/ +audio2motion_ckpt_steps: 40000 diff --git a/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_postnet_sync_pitch.yaml b/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_postnet_sync_pitch.yaml new file mode 100644 index 00000000..0e1b2505 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_postnet_sync_pitch.yaml @@ -0,0 +1,9 @@ +base_config: + - egs/egs_bases/postnet/base.yaml + +video_id: May + +task_cls: tasks.postnet.lm3d_postnet_adv_sync_pitch.PostnetAdvSyncTask +audio2motion_task_cls: tasks.audio2motion.lm3d_vae_sync_pitch.VAESyncAudio2MotionTask +audio2motion_work_dir: checkpoints/lrs3/lm3d_vae_sync_pitch/ +audio2motion_ckpt_steps: 40000 \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_radnerf.yaml b/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_radnerf.yaml new file mode 100644 index 00000000..eadcaec7 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_radnerf.yaml @@ -0,0 +1,17 @@ +base_config: + - egs/egs_bases/radnerf/lm3d_radnerf.yaml + +video_id: May # the video file should be located at `data/raw/videos/.mp4` +amp: true + +# to tune scale +# https://github.com/NVlabs/instant-ngp/blob/master/docs/nerf_dataset_tips.md +# https://github.com/ashawkey/torch-ngp/issues/112 +# The occupancy grid works fine in LEGO dataset. (~10x accelerated) +# In my experiment (and on my dataset), I found that occupancy grid sampling is vulnerable to scale. +# In specific scale range, the occ grid sampling works and accelerates rendering. +# But outside of that range, the acceleration gain disappears, or it fails to converge at all. +# (Without the occ grid sampling, the model has learned the scene in that scales.) +# I think this is reasonable because covering the the camera-viewed region with a predefined grid +# is easier to fail than sampling without grids. +# With an manual scale tuning, I can get the expected acceleration gain. \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_radnerf_hash.yaml b/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_radnerf_hash.yaml new file mode 100644 index 00000000..8aa7b0a7 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_radnerf_hash.yaml @@ -0,0 +1,19 @@ +base_config: + - egs/egs_bases/radnerf/lm3d_radnerf.yaml + +video_id: May # the video file should be located at `data/raw/videos/.mp4` +individual_embedding_num: 10000 # should be equal or larger than the number of frames +amp: true + +grid_type: hashgrid +# to tune scale +# https://github.com/NVlabs/instant-ngp/blob/master/docs/nerf_dataset_tips.md +# https://github.com/ashawkey/torch-ngp/issues/112 +# The occupancy grid works fine in LEGO dataset. (~10x accelerated) +# In my experiment (and on my dataset), I found that occupancy grid sampling is vulnerable to scale. +# In specific scale range, the occ grid sampling works and accelerates rendering. +# But outside of that range, the acceleration gain disappears, or it fails to converge at all. +# (Without the occ grid sampling, the model has learned the scene in that scales.) +# I think this is reasonable because covering the the camera-viewed region with a predefined grid +# is easier to fail than sampling without grids. +# With an manual scale tuning, I can get the expected acceleration gain. \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_radnerf_hash_smoothstep.yaml b/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_radnerf_hash_smoothstep.yaml new file mode 100644 index 00000000..4e86aa5d --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_radnerf_hash_smoothstep.yaml @@ -0,0 +1,20 @@ +base_config: + - egs/egs_bases/radnerf/lm3d_radnerf.yaml + +video_id: May # the video file should be located at `data/raw/videos/.mp4` +individual_embedding_num: 10000 # should be equal or larger than the number of frames +amp: true + +grid_type: hashgrid +grid_interpolation_type: smoothstep +# to tune scale +# https://github.com/NVlabs/instant-ngp/blob/master/docs/nerf_dataset_tips.md +# https://github.com/ashawkey/torch-ngp/issues/112 +# The occupancy grid works fine in LEGO dataset. (~10x accelerated) +# In my experiment (and on my dataset), I found that occupancy grid sampling is vulnerable to scale. +# In specific scale range, the occ grid sampling works and accelerates rendering. +# But outside of that range, the acceleration gain disappears, or it fails to converge at all. +# (Without the occ grid sampling, the model has learned the scene in that scales.) +# I think this is reasonable because covering the the camera-viewed region with a predefined grid +# is easier to fail than sampling without grids. +# With an manual scale tuning, I can get the expected acceleration gain. \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_radnerf_smoothstep.yaml b/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_radnerf_smoothstep.yaml new file mode 100644 index 00000000..863c3ad1 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_radnerf_smoothstep.yaml @@ -0,0 +1,19 @@ +base_config: + - egs/egs_bases/radnerf/lm3d_radnerf.yaml + +video_id: May # the video file should be located at `data/raw/videos/.mp4` +individual_embedding_num: 10000 # should be equal or larger than the number of frames +amp: true + +grid_interpolation_type: smoothstep +# to tune scale +# https://github.com/NVlabs/instant-ngp/blob/master/docs/nerf_dataset_tips.md +# https://github.com/ashawkey/torch-ngp/issues/112 +# The occupancy grid works fine in LEGO dataset. (~10x accelerated) +# In my experiment (and on my dataset), I found that occupancy grid sampling is vulnerable to scale. +# In specific scale range, the occ grid sampling works and accelerates rendering. +# But outside of that range, the acceleration gain disappears, or it fails to converge at all. +# (Without the occ grid sampling, the model has learned the scene in that scales.) +# I think this is reasonable because covering the the camera-viewed region with a predefined grid +# is easier to fail than sampling without grids. +# With an manual scale tuning, I can get the expected acceleration gain. \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_radnerf_torso.yaml b/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_radnerf_torso.yaml new file mode 100644 index 00000000..95969ebd --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_radnerf_torso.yaml @@ -0,0 +1,8 @@ +base_config: + - egs/egs_bases/radnerf/lm3d_radnerf.yaml + +video_id: May # the video file should be located at `data/raw/videos/.mp4` +task_cls: tasks.radnerfs.radnerf_torso.RADNeRFTorsoTask +head_model_dir: checkpoints/May/lm3d_radnerf + +torso_train_mode: 1 diff --git a/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_radnerf_torso_head_aware.yaml b/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_radnerf_torso_head_aware.yaml new file mode 100644 index 00000000..252ccf8f --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/May/lm3d_radnerf_torso_head_aware.yaml @@ -0,0 +1,9 @@ +base_config: + - egs/egs_bases/radnerf/lm3d_radnerf.yaml + +video_id: May # the video file should be located at `data/raw/videos/.mp4` +task_cls: tasks.radnerfs.radnerf_torso.RADNeRFTorsoTask +head_model_dir: checkpoints/May/lm3d_radnerf + +torso_train_mode: 1 +torso_head_aware: true # head aware torso nerf to avoid head-torso separation artifacts! diff --git a/Geneface_main/GeneFace/egs/datasets/videos/May/radnerf.yaml b/Geneface_main/GeneFace/egs/datasets/videos/May/radnerf.yaml new file mode 100644 index 00000000..496a9b95 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/May/radnerf.yaml @@ -0,0 +1,17 @@ +base_config: + - egs/egs_bases/radnerf/radnerf.yaml + +video_id: May # the video file should be located at `data/raw/videos/.mp4` +amp: true + +# to tune scale +# https://github.com/NVlabs/instant-ngp/blob/master/docs/nerf_dataset_tips.md +# https://github.com/ashawkey/torch-ngp/issues/112 +# The occupancy grid works fine in LEGO dataset. (~10x accelerated) +# In my experiment (and on my dataset), I found that occupancy grid sampling is vulnerable to scale. +# In specific scale range, the occ grid sampling works and accelerates rendering. +# But outside of that range, the acceleration gain disappears, or it fails to converge at all. +# (Without the occ grid sampling, the model has learned the scene in that scales.) +# I think this is reasonable because covering the the camera-viewed region with a predefined grid +# is easier to fail than sampling without grids. +# With an manual scale tuning, I can get the expected acceleration gain. \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/datasets/videos/May/radnerf_torso.yaml b/Geneface_main/GeneFace/egs/datasets/videos/May/radnerf_torso.yaml new file mode 100644 index 00000000..8c557ade --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/May/radnerf_torso.yaml @@ -0,0 +1,8 @@ +base_config: + - egs/egs_bases/radnerf/radnerf.yaml + +video_id: May # the video file should be located at `data/raw/videos/.mp4` +task_cls: tasks.radnerfs.radnerf_torso.RADNeRFTorsoTask +head_model_dir: checkpoints/May/radnerf + +torso_train_mode: 1 diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Obama/adnerf.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Obama/adnerf.yaml new file mode 100644 index 00000000..87fbaf33 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Obama/adnerf.yaml @@ -0,0 +1,5 @@ +base_config: + - egs/egs_bases/nerf/adnerf.yaml + +video_id: Obama # the video file should be located at `data/raw/videos/.mp4` + diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Obama/adnerf_torso.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Obama/adnerf_torso.yaml new file mode 100644 index 00000000..dfe23084 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Obama/adnerf_torso.yaml @@ -0,0 +1,5 @@ +base_config: + - egs/egs_bases/nerf/adnerf_torso.yaml + +video_id: Obama # the video file should be located at `data/raw/videos/.mp4` +head_model_dir: checkpoints/Obama/adnerf # the path of the head_nerf model, e.g., `checkpoints//lm3d_nerf` diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Obama/base.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Obama/base.yaml new file mode 100644 index 00000000..e85aa578 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Obama/base.yaml @@ -0,0 +1,4 @@ +video_id: Obama +infer_eye_blink_ref_frames_start_idx: 0 # start index of the ref blink sequence in the GT dataset +infer_eye_blink_ref_frames_end_idx: 60 # end index of the ref blink sequence in the GT dataset +infer_sil_ref_frame_idx: 6072 # index of the ref frame with a closed mouth in the GT dataset \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Obama/lm3d_nerf.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Obama/lm3d_nerf.yaml new file mode 100644 index 00000000..2bc412fa --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Obama/lm3d_nerf.yaml @@ -0,0 +1,4 @@ +base_config: + - egs/egs_bases/nerf/lm3d_nerf.yaml + - ./base.yaml + diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Obama/lm3d_nerf_torso.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Obama/lm3d_nerf_torso.yaml new file mode 100644 index 00000000..0b2332ca --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Obama/lm3d_nerf_torso.yaml @@ -0,0 +1,5 @@ +base_config: + - egs/egs_bases/nerf/lm3d_nerf_torso.yaml + - ./base.yaml + +head_model_dir: checkpoints/Obama/lm3d_nerf # the path of the head_nerf model, e.g., `checkpoints//lm3d_nerf` diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Obama/lm3d_postnet_sync.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Obama/lm3d_postnet_sync.yaml new file mode 100644 index 00000000..9a24edbb --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Obama/lm3d_postnet_sync.yaml @@ -0,0 +1,5 @@ +base_config: + - egs/egs_bases/postnet/base.yaml + +video_id: Obama # the video file should be located at `data/raw/videos/.mp4` + diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Obama/radnerf.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Obama/radnerf.yaml new file mode 100644 index 00000000..24e02007 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Obama/radnerf.yaml @@ -0,0 +1,12 @@ +base_config: + - egs/egs_bases/radnerf/radnerf.yaml + +video_id: Obama # the video file should be located at `data/raw/videos/.mp4` +individual_embedding_num: 10000 # should be equal or larger than the number of frames +amp: true + +# Q: How to adjust bound and scale? + +# A: You could start with a large bound (e.g., 16) +# or a small scale (e.g., 0.3) +# to make sure the object falls into the bounding box. \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Obama2/adnerf.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Obama2/adnerf.yaml new file mode 100644 index 00000000..2bd2996e --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Obama2/adnerf.yaml @@ -0,0 +1,5 @@ +base_config: + - egs/egs_bases/nerf/adnerf.yaml + +video_id: Obama2 # the video file should be located at `data/raw/videos/.mp4` + diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Obama2/adnerf_torso.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Obama2/adnerf_torso.yaml new file mode 100644 index 00000000..e6dc6534 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Obama2/adnerf_torso.yaml @@ -0,0 +1,5 @@ +base_config: + - egs/egs_bases/nerf/adnerf_torso.yaml + +video_id: Obama2 # the video file should be located at `data/raw/videos/.mp4` +head_model_dir: checkpoints/Obama2/adnerf # the path of the head_nerf model, e.g., `checkpoints//lm3d_nerf` diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Obama2/base.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Obama2/base.yaml new file mode 100644 index 00000000..b2d082ba --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Obama2/base.yaml @@ -0,0 +1,5 @@ +video_id: Obama2 # the video file should be located at `data/raw/videos/.mp4` + +infer_eye_blink_ref_frames_start_idx: 0 # start index of the ref blink sequence in the GT dataset +infer_eye_blink_ref_frames_end_idx: 60 # end index of the ref blink sequence in the GT dataset +infer_sil_ref_frame_idx: 17 # index of the ref frame with a closed mouth in the GT dataset diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Obama2/lm3d_nerf.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Obama2/lm3d_nerf.yaml new file mode 100644 index 00000000..2bc412fa --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Obama2/lm3d_nerf.yaml @@ -0,0 +1,4 @@ +base_config: + - egs/egs_bases/nerf/lm3d_nerf.yaml + - ./base.yaml + diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Obama2/lm3d_nerf_torso.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Obama2/lm3d_nerf_torso.yaml new file mode 100644 index 00000000..7d23e6c9 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Obama2/lm3d_nerf_torso.yaml @@ -0,0 +1,5 @@ +base_config: + - egs/egs_bases/nerf/lm3d_nerf_torso.yaml + - ./base.yaml + +head_model_dir: checkpoints/Obama2/lm3d_nerf # the path of the head_nerf model, e.g., `checkpoints//lm3d_nerf` diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Obama2/lm3d_postnet_sync.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Obama2/lm3d_postnet_sync.yaml new file mode 100644 index 00000000..2e98983b --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Obama2/lm3d_postnet_sync.yaml @@ -0,0 +1,5 @@ +base_config: + - egs/egs_bases/postnet/base.yaml + +video_id: Obama2 # the video file should be located at `data/raw/videos/.mp4` + diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Obama2/lm3d_postnet_sync_pitch.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Obama2/lm3d_postnet_sync_pitch.yaml new file mode 100644 index 00000000..831c92cf --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Obama2/lm3d_postnet_sync_pitch.yaml @@ -0,0 +1,9 @@ +base_config: + - egs/egs_bases/postnet/base.yaml + +video_id: Obama2 + +task_cls: tasks.postnet.lm3d_postnet_adv_sync_pitch.PostnetAdvSyncTask +audio2motion_task_cls: tasks.audio2motion.lm3d_vae_sync_pitch.VAESyncAudio2MotionTask +audio2motion_work_dir: checkpoints/lrs3/lm3d_vae_sync_pitch/ +audio2motion_ckpt_steps: 40000 \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Obama2/lm3d_postnet_sync_pitch_reg.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Obama2/lm3d_postnet_sync_pitch_reg.yaml new file mode 100644 index 00000000..249d9ab4 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Obama2/lm3d_postnet_sync_pitch_reg.yaml @@ -0,0 +1,11 @@ +base_config: + - egs/egs_bases/postnet/base.yaml + +video_id: Obama2 + +task_cls: tasks.postnet.lm3d_postnet_adv_sync_pitch.PostnetAdvSyncTask +audio2motion_task_cls: tasks.audio2motion.lm3d_vae_sync_pitch.VAESyncAudio2MotionTask +audio2motion_work_dir: checkpoints/lrs3/lm3d_vae_sync_pitch/ +audio2motion_ckpt_steps: 40000 + +postnet_lambda_reg: 0.02 \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Obama2/lm3d_postnet_sync_pitch_reg_continuity.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Obama2/lm3d_postnet_sync_pitch_reg_continuity.yaml new file mode 100644 index 00000000..1cef0436 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Obama2/lm3d_postnet_sync_pitch_reg_continuity.yaml @@ -0,0 +1,12 @@ +base_config: + - egs/egs_bases/postnet/base.yaml + +video_id: Obama2 + +task_cls: tasks.postnet.lm3d_postnet_adv_sync_pitch.PostnetAdvSyncTask +audio2motion_task_cls: tasks.audio2motion.lm3d_vae_sync_pitch.VAESyncAudio2MotionTask +audio2motion_work_dir: checkpoints/lrs3/lm3d_vae_sync_pitch/ +audio2motion_ckpt_steps: 40000 + +postnet_lambda_reg: 0.02 +postnet_lambda_continuity: 0.10 \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Obama2/lm3d_radnerf.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Obama2/lm3d_radnerf.yaml new file mode 100644 index 00000000..6acfa205 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Obama2/lm3d_radnerf.yaml @@ -0,0 +1,18 @@ +base_config: + - egs/egs_bases/radnerf/lm3d_radnerf.yaml + +video_id: Obama2 # the video file should be located at `data/raw/videos/.mp4` +individual_embedding_num: 13000 # should be equal or larger than the number of frames +amp: true + +# to tune scale +# https://github.com/NVlabs/instant-ngp/blob/master/docs/nerf_dataset_tips.md +# https://github.com/ashawkey/torch-ngp/issues/112 +# The occupancy grid works fine in LEGO dataset. (~10x accelerated) +# In my experiment (and on my dataset), I found that occupancy grid sampling is vulnerable to scale. +# In specific scale range, the occ grid sampling works and accelerates rendering. +# But outside of that range, the acceleration gain disappears, or it fails to converge at all. +# (Without the occ grid sampling, the model has learned the scene in that scales.) +# I think this is reasonable because covering the the camera-viewed region with a predefined grid +# is easier to fail than sampling without grids. +# With an manual scale tuning, I can get the expected acceleration gain. \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Obama2/lm3d_radnerf_torso.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Obama2/lm3d_radnerf_torso.yaml new file mode 100644 index 00000000..5739e70c --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Obama2/lm3d_radnerf_torso.yaml @@ -0,0 +1,8 @@ +base_config: + - egs/egs_bases/radnerf/lm3d_radnerf.yaml + +video_id: Obama2 # the video file should be located at `data/raw/videos/.mp4` +task_cls: tasks.radnerfs.radnerf_torso.RADNeRFTorsoTask +head_model_dir: checkpoints/Obama2/lm3d_radnerf + +torso_train_mode: 1 diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Obama2/radnerf.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Obama2/radnerf.yaml new file mode 100644 index 00000000..28cd80a0 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Obama2/radnerf.yaml @@ -0,0 +1,18 @@ +base_config: + - egs/egs_bases/radnerf/radnerf.yaml + +video_id: Obama2 # the video file should be located at `data/raw/videos/.mp4` +individual_embedding_num: 13000 # should be equal or larger than the number of frames +amp: true + +# to tune scale +# https://github.com/NVlabs/instant-ngp/blob/master/docs/nerf_dataset_tips.md +# https://github.com/ashawkey/torch-ngp/issues/112 +# The occupancy grid works fine in LEGO dataset. (~10x accelerated) +# In my experiment (and on my dataset), I found that occupancy grid sampling is vulnerable to scale. +# In specific scale range, the occ grid sampling works and accelerates rendering. +# But outside of that range, the acceleration gain disappears, or it fails to converge at all. +# (Without the occ grid sampling, the model has learned the scene in that scales.) +# I think this is reasonable because covering the the camera-viewed region with a predefined grid +# is easier to fail than sampling without grids. +# With an manual scale tuning, I can get the expected acceleration gain. \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Obama2/radnerf_torso.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Obama2/radnerf_torso.yaml new file mode 100644 index 00000000..46ac2933 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Obama2/radnerf_torso.yaml @@ -0,0 +1,8 @@ +base_config: + - egs/egs_bases/radnerf/radnerf.yaml + +video_id: Obama2 # the video file should be located at `data/raw/videos/.mp4` +task_cls: tasks.radnerfs.radnerf_torso.RADNeRFTorsoTask +head_model_dir: checkpoints/Obama2/radnerf + +torso_train_mode: 1 diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/adnerf.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/adnerf.yaml new file mode 100644 index 00000000..f9707f42 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/adnerf.yaml @@ -0,0 +1,5 @@ +base_config: + - egs/egs_bases/nerf/adnerf.yaml + +video_id: Zhang2 # the video file should be located at `data/raw/videos/.mp4` + diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/adnerf_torso.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/adnerf_torso.yaml new file mode 100644 index 00000000..758cf113 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/adnerf_torso.yaml @@ -0,0 +1,5 @@ +base_config: + - egs/egs_bases/nerf/adnerf_torso.yaml + +video_id: Zhang2 # the video file should be located at `data/raw/videos/.mp4` +head_model_dir: checkpoints/Zhang2/adnerf # the path of the head_nerf model, e.g., `checkpoints//lm3d_nerf` diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/audio2pose.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/audio2pose.yaml new file mode 100644 index 00000000..c35f92a6 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/audio2pose.yaml @@ -0,0 +1,4 @@ +base_config: + - egs/egs_bases/audio2pose/base.yaml + +video_id: Zhang2 # the video file should be located at `data/raw/videos/.mp4` diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/base.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/base.yaml new file mode 100644 index 00000000..12b9f62f --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/base.yaml @@ -0,0 +1,3 @@ +infer_eye_blink_ref_frames_start_idx: 40 # start index of the ref blink sequence in the GT dataset +infer_eye_blink_ref_frames_end_idx: 100 # end index of the ref blink sequence in the GT dataset +infer_sil_ref_frame_idx: 7780 # index of the ref frame with a closed mouth in the GT dataset \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/lm3d_nerf.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/lm3d_nerf.yaml new file mode 100644 index 00000000..1e868831 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/lm3d_nerf.yaml @@ -0,0 +1,5 @@ +base_config: + - egs/egs_bases/nerf/lm3d_nerf.yaml + - ./base.yaml + +video_id: Zhang2 diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/lm3d_nerf_torso.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/lm3d_nerf_torso.yaml new file mode 100644 index 00000000..1978b88b --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/lm3d_nerf_torso.yaml @@ -0,0 +1,6 @@ +base_config: + - egs/egs_bases/nerf/lm3d_nerf_torso.yaml + - ./base.yaml + +video_id: Zhang2 +head_model_dir: checkpoints/Zhang2/lm3d_nerf diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/lm3d_postnet_sync.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/lm3d_postnet_sync.yaml new file mode 100644 index 00000000..e7b8b62b --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/lm3d_postnet_sync.yaml @@ -0,0 +1,5 @@ +base_config: + - egs/egs_bases/postnet/base.yaml + +video_id: Zhang2 + diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/lm3d_postnet_sync_pitch.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/lm3d_postnet_sync_pitch.yaml new file mode 100644 index 00000000..fe58d559 --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/lm3d_postnet_sync_pitch.yaml @@ -0,0 +1,9 @@ +base_config: + - egs/egs_bases/postnet/base.yaml + +video_id: Zhang2 + +task_cls: tasks.postnet.lm3d_postnet_adv_sync_pitch.PostnetAdvSyncTask +audio2motion_task_cls: tasks.audio2motion.lm3d_vae_sync_pitch.VAESyncAudio2MotionTask +audio2motion_work_dir: checkpoints/lrs3/lm3d_vae_sync_pitch/ +audio2motion_ckpt_steps: 40000 \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/lm3d_postnet_sync_pitch_reg.yaml b/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/lm3d_postnet_sync_pitch_reg.yaml new file mode 100644 index 00000000..3ca0e5ee --- /dev/null +++ b/Geneface_main/GeneFace/egs/datasets/videos/Zhang2/lm3d_postnet_sync_pitch_reg.yaml @@ -0,0 +1,11 @@ +base_config: + - egs/egs_bases/postnet/base.yaml + +video_id: Zhang2 + +task_cls: tasks.postnet.lm3d_postnet_adv_sync_pitch.PostnetAdvSyncTask +audio2motion_task_cls: tasks.audio2motion.lm3d_vae_sync_pitch.VAESyncAudio2MotionTask +audio2motion_work_dir: checkpoints/lrs3/lm3d_vae_sync_pitch/ +audio2motion_ckpt_steps: 40000 + +postnet_lambda_reg: 0.02 \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/egs_bases/audio2motion/base.yaml b/Geneface_main/GeneFace/egs/egs_bases/audio2motion/base.yaml new file mode 100644 index 00000000..4fcc8819 --- /dev/null +++ b/Geneface_main/GeneFace/egs/egs_bases/audio2motion/base.yaml @@ -0,0 +1,51 @@ +# dataset-related +binary_data_dir: data/binary/lrs3 + +# project-related +work_dir: '' +load_ckpt: '' +tb_log_interval: 100 + +# testing related +gen_dir_name: '' +save_gt: true + +# training-scheme-related +num_ckpt_keep: 100 +val_check_interval: 2000 +valid_infer_interval: 2000 +max_updates: 4_0000 +seed: 9999 +lr: 0.0005 +scheduler: exponential # exponential|rsqrt|warmup|none|step_lr +warmup_updates: 1000 +optimizer_adam_beta1: 0.9 +optimizer_adam_beta2: 0.999 +weight_decay: 0 +accumulate_grad_batches: 1 +clip_grad_norm: 1 +clip_grad_value: 0 +num_sanity_val_steps: 5 +num_valid_plots: 1 +eval_max_batches: 10 # num_test_plots +print_nan_grads: false +resume_from_checkpoint: 0 # specify the step, 0 for latest +amp: false +valid_monitor_key: val_loss +valid_monitor_mode: min +save_best: false +debug: false +save_codes: +- tasks +- modules +- egs + +# model-related +hidden_size: 256 + +# infer-related +infer_audio_source_name: '' +infer_out_npy_name: '' +infer_ckpt_steps: 40000 + +load_db_to_memory: false # enable it for faster indexing \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/egs_bases/audio2motion/vae_sync.yaml b/Geneface_main/GeneFace/egs/egs_bases/audio2motion/vae_sync.yaml new file mode 100644 index 00000000..f707922b --- /dev/null +++ b/Geneface_main/GeneFace/egs/egs_bases/audio2motion/vae_sync.yaml @@ -0,0 +1,10 @@ +base_config: + - ./base.yaml + +# VAE related +task_cls: tasks.audio2motion.lm3d_vae_sync.VAESyncAudio2MotionTask +lambda_kl: 0.5 + +# SyncNet related +syncnet_work_dir: checkpoints/lrs3/syncnet +syncnet_ckpt_steps: 1000 diff --git a/Geneface_main/GeneFace/egs/egs_bases/audio2pose/base.yaml b/Geneface_main/GeneFace/egs/egs_bases/audio2pose/base.yaml new file mode 100644 index 00000000..e71cff9c --- /dev/null +++ b/Geneface_main/GeneFace/egs/egs_bases/audio2pose/base.yaml @@ -0,0 +1,47 @@ +# dataset-related +raw_data_dir: data/raw/videos +processed_data_dir: data/processed/videos +binary_data_dir: data/binary/videos +video_id: '' +task_cls: '' + +# project-related +work_dir: '' +load_ckpt: '' +tb_log_interval: 100 +val_check_interval: 1000 +valid_infer_interval: 1000 +num_sanity_val_steps: 5 +num_valid_plots: 1 +eval_max_batches: 10 # num_test_plots +print_nan_grads: false +resume_from_checkpoint: 0 # specify the step, 0 for latest +amp: false +valid_monitor_key: val_loss +valid_monitor_mode: min +save_best: true +debug: false +save_codes: +- tasks +- modules +- egs +accumulate_grad_batches: 1 +clip_grad_norm: 1. + +# training-scheme-related +task_cls: tasks.audio2pose.audio2pose.Audio2PoseTask +max_updates: 1_0000 +seed: 9999 +lr: 0.0005 +optimizer_adam_beta1: 0.9 +optimizer_adam_beta2: 0.999 +scheduler: exponential # exponential|rsqrt|warmup|none|step_lr +warmup_updates: 1000 + +valid_infer_interval: 1000 +val_check_interval: 1000 +num_ckpt_keep: 10 + +infer_audio_source_name: '' +infer_out_npy_name: '' +reception_field: 100 \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/egs_bases/nerf/adnerf.yaml b/Geneface_main/GeneFace/egs/egs_bases/nerf/adnerf.yaml new file mode 100644 index 00000000..49206120 --- /dev/null +++ b/Geneface_main/GeneFace/egs/egs_bases/nerf/adnerf.yaml @@ -0,0 +1,8 @@ +base_config: + - egs/egs_bases/nerf/base.yaml + +task_cls: tasks.nerfs.adnerf.ADNeRFTask +cond_type: deepspeech +no_smo_iterations: 20_0000 +cond_win_size: 16 +smo_win_size: 8 \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/egs_bases/nerf/adnerf_torso.yaml b/Geneface_main/GeneFace/egs/egs_bases/nerf/adnerf_torso.yaml new file mode 100644 index 00000000..222077e6 --- /dev/null +++ b/Geneface_main/GeneFace/egs/egs_bases/nerf/adnerf_torso.yaml @@ -0,0 +1,7 @@ +base_config: + - egs/egs_bases/nerf/adnerf.yaml + +task_cls: tasks.nerfs.adnerf_torso.ADNeRFTorsoTask +no_smo_iterations: 0 # nerf_torso use the fixed audatt_net from head_nerf +head_model_dir: '' +use_color: false diff --git a/Geneface_main/GeneFace/egs/egs_bases/nerf/base.yaml b/Geneface_main/GeneFace/egs/egs_bases/nerf/base.yaml new file mode 100644 index 00000000..921ba892 --- /dev/null +++ b/Geneface_main/GeneFace/egs/egs_bases/nerf/base.yaml @@ -0,0 +1,79 @@ +# dataset-related +raw_data_dir: data/raw/videos +processed_data_dir: data/processed/videos +binary_data_dir: data/binary/videos +video_id: '' +task_cls: '' + +# project-related +work_dir: '' +load_ckpt: '' +tb_log_interval: 100 +num_ckpt_keep: 1 +val_check_interval: 10000 +valid_infer_interval: 10000 +num_sanity_val_steps: 0 +num_valid_plots: 5 +eval_max_batches: 100 # num_test_plots +print_nan_grads: false +resume_from_checkpoint: 0 # specify the step, 0 for latest +amp: false +valid_monitor_key: val_loss +valid_monitor_mode: min +save_best: true +debug: false +save_codes: +- tasks +- modules +- egs + +# testing related +gen_dir_name: '' +save_gt: true + +# training-scheme-related +max_updates: 40_0000 +seed: 9999 +lr: 0.0005 +scheduler: exponential # exponential|rsqrt|warmup|none|step_lr +warmup_updates: 0 +optimizer_adam_beta1: 0.9 +optimizer_adam_beta2: 0.999 +weight_decay: 0 +clip_grad_norm: 0 # disable grad clipping +clip_grad_value: 0 # disable grad clipping +rays_sampler_type: uniform +in_rect_percent: 0.95 +accumulate_grad_batches: 1 + +# model-related +use_window_cond: true +with_att: true # only available when use win_cond, use a attention Net in AD-NeRF +cond_type: '' +cond_dim: 64 +hidden_size: 256 + +# NeRF-related +near: 0.3 +far: 0.9 +n_rays: 1600 # default 2048, 1600 for RTX2080Ti +n_samples_per_ray: 64 +n_samples_per_ray_fine: 128 +embedding_args: + multi_res_pos: 10 # log2+1 of max freq for positional encoding (3D location) + multi_res_views: 4 # log2+1 of max freq for positional encoding (2D direction) + +infer_cond_name: '' +infer_out_video_name: '' +infer_scale_factor: 1.0 +infer_smo_std: 0. +infer_audio_source_name: '' +infer_c2w_name: '' + +# postprocessing params +infer_lm3d_clamp_std: 2.0 +infer_lm3d_lle_percent: 0. # percent of lle fused feature to compose the processed lm3d +infer_lm3d_smooth_sigma: 0. # sigma of gaussian kernel to smooth the predicted lm3d +infer_pose_smooth_sigma: 2. + +load_imgs_to_memory: false # load uint8 training img to memory, which reduce io costs, at the expense of more memory occupation \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/egs_bases/nerf/lm3d_nerf.yaml b/Geneface_main/GeneFace/egs/egs_bases/nerf/lm3d_nerf.yaml new file mode 100644 index 00000000..87af5a4e --- /dev/null +++ b/Geneface_main/GeneFace/egs/egs_bases/nerf/lm3d_nerf.yaml @@ -0,0 +1,18 @@ +base_config: + - egs/egs_bases/nerf/base.yaml + +task_cls: tasks.nerfs.lm3d_nerf.Lm3dNeRFTask +cond_type: idexp_lm3d_normalized +no_smo_iterations: 20_0000 + +use_window_cond: true # the NeRF only takes the exp at current frame as condition +with_att: true # only available when use win_cond, use a attention Net in AD-NeRF +cond_win_size: 1 +smo_win_size: 5 + +infer_inject_eye_blink_mode: none # none|gt|period. `gt` uses the eye blink sequence from GT dataset, `period` use a ref blink sequence from GT dataset and repeat it to the final length +infer_eye_blink_ref_frames_start_idx: '' # start index of the ref blink sequence in the GT dataset +infer_eye_blink_ref_frames_end_idx: '' # end index of the ref blink sequence in the GT dataset + +infer_close_mouth_when_sil: False # detect sil frames, then set the mouth to close in these frames +infer_sil_ref_frame_idx: '' # index of the ref frame with a closed mouth in the GT dataset \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/egs_bases/nerf/lm3d_nerf_torso.yaml b/Geneface_main/GeneFace/egs/egs_bases/nerf/lm3d_nerf_torso.yaml new file mode 100644 index 00000000..ea05f8d2 --- /dev/null +++ b/Geneface_main/GeneFace/egs/egs_bases/nerf/lm3d_nerf_torso.yaml @@ -0,0 +1,9 @@ +base_config: + - egs/egs_bases/nerf/lm3d_nerf.yaml + +task_cls: tasks.nerfs.lm3d_nerf_torso.Lm3dNeRFTorsoTask + +no_smo_iterations: 0 # nerf_torso use the fixed audatt_net from head_nerf +use_color: true + +head_model_dir: '' diff --git a/Geneface_main/GeneFace/egs/egs_bases/postnet/base.yaml b/Geneface_main/GeneFace/egs/egs_bases/postnet/base.yaml new file mode 100644 index 00000000..ab67c94d --- /dev/null +++ b/Geneface_main/GeneFace/egs/egs_bases/postnet/base.yaml @@ -0,0 +1,38 @@ +base_config: + - egs/egs_bases/audio2motion/vae_sync.yaml + +task_cls: tasks.postnet.lm3d_postnet_adv_sync.PostnetAdvSyncTask +audio2motion_task_cls: tasks.audio2motion.lm3d_vae_sync.VAESyncAudio2MotionTask +person_binary_data_dir: data/binary/videos +# postnet training +postnet_lr: 0.0001 +postnet_lambda_adv: 0.85 +postnet_lambda_sync: 0.1 +postnet_lambda_mse: 0.05 + +# Discriminator +postnet_disc_lr: 0.0001 +discriminator_scheduler_params: + gamma: 0.5 + step_size: 40000 +postnet_disc_start_steps: 0 +postnet_disc_interval: 1 + +# Training Schedule +scheduler: none +num_ckpt_keep: 500 +val_check_interval: 1000 +valid_infer_interval: 1000 +max_updates: 20000 + +# Pretrained Ckpts +audio2motion_work_dir: checkpoints/lrs3/lm3d_vae_sync/ +audio2motion_ckpt_steps: 40000 +syncnet_work_dir: checkpoints/lrs3/syncnet +syncnet_ckpt_steps: 40000 + +infer_audio_source_name: data/raw/val_wavs/zozo.wav +infer_out_npy_name: infer_out/May/pred_lm3d/zozo.npy +infer_ckpt_steps: 6000 + +load_db_to_memory: false # enable it for faster indexing diff --git a/Geneface_main/GeneFace/egs/egs_bases/radnerf/base.yaml b/Geneface_main/GeneFace/egs/egs_bases/radnerf/base.yaml new file mode 100644 index 00000000..78e86fd2 --- /dev/null +++ b/Geneface_main/GeneFace/egs/egs_bases/radnerf/base.yaml @@ -0,0 +1,125 @@ +# dataset-related +raw_data_dir: data/raw/videos +processed_data_dir: data/processed/videos +binary_data_dir: data/binary/videos +video_id: '' +task_cls: '' + +# project-related +work_dir: '' +load_ckpt: '' +tb_log_interval: 100 +num_ckpt_keep: 1 +val_check_interval: 2000 +valid_infer_interval: 10000 +num_sanity_val_steps: 2 +num_valid_plots: 5 +eval_max_batches: 100 # num_test_plots +print_nan_grads: false +resume_from_checkpoint: 0 # specify the step, 0 for latest +amp: false +valid_monitor_key: val_loss +valid_monitor_mode: min +save_best: true +debug: false +save_codes: +- tasks +- modules +- egs + +# testing related +save_gt: true + +# training-scheme-related +seed: 9999 +lr: 0.0005 +scheduler: exponential # exponential|rsqrt|warmup|none|step_lr +warmup_updates: 0 +optimizer_adam_beta1: 0.9 +optimizer_adam_beta2: 0.999 +weight_decay: 0 +clip_grad_norm: 0 # disable grad clipping +clip_grad_value: 0 # disable grad clipping +accumulate_grad_batches: 1 + +# model-related +cond_type: '' # deepspeech, esperanto, idexp_lm3d + +# training +amp: true # use fp16 +load_imgs_to_memory: true # load uint8 training img to memory, which reduce io costs, at the expense of more memory occupation + +# NeRF-related +near: 0.3 +far: 0.9 +n_rays: 65536 # num rays sampled per image for each training step, default 4096*16 +cuda_ray: true # use CUDA raymarching instead of pytorch +max_steps: 16 # max num steps sampled per ray (only valid when using --cuda_ray) +num_steps: 16 # num steps sampled per ray (only valid when NOT using --cuda_ray) +upsample_steps: 0 # num steps up-sampled per ray (only valid when NOT using --cuda_ray) +update_extra_interval: 16 # iter interval to update extra status (only valid when using --cuda_ray) +max_ray_batch: 4096 # batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray) + + +max_updates: 25_0000 # 40_0000 for training the whole head, 5_0000 for finetuning the mouth +finetune_lips: true +finetune_lips_start_iter: 20_0000 +lambda_lpips_loss: 0.01 # auxiliary loss for finetune lips +lambda_weights_entropy: 0.0001 +lambda_ambient: 0.1 + +min_near: 0.05 # minimum near distance for camera +bound: 1 # assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching. +camera_scale: 4. # scale camera location into box[-bound, bound]^3 +camera_offset: [0, 0, 0] # offset of camera location +grid_size: 128 +desired_resolution: 2048 +log2_hashmap_size: 16 +dt_gamma: 0.00390625 # default 1/256, dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality) +density_thresh: 10 # threshold for density grid to be occupied (sigma) +density_thresh_torso: 0.01 # threshold for density grid to be occupied (alpha) +torso_shrink: 0.8 # shrink bg coords to allow more flexibility in deform + +smooth_lips: false + +# Network +grid_type: tiledgrid # tiledgrid or hashgrid +grid_interpolation_type: linear # smoothstep or linear +with_att: true +use_window_cond: true +torso_head_aware: false # head aware torso nerf to avoid head-torso separation artifacts! +num_layers_sigma: 3 +hidden_dim_sigma: 128 # 64 by radnerf is too small +geo_feat_dim: 128 # 64 by radnerf is too small +num_layers_color: 2 +hidden_dim_color: 128 # 64 by radnerf is too small +cond_out_dim: 64 +num_layers_ambient: 3 +hidden_dim_ambient: 128 # 64 by radnerf is too small +ambient_out_dim: 2 +individual_embedding_num: 13000 +individual_embedding_dim: 4 +torso_individual_embedding_dim: 8 + +# infer +infer_cond_name: '' +infer_out_video_name: '' +infer_scale_factor: 1.0 +infer_smo_std: 0. +infer_audio_source_name: '' +infer_c2w_name: '' +infer_lm3d_clamp_std: 2.5 +infer_lm3d_lle_percent: 0. # percent of lle fused feature to compose the processed lm3d +infer_lm3d_smooth_sigma: 0. # sigma of gaussian kernel to smooth the predicted lm3d +infer_bg_img_fname: '' # black, white, or a img fname +infer_smooth_camera_path: true +infer_smooth_camera_path_kernel_size: 7 + +# gui feat +gui_w: 512 +gui_h: 512 +gui_radius: 3.35 +gui_fovy: 21.24 +gui_max_spp: 1 # GUI rendering max sample per pixel + +load_imgs_to_memory: false # load uint8 training img to memory, which reduce io costs, at the expense of more memory occupation diff --git a/Geneface_main/GeneFace/egs/egs_bases/radnerf/lm3d_radnerf.yaml b/Geneface_main/GeneFace/egs/egs_bases/radnerf/lm3d_radnerf.yaml new file mode 100644 index 00000000..c90c391c --- /dev/null +++ b/Geneface_main/GeneFace/egs/egs_bases/radnerf/lm3d_radnerf.yaml @@ -0,0 +1,7 @@ +base_config: + - ./base.yaml + +task_cls: tasks.radnerfs.radnerf.RADNeRFTask +cond_type: idexp_lm3d_normalized +cond_win_size: 1 +smo_win_size: 5 diff --git a/Geneface_main/GeneFace/egs/egs_bases/radnerf/radnerf.yaml b/Geneface_main/GeneFace/egs/egs_bases/radnerf/radnerf.yaml new file mode 100644 index 00000000..cfdd7f56 --- /dev/null +++ b/Geneface_main/GeneFace/egs/egs_bases/radnerf/radnerf.yaml @@ -0,0 +1,7 @@ +base_config: + - ./base.yaml + +task_cls: tasks.radnerfs.radnerf.RADNeRFTask +cond_type: esperanto +cond_win_size: 16 +smo_win_size: 8 \ No newline at end of file diff --git a/Geneface_main/GeneFace/egs/egs_bases/syncnet/base.yaml b/Geneface_main/GeneFace/egs/egs_bases/syncnet/base.yaml new file mode 100644 index 00000000..4d48a493 --- /dev/null +++ b/Geneface_main/GeneFace/egs/egs_bases/syncnet/base.yaml @@ -0,0 +1,37 @@ +# dataset-related +binary_data_dir: data/binary/lrs3 + +# project-related +work_dir: '' +load_ckpt: '' +tb_log_interval: 100 +val_check_interval: 1000 +valid_infer_interval: 1000 +num_sanity_val_steps: 5 +num_valid_plots: 1 +eval_max_batches: 10 # num_test_plots +print_nan_grads: false +resume_from_checkpoint: 0 # specify the step, 0 for latest +amp: false +valid_monitor_key: val_loss +valid_monitor_mode: min +save_best: true +debug: false +save_codes: +- tasks +- modules +- egs +accumulate_grad_batches: 1 +clip_grad_norm: 1. + +# training-scheme-related +task_cls: tasks.syncnet.lm3d_syncnet.SyncNetTask +max_updates: 4_0000 +seed: 9999 +lr: 0.0005 +optimizer_adam_beta1: 0.9 +optimizer_adam_beta2: 0.999 +scheduler: none +num_ckpt_keep: 100 + +load_db_to_memory: false # enable it for faster indexing diff --git a/Geneface_main/GeneFace/inference/audio2motion/audio2motion_infer.py b/Geneface_main/GeneFace/inference/audio2motion/audio2motion_infer.py new file mode 100644 index 00000000..a104ed7f --- /dev/null +++ b/Geneface_main/GeneFace/inference/audio2motion/audio2motion_infer.py @@ -0,0 +1,138 @@ +import os +import torch +import librosa +import numpy as np +import importlib +import tqdm + +from utils.commons.tensor_utils import move_to_cuda +from utils.commons.ckpt_utils import load_ckpt, get_last_checkpoint +from utils.commons.hparams import hparams, set_hparams + + +class Audio2MotionInfer: + def __init__(self, hparams, device=None): + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.hparams = hparams + self.infer_max_length = hparams.get('infer_max_length', 500000) + self.device = device + self.audio2motion_task = self.build_audio2motion_task() + self.audio2motion_task.eval() + self.audio2motion_task.to(self.device) + + def build_audio2motion_task(self): + assert hparams['task_cls'] != '' + pkg = ".".join(hparams["task_cls"].split(".")[:-1]) + cls_name = hparams["task_cls"].split(".")[-1] + task_cls = getattr(importlib.import_module(pkg), cls_name) + task = task_cls() + task.build_model() + task.eval() + steps = hparams.get('infer_ckpt_steps', 40000) + load_ckpt(task.model, hparams['work_dir'], 'model', steps=steps) + ckpt, _ = get_last_checkpoint(hparams['work_dir'], steps=steps) + task.global_step = ckpt['global_step'] + return task + + def infer_once(self, inp): + self.inp = inp + samples = self.get_cond_from_input(inp) + out_name = self.forward_system(samples, inp) + print(f"The predicted 3D landmark sequence is saved at {out_name}") + + def get_cond_from_input(self, inp): + """ + :param inp: {'audio_source_name': (str)} + :return: a list that contains the condition feature of NeRF + """ + self.save_wav16k(inp) + from data_gen.process_lrs3.process_audio_hubert import get_hubert_from_16k_wav + hubert = get_hubert_from_16k_wav(self.wav16k_name).detach().numpy() + len_mel = hubert.shape[0] + x_multiply = 8 + if len_mel % x_multiply == 0: + num_to_pad = 0 + else: + num_to_pad = x_multiply - len_mel % x_multiply + hubert = np.pad(hubert, pad_width=((0,num_to_pad), (0,0))) + t_x = hubert.shape[0] + x_mask = torch.ones([1, t_x]).float() + y_mask = torch.ones([1, t_x//2]).float() + + from data_gen.process_lrs3.process_audio_mel_f0 import extract_mel_from_fname,extract_f0_from_wav_and_mel + wav, mel = extract_mel_from_fname(self.wav16k_name) + f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel) + f0 = f0.reshape([-1,1]) + if f0.shape[0] > len(hubert): + f0 = f0[:len(hubert)] + else: + num_to_pad = len(hubert) - len(f0) + f0 = np.pad(f0, pad_width=((0,num_to_pad), (0,0))) + f0 = f0.squeeze(-1) + + sample = { + 'hubert': torch.from_numpy(hubert).float().unsqueeze(0), + 'f0': torch.from_numpy(f0).float().unsqueeze(0), + 'x_mask': x_mask, + 'y_mask': y_mask, + } + return [sample] + + def forward_system(self, batches, inp): + out_dir = self._forward_audio2motion_task(batches, inp) + return out_dir + + def _forward_audio2motion_task(self, batches, inp): + with torch.no_grad(): + pred_lst = [] + for idx, batch in tqdm.tqdm(enumerate(batches), total=len(batches), + desc=f"Now VAE is predicting the action into {inp['out_npy_name']}"): + if self.device == 'cuda': + batch = move_to_cuda(batch) + + model_out = self.audio2motion_task.run_model(batch, infer=True) + pred = model_out['pred'].squeeze().cpu().numpy() + pred_lst.append(pred) + np.save(inp['out_npy_name'], pred_lst) + return inp['out_npy_name'] + + @classmethod + def example_run(cls, inp=None): + inp_tmp = { + 'audio_source_name': 'data/raw/val_wavs/zozo.wav', + 'out_npy_name': 'infer_out/lrs3/0.npy' + } + if inp is not None: + inp_tmp.update(inp) + inp = inp_tmp + if hparams.get("infer_audio_source_name", '') != '': + inp['audio_source_name'] = hparams['infer_audio_source_name'] + if hparams.get("infer_out_npy_name", '') != '': + inp['out_npy_name'] = hparams['infer_out_npy_name'] + out_dir = os.path.dirname(inp['out_npy_name']) + + os.makedirs(out_dir, exist_ok=True) + infer_ins = cls(hparams) + infer_ins.infer_once(inp) + + ############## + # IO-related + ############## + def save_wav16k(self, inp): + source_name = inp['audio_source_name'] + supported_types = ('.wav', '.mp3', '.mp4', '.avi') + assert source_name.endswith(supported_types), f"Now we only support {','.join(supported_types)} as audio source!" + wav16k_name = source_name[:-4] + '_16k.wav' + self.wav16k_name = wav16k_name + extract_wav_cmd = f"ffmpeg -i {source_name} -f wav -ar 16000 {wav16k_name} -y" + os.system(extract_wav_cmd) + print(f"I have extracted wav file (16khz) from {source_name} to {wav16k_name}.") + +if __name__ == '__main__': + set_hparams() + inp = { + 'audio_source_name': 'data/raw/val_wavs/zozo.wav', + 'out_npy_name': 'infer_outs/out.npy', + } + Audio2MotionInfer.example_run(inp) \ No newline at end of file diff --git a/Geneface_main/GeneFace/inference/audio2pose/audio2pose_infer.py b/Geneface_main/GeneFace/inference/audio2pose/audio2pose_infer.py new file mode 100644 index 00000000..dbb92c1f --- /dev/null +++ b/Geneface_main/GeneFace/inference/audio2pose/audio2pose_infer.py @@ -0,0 +1,153 @@ +import os +import torch +import librosa +import numpy as np +import importlib +import tqdm + +from utils.commons.tensor_utils import move_to_cuda +from utils.commons.ckpt_utils import load_ckpt, get_last_checkpoint +from utils.commons.hparams import hparams, set_hparams +from utils.commons.euler2rot import euler_trans_2_c2w + +from tasks.audio2pose.dataset_utils import Audio2PoseDataset + + +class Audio2PoseInfer: + def __init__(self, hparams, device=None): + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.hparams = hparams + self.infer_max_length = hparams.get('infer_max_length', 500000) + self.device = device + self.audio2pose_task = self.build_audio2pose_task() + self.audio2pose_task.eval() + self.audio2pose_task.to(self.device) + dataset = Audio2PoseDataset() + self.mean_trans = dataset.mean_trans.unsqueeze(0).numpy() + self.init_pose = torch.cat([dataset.euler_lst[0], dataset.trans_lst[0]], dim=0).numpy() + + def build_audio2pose_task(self): + assert hparams['task_cls'] != '' + pkg = ".".join(hparams["task_cls"].split(".")[:-1]) + cls_name = hparams["task_cls"].split(".")[-1] + task_cls = getattr(importlib.import_module(pkg), cls_name) + task = task_cls() + task.build_model() + task.eval() + # steps = hparams.get('infer_ckpt_steps', 5000) + steps = None + load_ckpt(task.model, hparams['work_dir'], 'model', steps=steps) + ckpt, _ = get_last_checkpoint(hparams['work_dir'], steps=steps) + task.global_step = ckpt['global_step'] + return task + + def infer_once(self, inp): + self.inp = inp + samples = self.get_cond_from_input(inp) + out_name = self.forward_system(samples, inp) + print(f"The predicted 3D landmark sequence is saved at {out_name}") + + def get_cond_from_input(self, inp): + """ + :param inp: {'audio_source_name': (str)} + :return: a list that contains the condition feature of NeRF + """ + + self.save_wav16k(inp) + # from data_gen.process_lrs3.process_audio_hubert import get_hubert_from_16k_wav + # hubert = get_hubert_from_16k_wav(self.wav16k_name).detach().numpy() + # len_mel = hubert.shape[0] + # x_multiply = 8 + # if len_mel % x_multiply == 0: + # num_to_pad = 0 + # else: + # num_to_pad = x_multiply - len_mel % x_multiply + # hubert = np.pad(hubert, pad_width=((0,num_to_pad), (0,0))) # [t_x, 1024] + # t_x = hubert.shape[0] + # hubert = hubert.reshape([t_x//2, 1024*2]) + # sample = { + # 'hubert': torch.from_numpy(hubert).float().unsqueeze(0), # [1, T, 2048] + # } + + # load the deepspeech features as the condition for lm3d torso nerf + wav16k_name = self.wav16k_name + deepspeech_name = wav16k_name[:-4] + '_deepspeech.npy' + if not os.path.exists(deepspeech_name): + print(f"Try to extract deepspeech from {wav16k_name}...") + # deepspeech_python = '/home/yezhenhui/anaconda3/envs/geneface/bin/python' # the path of your python interpreter that has installed DeepSpeech + # extract_deepspeech_cmd = f'{deepspeech_python} data_util/deepspeech_features/extract_ds_features.py --input={wav16k_name} --output={deepspeech_name}' + extract_deepspeech_cmd = f'python data_util/deepspeech_features/extract_ds_features.py --input={wav16k_name} --output={deepspeech_name}' + os.system(extract_deepspeech_cmd) + print(f"Saved deepspeech features of {wav16k_name} to {deepspeech_name}.") + else: + print(f"Try to load pre-extracted deepspeech from {deepspeech_name}...") + deepspeech_arr = np.load(deepspeech_name) # [T, w=16, c=29] + print(f"Loaded deepspeech features from {deepspeech_name}.") + # get window condition of deepspeech + sample = {} + # sample['deepspeech'] = torch.from_numpy(deepspeech_arr).float().reshape([-1, 16*29]) + sample['deepspeech'] = torch.from_numpy(deepspeech_arr[:, 7:9,:]).float().reshape([-1, 2*29]) + return [sample] + + def forward_system(self, batches, inp): + out_dir = self._forward_audio2pose_task(batches, inp) + return out_dir + + def _forward_audio2pose_task(self, batches, inp): + with torch.no_grad(): + pred_lst = [] + for idx, batch in tqdm.tqdm(enumerate(batches), total=len(batches), + desc=f"Now Audio2Pose model is predicting the head pose (camera2world matrix) into {inp['out_npy_name']}"): + if self.device == 'cuda': + batch = move_to_cuda(batch) + + # smo_pred_pose = self.audio2pose_task.model.autoregressive_infer(batch['hubert'].squeeze(), self.init_pose) + smo_pred_pose = self.audio2pose_task.model.autoregressive_infer(batch['deepspeech'].squeeze(), self.init_pose) + smo_pred_pose = smo_pred_pose.squeeze().cpu().numpy() + euler, trans = smo_pred_pose[:,:3], smo_pred_pose[:,3:6] + trans = trans + self.mean_trans + c2w = euler_trans_2_c2w(euler, trans).numpy() + pred_lst.append(c2w) + np.save(inp['out_npy_name'], pred_lst) + return inp['out_npy_name'] + + @classmethod + def example_run(cls, inp=None): + inp_tmp = { + 'audio_source_name': 'data/raw/val_wavs/zozo.wav', + 'out_npy_name': 'infer_outs/May/pred_c2w/zozo.npy' + } + if inp is not None: + inp_tmp.update(inp) + inp = inp_tmp + if hparams.get("infer_audio_source_name", '') != '': + inp['audio_source_name'] = hparams['infer_audio_source_name'] + if hparams.get("infer_out_npy_name", '') != '': + inp['out_npy_name'] = hparams['infer_out_npy_name'] + out_dir = os.path.dirname(inp['out_npy_name']) + + os.makedirs(out_dir, exist_ok=True) + infer_ins = cls(hparams) + infer_ins.infer_once(inp) + + ############## + # IO-related + ############## + def save_wav16k(self, inp): + source_name = inp['audio_source_name'] + supported_types = ('.wav', '.mp3', '.mp4', '.avi') + assert source_name.endswith(supported_types), f"Now we only support {','.join(supported_types)} as audio source!" + wav16k_name = source_name[:-4] + '_16k.wav' + self.wav16k_name = wav16k_name + extract_wav_cmd = f"ffmpeg -i {source_name} -f wav -ar 16000 -v quiet {wav16k_name} -y" + os.system(extract_wav_cmd) + print(f"Extracted wav file (16khz) from {source_name} to {wav16k_name}.") + +if __name__ == '__main__': + set_hparams() + inp = { + 'audio_source_name': 'data/raw/val_wavs/zozo.wav', + 'out_npy_name': 'infer_outs/May/pred_c2w/zozo.npy', + } + Audio2PoseInfer.example_run(inp) \ No newline at end of file diff --git a/Geneface_main/GeneFace/inference/nerfs/__pycache__/base_nerf_infer.cpython-39.pyc b/Geneface_main/GeneFace/inference/nerfs/__pycache__/base_nerf_infer.cpython-39.pyc new file mode 100644 index 00000000..04de0f7c Binary files /dev/null and b/Geneface_main/GeneFace/inference/nerfs/__pycache__/base_nerf_infer.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/inference/nerfs/__pycache__/lm3d_nerf_infer.cpython-39.pyc b/Geneface_main/GeneFace/inference/nerfs/__pycache__/lm3d_nerf_infer.cpython-39.pyc new file mode 100644 index 00000000..3bc61ded Binary files /dev/null and b/Geneface_main/GeneFace/inference/nerfs/__pycache__/lm3d_nerf_infer.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/inference/nerfs/adnerf_infer.py b/Geneface_main/GeneFace/inference/nerfs/adnerf_infer.py new file mode 100644 index 00000000..42a90293 --- /dev/null +++ b/Geneface_main/GeneFace/inference/nerfs/adnerf_infer.py @@ -0,0 +1,45 @@ +import os +import numpy as np +import torch +from inference.nerfs.base_nerf_infer import BaseNeRFInfer +from data_gen.nerf.binarizer import get_win_conds + + +class AdNeRFInfer(BaseNeRFInfer): + def get_cond_from_input(self, inp): + """ + :param inp: {'audio_source_name': (str), 'cond_name': (str, optional)} + :return: a list that contains the condition feature of NeRF + """ + self.save_wav16k(inp) + if inp.get('cond_name', None) is not None: + assert inp['cond_name'].endswith('.npy') + deepspeech_arr = np.load(inp['cond_name']) # [T, w=16, c=29] + print(f"I have Loaded pre-extracted deepspeech from {inp['cond_name']}!") + else: + wav16k_name = self.wav16k_name + print(f"Trying to extract deepspeech from {wav16k_name}...") + deepspeech_name = wav16k_name[:-4] + '_deepspeech.npy' + if not os.path.exists(deepspeech_name): + extract_deepspeech_cmd = f'python data_util/deepspeech_features/extract_ds_features.py --input={wav16k_name} --output={deepspeech_name}' + os.system(extract_deepspeech_cmd) + print(f"I have extracted deepspeech features from {wav16k_name} to {deepspeech_name}.") + else: + print(f"I have Loaded pre-extracted deepspeech from {deepspeech_name}!") + deepspeech_arr = np.load(deepspeech_name) # [T, w=16, c=29] + + num_samples = min(len(deepspeech_arr), self.infer_max_length) + samples = [{} for _ in range(num_samples)] + for idx, sample in enumerate(samples): + sample['cond_win'] = torch.from_numpy(deepspeech_arr[idx]).float().unsqueeze(0) # [B=1, w=16, C=29] + sample['cond_wins'] = torch.from_numpy(get_win_conds(deepspeech_arr, idx, smo_win_size=8)).float() #.unsqueeze(0) # [B=1,W=8, w=16, C=29] + return samples + +if __name__ == '__main__': + from utils.commons.hparams import set_hparams + from utils.commons.hparams import hparams as hp + inp = { + 'audio_source_name': 'data/raw/val_wavs/zozo.wav', + 'out_video_name': 'infer_outs/out.mp4', + } + AdNeRFInfer.example_run(inp) \ No newline at end of file diff --git a/Geneface_main/GeneFace/inference/nerfs/base_nerf_infer.py b/Geneface_main/GeneFace/inference/nerfs/base_nerf_infer.py new file mode 100644 index 00000000..ddf6120f --- /dev/null +++ b/Geneface_main/GeneFace/inference/nerfs/base_nerf_infer.py @@ -0,0 +1,317 @@ +import os +import sys +import cv2 +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import numpy as np +import importlib +import tqdm +import logging +import copy +import re +import random + +from utils.commons.ddp_utils import DDP +from utils.commons.hparams import hparams, set_hparams +from utils.commons.ckpt_utils import load_ckpt, get_last_checkpoint +from utils.commons.euler2rot import euler_trans_2_c2w, c2w_to_euler_trans +from utils.commons.tensor_utils import move_to_cpu, move_to_cuda, convert_to_tensor + +from tasks.nerfs.dataset_utils import NeRFDataset +from scipy.ndimage import gaussian_filter1d +from scipy.spatial.transform import Rotation + + +def smooth_camera_path(poses, kernel_size=7): + # smooth the camera trajectory (i.e., translation)... + # poses: [N, 4, 4], numpy array + N = poses.shape[0] + K = kernel_size // 2 + + trans = poses[:, :3, 3].copy() # [N, 3] + rots = poses[:, :3, :3].copy() # [N, 3, 3] + + for i in range(N): + start = max(0, i - K) + end = min(N, i + K + 1) + poses[i, :3, 3] = trans[start:end].mean(0) + try: + poses[i, :3, :3] = Rotation.from_matrix(rots[start:end]).mean().as_matrix() + except: + if i == 0: + poses[i, :3, :3] = rots[i] + else: + poses[i, :3, :3] = poses[i-1, :3, :3] + return poses + + +class BaseNeRFInfer: + def __init__(self, hparams, device=None): + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.hparams = hparams + self.infer_max_length = hparams.get('infer_max_length', 500000) # default render 10 seconds long + self.device = device + self.dataset_cls = NeRFDataset # the dataset only provides head pose + self.dataset = self.dataset_cls('trainval') + + assert hparams['task_cls'] != '' + pkg = ".".join(hparams["task_cls"].split(".")[:-1]) + cls_name = hparams["task_cls"].split(".")[-1] + self.task_cls = getattr(importlib.import_module(pkg), cls_name) + + self.all_gpu_ids = [int(x) for x in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",") if x != ''] + self.num_gpus = len(self.all_gpu_ids) + self.on_gpu = self.num_gpus > 0 + self.root_gpu = 0 + logging.info(f'GPU available: {torch.cuda.is_available()}, GPU used: {self.all_gpu_ids}') + self.use_ddp = self.num_gpus > 1 + self.proc_rank = 0 + + def build_nerf_task(self): + task = self.task_cls() + task.build_model() + task.eval() + load_ckpt(task.model, hparams['work_dir'], 'model') + ckpt, _ = get_last_checkpoint(hparams['work_dir']) + task.global_step = ckpt['global_step'] + return task + + def _forward_nerf_task_single_process(self, batches): + tmp_imgs_dir = self.inp['tmp_imgs_dir'] + os.makedirs(tmp_imgs_dir, exist_ok=True) + H, W = batches[0]['H'], batches[0]['W'] + H = int(hparams['infer_scale_factor']*H) + W = int(hparams['infer_scale_factor']*W) + idx_batch_lst = [(idx, batch) for idx,batch in enumerate(batches)] + + print(f"The tmp imge dir is {tmp_imgs_dir}.") + with torch.no_grad(): + for (idx, batch) in tqdm.tqdm(idx_batch_lst, total=len(idx_batch_lst), + desc=f"NeRF is rendering frames..."): + torch.cuda.empty_cache() + if self.device == 'cuda': + batch = move_to_cuda(batch) + model_out = self.nerf_task.run_model(batch, infer=True) + pred_rgb = model_out['rgb_map'] * 255 + pred_img = pred_rgb.view([H, W, 3]).cpu().numpy().astype(np.uint8) + out_name = os.path.join(tmp_imgs_dir, format(idx, '05d')+".png") + bgr_img = cv2.cvtColor(pred_img, cv2.COLOR_RGB2BGR) + cv2.imwrite(out_name, bgr_img) + batches[idx] = move_to_cpu(batch) + for k in list(batch.keys()): + del batch[k] + torch.cuda.empty_cache() + return tmp_imgs_dir + + def init_ddp_connection(self, proc_rank, world_size): + root_node = '127.0.0.1' + root_node = self.resolve_root_node_address(root_node) + os.environ['MASTER_ADDR'] = root_node + os.environ['MASTER_PORT'] = '12345' + dist.init_process_group('nccl', rank=proc_rank, world_size=world_size) + + def resolve_root_node_address(self, root_node): + if '[' in root_node: + name = root_node.split('[')[0] + number = root_node.split(',')[0] + if '-' in number: + number = number.split('-')[0] + number = re.sub('[^0-9]', '', number) + root_node = name + number + return root_node + + def configure_ddp(self, task): + task = DDP(task, device_ids=[self.root_gpu], find_unused_parameters=True) + random.seed(self.hparams['seed']) + np.random.seed(self.hparams['seed']) + return task + + def _forward_nerf_task_ddp(self, gpu_idx, batches, hparams_): + hparams.update(hparams_) # the global hparams dict in the subprocess is empty, so inplace-update it! + self.proc_rank = gpu_idx + self.init_ddp_connection(self.proc_rank, self.num_gpus) + # if dist.get_rank() != 0: + # sys.stdout = open(os.devnull, "w") + # sys.stderr = open(os.devnull, "w") + tmp_imgs_dir = self.inp['tmp_imgs_dir'] + os.makedirs(tmp_imgs_dir, exist_ok=True) + torch.cuda.set_device(gpu_idx) + self.root_gpu = gpu_idx + self.nerf_task = self.build_nerf_task() + self.nerf_task.eval() + self.nerf_task.cuda() + self.nerf_task = self.configure_ddp(self.nerf_task) + dist.barrier() + nerf_task = self.nerf_task.module + self.dataset = self.dataset_cls('train') + + idx_batch_lst = [(idx, batch) for idx,batch in enumerate(batches)] + num_batchs_per_gpu = len(batches) // self.num_gpus + if self.proc_rank != self.num_gpus-1: + idx_batch_lst = idx_batch_lst[self.proc_rank*num_batchs_per_gpu:(self.proc_rank+1)*num_batchs_per_gpu] + else: + idx_batch_lst = idx_batch_lst[self.proc_rank*num_batchs_per_gpu:] + + H, W = batches[0]['H'], batches[0]['W'] + H = int(hparams['infer_scale_factor']*H) + W = int(hparams['infer_scale_factor']*W) + with torch.no_grad(): + if dist.get_rank() == 0: + print(f"The tmp imge dir is {tmp_imgs_dir}.") + for (idx, batch) in tqdm.tqdm(idx_batch_lst, total=len(idx_batch_lst), + desc=f"Process {self.proc_rank} : NeRF is rendering frames..."): + torch.cuda.empty_cache() + if self.device == 'cuda': + batch = move_to_cuda(batch, self.root_gpu) + model_out = nerf_task.run_model(batch, infer=True) + pred_rgb = model_out['rgb_map'] * 255 + pred_img = pred_rgb.view([H, W, 3]).cpu().numpy().astype(np.uint8) + out_name = os.path.join(tmp_imgs_dir, format(idx, '05d')+".png") + bgr_img = cv2.cvtColor(pred_img, cv2.COLOR_RGB2BGR) + cv2.imwrite(out_name, bgr_img) + batches[idx] = move_to_cpu(batch) + for k in list(batch.keys()): + del batch[k] + torch.cuda.empty_cache() + dist.barrier() + return tmp_imgs_dir + + def forward_system(self, batches): + if self.use_ddp: + del self.dataset + torch.multiprocessing.set_sharing_strategy('file_system') + batches = copy.deepcopy(batches) + mp.spawn(self._forward_nerf_task_ddp, nprocs=self.num_gpus, args=[batches, copy.deepcopy(hparams)]) + img_dir = self.inp['tmp_imgs_dir'] + else: + self.nerf_task = self.build_nerf_task() + self.nerf_task.eval() + self.nerf_task.to(self.device) + img_dir = self._forward_nerf_task_single_process(batches) + return img_dir + + def get_cond_from_input(self, inp): + """ + get the conditon features of NeRF + """ + raise NotImplementedError + + def get_pose_from_ds(self, samples): + """ + process the item into torch.tensor batch + """ + if self.use_pred_pose: + print(f"The head pose mode is: pred") + c2w_arr = np.load(self.inp['c2w_name'])[0] # [T, 3, 3] + print(f"Loaded head pose from {self.inp['c2w_name']}.") + assert len(samples) - len(c2w_arr) < 5 + if len(samples) > len(c2w_arr): + samples = samples[:len(c2w_arr)] + if len(samples) < len(c2w_arr): + c2w_arr = c2w_arr[:len(samples)] + else: + print(f"The head pose mode is: gt") + + for idx, sample in enumerate(samples): + if idx >= len(self.dataset.samples) and not self.use_pred_pose: + # since we use GT head pose from the dataset, the pred_samples cannot be longer than the GT samples + del samples[idx:] + break + sample['H'] = self.dataset.H + sample['W'] = self.dataset.W + sample['focal'] = self.dataset.focal + sample['cx'] = self.dataset.cx + sample['cy'] = self.dataset.cy + sample['near'] = hparams['near'] + sample['far'] = hparams['far'] + sample['bg_img'] = self.dataset.bg_img + + if self.use_pred_pose: + sample['c2w'] = torch.from_numpy(c2w_arr[idx]) + else: + sample['c2w'] = self.dataset.samples[idx]['c2w'][:3] + sample['c2w_t0'] = self.dataset.samples[0]['c2w'][:3] + + sample['t'] = torch.tensor([0,]).float() + euler, trans = c2w_to_euler_trans(sample['c2w']) + euler_t0, trans_t0 = c2w_to_euler_trans(sample['c2w_t0']) + sample['euler'] = torch.tensor(np.ascontiguousarray(euler)).float() + sample['trans'] = torch.tensor(np.ascontiguousarray(trans)).float() + sample['euler_t0'] = torch.tensor(np.ascontiguousarray(euler_t0)).float() + sample['trans_t0'] = torch.tensor(np.ascontiguousarray(trans_t0)).float() + + if hparams.get("infer_smo_head_pose", True) is True: + c2w_arr = torch.stack([s['c2w'] for s in samples]).numpy() + smo_c2w_arr = smooth_camera_path(c2w_arr) + for i, sample in enumerate(samples): + sample['c2w'] = convert_to_tensor(smo_c2w_arr[i]) + euler, trans = c2w_to_euler_trans(sample['c2w']) + sample['euler'] = convert_to_tensor(np.ascontiguousarray(euler)) + sample['trans'] = convert_to_tensor(np.ascontiguousarray(trans)) + return samples + + def postprocess_output(self, output): + tmp_imgs_dir = self.inp['tmp_imgs_dir'] + out_video_name = self.inp['out_video_name'] + self.save_mp4(tmp_imgs_dir, self.wav16k_name, out_video_name) + return out_video_name + + def infer_once(self, inp): + self.inp = inp + self.use_pred_pose = True if self.inp.get('c2w_name','') != '' else False + samples = self.get_cond_from_input(inp) + batches = self.get_pose_from_ds(samples) + image_dir = self.forward_system(batches) + if self.proc_rank == 0: + out_name = self.postprocess_output(image_dir) + print(f"The synthesized video is saved at {out_name}") + + @classmethod + def example_run(cls, inp=None): + from utils.commons.hparams import set_hparams + from utils.commons.hparams import hparams as hp + set_hparams() + inp_tmp = { + 'audio_source_name': 'data/raw/val_wavs/zozo.wav', + 'out_dir': 'infer_out', + 'out_video_name': 'infer_out/zozo.mp4' + } + if inp is not None: + inp_tmp.update(inp) + inp = inp_tmp + if hparams.get("infer_cond_name", '') != '': + inp['cond_name'] = hparams['infer_cond_name'] + if hparams.get("infer_audio_source_name", '') != '': + inp['audio_source_name'] = hparams['infer_audio_source_name'] + if hparams.get("infer_out_video_name", '') != '': + inp['out_video_name'] = hparams['infer_out_video_name'] + if hparams.get("infer_c2w_name", '') != '': + inp['c2w_name'] = hparams['infer_c2w_name'] + out_dir = os.path.dirname(inp['out_video_name']) + video_name = os.path.basename(inp['out_video_name'])[:-4] + tmp_imgs_dir = os.path.join(out_dir, "tmp_imgs", video_name) + inp['tmp_imgs_dir'] = tmp_imgs_dir + + os.makedirs(out_dir, exist_ok=True) + os.makedirs(tmp_imgs_dir, exist_ok=True) + infer_ins = cls(hp) + infer_ins.infer_once(inp) + + ############## + # IO-related + ############## + @classmethod + def save_mp4(self, img_dir, wav_name, out_name): + os.system(f"ffmpeg -i {img_dir}/%5d.png -i {wav_name} -shortest -v quiet -c:v mpeg4 -pix_fmt yuv420p -b:v 2000k -r 25 -strict -2 -y {out_name}") + + def save_wav16k(self, inp): + source_name = inp['audio_source_name'] + supported_types = ('.wav', '.mp3', '.mp4', '.avi') + assert source_name.endswith(supported_types), f"Now we only support {','.join(supported_types)} as audio source!" + wav16k_name = source_name[:-4] + '_16k.wav' + self.wav16k_name = wav16k_name + extract_wav_cmd = f"ffmpeg -i {source_name} -v quiet -f wav -ar 16000 {wav16k_name} -y" + os.system(extract_wav_cmd) + print(f"Saved 16khz wav file to {wav16k_name}.") diff --git a/Geneface_main/GeneFace/inference/nerfs/lm3d_nerf_infer.py b/Geneface_main/GeneFace/inference/nerfs/lm3d_nerf_infer.py new file mode 100644 index 00000000..3649edb6 --- /dev/null +++ b/Geneface_main/GeneFace/inference/nerfs/lm3d_nerf_infer.py @@ -0,0 +1,153 @@ +import os +import numpy as np +import torch +import tqdm +import cv2 +import importlib +import math +from scipy.ndimage import gaussian_filter1d + +from inference.nerfs.base_nerf_infer import BaseNeRFInfer +from data_util.extract_mel import get_mel_from_fname +from utils.commons.ckpt_utils import load_ckpt +from utils.commons.hparams import hparams, set_hparams +from utils.commons.tensor_utils import move_to_cuda, convert_to_tensor, convert_to_np +from utils.commons.euler2rot import euler_trans_2_c2w, c2w_to_euler_trans +from modules.postnet.lle import compute_LLE_projection, find_k_nearest_neighbors + + +class LM3dNeRFInfer(BaseNeRFInfer): + + def get_cond_from_input(self, inp): + """ + :param inp: {'audio_source_name': (str), 'cond_name': (str, optional)} + :return: a list that contains the condition feature of NeRF + """ + self.save_wav16k(inp) + + # load the lm3d as the condition for lm3d head nerf + assert inp['cond_name'].endswith('.npy') + lm3d_arr = np.load(inp['cond_name'])[0] # [T, w=16, c=29] + idexp_lm3d = torch.from_numpy(lm3d_arr).float() + print(f"Loaded pre-extracted 3D landmark sequence from {inp['cond_name']}!") + + # load the deepspeech features as the condition for lm3d torso nerf + wav16k_name = self.wav16k_name + deepspeech_name = wav16k_name[:-4] + '_deepspeech.npy' + if not os.path.exists(deepspeech_name): + print(f"Try to extract deepspeech from {wav16k_name}...") + # deepspeech_python = '/home/yezhenhui/anaconda3/envs/geneface/bin/python' # the path of your python interpreter that has installed DeepSpeech + # extract_deepspeech_cmd = f'{deepspeech_python} data_util/deepspeech_features/extract_ds_features.py --input={wav16k_name} --output={deepspeech_name}' + extract_deepspeech_cmd = f'python data_util/deepspeech_features/extract_ds_features.py --input={wav16k_name} --output={deepspeech_name}' + os.system(extract_deepspeech_cmd) + print(f"Saved deepspeech features of {wav16k_name} to {deepspeech_name}.") + else: + print(f"Try to load pre-extracted deepspeech from {deepspeech_name}...") + deepspeech_arr = np.load(deepspeech_name) # [T, w=16, c=29] + print(f"Loaded deepspeech features from {deepspeech_name}.") + # get window condition of deepspeech + from data_gen.nerf.binarizer import get_win_conds + num_samples = min(len(lm3d_arr), len(deepspeech_arr), self.infer_max_length) + samples = [{} for _ in range(num_samples)] + for idx, sample in enumerate(samples): + sample['deepspeech_win'] = torch.from_numpy(deepspeech_arr[idx]).float().unsqueeze(0) # [B=1, w=16, C=29] + sample['deepspeech_wins'] = torch.from_numpy(get_win_conds(deepspeech_arr, idx, smo_win_size=8)).float() # [W=8, w=16, C=29] + + idexp_lm3d_mean = self.dataset.idexp_lm3d_mean + idexp_lm3d_std = self.dataset.idexp_lm3d_std + idexp_lm3d_normalized = (idexp_lm3d.reshape([-1,68,3]) - idexp_lm3d_mean)/idexp_lm3d_std + + # step1. clamp the lm3d, to regularize apparent outliers + lm3d_clamp_std = hparams['infer_lm3d_clamp_std'] + idexp_lm3d_normalized[:,0:17] = torch.clamp(idexp_lm3d_normalized[:,0:17], -lm3d_clamp_std, lm3d_clamp_std) # yaw_x_y_z + idexp_lm3d_normalized[:,17:27,0:2] = torch.clamp(idexp_lm3d_normalized[:,17:27,0:2], -lm3d_clamp_std/2, lm3d_clamp_std/2) # brow_x_y + idexp_lm3d_normalized[:,17:27,2] = torch.clamp(idexp_lm3d_normalized[:,17:27,2], -lm3d_clamp_std, lm3d_clamp_std) # brow_z + idexp_lm3d_normalized[:,27:36] = torch.clamp(idexp_lm3d_normalized[:,27:36], -lm3d_clamp_std, lm3d_clamp_std) # nose + idexp_lm3d_normalized[:,36:48,0:2] = torch.clamp(idexp_lm3d_normalized[:,36:48,0:2], -lm3d_clamp_std/2, lm3d_clamp_std/2) # eye_x_y + idexp_lm3d_normalized[:,36:48,2] = torch.clamp(idexp_lm3d_normalized[:,36:48,2], -lm3d_clamp_std, lm3d_clamp_std) # eye_z + idexp_lm3d_normalized[:,48:68] = torch.clamp(idexp_lm3d_normalized[:,48:68], -lm3d_clamp_std, lm3d_clamp_std) # mouth + idexp_lm3d_normalized = idexp_lm3d_normalized.reshape([-1,68*3]) + + # step2. LLE projection to drag the predicted lm3d closer to the GT lm3d + LLE_percent = hparams['infer_lm3d_lle_percent'] + if LLE_percent > 0: + idexp_lm3d_normalized_database = torch.stack([s['idexp_lm3d_normalized'] for s in self.dataset.samples]).reshape([-1, 68*3]) + feat_fuse, _, _ = compute_LLE_projection(feats=idexp_lm3d_normalized[:, :48*3], feat_database=idexp_lm3d_normalized_database[:, :48*3], K=10) + idexp_lm3d_normalized[:, :48*3] = LLE_percent * feat_fuse + (1-LLE_percent) * idexp_lm3d_normalized[:,:48*3] + + # step3. inject eye blink + inject_eye_blink_mode = hparams.get("infer_inject_eye_blink_mode", "none") + print(f"The eye blink mode is: {inject_eye_blink_mode}") + if inject_eye_blink_mode == 'none': + pass + elif inject_eye_blink_mode == 'period': + # get a eye blink period (~40 frames) from the gt data + # then repeat it to the whole sequence length + blink_ref_frames_start_idx = hparams["infer_eye_blink_ref_frames_start_idx"] # the index of start frame of a blink period, + blink_ref_frames_end_idx = hparams["infer_eye_blink_ref_frames_end_idx"] # the index of end frame of a blink period, + assert blink_ref_frames_start_idx != '' or blink_ref_frames_end_idx != '', "If you want to use `period` eye blink mode, please find a eye blink period in your GT frames, then set `infer_eye_blink_pattern_start_idx` in your config file" + idexp_lm3d_normalized_database = torch.stack([s['idexp_lm3d_normalized'] for s in self.dataset.samples]).reshape([-1, 68*3]) + blink_eye_pattern = idexp_lm3d_normalized_database[blink_ref_frames_start_idx:blink_ref_frames_end_idx+1, 17*3:48*3].clone() + repeated_blink_eye_pattern = blink_eye_pattern.repeat([len(idexp_lm3d_normalized)//len(blink_eye_pattern)+1,1])[:len(idexp_lm3d_normalized)] + idexp_lm3d_normalized = idexp_lm3d_normalized.reshape([-1, 68*3]) + idexp_lm3d_normalized[:, 17*3:48*3] = repeated_blink_eye_pattern + idexp_lm3d_normalized = idexp_lm3d_normalized.reshape([-1, 68,3]) + elif inject_eye_blink_mode == 'gt': + # use the eye blink sequence from the gt data + idexp_lm3d_normalized_database = torch.stack([s['idexp_lm3d_normalized'] for s in self.dataset.samples]).reshape([-1, 68*3]) + blink_eye_pattern = idexp_lm3d_normalized_database[:, 17*3:48*3].clone() + repeated_blink_eye_pattern = blink_eye_pattern.repeat([len(idexp_lm3d_normalized)//len(blink_eye_pattern)+1,1])[:len(idexp_lm3d_normalized)] + idexp_lm3d_normalized = idexp_lm3d_normalized.reshape([-1, 68*3]) + idexp_lm3d_normalized[:, 17*3:48*3] = repeated_blink_eye_pattern + idexp_lm3d_normalized = idexp_lm3d_normalized.reshape([-1, 68,3]) + + else: + raise NotImplementedError() + + # step4. close the mouth in silent frames + # todo: remove `infer_sil_ref_frame_idx`, close the mouth using the current frame instead. + if hparams.get('infer_close_mouth_when_sil', False): + idexp_lm3d_normalized = idexp_lm3d_normalized.reshape([-1, 68*3]) + mel, energy = get_mel_from_fname(self.wav16k_name, return_energy=True) + energy = energy.reshape([-1]) + if len(energy) < 2*len(idexp_lm3d_normalized): + energy = np.concatenate([energy] + [energy[-1:]]*(2*len(idexp_lm3d_normalized)-len(energy))) + energy = energy[:2*len(idexp_lm3d_normalized)] + energy = energy.reshape([-1,2]).max(axis=1) # downsample with max_pool + is_sil_mask = energy < 1e-5 + sil_index = np.where(is_sil_mask)[0] + sil_ref_frame_idx = hparams['infer_sil_ref_frame_idx'] + assert sil_ref_frame_idx != '', "Please set `infer_sil_ref_frame_idx` to the index of a frame with closed mouth in the GT dataset" + idexp_lm3d_normalized_database = torch.stack([s['idexp_lm3d_normalized'] for s in self.dataset.samples]).reshape([-1, 68*3]) + sil_mouth_pattern = idexp_lm3d_normalized_database[sil_ref_frame_idx, 48*3:68*3].clone() + repeated_sil_mouth_pattern = sil_mouth_pattern.unsqueeze(0).repeat([len(sil_index),1]) + idexp_lm3d_normalized[sil_index, 48*3:68*3] = repeated_sil_mouth_pattern + + # step5. gaussian filter to smooth the whole sequence + lm3d_smooth_sigma = hparams['infer_lm3d_smooth_sigma'] + if lm3d_smooth_sigma > 0: + idexp_lm3d_normalized[:, :48*3] = convert_to_tensor(gaussian_filter1d(idexp_lm3d_normalized[:, :48*3].numpy(), sigma=lm3d_smooth_sigma)) + # idexp_lm3d_normalized = convert_to_tensor(gaussian_filter1d(idexp_lm3d_normalized.numpy(), sigma=lm3d_smooth_sigma)) + + idexp_lm3d_normalized_numpy = idexp_lm3d_normalized.cpu().numpy() + idexp_lm3d_normalized_win_numpy = np.stack([get_win_conds(idexp_lm3d_normalized_numpy, i, smo_win_size=hparams['cond_win_size'], pad_option='edge') for i in range(idexp_lm3d_normalized_numpy.shape[0])]) + idexp_lm3d_normalized_win = torch.from_numpy(idexp_lm3d_normalized_win_numpy) + + for idx, sample in enumerate(samples): + sample['cond'] = idexp_lm3d_normalized[idx].unsqueeze(0) + if hparams['use_window_cond']: + sample['cond_win'] = idexp_lm3d_normalized_win[idx] + sample['cond_wins'] = torch.from_numpy(get_win_conds(idexp_lm3d_normalized_win_numpy, idx, hparams['smo_win_size'], 'edge')) + return samples + + +if __name__ == '__main__': + from utils.commons.hparams import set_hparams + from utils.commons.hparams import hparams as hp + inp = { + 'audio_source_name': 'data/raw/val_wavs/zozo.wav', + 'cond_name': 'infer_out/May/pred_lm3d/zozo.npy', + 'out_video_name': 'infer_out/May/pred_video/zozo.mp4', + } + + LM3dNeRFInfer.example_run(inp) \ No newline at end of file diff --git a/Geneface_main/GeneFace/inference/nerfs/lm3d_radnerf_infer.py b/Geneface_main/GeneFace/inference/nerfs/lm3d_radnerf_infer.py new file mode 100644 index 00000000..a2db8557 --- /dev/null +++ b/Geneface_main/GeneFace/inference/nerfs/lm3d_radnerf_infer.py @@ -0,0 +1,99 @@ +import torch +import numpy as np + +from utils.commons.hparams import hparams + +from tasks.radnerfs.dataset_utils import RADNeRFDataset +from inference.nerfs.lm3d_nerf_infer import LM3dNeRFInfer +from data_util.face3d_helper import Face3DHelper + + +class LM3d_RADNeRFInfer(LM3dNeRFInfer): + def __init__(self, hparams, device=None): + super().__init__(hparams, device) + self.dataset_cls = RADNeRFDataset # the dataset only provides head pose + self.dataset = self.dataset_cls('trainval', training=False) + self.face3d_helper = Face3DHelper() + + def get_pose_from_ds(self, samples): + """ + process the item into torch.tensor batch + """ + for i, sample in enumerate(samples): + ds_sample = self.dataset[i] + sample['rays_o'] = ds_sample['rays_o'] + sample['rays_d'] = ds_sample['rays_d'] + sample['bg_coords'] = ds_sample['bg_coords'] + sample['pose'] = ds_sample['pose'] + sample['idx'] = ds_sample['idx'] + sample['bg_img'] = ds_sample['bg_img'] + sample['H'] = ds_sample['H'] + sample['W'] = ds_sample['W'] + return samples + + def get_cond_from_input(self, inp): + """ + :param inp: {'audio_source_name': (str), 'cond_name': (str, optional)} + :return: a list that contains the condition feature of NeRF + """ + self.save_wav16k(inp) + + # load the lm3d as the condition for lm3d head nerf + assert inp['cond_name'].endswith('.npy') + lm3d_arr = np.load(inp['cond_name'])[0] # [T, w=16, c=29] + idexp_lm3d = torch.from_numpy(lm3d_arr).float() + print(f"Loaded pre-extracted 3D landmark sequence from {inp['cond_name']}!") + # idexp_lm3d = self.face3d_helper.close_eyes_for_idexp_lm3d(idexp_lm3d) + # idexp_lm3d = self.face3d_helper.close_mouth_for_idexp_lm3d(idexp_lm3d) + + idexp_lm3d_mean = self.dataset.idexp_lm3d_mean + idexp_lm3d_std = self.dataset.idexp_lm3d_std + idexp_lm3d_normalized = (idexp_lm3d.reshape([-1,68,3]) - idexp_lm3d_mean)/idexp_lm3d_std + + # step3. clamp the lm3d, to regularize apparent outliers + lm3d_clamp_std = hparams['infer_lm3d_clamp_std'] + idexp_lm3d_normalized[:,0:17] = torch.clamp(idexp_lm3d_normalized[:,0:17], -lm3d_clamp_std, lm3d_clamp_std) # yaw_x_y_z + idexp_lm3d_normalized[:,17:27,0:2] = torch.clamp(idexp_lm3d_normalized[:,17:27,0:2], -lm3d_clamp_std/2, lm3d_clamp_std/2) # brow_x_y + idexp_lm3d_normalized[:,17:27,2] = torch.clamp(idexp_lm3d_normalized[:,17:27,2], -lm3d_clamp_std, lm3d_clamp_std) # brow_z + idexp_lm3d_normalized[:,27:36] = torch.clamp(idexp_lm3d_normalized[:,27:36], -lm3d_clamp_std, lm3d_clamp_std) # nose + idexp_lm3d_normalized[:,36:48,0:2] = torch.clamp(idexp_lm3d_normalized[:,36:48,0:2], -lm3d_clamp_std/2, lm3d_clamp_std/2) # eye_x_y + idexp_lm3d_normalized[:,36:48,2] = torch.clamp(idexp_lm3d_normalized[:,36:48,2], -lm3d_clamp_std, lm3d_clamp_std) # eye_z + idexp_lm3d_normalized[:,48:68] = torch.clamp(idexp_lm3d_normalized[:,48:68], -lm3d_clamp_std, lm3d_clamp_std) # mouth + + _lambda_other = 0.2 + _lambda_lip = 0.2 + moving_lm = idexp_lm3d_normalized[0].clone() + for i in range(len(idexp_lm3d_normalized)): + idexp_lm3d_normalized[i,0:17] = _lambda_other * moving_lm[0:17] + (1 - _lambda_other) * idexp_lm3d_normalized[i,0:17] # yaw + idexp_lm3d_normalized[i,17:27] = _lambda_other * moving_lm[17:27] + (1 - _lambda_other) * idexp_lm3d_normalized[i,17:27] # brow + idexp_lm3d_normalized[i,27:36] = _lambda_other * moving_lm[27:36] + (1 - _lambda_other) * idexp_lm3d_normalized[i,27:36] # nose + idexp_lm3d_normalized[i,36:48] = _lambda_other * moving_lm[36:48] + (1 - _lambda_other) * idexp_lm3d_normalized[i,36:48] # eye + idexp_lm3d_normalized[i,48:68] = _lambda_lip * moving_lm[48:68] + (1 - _lambda_lip) * idexp_lm3d_normalized[i,48:68] + moving_lm.data = idexp_lm3d_normalized[i].data + + idexp_lm3d_normalized = idexp_lm3d_normalized.reshape([-1,68*3]) + from data_gen.nerf.binarizer import get_win_conds + idexp_lm3d_normalized_numpy = idexp_lm3d_normalized.cpu().numpy() + idexp_lm3d_normalized_win_numpy = np.stack([get_win_conds(idexp_lm3d_normalized_numpy, i, smo_win_size=hparams['cond_win_size'], pad_option='edge') for i in range(idexp_lm3d_normalized_numpy.shape[0])]) + idexp_lm3d_normalized_win = torch.from_numpy(idexp_lm3d_normalized_win_numpy) + + samples = [{} for _ in range(len(idexp_lm3d_normalized))] + for idx, sample in enumerate(samples): + sample['cond'] = idexp_lm3d_normalized[idx].unsqueeze(0) + if hparams['use_window_cond']: + sample['cond_win'] = idexp_lm3d_normalized_win[idx] + sample['cond_wins'] = torch.from_numpy(get_win_conds(idexp_lm3d_normalized_win_numpy, idx, hparams['smo_win_size'], 'edge')) + return samples + + + +if __name__ == '__main__': + from utils.commons.hparams import set_hparams + from utils.commons.hparams import hparams as hp + inp = { + 'audio_source_name': 'data/raw/val_wavs/zozo.wav', + 'cond_name': 'infer_out/May/pred_lm3d/zozo.npy', + 'out_video_name': 'infer_out/May/pred_video/zozo.mp4', + } + + LM3d_RADNeRFInfer.example_run(inp) \ No newline at end of file diff --git a/Geneface_main/GeneFace/inference/nerfs/radnerf_gui.py b/Geneface_main/GeneFace/inference/nerfs/radnerf_gui.py new file mode 100644 index 00000000..dc4ca037 --- /dev/null +++ b/Geneface_main/GeneFace/inference/nerfs/radnerf_gui.py @@ -0,0 +1,608 @@ +import math +import os +import importlib +import copy + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader + +import numpy as np +import dearpygui.dearpygui as dpg +from scipy.spatial.transform import Rotation as R + +from modules.radnerfs.utils import get_audio_features +from tasks.radnerfs.dataset_utils import RADNeRFDataset +from utils.commons.tensor_utils import move_to_cuda +# from .asr import ASR + + +class OrbitCamera: + def __init__(self, W, H, r=2, fovy=60): + self.W = W + self.H = H + self.radius = r # camera distance from center + self.fovy = fovy # in degree + self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point + self.rot = R.from_matrix([[0, -1, 0], [0, 0, -1], [1, 0, 0]]) # init camera matrix: [[1, 0, 0], [0, -1, 0], [0, 0, 1]] (to suit ngp convention) + self.up = np.array([1, 0, 0], dtype=np.float32) # need to be normalized! + + # pose + @property + def pose(self): + # first move camera to radius + res = np.eye(4, dtype=np.float32) + res[2, 3] -= self.radius + # rotate + rot = np.eye(4, dtype=np.float32) + rot[:3, :3] = self.rot.as_matrix() + res = rot @ res + # translate + res[:3, 3] -= self.center + return res + + def update_pose(self, pose): + # pose: [4, 4] numpy array + # assert self.center is 0 + self.radius = np.linalg.norm(pose[:3, 3]) + T = np.eye(4) + T[2, 3] = -self.radius + rot = pose @ np.linalg.inv(T) + self.rot = R.from_matrix(rot[:3, :3]) + + def update_intrinsics(self, intrinsics): + fl_x, fl_y, cx, cy = intrinsics + self.W = int(cx * 2) + self.H = int(cy * 2) + self.fovy = np.rad2deg(2 * np.arctan2(self.H, 2 * fl_y)) + + # intrinsics + @property + def intrinsics(self): + focal = self.H / (2 * np.tan(np.deg2rad(self.fovy) / 2)) + return np.array([focal, focal, self.W // 2, self.H // 2]) + + def orbit(self, dx, dy): + # rotate along camera up/side axis! + side = self.rot.as_matrix()[:3, 0] # why this is side --> ? # already normalized. + rotvec_x = self.up * np.radians(-0.01 * dx) + rotvec_y = side * np.radians(-0.01 * dy) + self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot + + def scale(self, delta): + self.radius *= 1.1 ** (-delta) + + def pan(self, dx, dy, dz=0): + # pan in camera coordinate system (careful on the sensitivity!) + self.center += 0.0001 * self.rot.as_matrix()[:3, :3] @ np.array([dx, dy, dz]) + + +class NeRFGUI: + def __init__(self, hparams, task, dataset, debug=True): + self.hparams_bak = hparams # shared with the trainer's opt to support in-place modification of rendering parameters. + self.hparams = copy.deepcopy(hparams) # shared with the trainer's opt to support in-place modification of rendering parameters. + self.W = self.hparams['gui_w'] + self.H = self.hparams['gui_h'] + self.cam = OrbitCamera(self.W, self.H, r=self.hparams['gui_radius'], fovy=self.hparams['gui_fovy']) + self.debug = debug + self.training = False + self.step = 0 # training step + + self.task = task + self.dataset = dataset + + # override with dataloader's intrinsics + self.W = dataset.W + self.H = dataset.H + self.cam.update_intrinsics(dataset.intrinsics) + # use dataloader's pose + pose_init = dataset.poses[0] + self.cam.update_pose(pose_init) + self.cam_t0 = copy.deepcopy(self.cam) + + # use dataloader's bg + bg_img = dataset.bg_img #.view(1, -1, 3) + if self.H != bg_img.shape[0] or self.W != bg_img.shape[1]: + bg_img = F.interpolate(bg_img.permute(2, 0, 1).unsqueeze(0).contiguous(), (self.H, self.W), mode='bilinear').squeeze(0).permute(1, 2, 0).contiguous() + self.bg_color = bg_img.view(1, -1, 3) + + # audio features (from dataloader, only used in non-playing mode) + self.cond_features = dataset.conds # [N, 29, 16] + self.cond_idx = 0 + + # control eye + # self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item() + + # playing seq from dataloader, or pause. + self.playing = False + self.loader = iter(DataLoader(self.dataset, batch_size=1, collate_fn=self.dataset.collater, shuffle=False)) + + self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32) + self.need_update = True # camera moved, should reset accumulation + self.spp = 1 # sample per pixel + self.mode = 'image' # choose from ['image', 'depth'] + + self.dynamic_resolution = False # assert False! + self.downscale = 1 + self.train_steps = 16 + + self.ind_index = 0 + self.ind_num = self.task.model.individual_embedding_num + + # build asr + # if self.opt.asr: + # self.asr = ASR(opt) + # print(" ") + + dpg.create_context() + self.register_dpg() + self.test_step() + + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + # if self.opt.asr: + # self.asr.stop() + dpg.destroy_context() + + # def train_step(self): + + # starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + # starter.record() + + # outputs = self.trainer.train_gui(self.data_loader, step=self.train_steps) + + # ender.record() + # torch.cuda.synchronize() + # t = starter.elapsed_time(ender) + + # self.step += self.train_steps + # self.need_update = True + + # dpg.set_value("_log_train_time", f'{t:.4f}ms ({int(1000/t)} FPS)') + # dpg.set_value("_log_train_log", f'step = {self.step: 5d} (+{self.train_steps: 2d}), loss = {outputs["loss"]:.4f}, lr = {outputs["lr"]:.5f}') + + # # dynamic train steps + # # max allowed train time per-frame is 500 ms + # full_t = t / self.train_steps * 16 + # train_steps = min(16, max(4, int(16 * 500 / full_t))) + # if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8: + # self.train_steps = train_steps + + def prepare_buffer(self, outputs): + if self.mode == 'image': + return outputs['image'] + else: + return np.expand_dims(outputs['depth'], -1).repeat(3, -1) + + def test_step(self): + + if self.need_update or self.spp < hparams['gui_max_spp']: + + starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + starter.record() + + if self.playing: + try: + data = next(self.loader) + except StopIteration: + self.loader = iter(DataLoader(self.dataset, batch_size=1, collate_fn=self.dataset.collater, shuffle=False)) + data = next(self.loader) + + # if self.opt.asr: + # use the live audio stream + # data['auds'] = self.asr.get_next_feat() + move_to_cuda(data) + outputs = self.task.test_gui_with_data(data, self.W, self.H) + + # sync local camera pose + self.cam.update_pose(data['pose_matrix'][0].detach().cpu().numpy()) + + else: + if self.cond_features is not None: + auds = get_audio_features(self.cond_features, 2, self.cond_idx) + else: + auds = None + outputs = self.task.test_gui_with_editable_data(self.cam.pose, self.cam.intrinsics, self.W, self.H, auds, self.ind_index, self.bg_color, self.spp, self.downscale) + + ender.record() + torch.cuda.synchronize() + t = starter.elapsed_time(ender) + + # update dynamic resolution + if self.dynamic_resolution: + # max allowed infer time per-frame is 200 ms + full_t = t / (self.downscale ** 2) + downscale = min(1, max(1/4, math.sqrt(200 / full_t))) + if downscale > self.downscale * 1.2 or downscale < self.downscale * 0.8: + self.downscale = downscale + + if self.need_update: + self.render_buffer = self.prepare_buffer(outputs) + self.spp = 1 + self.need_update = False + else: + self.render_buffer = (self.render_buffer * self.spp + self.prepare_buffer(outputs)) / (self.spp + 1) + self.spp += 1 + + if self.playing: + self.need_update = True + + dpg.set_value("_log_infer_time", f'{t:.4f}ms ({int(1000/t)} FPS)') + dpg.set_value("_log_resolution", f'{int(self.downscale * self.W)}x{int(self.downscale * self.H)}') + dpg.set_value("_log_spp", self.spp) + dpg.set_value("_texture", self.render_buffer) + + + def register_dpg(self): + + ### register texture + + with dpg.texture_registry(show=False): + dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture") + + ### register window + + # the rendered image, as the primary window + with dpg.window(tag="_primary_window", width=self.W, height=self.H): + + # add the texture + dpg.add_image("_texture") + + # dpg.set_primary_window("_primary_window", True) + + dpg.show_tool(dpg.mvTool_Metrics) + + # control window + with dpg.window(label="Control", tag="_control_window", width=400, height=300): + + # button theme + with dpg.theme() as theme_button: + with dpg.theme_component(dpg.mvButton): + dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18)) + dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47)) + dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83)) + dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5) + dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3) + + # time + # if not self.opt.test: + # with dpg.group(horizontal=True): + # dpg.add_text("Train time: ") + # dpg.add_text("no data", tag="_log_train_time") + + with dpg.group(horizontal=True): + dpg.add_text("Infer time: ") + dpg.add_text("no data", tag="_log_infer_time") + + with dpg.group(horizontal=True): + dpg.add_text("SPP: ") + dpg.add_text("1", tag="_log_spp") + + # train button + # if not self.opt.test: + # with dpg.collapsing_header(label="Train", default_open=True): + + # # train / stop + # with dpg.group(horizontal=True): + # dpg.add_text("Train: ") + + # def callback_train(sender, app_data): + # if self.training: + # self.training = False + # dpg.configure_item("_button_train", label="start") + # else: + # self.training = True + # dpg.configure_item("_button_train", label="stop") + + # dpg.add_button(label="start", tag="_button_train", callback=callback_train) + # dpg.bind_item_theme("_button_train", theme_button) + + # def callback_reset(sender, app_data): + # @torch.no_grad() + # def weight_reset(m: nn.Module): + # reset_parameters = getattr(m, "reset_parameters", None) + # if callable(reset_parameters): + # m.reset_parameters() + # self.trainer.model.apply(fn=weight_reset) + # self.trainer.model.reset_extra_state() # for cuda_ray density_grid and step_counter + # self.need_update = True + + # dpg.add_button(label="reset", tag="_button_reset", callback=callback_reset) + # dpg.bind_item_theme("_button_reset", theme_button) + + # # save ckpt + # with dpg.group(horizontal=True): + # dpg.add_text("Checkpoint: ") + + # def callback_save(sender, app_data): + # self.trainer.save_checkpoint(full=True, best=False) + # dpg.set_value("_log_ckpt", "saved " + os.path.basename(self.trainer.stats["checkpoints"][-1])) + # self.trainer.epoch += 1 # use epoch to indicate different calls. + + # dpg.add_button(label="save", tag="_button_save", callback=callback_save) + # dpg.bind_item_theme("_button_save", theme_button) + + # dpg.add_text("", tag="_log_ckpt") + + # # save mesh + # with dpg.group(horizontal=True): + # dpg.add_text("Marching Cubes: ") + + # def callback_mesh(sender, app_data): + # self.trainer.save_mesh(resolution=256, threshold=10) + # dpg.set_value("_log_mesh", "saved " + f'{self.trainer.name}_{self.trainer.epoch}.ply') + # self.trainer.epoch += 1 # use epoch to indicate different calls. + + # dpg.add_button(label="mesh", tag="_button_mesh", callback=callback_mesh) + # dpg.bind_item_theme("_button_mesh", theme_button) + + # dpg.add_text("", tag="_log_mesh") + + # with dpg.group(horizontal=True): + # dpg.add_text("", tag="_log_train_log") + + # debug info + if self.debug: + with dpg.collapsing_header(label="Debug"): + # pose + dpg.add_separator() + dpg.add_text("Camera Pose:") + dpg.add_text(str(self.cam.pose), tag="_log_pose") + + # rendering options + with dpg.collapsing_header(label="Options", default_open=True): + + # playing + with dpg.group(horizontal=True): + dpg.add_text("Play: ") + + def callback_play(sender, app_data): + + if self.playing: + self.playing = False + dpg.configure_item("_button_play", label="start") + else: + self.playing = True + dpg.configure_item("_button_play", label="stop") + # if self.opt.asr: + # self.asr.warm_up() + self.need_update = True + + dpg.add_button(label="start", tag="_button_play", callback=callback_play) + dpg.bind_item_theme("_button_play", theme_button) + + def callback_reset_pose(sender, app_data): + self.cam = copy.deepcopy(self.cam_t0) + dpg.set_value("_log_pose", str(self.cam.pose)) + self.need_update = True + + dpg.add_button(label="reset pose", tag="_button_reset_pose", callback=callback_reset_pose) + dpg.bind_item_theme("_button_reset_pose", theme_button) + # set asr + # if self.opt.asr: + + # # clear queue button + # def callback_clear_queue(sender, app_data): + + # self.asr.clear_queue() + # self.need_update = True + + # dpg.add_button(label="clear", tag="_button_clear_queue", callback=callback_clear_queue) + # dpg.bind_item_theme("_button_clear_queue", theme_button) + + # dynamic rendering resolution + with dpg.group(horizontal=True): + + def callback_set_dynamic_resolution(sender, app_data): + if self.dynamic_resolution: + self.dynamic_resolution = False + self.downscale = 1 + else: + self.dynamic_resolution = True + self.need_update = True + + # Disable dynamic resolution for face. + # dpg.add_checkbox(label="dynamic resolution", default_value=self.dynamic_resolution, callback=callback_set_dynamic_resolution) + dpg.add_text(f"{self.W}x{self.H}", tag="_log_resolution") + + # mode combo + def callback_change_mode(sender, app_data): + self.mode = app_data + self.need_update = True + + dpg.add_combo(('image', 'depth'), label='mode', default_value=self.mode, callback=callback_change_mode) + + + # bg_color picker + def callback_change_bg(sender, app_data): + self.bg_color = torch.tensor(app_data[:3], dtype=torch.float32) # only need RGB in [0, 1] + self.need_update = True + + dpg.add_color_edit((255, 255, 255), label="Background Color", width=200, tag="_color_editor", no_alpha=True, callback=callback_change_bg) + + # audio index slider + # if not self.opt.asr: + def callback_set_audio_index(sender, app_data): + self.cond_idx = app_data + self.need_update = True + + dpg.add_slider_int(label="Audio", min_value=0, max_value=self.cond_features.shape[0] - 1, format="%d", default_value=self.cond_idx, callback=callback_set_audio_index) + + # ind code index slider + if hparams['individual_embedding_num'] > 0: + def callback_set_individual_code(sender, app_data): + self.ind_index = app_data + self.need_update = True + + dpg.add_slider_int(label="Individual", min_value=0, max_value=self.ind_num - 1, format="%d", default_value=self.ind_index, callback=callback_set_individual_code) + + + # eye area slider + # if self.opt.exp_eye: + # def callback_set_eye(sender, app_data): + # self.eye_area = app_data + # self.need_update = True + + # dpg.add_slider_float(label="eye area", min_value=0, max_value=0.5, format="%.2f percent", default_value=self.eye_area, callback=callback_set_eye) + + # fov slider + def callback_set_fovy(sender, app_data): + self.cam.fovy = app_data + self.need_update = True + + dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy) + + # dt_gamma slider + def callback_set_dt_gamma(sender, app_data): + self.hparams['dt_gamma'] = app_data + self.need_update = True + + dpg.add_slider_float(label="dt_gamma", min_value=0, max_value=0.1, format="%.5f", default_value=self.hparams['dt_gamma'], callback=callback_set_dt_gamma) + + # max_steps slider + def callback_set_max_steps(sender, app_data): + self.hparams['max_steps'] = app_data + self.need_update = True + + dpg.add_slider_int(label="max steps", min_value=1, max_value=1024, format="%d", default_value=self.hparams['max_steps'], callback=callback_set_max_steps) + + # aabb slider + def callback_set_aabb(sender, app_data, user_data): + # user_data is the dimension for aabb (xmin, ymin, zmin, xmax, ymax, zmax) + self.task.model.aabb_infer[user_data] = app_data + + # also change train aabb ? [better not...] + #self.trainer.model.aabb_train[user_data] = app_data + + self.need_update = True + + dpg.add_separator() + dpg.add_text("Axis-aligned bounding box:") + + with dpg.group(horizontal=True): + dpg.add_slider_float(label="x", width=150, min_value=-self.hparams['bound'], max_value=0, format="%.2f", default_value=-self.hparams['bound'], callback=callback_set_aabb, user_data=0) + dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.hparams['bound'], format="%.2f", default_value=self.hparams['bound'], callback=callback_set_aabb, user_data=3) + + with dpg.group(horizontal=True): + dpg.add_slider_float(label="y", width=150, min_value=-self.hparams['bound'], max_value=0, format="%.2f", default_value=-self.hparams['bound'], callback=callback_set_aabb, user_data=1) + dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.hparams['bound'], format="%.2f", default_value=self.hparams['bound'], callback=callback_set_aabb, user_data=4) + + with dpg.group(horizontal=True): + dpg.add_slider_float(label="z", width=150, min_value=-self.hparams['bound'], max_value=0, format="%.2f", default_value=-self.hparams['bound'], callback=callback_set_aabb, user_data=2) + dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.hparams['bound'], format="%.2f", default_value=self.hparams['bound'], callback=callback_set_aabb, user_data=5) + + + ### register camera handler + + def callback_camera_drag_rotate(sender, app_data): + + if not dpg.is_item_focused("_primary_window"): + return + + dx = app_data[1] + dy = app_data[2] + + self.cam.orbit(dx, dy) + self.need_update = True + + if self.debug: + dpg.set_value("_log_pose", str(self.cam.pose)) + + + def callback_camera_wheel_scale(sender, app_data): + + if not dpg.is_item_focused("_primary_window"): + return + + delta = app_data + + self.cam.scale(delta) + self.need_update = True + + if self.debug: + dpg.set_value("_log_pose", str(self.cam.pose)) + + + def callback_camera_drag_pan(sender, app_data): + + if not dpg.is_item_focused("_primary_window"): + return + + dx = app_data[1] + dy = app_data[2] + + self.cam.pan(dx, dy) + self.need_update = True + + if self.debug: + dpg.set_value("_log_pose", str(self.cam.pose)) + + + with dpg.handler_registry(): + dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate) + dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale) + dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan) + + + dpg.create_viewport(title='GeneFace-S', width=1080, height=720, resizable=True) + + ### global theme + with dpg.theme() as theme_no_padding: + with dpg.theme_component(dpg.mvAll): + # set all padding to 0 to avoid scroll bar + dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core) + dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core) + dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core) + + dpg.bind_item_theme("_primary_window", theme_no_padding) + + dpg.setup_dearpygui() + + #dpg.show_metrics() + + dpg.show_viewport() + + + def render(self): + + while dpg.is_dearpygui_running(): + # update texture every frame + # if self.training: + # self.train_step() + # audio stream thread... + # if self.opt.asr and self.playing: + # run 2 ASR steps (audio is at 50FPS, video is at 25FPS) + # for _ in range(2): + # self.asr.run_step() + self.test_step() + dpg.render_dearpygui_frame() + + +if __name__ == '__main__': + from utils.commons.hparams import set_hparams, hparams + from inference.nerfs.lm3d_radnerf_infer import LM3d_RADNeRFInfer + set_hparams() + + inferencer = LM3d_RADNeRFInfer(hparams) + nerf_task = inferencer.build_nerf_task().cuda() + dataset = inferencer.dataset + if hparams['infer_cond_name'] != '': + cond = np.load(hparams['infer_cond_name'])[0] # [T, w=16, c=29] + assert hparams['cond_type'] == 'idexp_lm3d_normalized' + idexp_lm3d = torch.from_numpy(cond).float() + idexp_lm3d_mean = dataset.idexp_lm3d_mean + idexp_lm3d_std = dataset.idexp_lm3d_std + idexp_lm3d_normalized = (idexp_lm3d.reshape([-1,68,3]) - idexp_lm3d_mean)/idexp_lm3d_std + cond_win = idexp_lm3d_normalized.reshape([-1, 1, 204]) # [T, 1, 204] + dataset.conds = cond_win + dataset.samples = dataset.samples[:len(cond_win)] + + if hparams['amp']: + nerf_task = nerf_task.half() + with NeRFGUI(hparams, nerf_task, dataset) as gui: + gui.render() + print("GoodBye!") \ No newline at end of file diff --git a/Geneface_main/GeneFace/inference/postnet/postnet_infer.py b/Geneface_main/GeneFace/inference/postnet/postnet_infer.py new file mode 100644 index 00000000..2bc362dd --- /dev/null +++ b/Geneface_main/GeneFace/inference/postnet/postnet_infer.py @@ -0,0 +1,138 @@ +import os +import torch +import librosa +import numpy as np +import importlib +import tqdm + +from utils.commons.tensor_utils import move_to_cuda +from utils.commons.ckpt_utils import load_ckpt, get_last_checkpoint +from utils.commons.hparams import hparams, set_hparams + + +class PostnetInfer: + def __init__(self, hparams, device=None): + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.hparams = hparams + self.infer_max_length = hparams.get('infer_max_length', 500000) + self.device = device + self.postnet_task = self.build_postnet_task() + self.postnet_task.eval() + self.postnet_task.to(self.device) + + def build_postnet_task(self): + assert hparams['task_cls'] != '' + pkg = ".".join(hparams["task_cls"].split(".")[:-1]) + cls_name = hparams["task_cls"].split(".")[-1] + task_cls = getattr(importlib.import_module(pkg), cls_name) + task = task_cls() + task.build_model() + task.eval() + steps = hparams.get('infer_ckpt_steps', 12000) + load_ckpt(task.model, hparams['work_dir'], 'model', steps=steps) + load_ckpt(task.audio2motion_task, hparams['work_dir'], 'audio2motion_task', steps=steps) + load_ckpt(task.syncnet_task, hparams['work_dir'], 'syncnet_task', steps=steps) + task.global_step = steps + return task + + def infer_once(self, inp): + self.inp = inp + samples = self.get_cond_from_input(inp) + out_name = self.forward_system(samples, inp) + print(f"The predicted and refined 3D landmark sequence is saved at {out_name}") + + def get_cond_from_input(self, inp): + """ + :param inp: {'audio_source_name': (str)} + :return: a list that contains the condition feature of NeRF + """ + self.save_wav16k(inp) + from data_gen.process_lrs3.process_audio_hubert import get_hubert_from_16k_wav + hubert = get_hubert_from_16k_wav(self.wav16k_name).detach().numpy() + len_mel = hubert.shape[0] + x_multiply = 8 + if len_mel % x_multiply == 0: + num_to_pad = 0 + else: + num_to_pad = x_multiply - len_mel % x_multiply + hubert = np.pad(hubert, pad_width=((0,num_to_pad), (0,0))) + + from data_gen.process_lrs3.process_audio_mel_f0 import extract_mel_from_fname,extract_f0_from_wav_and_mel + wav, mel = extract_mel_from_fname(self.wav16k_name) + f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel) + f0 = f0.reshape([-1,1]) + if f0.shape[0] > len(hubert): + f0 = f0[:len(hubert)] + else: + num_to_pad = len(hubert) - len(f0) + f0 = np.pad(f0, pad_width=((0,num_to_pad), (0,0))) + f0 = f0.squeeze(-1) + + t_x = hubert.shape[0] + x_mask = torch.ones([1, t_x]).float() + y_mask = torch.ones([1, t_x//2]).float() + sample = { + 'hubert': torch.from_numpy(hubert).float().unsqueeze(0), + 'f0': torch.from_numpy(f0).float().unsqueeze(0), + 'x_mask': x_mask, + 'y_mask': y_mask, + } + return [sample] + + def forward_system(self, batches, inp): + out_dir = self._forward_postnet_task(batches, inp) + return out_dir + + def _forward_postnet_task(self, batches, inp): + with torch.no_grad(): + pred_lst = [] + for idx, batch in tqdm.tqdm(enumerate(batches), total=len(batches), + desc=f"Now VAE is predicting the action into {inp['out_npy_name']}"): + if self.device == 'cuda': + batch = move_to_cuda(batch) + model_out = self.postnet_task.run_model(batch, infer=True, temperature=1.) + pred = model_out['refine_lm3d'].squeeze().cpu().numpy() + pred_lst.append(pred) + np.save(inp['out_npy_name'], pred_lst) + return inp['out_npy_name'] + + @classmethod + def example_run(cls, inp=None): + inp_tmp = { + 'audio_source_name': 'data/raw/val_wavs/zozo.wav', + 'out_npy_name': 'infer_out/lrs3/0.npy' + } + if inp is not None: + inp_tmp.update(inp) + inp = inp_tmp + if hparams.get("infer_audio_source_name", '') != '': + inp['audio_source_name'] = hparams['infer_audio_source_name'] + if hparams.get("infer_out_npy_name", '') != '': + inp['out_npy_name'] = hparams['infer_out_npy_name'] + out_dir = os.path.dirname(inp['out_npy_name']) + + os.makedirs(out_dir, exist_ok=True) + infer_ins = cls(hparams) + infer_ins.infer_once(inp) + + ############## + # IO-related + ############## + def save_wav16k(self, inp): + source_name = inp['audio_source_name'] + supported_types = ('.wav', '.mp3', '.mp4', '.avi') + assert source_name.endswith(supported_types), f"Now we only support {','.join(supported_types)} as audio source!" + wav16k_name = source_name[:-4] + '_16k.wav' + self.wav16k_name = wav16k_name + extract_wav_cmd = f"ffmpeg -i {source_name} -f wav -ar 16000 {wav16k_name} -y" + os.system(extract_wav_cmd) + print(f"Extracted wav file (16khz) from {source_name} to {wav16k_name}.") + +if __name__ == '__main__': + set_hparams() + inp = { + 'audio_source_name': 'data/raw/val_wavs/zozo.wav', + 'out_npy_name': 'infer_out/May/pred_lm3d/zozo.npy', + } + PostnetInfer.example_run(inp) \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/audio2motion/__pycache__/flow_base.cpython-39.pyc b/Geneface_main/GeneFace/modules/audio2motion/__pycache__/flow_base.cpython-39.pyc new file mode 100644 index 00000000..b2bf7ccf Binary files /dev/null and b/Geneface_main/GeneFace/modules/audio2motion/__pycache__/flow_base.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/modules/audio2motion/__pycache__/transformer_base.cpython-39.pyc b/Geneface_main/GeneFace/modules/audio2motion/__pycache__/transformer_base.cpython-39.pyc new file mode 100644 index 00000000..6b15632b Binary files /dev/null and b/Geneface_main/GeneFace/modules/audio2motion/__pycache__/transformer_base.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/modules/audio2motion/__pycache__/transformer_models.cpython-39.pyc b/Geneface_main/GeneFace/modules/audio2motion/__pycache__/transformer_models.cpython-39.pyc new file mode 100644 index 00000000..f0ac8862 Binary files /dev/null and b/Geneface_main/GeneFace/modules/audio2motion/__pycache__/transformer_models.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/modules/audio2motion/__pycache__/utils.cpython-39.pyc b/Geneface_main/GeneFace/modules/audio2motion/__pycache__/utils.cpython-39.pyc new file mode 100644 index 00000000..eee28a97 Binary files /dev/null and b/Geneface_main/GeneFace/modules/audio2motion/__pycache__/utils.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/modules/audio2motion/__pycache__/vae.cpython-39.pyc b/Geneface_main/GeneFace/modules/audio2motion/__pycache__/vae.cpython-39.pyc new file mode 100644 index 00000000..6672964e Binary files /dev/null and b/Geneface_main/GeneFace/modules/audio2motion/__pycache__/vae.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/modules/audio2motion/cnn_models.py b/Geneface_main/GeneFace/modules/audio2motion/cnn_models.py new file mode 100644 index 00000000..b58e8c47 --- /dev/null +++ b/Geneface_main/GeneFace/modules/audio2motion/cnn_models.py @@ -0,0 +1,359 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def init_weights_func(m): + classname = m.__class__.__name__ + if classname.find("Conv1d") != -1: + torch.nn.init.xavier_uniform_(m.weight) + + +class LambdaLayer(nn.Module): + def __init__(self, lambd): + super(LambdaLayer, self).__init__() + self.lambd = lambd + + def forward(self, x): + return self.lambd(x) + + +class LayerNorm(torch.nn.LayerNorm): + """Layer normalization module. + :param int nout: output dim size + :param int dim: dimension to be normalized + """ + + def __init__(self, nout, dim=-1, eps=1e-5): + """Construct an LayerNorm object.""" + super(LayerNorm, self).__init__(nout, eps=eps) + self.dim = dim + + def forward(self, x): + """Apply layer normalization. + :param torch.Tensor x: input tensor + :return: layer normalized tensor + :rtype torch.Tensor + """ + if self.dim == -1: + return super(LayerNorm, self).forward(x) + return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1) + + + +class ResidualBlock(nn.Module): + """Implements conv->PReLU->norm n-times""" + + def __init__(self, channels, kernel_size, dilation, n=2, norm_type='bn', dropout=0.0, + c_multiple=2, ln_eps=1e-12, bias=False): + super(ResidualBlock, self).__init__() + + if norm_type == 'bn': + norm_builder = lambda: nn.BatchNorm1d(channels) + elif norm_type == 'in': + norm_builder = lambda: nn.InstanceNorm1d(channels, affine=True) + elif norm_type == 'gn': + norm_builder = lambda: nn.GroupNorm(8, channels) + elif norm_type == 'ln': + norm_builder = lambda: LayerNorm(channels, dim=1, eps=ln_eps) + else: + norm_builder = lambda: nn.Identity() + + self.blocks = [ + nn.Sequential( + norm_builder(), + nn.Conv1d(channels, c_multiple * channels, kernel_size, dilation=dilation, + padding=(dilation * (kernel_size - 1)) // 2, bias=bias), + LambdaLayer(lambda x: x * kernel_size ** -0.5), + nn.GELU(), + nn.Conv1d(c_multiple * channels, channels, 1, dilation=dilation, bias=bias), + ) + for _ in range(n) + ] + + self.blocks = nn.ModuleList(self.blocks) + self.dropout = dropout + + def forward(self, x): + nonpadding = (x.abs().sum(1) > 0).float()[:, None, :] + for b in self.blocks: + x_ = b(x) + if self.dropout > 0 and self.training: + x_ = F.dropout(x_, self.dropout, training=self.training) + x = x + x_ + x = x * nonpadding + return x + + +class ConvBlocks(nn.Module): + """Decodes the expanded phoneme encoding into spectrograms""" + + def __init__(self, channels, out_dims, dilations, kernel_size, + norm_type='ln', layers_in_block=2, c_multiple=2, + dropout=0.0, ln_eps=1e-5, init_weights=True, is_BTC=True, bias=False): + super(ConvBlocks, self).__init__() + self.is_BTC = is_BTC + self.res_blocks = nn.Sequential( + *[ResidualBlock(channels, kernel_size, d, + n=layers_in_block, norm_type=norm_type, c_multiple=c_multiple, + dropout=dropout, ln_eps=ln_eps, bias=bias) + for d in dilations], + ) + if norm_type == 'bn': + norm = nn.BatchNorm1d(channels) + elif norm_type == 'in': + norm = nn.InstanceNorm1d(channels, affine=True) + elif norm_type == 'gn': + norm = nn.GroupNorm(8, channels) + elif norm_type == 'ln': + norm = LayerNorm(channels, dim=1, eps=ln_eps) + self.last_norm = norm + self.post_net1 = nn.Conv1d(channels, out_dims, kernel_size=3, padding=1, bias=bias) + if init_weights: + self.apply(init_weights_func) + + def forward(self, x): + """ + + :param x: [B, T, H] + :return: [B, T, H] + """ + if self.is_BTC: + x = x.transpose(1, 2) # [B, C, T] + nonpadding = (x.abs().sum(1) > 0).float()[:, None, :] + x = self.res_blocks(x) * nonpadding + x = self.last_norm(x) * nonpadding + x = self.post_net1(x) * nonpadding + if self.is_BTC: + x = x.transpose(1, 2) + return x + + +class SeqLevelConvolutionalModel(nn.Module): + def __init__(self, out_dim=64, dropout=0.5, audio_feat_type='ppg', backbone_type='unet', norm_type='bn'): + nn.Module.__init__(self) + self.audio_feat_type = audio_feat_type + if audio_feat_type == 'ppg': + self.audio_encoder = nn.Sequential(*[ + nn.Conv1d(29, 48, 3, 1, 1, bias=False), + nn.BatchNorm1d(48) if norm_type=='bn' else LayerNorm(48, dim=1), + nn.GELU(), + nn.Conv1d(48, 48, 3, 1, 1, bias=False) + ]) + self.energy_encoder = nn.Sequential(*[ + nn.Conv1d(1, 16, 3, 1, 1, bias=False), + nn.BatchNorm1d(16) if norm_type=='bn' else LayerNorm(16, dim=1), + nn.GELU(), + nn.Conv1d(16, 16, 3, 1, 1, bias=False) + ]) + elif audio_feat_type == 'mel': + self.mel_encoder = nn.Sequential(*[ + nn.Conv1d(80, 64, 3, 1, 1, bias=False), + nn.BatchNorm1d(64) if norm_type=='bn' else LayerNorm(64, dim=1), + nn.GELU(), + nn.Conv1d(64, 64, 3, 1, 1, bias=False) + ]) + else: + raise NotImplementedError("now only ppg or mel are supported!") + + self.style_encoder = nn.Sequential(*[ + nn.Linear(135, 256), + nn.GELU(), + nn.Linear(256, 256) + ]) + + if backbone_type == 'resnet': + self.backbone = ResNetBackbone() + elif backbone_type == 'unet': + self.backbone = UNetBackbone() + elif backbone_type == 'resblocks': + self.backbone = ResBlocksBackbone() + else: + raise NotImplementedError("Now only resnet and unet are supported!") + + self.out_layer = nn.Sequential( + nn.BatchNorm1d(512) if norm_type=='bn' else LayerNorm(512, dim=1), + nn.Conv1d(512, 64, 3, 1, 1, bias=False), + nn.PReLU(), + nn.Conv1d(64, out_dim, 3, 1, 1, bias=False) + ) + self.feat_dropout = nn.Dropout(p=dropout) + + @property + def device(self): + return self.backbone.parameters().__next__().device + + def forward(self, batch, ret, log_dict=None): + style, x_mask = batch['style'].to(self.device), batch['x_mask'].to(self.device) + style_feat = self.style_encoder(style) # [B,C=135] => [B,C=128] + + if self.audio_feat_type == 'ppg': + audio, energy = batch['audio'].to(self.device), batch['energy'].to(self.device) + audio_feat = self.audio_encoder(audio.transpose(1,2)).transpose(1,2) * x_mask.unsqueeze(2) # [B,T,C=29] => [B,T,C=48] + energy_feat = self.energy_encoder(energy.transpose(1,2)).transpose(1,2) * x_mask.unsqueeze(2) # [B,T,C=1] => [B,T,C=16] + feat = torch.cat([audio_feat, energy_feat], dim=2) # [B,T,C=48+16] + elif self.audio_feat_type == 'mel': + mel = batch['mel'].to(self.device) + feat = self.mel_encoder(mel.transpose(1,2)).transpose(1,2) * x_mask.unsqueeze(2) # [B,T,C=64] + + feat, x_mask = self.backbone(x=feat, sty=style_feat, x_mask=x_mask) + + out = self.out_layer(feat.transpose(1,2)).transpose(1,2) * x_mask.unsqueeze(2) # [B,T//2,C=256] => [B,T//2,C=64] + + ret['pred'] = out + ret['mask'] = x_mask + return out + + +class ResBlocksBackbone(nn.Module): + def __init__(self, in_dim=64, out_dim=512, p_dropout=0.5, norm_type='bn'): + super(ResBlocksBackbone,self).__init__() + self.resblocks_0 = ConvBlocks(channels=in_dim, out_dims=64, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False) + self.resblocks_1 = ConvBlocks(channels=64, out_dims=128, dilations=[1]*4, kernel_size=3, norm_type=norm_type, is_BTC=False) + self.resblocks_2 = ConvBlocks(channels=128, out_dims=256, dilations=[1]*14, kernel_size=3, norm_type=norm_type, is_BTC=False) + self.resblocks_3 = ConvBlocks(channels=512, out_dims=512, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False) + self.resblocks_4 = ConvBlocks(channels=512, out_dims=out_dim, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False) + + self.downsampler = LambdaLayer(lambda x: F.interpolate(x, scale_factor=0.5, mode='linear')) + self.upsampler = LambdaLayer(lambda x: F.interpolate(x, scale_factor=4, mode='linear')) + + self.dropout = nn.Dropout(p=p_dropout) + + def forward(self, x, sty, x_mask=1.): + """ + x: [B, T, C] + sty: [B, C=256] + x_mask: [B, T] + ret: [B, T/2, C] + """ + x = x.transpose(1, 2) # [B, C, T] + x_mask = x_mask[:, None, :] # [B, 1, T] + + x = self.resblocks_0(x) * x_mask # [B, C, T] + + x_mask = self.downsampler(x_mask) # [B, 1, T/2] + x = self.downsampler(x) * x_mask # [B, C, T/2] + x = self.resblocks_1(x) * x_mask # [B, C, T/2] + x = self.resblocks_2(x) * x_mask # [B, C, T/2] + + x = self.dropout(x.transpose(1,2)).transpose(1,2) + sty = sty[:, :, None].repeat([1,1,x_mask.shape[2]]) # [B,C=256,T/2] + x = torch.cat([x, sty], dim=1) # [B, C=256+256, T/2] + + x = self.resblocks_3(x) * x_mask # [B, C, T/2] + x = self.resblocks_4(x) * x_mask # [B, C, T/2] + + x = x.transpose(1,2) + x_mask = x_mask.squeeze(1) + return x, x_mask + + + +class ResNetBackbone(nn.Module): + def __init__(self, in_dim=64, out_dim=512, p_dropout=0.5, norm_type='bn'): + super(ResNetBackbone,self).__init__() + self.resblocks_0 = ConvBlocks(channels=in_dim, out_dims=64, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False) + self.resblocks_1 = ConvBlocks(channels=64, out_dims=128, dilations=[1]*4, kernel_size=3, norm_type=norm_type, is_BTC=False) + self.resblocks_2 = ConvBlocks(channels=128, out_dims=256, dilations=[1]*14, kernel_size=3, norm_type=norm_type, is_BTC=False) + self.resblocks_3 = ConvBlocks(channels=512, out_dims=512, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False) + self.resblocks_4 = ConvBlocks(channels=512, out_dims=out_dim, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False) + + self.downsampler = LambdaLayer(lambda x: F.interpolate(x, scale_factor=0.5, mode='linear')) + self.upsampler = LambdaLayer(lambda x: F.interpolate(x, scale_factor=4, mode='linear')) + + self.dropout = nn.Dropout(p=p_dropout) + + def forward(self, x, sty, x_mask=1.): + """ + x: [B, T, C] + sty: [B, C=256] + x_mask: [B, T] + ret: [B, T/2, C] + """ + x = x.transpose(1, 2) # [B, C, T] + x_mask = x_mask[:, None, :] # [B, 1, T] + + x = self.resblocks_0(x) * x_mask # [B, C, T] + + x_mask = self.downsampler(x_mask) # [B, 1, T/2] + x = self.downsampler(x) * x_mask # [B, C, T/2] + x = self.resblocks_1(x) * x_mask # [B, C, T/2] + + x_mask = self.downsampler(x_mask) # [B, 1, T/4] + x = self.downsampler(x) * x_mask # [B, C, T/4] + x = self.resblocks_2(x) * x_mask # [B, C, T/4] + + x_mask = self.downsampler(x_mask) # [B, 1, T/8] + x = self.downsampler(x) * x_mask # [B, C, T/8] + x = self.dropout(x.transpose(1,2)).transpose(1,2) + sty = sty[:, :, None].repeat([1,1,x_mask.shape[2]]) # [B,C=256,T/8] + x = torch.cat([x, sty], dim=1) # [B, C=256+256, T/8] + x = self.resblocks_3(x) * x_mask # [B, C, T/8] + + x_mask = self.upsampler(x_mask) # [B, 1, T/2] + x = self.upsampler(x) * x_mask # [B, C, T/2] + x = self.resblocks_4(x) * x_mask # [B, C, T/2] + + x = x.transpose(1,2) + x_mask = x_mask.squeeze(1) + return x, x_mask + + +class UNetBackbone(nn.Module): + def __init__(self, in_dim=64, out_dim=512, p_dropout=0.5, norm_type='bn'): + super(UNetBackbone, self).__init__() + self.resblocks_0 = ConvBlocks(channels=in_dim, out_dims=64, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False) + self.resblocks_1 = ConvBlocks(channels=64, out_dims=128, dilations=[1]*4, kernel_size=3, norm_type=norm_type, is_BTC=False) + self.resblocks_2 = ConvBlocks(channels=128, out_dims=256, dilations=[1]*8, kernel_size=3, norm_type=norm_type, is_BTC=False) + self.resblocks_3 = ConvBlocks(channels=512, out_dims=512, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False) + self.resblocks_4 = ConvBlocks(channels=768, out_dims=512, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False) # [768 = c3(512) + c2(256)] + self.resblocks_5 = ConvBlocks(channels=640, out_dims=out_dim, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False) # [640 = c4(512) + c1(128)] + + self.downsampler = nn.Upsample(scale_factor=0.5, mode='linear') + self.upsampler = nn.Upsample(scale_factor=2, mode='linear') + self.dropout = nn.Dropout(p=p_dropout) + + def forward(self, x, sty, x_mask=1.): + """ + x: [B, T, C] + sty: [B, C=256] + x_mask: [B, T] + ret: [B, T/2, C] + """ + x = x.transpose(1, 2) # [B, C, T] + x_mask = x_mask[:, None, :] # [B, 1, T] + + x0 = self.resblocks_0(x) * x_mask # [B, C, T] + + x_mask = self.downsampler(x_mask) # [B, 1, T/2] + x = self.downsampler(x0) * x_mask # [B, C, T/2] + x1 = self.resblocks_1(x) * x_mask # [B, C, T/2] + + x_mask = self.downsampler(x_mask) # [B, 1, T/4] + x = self.downsampler(x1) * x_mask # [B, C, T/4] + x2 = self.resblocks_2(x) * x_mask # [B, C, T/4] + + x_mask = self.downsampler(x_mask) # [B, 1, T/8] + x = self.downsampler(x2) * x_mask # [B, C, T/8] + x = self.dropout(x.transpose(1,2)).transpose(1,2) + sty = sty[:, :, None].repeat([1,1,x_mask.shape[2]]) # [B,C=256,T/8] + x = torch.cat([x, sty], dim=1) # [B, C=256+256, T/8] + x3 = self.resblocks_3(x) * x_mask # [B, C, T/8] + + x_mask = self.upsampler(x_mask) # [B, 1, T/4] + x = self.upsampler(x3) * x_mask # [B, C, T/4] + x = torch.cat([x, self.dropout(x2.transpose(1,2)).transpose(1,2)], dim=1) # + x4 = self.resblocks_4(x) * x_mask # [B, C, T/4] + + x_mask = self.upsampler(x_mask) # [B, 1, T/2] + x = self.upsampler(x4) * x_mask # [B, C, T/2] + x = torch.cat([x, self.dropout(x1.transpose(1,2)).transpose(1,2)], dim=1) + x5 = self.resblocks_5(x) * x_mask # [B, C, T/2] + + x = x5.transpose(1,2) + x_mask = x_mask.squeeze(1) + return x, x_mask + + +if __name__ == '__main__': + pass diff --git a/Geneface_main/GeneFace/modules/audio2motion/flow_base.py b/Geneface_main/GeneFace/modules/audio2motion/flow_base.py new file mode 100644 index 00000000..d2ff1c62 --- /dev/null +++ b/Geneface_main/GeneFace/modules/audio2motion/flow_base.py @@ -0,0 +1,838 @@ +import scipy +from scipy import linalg +from torch.nn import functional as F +import torch +from torch import nn +import numpy as np + +import modules.audio2motion.utils as utils +from modules.audio2motion.transformer_models import FFTBlocks +from utils.commons.hparams import hparams + + +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + +class WN(torch.nn.Module): + def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, + p_dropout=0, share_cond_layers=False): + super(WN, self).__init__() + assert (kernel_size % 2 == 1) + assert (hidden_channels % 2 == 0) + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = p_dropout + self.share_cond_layers = share_cond_layers + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + + self.drop = nn.Dropout(p_dropout) + + self.use_adapters = hparams.get("use_adapters", False) + if self.use_adapters: + self.adapter_layers = torch.nn.ModuleList() + + if gin_channels != 0 and not share_cond_layers: + cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) + self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') + + for i in range(n_layers): + dilation = dilation_rate ** i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size, + dilation=dilation, padding=padding) + in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') + self.in_layers.append(in_layer) + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') + self.res_skip_layers.append(res_skip_layer) + + if self.use_adapters: + adapter_layer = MlpAdapter(in_out_dim=res_skip_channels, hid_dim=res_skip_channels//4) + self.adapter_layers.append(adapter_layer) + + def forward(self, x, x_mask=None, g=None, **kwargs): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + + if g is not None and not self.share_cond_layers: + g = self.cond_layer(g) + + for i in range(self.n_layers): + x_in = self.in_layers[i](x) + x_in = self.drop(x_in) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :] + else: + g_l = torch.zeros_like(x_in) + + acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) + + res_skip_acts = self.res_skip_layers[i](acts) + if self.use_adapters: + res_skip_acts = self.adapter_layers[i](res_skip_acts.transpose(1,2)).transpose(1,2) + if i < self.n_layers - 1: + x = (x + res_skip_acts[:, :self.hidden_channels, :]) * x_mask + output = output + res_skip_acts[:, self.hidden_channels:, :] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + def remove_weight_norm(m): + try: + nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(remove_weight_norm) + + def enable_adapters(self): + if not self.use_adapters: + return + for adapter_layer in self.adapter_layers: + adapter_layer.enable() + + def disable_adapters(self): + if not self.use_adapters: + return + for adapter_layer in self.adapter_layers: + adapter_layer.disable() + +class Permute(nn.Module): + def __init__(self, *args): + super(Permute, self).__init__() + self.args = args + + def forward(self, x): + return x.permute(self.args) + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-4): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + n_dims = len(x.shape) + mean = torch.mean(x, 1, keepdim=True) + variance = torch.mean((x - mean) ** 2, 1, keepdim=True) + + x = (x - mean) * torch.rsqrt(variance + self.eps) + + shape = [1, -1] + [1] * (n_dims - 2) + x = x * self.gamma.view(*shape) + self.beta.view(*shape) + return x + + +class ConvReluNorm(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + assert n_layers > 1, "Number of layers should be larger than 0." + + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = nn.Sequential( + nn.ReLU(), + nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + + +class ActNorm(nn.Module): + def __init__(self, channels, ddi=False, **kwargs): + super().__init__() + self.channels = channels + self.initialized = not ddi + + self.logs = nn.Parameter(torch.zeros(1, channels, 1)) + self.bias = nn.Parameter(torch.zeros(1, channels, 1)) + + def forward(self, x, x_mask=None, reverse=False, **kwargs): + if x_mask is None: + x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype) + x_len = torch.sum(x_mask, [1, 2]) + if not self.initialized: + self.initialize(x, x_mask) + self.initialized = True + + if reverse: + z = (x - self.bias) * torch.exp(-self.logs) * x_mask + logdet = torch.sum(-self.logs) * x_len + else: + z = (self.bias + torch.exp(self.logs) * x) * x_mask + logdet = torch.sum(self.logs) * x_len # [b] + return z, logdet + + def store_inverse(self): + pass + + def set_ddi(self, ddi): + self.initialized = not ddi + + def initialize(self, x, x_mask): + with torch.no_grad(): + denom = torch.sum(x_mask, [0, 2]) + m = torch.sum(x * x_mask, [0, 2]) / denom + m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom + v = m_sq - (m ** 2) + logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6)) + + bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype) + logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype) + + self.bias.data.copy_(bias_init) + self.logs.data.copy_(logs_init) + + +class InvConvNear(nn.Module): + def __init__(self, channels, n_split=4, no_jacobian=False, lu=True, n_sqz=2, **kwargs): + super().__init__() + assert (n_split % 2 == 0) + self.channels = channels + self.n_split = n_split + self.n_sqz = n_sqz + self.no_jacobian = no_jacobian + + w_init = torch.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0] + if torch.det(w_init) < 0: + w_init[:, 0] = -1 * w_init[:, 0] + self.lu = lu + if lu: + # LU decomposition can slightly speed up the inverse + np_p, np_l, np_u = linalg.lu(w_init) + np_s = np.diag(np_u) + np_sign_s = np.sign(np_s) + np_log_s = np.log(np.abs(np_s)) + np_u = np.triu(np_u, k=1) + l_mask = np.tril(np.ones(w_init.shape, dtype=float), -1) + eye = np.eye(*w_init.shape, dtype=float) + + self.register_buffer('p', torch.Tensor(np_p.astype(float))) + self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(float))) + self.l = nn.Parameter(torch.Tensor(np_l.astype(float)), requires_grad=True) + self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(float)), requires_grad=True) + self.u = nn.Parameter(torch.Tensor(np_u.astype(float)), requires_grad=True) + self.register_buffer('l_mask', torch.Tensor(l_mask)) + self.register_buffer('eye', torch.Tensor(eye)) + else: + self.weight = nn.Parameter(w_init) + + def forward(self, x, x_mask=None, reverse=False, **kwargs): + b, c, t = x.size() + assert (c % self.n_split == 0) + if x_mask is None: + x_mask = 1 + x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t + else: + x_len = torch.sum(x_mask, [1, 2]) + + x = x.view(b, self.n_sqz, c // self.n_split, self.n_split // self.n_sqz, t) + x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t) + + if self.lu: + self.weight, log_s = self._get_weight() + logdet = log_s.sum() + logdet = logdet * (c / self.n_split) * x_len + else: + logdet = torch.logdet(self.weight) * (c / self.n_split) * x_len # [b] + + if reverse: + if hasattr(self, "weight_inv"): + weight = self.weight_inv + else: + weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype) + logdet = -logdet + else: + weight = self.weight + if self.no_jacobian: + logdet = 0 + + weight = weight.view(self.n_split, self.n_split, 1, 1) + z = F.conv2d(x, weight) + + z = z.view(b, self.n_sqz, self.n_split // self.n_sqz, c // self.n_split, t) + z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask + return z, logdet + + def _get_weight(self): + l, log_s, u = self.l, self.log_s, self.u + l = l * self.l_mask + self.eye + u = u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(log_s)) + weight = torch.matmul(self.p, torch.matmul(l, u)) + return weight, log_s + + def store_inverse(self): + weight, _ = self._get_weight() + self.weight_inv = torch.inverse(weight.float()).to(next(self.parameters()).device) + + +class InvConv(nn.Module): + def __init__(self, channels, no_jacobian=False, lu=True, **kwargs): + super().__init__() + w_shape = [channels, channels] + w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(float) + LU_decomposed = lu + if not LU_decomposed: + # Sample a random orthogonal matrix: + self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init))) + else: + np_p, np_l, np_u = linalg.lu(w_init) + np_s = np.diag(np_u) + np_sign_s = np.sign(np_s) + np_log_s = np.log(np.abs(np_s)) + np_u = np.triu(np_u, k=1) + l_mask = np.tril(np.ones(w_shape, dtype=float), -1) + eye = np.eye(*w_shape, dtype=float) + + self.register_buffer('p', torch.Tensor(np_p.astype(float))) + self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(float))) + self.l = nn.Parameter(torch.Tensor(np_l.astype(float))) + self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(float))) + self.u = nn.Parameter(torch.Tensor(np_u.astype(float))) + self.l_mask = torch.Tensor(l_mask) + self.eye = torch.Tensor(eye) + self.w_shape = w_shape + self.LU = LU_decomposed + self.weight = None + + def get_weight(self, device, reverse): + w_shape = self.w_shape + self.p = self.p.to(device) + self.sign_s = self.sign_s.to(device) + self.l_mask = self.l_mask.to(device) + self.eye = self.eye.to(device) + l = self.l * self.l_mask + self.eye + u = self.u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(self.log_s)) + dlogdet = self.log_s.sum() + if not reverse: + w = torch.matmul(self.p, torch.matmul(l, u)) + else: + l = torch.inverse(l.double()).float() + u = torch.inverse(u.double()).float() + w = torch.matmul(u, torch.matmul(l, self.p.inverse())) + return w.view(w_shape[0], w_shape[1], 1), dlogdet + + def forward(self, x, x_mask=None, reverse=False, **kwargs): + """ + log-det = log|abs(|W|)| * pixels + """ + b, c, t = x.size() + if x_mask is None: + x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t + else: + x_len = torch.sum(x_mask, [1, 2]) + logdet = 0 + if not reverse: + weight, dlogdet = self.get_weight(x.device, reverse) + z = F.conv1d(x, weight) + if logdet is not None: + logdet = logdet + dlogdet * x_len + return z, logdet + else: + if self.weight is None: + weight, dlogdet = self.get_weight(x.device, reverse) + else: + weight, dlogdet = self.weight, self.dlogdet + z = F.conv1d(x, weight) + if logdet is not None: + logdet = logdet - dlogdet * x_len + return z, logdet + + def store_inverse(self): + self.weight, self.dlogdet = self.get_weight('cuda', reverse=True) + + +class Flip(nn.Module): + def forward(self, x, *args, reverse=False, **kwargs): + x = torch.flip(x, [1]) + logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) + return x, logdet + + def store_inverse(self): + pass + + +class CouplingBlock(nn.Module): + def __init__(self, in_channels, hidden_channels, kernel_size, dilation_rate, n_layers, + gin_channels=0, p_dropout=0, sigmoid_scale=False, + share_cond_layers=False, wn=None): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = p_dropout + self.sigmoid_scale = sigmoid_scale + + start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1) + start = torch.nn.utils.weight_norm(start) + self.start = start + # Initializing last layer to 0 makes the affine coupling layers + # do nothing at first. This helps with training stability + end = torch.nn.Conv1d(hidden_channels, in_channels, 1) + end.weight.data.zero_() + end.bias.data.zero_() + self.end = end + self.wn = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels, + p_dropout, share_cond_layers) + if wn is not None: + self.wn.in_layers = wn.in_layers + self.wn.res_skip_layers = wn.res_skip_layers + + def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs): + if x_mask is None: + x_mask = 1 + x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:] + + x = self.start(x_0) * x_mask + x = self.wn(x, x_mask, g) + out = self.end(x) + + z_0 = x_0 + m = out[:, :self.in_channels // 2, :] + logs = out[:, self.in_channels // 2:, :] + if self.sigmoid_scale: + logs = torch.log(1e-6 + torch.sigmoid(logs + 2)) + if reverse: + z_1 = (x_1 - m) * torch.exp(-logs) * x_mask + logdet = torch.sum(-logs * x_mask, [1, 2]) + else: + z_1 = (m + torch.exp(logs) * x_1) * x_mask + logdet = torch.sum(logs * x_mask, [1, 2]) + z = torch.cat([z_0, z_1], 1) + return z, logdet + + def store_inverse(self): + self.wn.remove_weight_norm() + + +class GlowFFTBlocks(FFTBlocks): + def __init__(self, hidden_size=128, gin_channels=256, num_layers=2, ffn_kernel_size=5, + dropout=None, num_heads=4, use_pos_embed=True, use_last_norm=True, + norm='ln', use_pos_embed_alpha=True): + super().__init__(hidden_size, num_layers, ffn_kernel_size, dropout, num_heads, use_pos_embed, + use_last_norm, norm, use_pos_embed_alpha) + self.inp_proj = nn.Conv1d(hidden_size + gin_channels, hidden_size, 1) + + def forward(self, x, x_mask=None, g=None): + """ + :param x: [B, C_x, T] + :param x_mask: [B, 1, T] + :param g: [B, C_g, T] + :return: [B, C_x, T] + """ + if g is not None: + x = self.inp_proj(torch.cat([x, g], 1)) + x = x.transpose(1, 2) + x = super(GlowFFTBlocks, self).forward(x, x_mask[:, 0] == 0) + x = x.transpose(1, 2) + return x + + +class TransformerCouplingBlock(nn.Module): + def __init__(self, in_channels, hidden_channels, n_layers, + gin_channels=0, p_dropout=0, sigmoid_scale=False): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = p_dropout + self.sigmoid_scale = sigmoid_scale + + start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1) + self.start = start + # Initializing last layer to 0 makes the affine coupling layers + # do nothing at first. This helps with training stability + end = torch.nn.Conv1d(hidden_channels, in_channels, 1) + end.weight.data.zero_() + end.bias.data.zero_() + self.end = end + self.fft_blocks = GlowFFTBlocks( + hidden_size=hidden_channels, + ffn_kernel_size=3, + gin_channels=gin_channels, + num_layers=n_layers) + + def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs): + if x_mask is None: + x_mask = 1 + x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:] + + x = self.start(x_0) * x_mask + x = self.fft_blocks(x, x_mask, g) + out = self.end(x) + + z_0 = x_0 + m = out[:, :self.in_channels // 2, :] + logs = out[:, self.in_channels // 2:, :] + if self.sigmoid_scale: + logs = torch.log(1e-6 + torch.sigmoid(logs + 2)) + if reverse: + z_1 = (x_1 - m) * torch.exp(-logs) * x_mask + logdet = torch.sum(-logs * x_mask, [1, 2]) + else: + z_1 = (m + torch.exp(logs) * x_1) * x_mask + logdet = torch.sum(logs * x_mask, [1, 2]) + z = torch.cat([z_0, z_1], 1) + return z, logdet + + def store_inverse(self): + pass + + +class FreqFFTCouplingBlock(nn.Module): + def __init__(self, in_channels, hidden_channels, n_layers, + gin_channels=0, p_dropout=0, sigmoid_scale=False): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = p_dropout + self.sigmoid_scale = sigmoid_scale + + hs = hidden_channels + stride = 8 + self.start = torch.nn.Conv2d(3, hs, kernel_size=stride * 2, + stride=stride, padding=stride // 2) + end = nn.ConvTranspose2d(hs, 2, kernel_size=stride, stride=stride) + end.weight.data.zero_() + end.bias.data.zero_() + self.end = nn.Sequential( + nn.Conv2d(hs * 3, hs, 3, 1, 1), + nn.ReLU(), + nn.GroupNorm(4, hs), + nn.Conv2d(hs, hs, 3, 1, 1), + end + ) + self.fft_v = FFTBlocks(hidden_size=hs, ffn_kernel_size=1, num_layers=n_layers) + self.fft_h = nn.Sequential( + nn.Conv1d(hs, hs, 3, 1, 1), + nn.ReLU(), + nn.Conv1d(hs, hs, 3, 1, 1), + ) + self.fft_g = nn.Sequential( + nn.Conv1d( + gin_channels - 160, hs, kernel_size=stride * 2, stride=stride, padding=stride // 2), + Permute(0, 2, 1), + FFTBlocks(hidden_size=hs, ffn_kernel_size=1, num_layers=n_layers), + Permute(0, 2, 1), + ) + + def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs): + g_, _ = utils.unsqueeze(g) + g_mel = g_[:, :80] + g_txt = g_[:, 80:] + g_mel, _ = utils.squeeze(g_mel) + g_txt, _ = utils.squeeze(g_txt) # [B, C, T] + + if x_mask is None: + x_mask = 1 + x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:] + x = torch.stack([x_0, g_mel[:, :80], g_mel[:, 80:]], 1) + x = self.start(x) # [B, C, N_bins, T] + B, C, N_bins, T = x.shape + + x_v = self.fft_v(x.permute(0, 3, 2, 1).reshape(B * T, N_bins, C)) + x_v = x_v.reshape(B, T, N_bins, -1).permute(0, 3, 2, 1) + # x_v = x + + x_h = self.fft_h(x.permute(0, 2, 1, 3).reshape(B * N_bins, C, T)) + x_h = x_h.reshape(B, N_bins, -1, T).permute(0, 2, 1, 3) + # x_h = x + + x_g = self.fft_g(g_txt)[:, :, None, :].repeat(1, 1, 10, 1) + x = torch.cat([x_v, x_h, x_g], 1) + out = self.end(x) + + z_0 = x_0 + m = out[:, 0] + logs = out[:, 1] + if self.sigmoid_scale: + logs = torch.log(1e-6 + torch.sigmoid(logs + 2)) + if reverse: + z_1 = (x_1 - m) * torch.exp(-logs) * x_mask + logdet = torch.sum(-logs * x_mask, [1, 2]) + else: + z_1 = (m + torch.exp(logs) * x_1) * x_mask + logdet = torch.sum(logs * x_mask, [1, 2]) + z = torch.cat([z_0, z_1], 1) + return z, logdet + + def store_inverse(self): + pass + + + +class ResidualCouplingLayer(nn.Module): + def __init__(self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=0, + gin_channels=0, + mean_only=False, + nn_type='wn'): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + if nn_type == 'wn': + self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, + gin_channels=gin_channels) + # elif nn_type == 'conv': + # self.enc = ConditionalConvBlocks( + # hidden_channels, gin_channels, hidden_channels, [1] * n_layers, kernel_size, + # layers_in_block=1, is_BTC=False) + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask=x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = -torch.sum(logs, [1, 2]) + return x, logdet + + +class ResidualCouplingBlock(nn.Module): + def __init__(self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + n_flows=4, + gin_channels=0, + nn_type='wn'): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + for i in range(n_flows): + self.flows.append( + ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, + gin_channels=gin_channels, mean_only=True, nn_type=nn_type)) + self.flows.append(Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x, _ = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class Glow(nn.Module): + def __init__(self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_blocks, + n_layers, + p_dropout=0., + n_split=4, + n_sqz=2, + sigmoid_scale=False, + gin_channels=0, + inv_conv_type='near', + share_cond_layers=False, + share_wn_layers=0, + ): + super().__init__() + """ + Note that regularization likes weight decay can leads to Nan error! + """ + + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_blocks = n_blocks + self.n_layers = n_layers + self.p_dropout = p_dropout + self.n_split = n_split + self.n_sqz = n_sqz + self.sigmoid_scale = sigmoid_scale + self.gin_channels = gin_channels + self.share_cond_layers = share_cond_layers + if gin_channels != 0 and share_cond_layers: + cond_layer = torch.nn.Conv1d(gin_channels * n_sqz, 2 * hidden_channels * n_layers, 1) + self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') + wn = None + self.flows = nn.ModuleList() + for b in range(n_blocks): + self.flows.append(ActNorm(channels=in_channels * n_sqz)) + if inv_conv_type == 'near': + self.flows.append(InvConvNear(channels=in_channels * n_sqz, n_split=n_split, n_sqz=n_sqz)) + if inv_conv_type == 'invconv': + self.flows.append(InvConv(channels=in_channels * n_sqz)) + if share_wn_layers > 0: + if b % share_wn_layers == 0: + wn = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels * n_sqz, + p_dropout, share_cond_layers) + self.flows.append( + CouplingBlock( + in_channels * n_sqz, + hidden_channels, + kernel_size=kernel_size, + dilation_rate=dilation_rate, + n_layers=n_layers, + gin_channels=gin_channels * n_sqz, + p_dropout=p_dropout, + sigmoid_scale=sigmoid_scale, + share_cond_layers=share_cond_layers, + wn=wn + )) + + def forward(self, x, x_mask=None, g=None, reverse=False, return_hiddens=False): + """ + x: [B,T,C] + x_mask: [B,T] + g: [B,T,C] + """ + x = x.transpose(1,2) + x_mask = x_mask.unsqueeze(1) + if g is not None: + g = g.transpose(1,2) + + logdet_tot = 0 + if not reverse: + flows = self.flows + else: + flows = reversed(self.flows) + if return_hiddens: + hs = [] + if self.n_sqz > 1: + x, x_mask_ = utils.squeeze(x, x_mask, self.n_sqz) + if g is not None: + g, _ = utils.squeeze(g, x_mask, self.n_sqz) + x_mask = x_mask_ + if self.share_cond_layers and g is not None: + g = self.cond_layer(g) + for f in flows: + x, logdet = f(x, x_mask, g=g, reverse=reverse) + if return_hiddens: + hs.append(x) + logdet_tot += logdet + if self.n_sqz > 1: + x, x_mask = utils.unsqueeze(x, x_mask, self.n_sqz) + + x = x.transpose(1,2) + if return_hiddens: + return x, logdet_tot, hs + return x, logdet_tot + + def store_inverse(self): + def remove_weight_norm(m): + try: + nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(remove_weight_norm) + for f in self.flows: + f.store_inverse() + + +if __name__ == '__main__': + model = Glow(in_channels=64, + hidden_channels=128, + kernel_size=5, + dilation_rate=1, + n_blocks=12, + n_layers=4, + p_dropout=0.0, + n_split=4, + n_sqz=2, + sigmoid_scale=False, + gin_channels=80 + ) + exp = torch.rand([1,1440,64]) + mel = torch.rand([1,1440,80]) + x_mask = torch.ones([1,1440],dtype=torch.float32) + y, logdet = model(exp, x_mask,g=mel, reverse=False) + pred_exp, logdet = model(y, x_mask,g=mel, reverse=False) + # y: [b, t,c=64] + print(" ") \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/audio2motion/multi_length_disc.py b/Geneface_main/GeneFace/modules/audio2motion/multi_length_disc.py new file mode 100644 index 00000000..4a57df2c --- /dev/null +++ b/Geneface_main/GeneFace/modules/audio2motion/multi_length_disc.py @@ -0,0 +1,340 @@ +import numpy as np +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +from modules.audio2motion.cnn_models import LambdaLayer + + +class Discriminator1DFactory(nn.Module): + def __init__(self, time_length, kernel_size=3, in_dim=1, hidden_size=128, norm_type='bn'): + super(Discriminator1DFactory, self).__init__() + padding = kernel_size // 2 + + def discriminator_block(in_filters, out_filters, first=False): + """ + Input: (B, c, T) + Output:(B, c, T//2) + """ + conv = nn.Conv1d(in_filters, out_filters, kernel_size, 2, padding) + block = [ + conv, # padding = kernel//2 + nn.LeakyReLU(0.2, inplace=True), + nn.Dropout2d(0.25) + ] + if norm_type == 'bn' and not first: + block.append(nn.BatchNorm1d(out_filters, 0.8)) + if norm_type == 'in' and not first: + block.append(nn.InstanceNorm1d(out_filters, affine=True)) + block = nn.Sequential(*block) + return block + + if time_length >= 8: + self.model = nn.ModuleList([ + discriminator_block(in_dim, hidden_size, first=True), + discriminator_block(hidden_size, hidden_size), + discriminator_block(hidden_size, hidden_size), + ]) + ds_size = time_length // (2 ** 3) + elif time_length == 3: + self.model = nn.ModuleList([ + nn.Sequential(*[ + nn.Conv1d(in_dim, hidden_size, 3, 1, 0), + nn.LeakyReLU(0.2, inplace=True), + nn.Dropout2d(0.25), + nn.Conv1d(hidden_size, hidden_size, 1, 1, 0), + nn.LeakyReLU(0.2, inplace=True), + nn.Dropout2d(0.25), + nn.BatchNorm1d(hidden_size, 0.8), + nn.Conv1d(hidden_size, hidden_size, 1, 1, 0), + nn.LeakyReLU(0.2, inplace=True), + nn.Dropout2d(0.25), + nn.BatchNorm1d(hidden_size, 0.8) + ]) + ]) + ds_size = 1 + elif time_length == 1: + self.model = nn.ModuleList([ + nn.Sequential(*[ + nn.Linear(in_dim, hidden_size), + nn.LeakyReLU(0.2, inplace=True), + nn.Dropout2d(0.25), + nn.Linear(hidden_size, hidden_size), + nn.LeakyReLU(0.2, inplace=True), + nn.Dropout2d(0.25), + ]) + ]) + ds_size = 1 + + self.adv_layer = nn.Linear(hidden_size * ds_size, 1) + + def forward(self, x): + """ + + :param x: [B, C, T] + :return: validity: [B, 1], h: List of hiddens + """ + h = [] + if x.shape[-1] == 1: + x = x.squeeze(-1) + for l in self.model: + x = l(x) + h.append(x) + if x.ndim == 2: + b, ct = x.shape + use_sigmoid = True + else: + b, c, t = x.shape + ct = c * t + use_sigmoid = False + x = x.view(b, ct) + validity = self.adv_layer(x) # [B, 1] + if use_sigmoid: + validity = torch.sigmoid(validity) + return validity, h + + +class CosineDiscriminator1DFactory(nn.Module): + def __init__(self, time_length, kernel_size=3, in_dim=1, hidden_size=128, norm_type='bn'): + super().__init__() + padding = kernel_size // 2 + + def discriminator_block(in_filters, out_filters, first=False): + """ + Input: (B, c, T) + Output:(B, c, T//2) + """ + conv = nn.Conv1d(in_filters, out_filters, kernel_size, 2, padding) + block = [ + conv, # padding = kernel//2 + nn.LeakyReLU(0.2, inplace=True), + nn.Dropout2d(0.25) + ] + if norm_type == 'bn' and not first: + block.append(nn.BatchNorm1d(out_filters, 0.8)) + if norm_type == 'in' and not first: + block.append(nn.InstanceNorm1d(out_filters, affine=True)) + block = nn.Sequential(*block) + return block + + self.model1 = nn.ModuleList([ + discriminator_block(in_dim, hidden_size, first=True), + discriminator_block(hidden_size, hidden_size), + discriminator_block(hidden_size, hidden_size), + ]) + + self.model2 = nn.ModuleList([ + discriminator_block(in_dim, hidden_size, first=True), + discriminator_block(hidden_size, hidden_size), + discriminator_block(hidden_size, hidden_size), + ]) + + self.relu = nn.ReLU() + def forward(self, x1, x2): + """ + + :param x1: [B, C, T] + :param x2: [B, C, T] + :return: validity: [B, 1], h: List of hiddens + """ + h1, h2 = [], [] + for l in self.model1: + x1 = l(x1) + h1.append(x1) + for l in self.model2: + x2 = l(x2) + h2.append(x1) + b,c,t = x1.shape + x1 = x1.view(b, c*t) + x2 = x2.view(b, c*t) + x1 = self.relu(x1) + x2 = self.relu(x2) + # x1 = F.normalize(x1, p=2, dim=1) + # x2 = F.normalize(x2, p=2, dim=1) + validity = F.cosine_similarity(x1, x2) + return validity, [h1,h2] + + +class MultiWindowDiscriminator(nn.Module): + def __init__(self, time_lengths, cond_dim=80, in_dim=64, kernel_size=3, hidden_size=128, disc_type='standard', norm_type='bn', reduction='sum'): + super(MultiWindowDiscriminator, self).__init__() + self.win_lengths = time_lengths + self.reduction = reduction + self.disc_type = disc_type + + if cond_dim > 0: + self.use_cond = True + self.cond_proj_layers = nn.ModuleList() + self.in_proj_layers = nn.ModuleList() + else: + self.use_cond = False + + self.conv_layers = nn.ModuleList() + for time_length in time_lengths: + conv_layer = [ + Discriminator1DFactory( + time_length, kernel_size, in_dim=64, hidden_size=hidden_size, + norm_type=norm_type) if self.disc_type == 'standard' + else CosineDiscriminator1DFactory(time_length, kernel_size, in_dim=64, + hidden_size=hidden_size,norm_type=norm_type) + ] + self.conv_layers += conv_layer + if self.use_cond: + self.cond_proj_layers.append(nn.Linear(cond_dim, 64)) + self.in_proj_layers.append(nn.Linear(in_dim, 64)) + + def clip(self, x, cond, x_len, win_length, start_frames=None): + '''Ramdom clip x to win_length. + Args: + x (tensor) : (B, T, C). + cond (tensor) : (B, T, H). + x_len (tensor) : (B,). + win_length (int): target clip length + + Returns: + (tensor) : (B, c_in, win_length, n_bins). + + ''' + clip_from_same_frame = start_frames is None + T_start = 0 + # T_end = x_len.max() - win_length + T_end = x_len.min() - win_length + if T_end < 0: + return None, None, start_frames + T_end = T_end.item() + if start_frames is None: + start_frame = np.random.randint(low=T_start, high=T_end + 1) + start_frames = [start_frame] * x.size(0) + else: + start_frame = start_frames[0] + + + if clip_from_same_frame: + x_batch = x[:, start_frame: start_frame + win_length, :] + c_batch = cond[:, start_frame: start_frame + win_length, :] if cond is not None else None + else: + x_lst = [] + c_lst = [] + for i, start_frame in enumerate(start_frames): + x_lst.append(x[i, start_frame: start_frame + win_length, :]) + if cond is not None: + c_lst.append(cond[i, start_frame: start_frame + win_length, :]) + x_batch = torch.stack(x_lst, dim=0) + if cond is None: + c_batch = None + else: + c_batch = torch.stack(c_lst, dim=0) + return x_batch, c_batch, start_frames + + def forward(self, x, x_len, cond=None, start_frames_wins=None): + ''' + Args: + x (tensor): input mel, (B, T, C). + x_length (tensor): len of per mel. (B,). + + Returns: + tensor : (B). + ''' + validity = [] + if start_frames_wins is None: + start_frames_wins = [None] * len(self.conv_layers) + h = [] + for i, start_frames in zip(range(len(self.conv_layers)), start_frames_wins): + x_clip, c_clip, start_frames = self.clip( + x, cond, x_len, self.win_lengths[i], start_frames) # (B, win_length, C) + start_frames_wins[i] = start_frames + if x_clip is None: + continue + if self.disc_type == 'standard': + if self.use_cond: + x_clip = self.in_proj_layers[i](x_clip) # (B, T, C) + c_clip = self.cond_proj_layers[i](c_clip) + x_clip = x_clip + c_clip + validity_pred, h_ = self.conv_layers[i](x_clip.transpose(1,2)) + elif self.disc_type == 'cosine': + assert self.use_cond is True + x_clip = self.in_proj_layers[i](x_clip) # (B, T, C) + c_clip = self.cond_proj_layers[i](c_clip) + validity_pred, h_ = self.conv_layers[i](x_clip.transpose(1,2), c_clip.transpose(1,2)) + else: + raise NotImplementedError + + h += h_ + validity.append(validity_pred) + if len(validity) != len(self.conv_layers): + return None, start_frames_wins, h + if self.reduction == 'sum': + validity = sum(validity) # [B] + elif self.reduction == 'stack': + validity = torch.stack(validity, -1) # [B, W_L] + return validity, start_frames_wins, h + + +class Discriminator(nn.Module): + def __init__(self, x_dim=80, y_dim=64, disc_type='standard', + uncond_disc=False, kernel_size=3, hidden_size=128, norm_type='bn', reduction='sum', time_lengths=(8,16,32)): + """_summary_ + + Args: + time_lengths (list, optional): the list of window size. Defaults to [32, 64, 128]. + x_dim (int, optional): the dim of audio features. Defaults to 80, corresponding to mel-spec. + y_dim (int, optional): the dim of facial coeff. Defaults to 64, correspond to exp; other options can be 7(pose) or 71(exp+pose). + kernel (tuple, optional): _description_. Defaults to (3, 3). + c_in (int, optional): _description_. Defaults to 1. + hidden_size (int, optional): _description_. Defaults to 128. + norm_type (str, optional): _description_. Defaults to 'bn'. + reduction (str, optional): _description_. Defaults to 'sum'. + uncond_disc (bool, optional): _description_. Defaults to False. + """ + super(Discriminator, self).__init__() + self.time_lengths = time_lengths + self.x_dim, self.y_dim = x_dim, y_dim + self.disc_type = disc_type + self.reduction = reduction + self.uncond_disc = uncond_disc + + if uncond_disc: + self.x_dim = 0 + cond_dim = 0 + + else: + cond_dim = 64 + self.mel_encoder = nn.Sequential(*[ + nn.Conv1d(self.x_dim, 64, 3, 1, 1, bias=False), + nn.BatchNorm1d(64), + nn.GELU(), + nn.Conv1d(64, cond_dim, 3, 1, 1, bias=False) + ]) + + self.disc = MultiWindowDiscriminator( + time_lengths=self.time_lengths, + in_dim=self.y_dim, + cond_dim=cond_dim, + kernel_size=kernel_size, + hidden_size=hidden_size, norm_type=norm_type, + reduction=reduction, + disc_type=disc_type + ) + self.downsampler = LambdaLayer(lambda x: F.interpolate(x.transpose(1,2), scale_factor=0.5, mode='nearest').transpose(1,2)) + + @property + def device(self): + return self.disc.parameters().__next__().device + + def forward(self,x, batch, start_frames_wins=None): + """ + + :param x: [B, T, C] + :param cond: [B, T, cond_size] + :return: + """ + x = x.to(self.device) + if not self.uncond_disc: + mel = self.downsampler(batch['mel'].to(self.device)) + mel_feat = self.mel_encoder(mel.transpose(1,2)).transpose(1,2) + else: + mel_feat = None + x_len = x.sum(-1).ne(0).int().sum([1]) + disc_confidence, start_frames_wins, h = self.disc(x, x_len, mel_feat, start_frames_wins=start_frames_wins) + return disc_confidence + diff --git a/Geneface_main/GeneFace/modules/audio2motion/transformer_base.py b/Geneface_main/GeneFace/modules/audio2motion/transformer_base.py new file mode 100644 index 00000000..39bbe007 --- /dev/null +++ b/Geneface_main/GeneFace/modules/audio2motion/transformer_base.py @@ -0,0 +1,988 @@ +import math +import torch +from torch import nn +from torch.nn import Parameter +import torch.onnx.operators +import torch.nn.functional as F +from collections import defaultdict + + +def make_positions(tensor, padding_idx): + """Replace non-padding symbols with their position numbers. + + Position numbers begin at padding_idx+1. Padding symbols are ignored. + """ + # The series of casts and type-conversions here are carefully + # balanced to both work with ONNX export and XLA. In particular XLA + # prefers ints, cumsum defaults to output longs, and ONNX doesn't know + # how to handle the dtype kwarg in cumsum. + mask = tensor.ne(padding_idx).int() + return ( + torch.cumsum(mask, dim=1).type_as(mask) * mask + ).long() + padding_idx + + +def softmax(x, dim): + return F.softmax(x, dim=dim, dtype=torch.float32) + + +INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0) + +def _get_full_incremental_state_key(module_instance, key): + module_name = module_instance.__class__.__name__ + + # assign a unique ID to each module instance, so that incremental state is + # not shared across module instances + if not hasattr(module_instance, '_instance_id'): + INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1 + module_instance._instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name] + + return '{}.{}.{}'.format(module_name, module_instance._instance_id, key) + + + +def get_incremental_state(module, incremental_state, key): + """Helper for getting incremental state for an nn.Module.""" + full_key = _get_full_incremental_state_key(module, key) + if incremental_state is None or full_key not in incremental_state: + return None + return incremental_state[full_key] + + +def set_incremental_state(module, incremental_state, key, value): + """Helper for setting incremental state for an nn.Module.""" + if incremental_state is not None: + full_key = _get_full_incremental_state_key(module, key) + incremental_state[full_key] = value + + + +class Reshape(nn.Module): + def __init__(self, *args): + super(Reshape, self).__init__() + self.shape = args + + def forward(self, x): + return x.view(self.shape) + + +class Permute(nn.Module): + def __init__(self, *args): + super(Permute, self).__init__() + self.args = args + + def forward(self, x): + return x.permute(self.args) + + +class LinearNorm(torch.nn.Module): + def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): + super(LinearNorm, self).__init__() + self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) + + torch.nn.init.xavier_uniform_( + self.linear_layer.weight, + gain=torch.nn.init.calculate_gain(w_init_gain)) + + def forward(self, x): + return self.linear_layer(x) + + +class ConvNorm(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, + padding=None, dilation=1, bias=True, w_init_gain='linear'): + super(ConvNorm, self).__init__() + if padding is None: + assert (kernel_size % 2 == 1) + padding = int(dilation * (kernel_size - 1) / 2) + + self.conv = torch.nn.Conv1d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, + bias=bias) + + torch.nn.init.xavier_uniform_( + self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) + + def forward(self, signal): + conv_signal = self.conv(signal) + return conv_signal + + +def Embedding(num_embeddings, embedding_dim, padding_idx=None): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) + if padding_idx is not None: + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +class GroupNorm1DTBC(nn.GroupNorm): + def forward(self, input): + return super(GroupNorm1DTBC, self).forward(input.permute(1, 2, 0)).permute(2, 0, 1) + + +def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False): + if not export and torch.cuda.is_available(): + try: + from apex.normalization import FusedLayerNorm + return FusedLayerNorm(normalized_shape, eps, elementwise_affine) + except ImportError: + pass + return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) + + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.) + return m + + +class SinusoidalPositionalEmbedding(nn.Module): + """This module produces sinusoidal positional embeddings of any length. + + Padding symbols are ignored. + """ + + def __init__(self, embedding_dim, padding_idx, init_size=1024): + super().__init__() + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.weights = SinusoidalPositionalEmbedding.get_embedding( + init_size, + embedding_dim, + padding_idx, + ) + self.register_buffer('_float_tensor', torch.FloatTensor(1)) + + @staticmethod + def get_embedding(num_embeddings, embedding_dim, padding_idx=None): + """Build sinusoidal embeddings. + + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) + emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + return emb + + def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs): + """Input is expected to be of size [bsz x seqlen].""" + bsz, seq_len = input.shape[:2] + max_pos = self.padding_idx + 1 + seq_len + if self.weights is None or max_pos > self.weights.size(0): + # recompute/expand embeddings if needed + self.weights = SinusoidalPositionalEmbedding.get_embedding( + max_pos, + self.embedding_dim, + self.padding_idx, + ) + self.weights = self.weights.to(self._float_tensor) + + if incremental_state is not None: + # positions is the same for every token when decoding a single step + pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len + return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1) + + positions = make_positions(input, self.padding_idx) if positions is None else positions + return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach() + + def max_positions(self): + """Maximum number of supported positions.""" + return int(1e5) # an arbitrary large number + + +class ConvTBC(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, padding=0): + super(ConvTBC, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.padding = padding + + self.weight = torch.nn.Parameter(torch.Tensor( + self.kernel_size, in_channels, out_channels)) + self.bias = torch.nn.Parameter(torch.Tensor(out_channels)) + + def forward(self, input): + return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding) + + +class MultiheadAttention(nn.Module): + def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True, + add_bias_kv=False, add_zero_attn=False, self_attention=False, + encoder_decoder_attention=False): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \ + 'value to be of the same size' + + if self.qkv_same_dim: + self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim)) + else: + self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim)) + self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim)) + self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) + + if bias: + self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim)) + else: + self.register_parameter('in_proj_bias', None) + + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.reset_parameters() + + self.enable_torch_version = False + if hasattr(F, "multi_head_attention_forward"): + self.enable_torch_version = True + else: + self.enable_torch_version = False + self.last_attn_probs = None + + def reset_parameters(self): + if self.qkv_same_dim: + nn.init.xavier_uniform_(self.in_proj_weight) + else: + nn.init.xavier_uniform_(self.k_proj_weight) + nn.init.xavier_uniform_(self.v_proj_weight) + nn.init.xavier_uniform_(self.q_proj_weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.in_proj_bias is not None: + nn.init.constant_(self.in_proj_bias, 0.) + nn.init.constant_(self.out_proj.bias, 0.) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + + def forward( + self, + query, key, value, + key_padding_mask=None, + incremental_state=None, + need_weights=True, + static_kv=False, + attn_mask=None, + before_softmax=False, + need_head_weights=False, + enc_dec_attn_constraint_mask=None, + reset_attn_weight=None + ): + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None: + if self.qkv_same_dim: + return F.multi_head_attention_forward(query, key, value, + self.embed_dim, self.num_heads, + self.in_proj_weight, + self.in_proj_bias, self.bias_k, self.bias_v, + self.add_zero_attn, self.dropout, + self.out_proj.weight, self.out_proj.bias, + self.training, key_padding_mask, need_weights, + attn_mask) + else: + return F.multi_head_attention_forward(query, key, value, + self.embed_dim, self.num_heads, + torch.empty([0]), + self.in_proj_bias, self.bias_k, self.bias_v, + self.add_zero_attn, self.dropout, + self.out_proj.weight, self.out_proj.bias, + self.training, key_padding_mask, need_weights, + attn_mask, use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, + k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight) + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if 'prev_key' in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + # self-attention + q, k, v = self.in_proj_qkv(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.in_proj_q(query) + if key is None: + assert value is None + k = v = None + else: + k = self.in_proj_k(key) + v = self.in_proj_v(key) + + else: + q = self.in_proj_q(query) + k = self.in_proj_k(key) + v = self.in_proj_v(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1) + + q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if 'prev_key' in saved_state: + prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + k = torch.cat((prev_key, k), dim=1) + if 'prev_value' in saved_state: + prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + v = torch.cat((prev_value, v), dim=1) + if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None: + prev_key_padding_mask = saved_state['prev_key_padding_mask'] + if static_kv: + key_padding_mask = prev_key_padding_mask + else: + key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1) + + saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state['prev_key_padding_mask'] = key_padding_mask + + self._set_input_buffer(incremental_state, saved_state) + + src_len = k.size(1) + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]): + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + if len(attn_mask.shape) == 2: + attn_mask = attn_mask.unsqueeze(0) + elif len(attn_mask.shape) == 3: + attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape( + bsz * self.num_heads, tgt_len, src_len) + attn_weights = attn_weights + attn_mask + + if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.masked_fill( + enc_dec_attn_constraint_mask.unsqueeze(2).bool(), + -1e8, + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + -1e8, + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v + + attn_weights_float = softmax(attn_weights, dim=-1) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training) + + if reset_attn_weight is not None: + if reset_attn_weight: + self.last_attn_probs = attn_probs.detach() + else: + assert self.last_attn_probs is not None + attn_probs = self.last_attn_probs + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + + if need_weights: + attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + else: + attn_weights = None + + return attn, (attn_weights, attn_logits) + + def in_proj_qkv(self, query): + return self._in_proj(query).chunk(3, dim=-1) + + def in_proj_q(self, query): + if self.qkv_same_dim: + return self._in_proj(query, end=self.embed_dim) + else: + bias = self.in_proj_bias + if bias is not None: + bias = bias[:self.embed_dim] + return F.linear(query, self.q_proj_weight, bias) + + def in_proj_k(self, key): + if self.qkv_same_dim: + return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim) + else: + weight = self.k_proj_weight + bias = self.in_proj_bias + if bias is not None: + bias = bias[self.embed_dim:2 * self.embed_dim] + return F.linear(key, weight, bias) + + def in_proj_v(self, value): + if self.qkv_same_dim: + return self._in_proj(value, start=2 * self.embed_dim) + else: + weight = self.v_proj_weight + bias = self.in_proj_bias + if bias is not None: + bias = bias[2 * self.embed_dim:] + return F.linear(value, weight, bias) + + def _in_proj(self, input, start=0, end=None): + weight = self.in_proj_weight + bias = self.in_proj_bias + weight = weight[start:end, :] + if bias is not None: + bias = bias[start:end] + return F.linear(input, weight, bias) + + def _get_input_buffer(self, incremental_state): + return get_incremental_state( + self, + incremental_state, + 'attn_state', + ) or {} + + def _set_input_buffer(self, incremental_state, buffer): + set_incremental_state( + self, + incremental_state, + 'attn_state', + buffer, + ) + + def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz): + return attn_weights + + def clear_buffer(self, incremental_state=None): + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if 'prev_key' in saved_state: + del saved_state['prev_key'] + if 'prev_value' in saved_state: + del saved_state['prev_value'] + self._set_input_buffer(incremental_state, saved_state) + + +class Swish(torch.autograd.Function): + @staticmethod + def forward(ctx, i): + result = i * torch.sigmoid(i) + ctx.save_for_backward(i) + return result + + @staticmethod + def backward(ctx, grad_output): + i = ctx.saved_variables[0] + sigmoid_i = torch.sigmoid(i) + return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) + + +class CustomSwish(nn.Module): + def forward(self, input_tensor): + return Swish.apply(input_tensor) + + +class TransformerFFNLayer(nn.Module): + def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'): + super().__init__() + self.kernel_size = kernel_size + self.dropout = dropout + self.act = act + if padding == 'SAME': + self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2) + elif padding == 'LEFT': + self.ffn_1 = nn.Sequential( + nn.ConstantPad1d((kernel_size - 1, 0), 0.0), + nn.Conv1d(hidden_size, filter_size, kernel_size) + ) + self.ffn_2 = Linear(filter_size, hidden_size) + if self.act == 'swish': + self.swish_fn = CustomSwish() + + def forward(self, x, incremental_state=None): + # x: T x B x C + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if 'prev_input' in saved_state: + prev_input = saved_state['prev_input'] + x = torch.cat((prev_input, x), dim=0) + x = x[-self.kernel_size:] + saved_state['prev_input'] = x + self._set_input_buffer(incremental_state, saved_state) + + x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1) + x = x * self.kernel_size ** -0.5 + + if incremental_state is not None: + x = x[-1:] + if self.act == 'gelu': + x = F.gelu(x) + if self.act == 'relu': + x = F.relu(x) + if self.act == 'swish': + x = self.swish_fn(x) + x = F.dropout(x, self.dropout, training=self.training) + x = self.ffn_2(x) + return x + + def _get_input_buffer(self, incremental_state): + return get_incremental_state( + self, + incremental_state, + 'f', + ) or {} + + def _set_input_buffer(self, incremental_state, buffer): + set_incremental_state( + self, + incremental_state, + 'f', + buffer, + ) + + def clear_buffer(self, incremental_state): + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if 'prev_input' in saved_state: + del saved_state['prev_input'] + self._set_input_buffer(incremental_state, saved_state) + + +class BatchNorm1dTBC(nn.Module): + def __init__(self, c): + super(BatchNorm1dTBC, self).__init__() + self.bn = nn.BatchNorm1d(c) + + def forward(self, x): + """ + + :param x: [T, B, C] + :return: [T, B, C] + """ + x = x.permute(1, 2, 0) # [B, C, T] + x = self.bn(x) # [B, C, T] + x = x.permute(2, 0, 1) # [T, B, C] + return x + + +class EncSALayer(nn.Module): + def __init__(self, c, num_heads, dropout, attention_dropout=0.1, + relu_dropout=0.1, kernel_size=9, padding='SAME', norm='ln', act='gelu'): + super().__init__() + self.c = c + self.dropout = dropout + self.num_heads = num_heads + if num_heads > 0: + if norm == 'ln': + self.layer_norm1 = LayerNorm(c) + elif norm == 'bn': + self.layer_norm1 = BatchNorm1dTBC(c) + elif norm == 'gn': + self.layer_norm1 = GroupNorm1DTBC(8, c) + self.self_attn = MultiheadAttention( + self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False) + if norm == 'ln': + self.layer_norm2 = LayerNorm(c) + elif norm == 'bn': + self.layer_norm2 = BatchNorm1dTBC(c) + elif norm == 'gn': + self.layer_norm2 = GroupNorm1DTBC(8, c) + self.ffn = TransformerFFNLayer( + c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act) + + def forward(self, x, encoder_padding_mask=None, **kwargs): + layer_norm_training = kwargs.get('layer_norm_training', None) + if layer_norm_training is not None: + self.layer_norm1.training = layer_norm_training + self.layer_norm2.training = layer_norm_training + if self.num_heads > 0: + residual = x + x = self.layer_norm1(x) + x, _, = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=encoder_padding_mask + ) + x = F.dropout(x, self.dropout, training=self.training) + x = residual + x + x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None] + + residual = x + x = self.layer_norm2(x) + x = self.ffn(x) + x = F.dropout(x, self.dropout, training=self.training) + x = residual + x + x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None] + return x + + +class DecSALayer(nn.Module): + def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1, + kernel_size=9, act='gelu', norm='ln'): + super().__init__() + self.c = c + self.dropout = dropout + if norm == 'ln': + self.layer_norm1 = LayerNorm(c) + elif norm == 'gn': + self.layer_norm1 = GroupNorm1DTBC(8, c) + self.self_attn = MultiheadAttention( + c, num_heads, self_attention=True, dropout=attention_dropout, bias=False + ) + if norm == 'ln': + self.layer_norm2 = LayerNorm(c) + elif norm == 'gn': + self.layer_norm2 = GroupNorm1DTBC(8, c) + self.encoder_attn = MultiheadAttention( + c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False, + ) + if norm == 'ln': + self.layer_norm3 = LayerNorm(c) + elif norm == 'gn': + self.layer_norm3 = GroupNorm1DTBC(8, c) + self.ffn = TransformerFFNLayer( + c, 4 * c, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act) + + def forward( + self, + x, + encoder_out=None, + encoder_padding_mask=None, + incremental_state=None, + self_attn_mask=None, + self_attn_padding_mask=None, + attn_out=None, + reset_attn_weight=None, + **kwargs, + ): + layer_norm_training = kwargs.get('layer_norm_training', None) + if layer_norm_training is not None: + self.layer_norm1.training = layer_norm_training + self.layer_norm2.training = layer_norm_training + self.layer_norm3.training = layer_norm_training + residual = x + x = self.layer_norm1(x) + x, _ = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + incremental_state=incremental_state, + attn_mask=self_attn_mask + ) + x = F.dropout(x, self.dropout, training=self.training) + x = residual + x + + attn_logits = None + if encoder_out is not None or attn_out is not None: + residual = x + x = self.layer_norm2(x) + if encoder_out is not None: + x, attn = self.encoder_attn( + query=x, + key=encoder_out, + value=encoder_out, + key_padding_mask=encoder_padding_mask, + incremental_state=incremental_state, + static_kv=True, + enc_dec_attn_constraint_mask=get_incremental_state(self, incremental_state, + 'enc_dec_attn_constraint_mask'), + reset_attn_weight=reset_attn_weight + ) + attn_logits = attn[1] + elif attn_out is not None: + x = self.encoder_attn.in_proj_v(attn_out) + if encoder_out is not None or attn_out is not None: + x = F.dropout(x, self.dropout, training=self.training) + x = residual + x + + residual = x + x = self.layer_norm3(x) + x = self.ffn(x, incremental_state=incremental_state) + x = F.dropout(x, self.dropout, training=self.training) + x = residual + x + return x, attn_logits + + def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None): + self.encoder_attn.clear_buffer(incremental_state) + self.ffn.clear_buffer(incremental_state) + + def set_buffer(self, name, tensor, incremental_state): + return set_incremental_state(self, incremental_state, name, tensor) + + +class ConvBlock(nn.Module): + def __init__(self, idim=80, n_chans=256, kernel_size=3, stride=1, norm='gn', dropout=0): + super().__init__() + self.conv = ConvNorm(idim, n_chans, kernel_size, stride=stride) + self.norm = norm + if self.norm == 'bn': + self.norm = nn.BatchNorm1d(n_chans) + elif self.norm == 'in': + self.norm = nn.InstanceNorm1d(n_chans, affine=True) + elif self.norm == 'gn': + self.norm = nn.GroupNorm(n_chans // 16, n_chans) + elif self.norm == 'ln': + self.norm = LayerNorm(n_chans // 16, n_chans) + elif self.norm == 'wn': + self.conv = torch.nn.utils.weight_norm(self.conv.conv) + self.dropout = nn.Dropout(dropout) + self.relu = nn.ReLU() + + def forward(self, x): + """ + + :param x: [B, C, T] + :return: [B, C, T] + """ + x = self.conv(x) + if not isinstance(self.norm, str): + if self.norm == 'none': + pass + elif self.norm == 'ln': + x = self.norm(x.transpose(1, 2)).transpose(1, 2) + else: + x = self.norm(x) + x = self.relu(x) + x = self.dropout(x) + return x + + +class ConvStacks(nn.Module): + def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn', + dropout=0, strides=None, res=True): + super().__init__() + self.conv = torch.nn.ModuleList() + self.kernel_size = kernel_size + self.res = res + self.in_proj = Linear(idim, n_chans) + if strides is None: + strides = [1] * n_layers + else: + assert len(strides) == n_layers + for idx in range(n_layers): + self.conv.append(ConvBlock( + n_chans, n_chans, kernel_size, stride=strides[idx], norm=norm, dropout=dropout)) + self.out_proj = Linear(n_chans, odim) + + def forward(self, x, return_hiddens=False): + """ + + :param x: [B, T, H] + :return: [B, T, H] + """ + x = self.in_proj(x) + x = x.transpose(1, -1) # (B, idim, Tmax) + hiddens = [] + for f in self.conv: + x_ = f(x) + x = x + x_ if self.res else x_ # (B, C, Tmax) + hiddens.append(x) + x = x.transpose(1, -1) + x = self.out_proj(x) # (B, Tmax, H) + if return_hiddens: + hiddens = torch.stack(hiddens, 1) # [B, L, C, T] + return x, hiddens + return x + + +class ConvGlobalStacks(nn.Module): + def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn', dropout=0, + strides=[2, 2, 2, 2, 2]): + super().__init__() + self.conv = torch.nn.ModuleList() + self.pooling = torch.nn.ModuleList() + self.kernel_size = kernel_size + self.in_proj = Linear(idim, n_chans) + for idx in range(n_layers): + self.conv.append(ConvBlock(n_chans, n_chans, kernel_size, stride=strides[idx], + norm=norm, dropout=dropout)) + self.pooling.append(nn.MaxPool1d(strides[idx])) + self.out_proj = Linear(n_chans, odim) + + def forward(self, x): + """ + + :param x: [B, T, H] + :return: [B, T, H] + """ + x = self.in_proj(x) + x = x.transpose(1, -1) # (B, idim, Tmax) + for f, p in zip(self.conv, self.pooling): + x = f(x) # (B, C, T) + x = x.transpose(1, -1) + x = self.out_proj(x.mean(1)) # (B, H) + return x + + +class ConvDecoder(nn.Module): + def __init__(self, c, dropout, kernel_size=9, act='gelu'): + super().__init__() + self.c = c + self.dropout = dropout + + self.pre_convs = nn.ModuleList() + self.pre_lns = nn.ModuleList() + for i in range(2): + self.pre_convs.append(TransformerFFNLayer( + c, c * 2, padding='LEFT', kernel_size=kernel_size, dropout=dropout, act=act)) + self.pre_lns.append(LayerNorm(c)) + + self.layer_norm_attn = LayerNorm(c) + self.encoder_attn = MultiheadAttention(c, 1, encoder_decoder_attention=True, bias=False) + + self.post_convs = nn.ModuleList() + self.post_lns = nn.ModuleList() + for i in range(8): + self.post_convs.append(TransformerFFNLayer( + c, c * 2, padding='LEFT', kernel_size=kernel_size, dropout=dropout, act=act)) + self.post_lns.append(LayerNorm(c)) + + def forward( + self, + x, + encoder_out=None, + encoder_padding_mask=None, + incremental_state=None, + **kwargs, + ): + attn_logits = None + for conv, ln in zip(self.pre_convs, self.pre_lns): + residual = x + x = ln(x) + x = conv(x) + residual + if encoder_out is not None: + residual = x + x = self.layer_norm_attn(x) + x, attn = self.encoder_attn( + query=x, + key=encoder_out, + value=encoder_out, + key_padding_mask=encoder_padding_mask, + incremental_state=incremental_state, + static_kv=True, + enc_dec_attn_constraint_mask=get_incremental_state(self, incremental_state, + 'enc_dec_attn_constraint_mask'), + ) + attn_logits = attn[1] + x = F.dropout(x, self.dropout, training=self.training) + x = residual + x + for conv, ln in zip(self.post_convs, self.post_lns): + residual = x + x = ln(x) + x = conv(x) + residual + return x, attn_logits + + def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None): + self.encoder_attn.clear_buffer(incremental_state) + self.ffn.clear_buffer(incremental_state) + + def set_buffer(self, name, tensor, incremental_state): + return set_incremental_state(self, incremental_state, name, tensor) diff --git a/Geneface_main/GeneFace/modules/audio2motion/transformer_models.py b/Geneface_main/GeneFace/modules/audio2motion/transformer_models.py new file mode 100644 index 00000000..70603980 --- /dev/null +++ b/Geneface_main/GeneFace/modules/audio2motion/transformer_models.py @@ -0,0 +1,209 @@ +from tkinter.tix import X_REGION +from numpy import isin +import torch +import torch.nn as nn +from modules.audio2motion.transformer_base import * + +DEFAULT_MAX_SOURCE_POSITIONS = 2000 +DEFAULT_MAX_TARGET_POSITIONS = 2000 + + +class TransformerEncoderLayer(nn.Module): + def __init__(self, hidden_size, dropout, kernel_size=None, num_heads=2, norm='ln'): + super().__init__() + self.hidden_size = hidden_size + self.dropout = dropout + self.num_heads = num_heads + self.op = EncSALayer( + hidden_size, num_heads, dropout=dropout, + attention_dropout=0.0, relu_dropout=dropout, + kernel_size=kernel_size + if kernel_size is not None else 9, + padding='SAME', + norm=norm, act='gelu' + ) + + def forward(self, x, **kwargs): + return self.op(x, **kwargs) + + +###################### +# fastspeech modules +###################### +class LayerNorm(torch.nn.LayerNorm): + """Layer normalization module. + :param int nout: output dim size + :param int dim: dimension to be normalized + """ + + def __init__(self, nout, dim=-1, eps=1e-5): + """Construct an LayerNorm object.""" + super(LayerNorm, self).__init__(nout, eps=eps) + self.dim = dim + + def forward(self, x): + """Apply layer normalization. + :param torch.Tensor x: input tensor + :return: layer normalized tensor + :rtype torch.Tensor + """ + if self.dim == -1: + return super(LayerNorm, self).forward(x) + return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1) + + +class FFTBlocks(nn.Module): + def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=None, + num_heads=2, use_pos_embed=True, use_last_norm=True, norm='ln', + use_pos_embed_alpha=True): + super().__init__() + self.num_layers = num_layers + embed_dim = self.hidden_size = hidden_size + self.dropout = dropout if dropout is not None else 0.1 + self.use_pos_embed = use_pos_embed + self.use_last_norm = use_last_norm + if use_pos_embed: + self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS + self.padding_idx = 0 + self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1 + self.embed_positions = SinusoidalPositionalEmbedding( + embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS, + ) + + self.layers = nn.ModuleList([]) + self.layers.extend([ + TransformerEncoderLayer(self.hidden_size, self.dropout, + kernel_size=ffn_kernel_size, num_heads=num_heads, + norm=norm) + for _ in range(self.num_layers) + ]) + if self.use_last_norm: + if norm == 'ln': + self.layer_norm = nn.LayerNorm(embed_dim) + elif norm == 'bn': + self.layer_norm = BatchNorm1dTBC(embed_dim) + elif norm == 'gn': + self.layer_norm = GroupNorm1DTBC(8, embed_dim) + else: + self.layer_norm = None + + def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False): + """ + :param x: [B, T, C] + :param padding_mask: [B, T] + :return: [B, T, C] or [L, B, T, C] + """ + padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask + nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None] # [T, B, 1] + if self.use_pos_embed: + positions = self.pos_embed_alpha * self.embed_positions(x[..., 0]) + x = x + positions + x = F.dropout(x, p=self.dropout, training=self.training) + # B x T x C -> T x B x C + x = x.transpose(0, 1) * nonpadding_mask_TB + hiddens = [] + for layer in self.layers: + x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB + hiddens.append(x) + if self.use_last_norm: + x = self.layer_norm(x) * nonpadding_mask_TB + if return_hiddens: + x = torch.stack(hiddens, 0) # [L, T, B, C] + x = x.transpose(1, 2) # [L, B, T, C] + else: + x = x.transpose(0, 1) # [B, T, C] + return x + +class SequentialSA(nn.Module): + def __init__(self,layers): + super(SequentialSA,self).__init__() + self.layers = nn.ModuleList(layers) + + def forward(self,x,x_mask): + """ + x: [batch, T, H] + x_mask: [batch, T] + """ + pad_mask = 1. - x_mask + for layer in self.layers: + if isinstance(layer, EncSALayer): + x = x.permute(1,0,2) + x = layer(x,pad_mask) + x = x.permute(1,0,2) + elif isinstance(layer, nn.Linear): + x = layer(x) * x_mask.unsqueeze(2) + elif isinstance(layer, nn.AvgPool1d): + x = x.permute(0,2,1) + x = layer(x) + x = x.permute(0,2,1) + elif isinstance(layer, nn.PReLU): + bs, t, hid = x.shape + x = x.reshape([bs*t,hid]) + x = layer(x) + x = x.reshape([bs, t, hid]) + else: # Relu + x = layer(x) + + return x + +class TransformerStyleFusionModel(nn.Module): + def __init__(self, num_heads=4, dropout = 0.1, out_dim = 64): + super(TransformerStyleFusionModel, self).__init__() + self.audio_layer = SequentialSA([ + nn.Linear(29, 48), + nn.ReLU(48), + nn.Linear(48, 128), + ]) + + self.energy_layer = SequentialSA([ + nn.Linear(1, 16), + nn.ReLU(16), + nn.Linear(16, 64), + ]) + + self.backbone1 = FFTBlocks(hidden_size=192,num_layers=3) + + self.sty_encoder = nn.Sequential(*[ + nn.Linear(135, 64), + nn.ReLU(), + nn.Linear(64, 128) + ]) + + self.backbone2 = FFTBlocks(hidden_size=320,num_layers=3) + + self.out_layer = SequentialSA([ + nn.AvgPool1d(kernel_size=2,stride=2,padding=0), #[b,hid,t_audio]=>[b,hid,t_audio//2] + nn.Linear(320,out_dim), + nn.PReLU(out_dim), + nn.Linear(out_dim,out_dim), + ]) + + self.dropout = nn.Dropout(p = dropout) + + def forward(self, audio, energy, style, x_mask, y_mask): + pad_mask = 1. - x_mask + audio_feat = self.audio_layer(audio, x_mask) + energy_feat = self.energy_layer(energy, x_mask) + feat = torch.cat((audio_feat, energy_feat), dim=-1) # [batch, T, H=48+16] + feat = self.backbone1(feat, pad_mask) + feat = self.dropout(feat) + + sty_feat = self.sty_encoder(style) # [batch,135]=>[batch, H=64] + sty_feat = sty_feat.unsqueeze(1).repeat(1, feat.shape[1], 1) # [batch, T, H=64] + + feat = torch.cat([feat, sty_feat], dim=-1) # [batch, T, H=64+64] + feat = self.backbone2(feat, pad_mask) # [batch, T, H=128] + out = self.out_layer(feat, y_mask) # [batch, T//2, H=out_dim] + + return out + + +if __name__ == '__main__': + model = TransformerStyleFusionModel() + audio = torch.rand(4,200,29) # [B,T,H] + energy = torch.rand(4,200,1) # [B,T,H] + style = torch.ones(4,135) # [B,T] + x_mask = torch.ones(4,200) # [B,T] + x_mask[3,10:] = 0 + ret = model(audio,energy,style, x_mask) + print(" ") \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/audio2motion/utils.py b/Geneface_main/GeneFace/modules/audio2motion/utils.py new file mode 100644 index 00000000..7eb56ec5 --- /dev/null +++ b/Geneface_main/GeneFace/modules/audio2motion/utils.py @@ -0,0 +1,29 @@ +import torch + + +def squeeze(x, x_mask=None, n_sqz=2): + b, c, t = x.size() + + t = (t // n_sqz) * n_sqz + x = x[:, :, :t] + x_sqz = x.view(b, c, t // n_sqz, n_sqz) + x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz) + + if x_mask is not None: + x_mask = x_mask[:, :, n_sqz - 1::n_sqz] + else: + x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype) + return x_sqz * x_mask, x_mask + + +def unsqueeze(x, x_mask=None, n_sqz=2): + b, c, t = x.size() + + x_unsqz = x.view(b, n_sqz, c // n_sqz, t) + x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz) + + if x_mask is not None: + x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz) + else: + x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype) + return x_unsqz * x_mask, x_mask diff --git a/Geneface_main/GeneFace/modules/audio2motion/vae.py b/Geneface_main/GeneFace/modules/audio2motion/vae.py new file mode 100644 index 00000000..830f627a --- /dev/null +++ b/Geneface_main/GeneFace/modules/audio2motion/vae.py @@ -0,0 +1,432 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F +import torch.distributions as dist +import numpy as np + +from modules.audio2motion.flow_base import Glow, WN, ResidualCouplingBlock +from modules.audio2motion.transformer_base import Embedding + +from utils.commons.pitch_utils import f0_to_coarse + + +class LambdaLayer(nn.Module): + def __init__(self, lambd): + super(LambdaLayer, self).__init__() + self.lambd = lambd + + def forward(self, x): + return self.lambd(x) + + +def make_positions(tensor, padding_idx): + """Replace non-padding symbols with their position numbers. + + Position numbers begin at padding_idx+1. Padding symbols are ignored. + """ + # The series of casts and type-conversions here are carefully + # balanced to both work with ONNX export and XLA. In particular XLA + # prefers ints, cumsum defaults to output longs, and ONNX doesn't know + # how to handle the dtype kwarg in cumsum. + mask = tensor.ne(padding_idx).int() + return ( + torch.cumsum(mask, dim=1).type_as(mask) * mask + ).long() + padding_idx + +class SinusoidalPositionalEmbedding(nn.Module): + """This module produces sinusoidal positional embeddings of any length. + + Padding symbols are ignored. + """ + + def __init__(self, embedding_dim, padding_idx, init_size=1024): + super().__init__() + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.weights = SinusoidalPositionalEmbedding.get_embedding( + init_size, + embedding_dim, + padding_idx, + ) + self.register_buffer('_float_tensor', torch.FloatTensor(1)) + + @staticmethod + def get_embedding(num_embeddings, embedding_dim, padding_idx=None): + """Build sinusoidal embeddings. + + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) + emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + return emb + + def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs): + """Input is expected to be of size [bsz x seqlen].""" + bsz, seq_len = input.shape[:2] + max_pos = self.padding_idx + 1 + seq_len + if self.weights is None or max_pos > self.weights.size(0): + # recompute/expand embeddings if needed + self.weights = SinusoidalPositionalEmbedding.get_embedding( + max_pos, + self.embedding_dim, + self.padding_idx, + ) + self.weights = self.weights.to(self._float_tensor) + + if incremental_state is not None: + # positions is the same for every token when decoding a single step + pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len + return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1) + + positions = make_positions(input, self.padding_idx) if positions is None else positions + return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach() + + def max_positions(self): + """Maximum number of supported positions.""" + return int(1e4) # an arbitrary large number + +class FVAEEncoder(nn.Module): + def __init__(self, in_channels, hidden_channels, latent_channels, kernel_size, + n_layers, gin_channels=0, p_dropout=0, strides=[4]): + super().__init__() + self.strides = strides + self.hidden_size = hidden_channels + self.pre_net = nn.Sequential(*[ + nn.Conv1d(in_channels, hidden_channels, kernel_size=s * 2, stride=s, padding=s // 2) + if i == 0 else + nn.Conv1d(hidden_channels, hidden_channels, kernel_size=s * 2, stride=s, padding=s // 2) + for i, s in enumerate(strides) + ]) + self.wn = WN(hidden_channels, kernel_size, 1, n_layers, gin_channels, p_dropout) + self.out_proj = nn.Conv1d(hidden_channels, latent_channels * 2, 1) + + self.latent_channels = latent_channels + + def forward(self, x, x_mask, g): + x = self.pre_net(x) + x_mask = x_mask[:, :, ::np.prod(self.strides)][:, :, :x.shape[-1]] + x = x * x_mask + x = self.wn(x, x_mask, g) * x_mask + x = self.out_proj(x) + m, logs = torch.split(x, self.latent_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) + return z, m, logs, x_mask + + +class FVAEDecoder(nn.Module): + def __init__(self, latent_channels, hidden_channels, out_channels, kernel_size, + n_layers, gin_channels=0, p_dropout=0, + strides=[4]): + super().__init__() + self.strides = strides + self.hidden_size = hidden_channels + self.pre_net = nn.Sequential(*[ + nn.ConvTranspose1d(latent_channels, hidden_channels, kernel_size=s, stride=s) + if i == 0 else + nn.ConvTranspose1d(hidden_channels, hidden_channels, kernel_size=s, stride=s) + for i, s in enumerate(strides) + ]) + self.wn = WN(hidden_channels, kernel_size, 1, n_layers, gin_channels, p_dropout) + self.out_proj = nn.Conv1d(hidden_channels, out_channels, 1) + + def forward(self, x, x_mask, g): + x = self.pre_net(x) + x = x * x_mask + x = self.wn(x, x_mask, g) * x_mask + x = self.out_proj(x) + return x + +class FVAE(nn.Module): + def __init__(self, + in_out_channels=64, hidden_channels=256, latent_size=16, + kernel_size=3, enc_n_layers=5, dec_n_layers=5, gin_channels=80, strides=[4,], + use_prior_glow=True, glow_hidden=256, glow_kernel_size=3, glow_n_blocks=5, + sqz_prior=False, use_pos_emb=False): + super(FVAE, self).__init__() + self.in_out_channels = in_out_channels + self.strides = strides + self.hidden_size = hidden_channels + self.latent_size = latent_size + self.use_prior_glow = use_prior_glow + self.sqz_prior = sqz_prior + self.g_pre_net = nn.Sequential(*[ + nn.Conv1d(gin_channels, gin_channels, kernel_size=s * 2, stride=s, padding=s // 2) + for i, s in enumerate(strides) + ]) + self.encoder = FVAEEncoder(in_out_channels, hidden_channels, latent_size, kernel_size, + enc_n_layers, gin_channels, strides=strides) + if use_prior_glow: + self.prior_flow = ResidualCouplingBlock( + latent_size, glow_hidden, glow_kernel_size, 1, glow_n_blocks, 4, gin_channels=gin_channels) + self.use_pos_embed = use_pos_emb + if sqz_prior: + self.query_proj = nn.Linear(latent_size, latent_size) + self.key_proj = nn.Linear(latent_size, latent_size) + self.value_proj = nn.Linear(latent_size, hidden_channels) + if self.in_out_channels in [7, 64]: + self.decoder = FVAEDecoder(hidden_channels, hidden_channels, in_out_channels, kernel_size, + dec_n_layers, gin_channels, strides=strides) + elif self.in_out_channels == 71: + self.exp_decoder = FVAEDecoder(hidden_channels, hidden_channels, 64, kernel_size, + dec_n_layers, gin_channels, strides=strides) + self.pose_decoder = FVAEDecoder(hidden_channels, hidden_channels, 7, kernel_size, + dec_n_layers, gin_channels, strides=strides) + if self.use_pos_embed: + self.embed_positions = SinusoidalPositionalEmbedding(self.latent_size, 0,init_size=2000+1,) + else: + self.decoder = FVAEDecoder(latent_size, hidden_channels, in_out_channels, kernel_size, + dec_n_layers, gin_channels, strides=strides) + + self.prior_dist = dist.Normal(0, 1) + + def forward(self, x=None, x_mask=None, g=None, infer=False, temperature=1. , **kwargs): + """ + + :param x: [B, T, C_in_out] + :param x_mask: [B, T] + :param g: [B, T, C_g] + :return: + """ + x_mask = x_mask[:, None, :] # [B, 1, T] + g = g.transpose(1,2) # [B, C_g, T] + g_for_sqz = g + + g_sqz = self.g_pre_net(g_for_sqz) + + if not infer: + x = x.transpose(1,2) # [B, C, T] + z_q, m_q, logs_q, x_mask_sqz = self.encoder(x, x_mask, g_sqz) + if self.sqz_prior: + z = z_q + if self.use_pos_embed: + position = self.embed_positions(z.transpose(1,2).abs().sum(-1)).transpose(1,2) + z = z + position + q = self.query_proj(z.mean(dim=-1,keepdim=True).transpose(1,2)) # [B, 1, C=16] + k = self.key_proj(z.transpose(1,2)) # [B, T, C=16] + v = self.value_proj(z.transpose(1,2)) # [B, T, C=256] + attn = torch.bmm(q,k.transpose(1,2)) # [B, 1, T] + attn = F.softmax(attn, dim=-1) + out = torch.bmm(attn, v) # [B, 1, C=256] + style_encoding = out.repeat([1,z_q.shape[-1],1]).transpose(1,2) # [B, C=256, T] + if self.in_out_channels == 71: + x_recon = torch.cat([self.exp_decoder(style_encoding, x_mask, g), self.pose_decoder(style_encoding, x_mask, g)], dim=1) + else: + x_recon = self.decoder(style_encoding, x_mask, g) + else: + if self.in_out_channels == 71: + x_recon = torch.cat([self.exp_decoder(z_q, x_mask, g), self.pose_decoder(z_q, x_mask, g)], dim=1) + else: + x_recon = self.decoder(z_q, x_mask, g) + q_dist = dist.Normal(m_q, logs_q.exp()) + if self.use_prior_glow: + logqx = q_dist.log_prob(z_q) + z_p = self.prior_flow(z_q, x_mask_sqz, g_sqz) + logpx = self.prior_dist.log_prob(z_p) + loss_kl = ((logqx - logpx) * x_mask_sqz).sum() / x_mask_sqz.sum() / logqx.shape[1] + else: + loss_kl = torch.distributions.kl_divergence(q_dist, self.prior_dist) + loss_kl = (loss_kl * x_mask_sqz).sum() / x_mask_sqz.sum() / z_q.shape[1] + z_p = z_q + return x_recon.transpose(1,2), loss_kl, z_p.transpose(1,2), m_q.transpose(1,2), logs_q.transpose(1,2) + else: + latent_shape = [g_sqz.shape[0], self.latent_size, g_sqz.shape[2]] + z_p = self.prior_dist.sample(latent_shape).to(g.device) * temperature # [B, latent_size, T_sqz] + if self.use_prior_glow: + z_p = self.prior_flow(z_p, 1, g_sqz, reverse=True) + if self.sqz_prior: + z = z_p + if self.use_pos_embed: + position = self.embed_positions(z.abs().sum(-1)) + z += position + q = self.query_proj(z.mean(dim=-1,keepdim=True).transpose(1,2)) # [B, 1, C=16] + k = self.key_proj(z.transpose(1,2)) # [B, T, C=16] + v = self.value_proj(z.transpose(1,2)) # [B, T, C=256] + attn = torch.bmm(q,k.transpose(1,2)) # [B, 1, T] + attn = F.softmax(attn, dim=-1) + out = torch.bmm(attn, v) # [B, 1, C=256] + style_encoding = out.repeat([1,z_p.shape[-1],1]).transpose(1,2) # [B, C=256, T] + x_recon = self.decoder(style_encoding, 1, g) + if self.in_out_channels == 71: + x_recon = torch.cat([self.exp_decoder(style_encoding, 1, g), self.pose_decoder(style_encoding, 1, g)], dim=1) + else: + x_recon = self.decoder(style_encoding, 1, g) + else: + if self.in_out_channels == 71: + x_recon = torch.cat([self.exp_decoder(z_p, 1, g), self.pose_decoder(z_p, 1, g)], dim=1) + else: + x_recon = self.decoder(z_p, 1, g) + return x_recon.transpose(1,2), z_p.transpose(1,2) + + +class VAEModel(nn.Module): + def __init__(self, in_out_dim=64, sqz_prior=False, cond_drop=False, use_prior_flow=True): + super().__init__() + mel_feat_dim = 64 + mel_in_dim = 1024 # hubert + + cond_dim = mel_feat_dim + self.mel_encoder = nn.Sequential(*[ + nn.Conv1d(mel_in_dim, 64, 3, 1, 1, bias=False), + nn.BatchNorm1d(64), + nn.GELU(), + nn.Conv1d(64, mel_feat_dim, 3, 1, 1, bias=False) + ]) + self.cond_drop = cond_drop + if self.cond_drop: + self.dropout = nn.Dropout(0.5) + + self.in_dim, self.out_dim = in_out_dim, in_out_dim + self.sqz_prior = sqz_prior + self.use_prior_flow = use_prior_flow + self.vae = FVAE(in_out_channels=in_out_dim, hidden_channels=256, latent_size=16, kernel_size=5, + enc_n_layers=8, dec_n_layers=4, gin_channels=cond_dim, strides=[4,], + use_prior_glow=self.use_prior_flow, glow_hidden=64, glow_kernel_size=3, glow_n_blocks=4,sqz_prior=sqz_prior) + self.downsampler = LambdaLayer(lambda x: F.interpolate(x.transpose(1,2), scale_factor=0.5, mode='nearest').transpose(1,2)) + + def num_params(self, model, print_out=True, model_name="model"): + parameters = filter(lambda p: p.requires_grad, model.parameters()) + parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 + if print_out: + print(f'| {model_name} Trainable Parameters: %.3fM' % parameters) + return parameters + + @property + def device(self): + return self.vae.parameters().__next__().device + + def forward(self, batch, ret, train=True, return_latent=False, temperature=1.): + infer = not train + mask = batch['y_mask'].to(self.device) + mel = batch['hubert'].to(self.device) + mel = self.downsampler(mel) + cond_feat = self.mel_encoder(mel.transpose(1,2)).transpose(1,2) + + if self.cond_drop: + cond_feat = self.dropout(cond_feat) + + if not infer: + exp = batch['y'].to(self.device) + x = exp + x_recon, loss_kl, z_p, m_q, logs_q = self.vae(x=x, x_mask=mask, g=cond_feat, infer=False) + x_recon = x_recon * mask.unsqueeze(-1) + ret['pred'] = x_recon + ret['mask'] = mask + ret['loss_kl'] = loss_kl + if return_latent: + ret['m_q'] = m_q + ret['z_p'] = z_p + return x_recon, loss_kl, m_q, logs_q + else: + x_recon, z_p = self.vae(x=None, x_mask=mask, g=cond_feat, infer=True, temperature=temperature) + x_recon = x_recon * mask.unsqueeze(-1) + ret['pred'] = x_recon + ret['mask'] = mask + + return x_recon + + +class PitchContourVAEModel(nn.Module): + def __init__(self, in_out_dim=64, sqz_prior=False, cond_drop=False, use_prior_flow=True): + super().__init__() + mel_feat_dim = 64 + mel_in_dim = 1024 # hubert + + cond_dim = mel_feat_dim + self.mel_encoder = nn.Sequential(*[ + nn.Conv1d(mel_in_dim, 64, 3, 1, 1, bias=False), + nn.BatchNorm1d(64), + nn.GELU(), + nn.Conv1d(64, mel_feat_dim, 3, 1, 1, bias=False) + ]) + + self.pitch_embed = Embedding(300, mel_feat_dim, None) + self.pitch_encoder = nn.Sequential(*[ + nn.Conv1d(mel_feat_dim, 64, 3, 1, 1, bias=False), + nn.BatchNorm1d(64), + nn.GELU(), + nn.Conv1d(64, 32, 3, 1, 1, bias=False) + ]) + cond_dim += 32 + + self.cond_drop = cond_drop + if self.cond_drop: + self.dropout = nn.Dropout(0.5) + + self.in_dim, self.out_dim = in_out_dim, in_out_dim + self.sqz_prior = sqz_prior + self.use_prior_flow = use_prior_flow + self.vae = FVAE(in_out_channels=in_out_dim, hidden_channels=256, latent_size=16, kernel_size=5, + enc_n_layers=8, dec_n_layers=4, gin_channels=cond_dim, strides=[4,], + use_prior_glow=self.use_prior_flow, glow_hidden=64, glow_kernel_size=3, glow_n_blocks=4,sqz_prior=sqz_prior) + self.downsampler = LambdaLayer(lambda x: F.interpolate(x.transpose(1,2), scale_factor=0.5, mode='nearest').transpose(1,2)) + + def num_params(self, model, print_out=True, model_name="model"): + parameters = filter(lambda p: p.requires_grad, model.parameters()) + parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 + if print_out: + print(f'| {model_name} Trainable Parameters: %.3fM' % parameters) + return parameters + + @property + def device(self): + return self.vae.parameters().__next__().device + + def forward(self, batch, ret, train=True, return_latent=False, temperature=1.): + infer = not train + mask = batch['y_mask'].to(self.device) + mel = batch['hubert'].to(self.device) + f0 = batch['f0'].to(self.device) # [b,t] + mel = self.downsampler(mel) + f0 = self.downsampler(f0.unsqueeze(-1)).squeeze(-1) + f0_coarse = f0_to_coarse(f0) + pitch_emb = self.pitch_embed(f0_coarse) + cond_feat = self.mel_encoder(mel.transpose(1,2)).transpose(1,2) + pitch_feat = self.pitch_encoder(pitch_emb.transpose(1,2)).transpose(1,2) + cond_feat = torch.cat([cond_feat, pitch_feat], dim=-1) + + if self.cond_drop: + cond_feat = self.dropout(cond_feat) + + if not infer: + exp = batch['y'].to(self.device) + x = exp + x_recon, loss_kl, z_p, m_q, logs_q = self.vae(x=x, x_mask=mask, g=cond_feat, infer=False) + x_recon = x_recon * mask.unsqueeze(-1) + ret['pred'] = x_recon + ret['mask'] = mask + ret['loss_kl'] = loss_kl + if return_latent: + ret['m_q'] = m_q + ret['z_p'] = z_p + return x_recon, loss_kl, m_q, logs_q + else: + x_recon, z_p = self.vae(x=None, x_mask=mask, g=cond_feat, infer=True, temperature=temperature) + x_recon = x_recon * mask.unsqueeze(-1) + ret['pred'] = x_recon + ret['mask'] = mask + + return x_recon + + +if __name__ == '__main__': + model = FVAE(in_out_channels=64, hidden_channels=128, latent_size=32,kernel_size=3, enc_n_layers=6, dec_n_layers=2, + gin_channels=80, strides=[4], use_prior_glow=False, glow_hidden=128, glow_kernel_size=3, glow_n_blocks=3) + x = torch.rand([8, 64, 1000]) + x_mask = torch.ones([8,1,1000]) + g = torch.rand([8, 80, 1000]) + train_out = model(x,x_mask,g,infer=False) + x_recon, loss_kl, z_p, m_q, logs_q = train_out + print(" ") + infer_out = model(x,x_mask,g,infer=True) + x_recon, z_p = infer_out + print(" ") diff --git a/Geneface_main/GeneFace/modules/audio2motion/vqvae.py b/Geneface_main/GeneFace/modules/audio2motion/vqvae.py new file mode 100644 index 00000000..310ffc7b --- /dev/null +++ b/Geneface_main/GeneFace/modules/audio2motion/vqvae.py @@ -0,0 +1,200 @@ +import scipy +from scipy import linalg +from torch.nn import functional as F +import torch +from torch import nn +import numpy as np +from modules.audio2motion.transformer_models import FFTBlocks +import modules.audio2motion.utils as utils +from modules.audio2motion.flow_base import Glow, WN, ResidualCouplingBlock +import torch.distributions as dist +from modules.audio2motion.cnn_models import LambdaLayer, LayerNorm + +from vector_quantize_pytorch import VectorQuantize + + +class FVAEEncoder(nn.Module): + def __init__(self, in_channels, hidden_channels, latent_channels, kernel_size, + n_layers, gin_channels=0, p_dropout=0, strides=[4]): + super().__init__() + self.strides = strides + self.hidden_size = hidden_channels + self.pre_net = nn.Sequential(*[ + nn.Conv1d(in_channels, hidden_channels, kernel_size=s * 2, stride=s, padding=s // 2) + if i == 0 else + nn.Conv1d(hidden_channels, hidden_channels, kernel_size=s * 2, stride=s, padding=s // 2) + for i, s in enumerate(strides) + ]) + self.wn = WN(hidden_channels, kernel_size, 1, n_layers, gin_channels, p_dropout) + self.out_proj = nn.Conv1d(hidden_channels, latent_channels * 2, 1) + self.latent_channels = latent_channels + + def forward(self, x, x_mask, g): + x = self.pre_net(x) + x_mask = x_mask[:, :, ::np.prod(self.strides)][:, :, :x.shape[-1]] + x = x * x_mask + x = self.wn(x, x_mask, g) * x_mask + x = self.out_proj(x) + m, logs = torch.split(x, self.latent_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) + return z, m, logs, x_mask + + +class FVAEDecoder(nn.Module): + def __init__(self, latent_channels, hidden_channels, out_channels, kernel_size, + n_layers, gin_channels=0, p_dropout=0, + strides=[4]): + super().__init__() + self.strides = strides + self.hidden_size = hidden_channels + self.pre_net = nn.Sequential(*[ + nn.ConvTranspose1d(latent_channels, hidden_channels, kernel_size=s, stride=s) + if i == 0 else + nn.ConvTranspose1d(hidden_channels, hidden_channels, kernel_size=s, stride=s) + for i, s in enumerate(strides) + ]) + self.wn = WN(hidden_channels, kernel_size, 1, n_layers, gin_channels, p_dropout) + self.out_proj = nn.Conv1d(hidden_channels, out_channels, 1) + + def forward(self, x, x_mask, g): + x = self.pre_net(x) + x = x * x_mask + x = self.wn(x, x_mask, g) * x_mask + x = self.out_proj(x) + return x + + +class VQVAE(nn.Module): + def __init__(self, + in_out_channels=64, hidden_channels=256, latent_size=16, + kernel_size=3, enc_n_layers=5, dec_n_layers=5, gin_channels=80, strides=[4,], + sqz_prior=False): + super().__init__() + self.in_out_channels = in_out_channels + self.strides = strides + self.hidden_size = hidden_channels + self.latent_size = latent_size + self.g_pre_net = nn.Sequential(*[ + nn.Conv1d(gin_channels, gin_channels, kernel_size=s * 2, stride=s, padding=s // 2) + for i, s in enumerate(strides) + ]) + self.encoder = FVAEEncoder(in_out_channels, hidden_channels, hidden_channels, kernel_size, + enc_n_layers, gin_channels, strides=strides) + # if use_prior_glow: + # self.prior_flow = ResidualCouplingBlock( + # latent_size, glow_hidden, glow_kernel_size, 1, glow_n_blocks, 4, gin_channels=gin_channels) + self.vq = VectorQuantize(dim=hidden_channels, codebook_size=256, codebook_dim=16) + + self.decoder = FVAEDecoder(hidden_channels, hidden_channels, in_out_channels, kernel_size, + dec_n_layers, gin_channels, strides=strides) + self.prior_dist = dist.Normal(0, 1) + self.sqz_prior = sqz_prior + + def forward(self, x=None, x_mask=None, g=None, infer=False, **kwargs): + """ + + :param x: [B, T, C_in_out] + :param x_mask: [B, T] + :param g: [B, T, C_g] + :return: + """ + x_mask = x_mask[:, None, :] # [B, 1, T] + g = g.transpose(1,2) # [B, C_g, T] + g_for_sqz = g + + g_sqz = self.g_pre_net(g_for_sqz) + + if not infer: + x = x.transpose(1,2) # [B, C, T] + z_q, m_q, logs_q, x_mask_sqz = self.encoder(x, x_mask, g_sqz) + if self.sqz_prior: + z_q = F.interpolate(z_q, scale_factor=1/8) + z_p, idx, commit_loss = self.vq(z_q.transpose(1,2)) + if self.sqz_prior: + z_p = F.interpolate(z_p.transpose(1,2),scale_factor=8).transpose(1,2) + + x_recon = self.decoder(z_p.transpose(1,2), x_mask, g) + return x_recon.transpose(1,2), commit_loss, z_p.transpose(1,2), m_q.transpose(1,2), logs_q.transpose(1,2) + else: + bs, t = g_sqz.shape[0], g_sqz.shape[2] + if self.sqz_prior: + t = t // 8 + latent_shape = [int(bs * t)] + latent_idx = torch.randint(0,256,latent_shape).to(self.vq.codebook.device) + # latent_idx = torch.ones_like(latent_idx, dtype=torch.long) + # z_p = torch.gather(self.vq.codebook, 0, latent_idx)# self.vq.codebook[latent_idx] + z_p = self.vq.codebook[latent_idx] + z_p = z_p.reshape([bs, t, -1]) + z_p = self.vq.project_out(z_p) + if self.sqz_prior: + z_p = F.interpolate(z_p.transpose(1,2),scale_factor=8).transpose(1,2) + + x_recon = self.decoder(z_p.transpose(1,2), 1, g) + return x_recon.transpose(1,2), z_p.transpose(1,2) + + +class VQVAEModel(nn.Module): + def __init__(self, in_out_dim=71, sqz_prior=False, enc_no_cond=False): + super().__init__() + self.mel_encoder = nn.Sequential(*[ + nn.Conv1d(80, 64, 3, 1, 1, bias=False), + nn.BatchNorm1d(64), + nn.GELU(), + nn.Conv1d(64, 64, 3, 1, 1, bias=False) + ]) + self.in_dim, self.out_dim = in_out_dim, in_out_dim + self.sqz_prior = sqz_prior + self.enc_no_cond = enc_no_cond + self.vae = VQVAE(in_out_channels=in_out_dim, hidden_channels=256, latent_size=16, kernel_size=5, + enc_n_layers=8, dec_n_layers=4, gin_channels=64, strides=[4,], sqz_prior=sqz_prior) + self.downsampler = LambdaLayer(lambda x: F.interpolate(x.transpose(1,2), scale_factor=0.5, mode='nearest').transpose(1,2)) + + @property + def device(self): + return self.vae.parameters().__next__().device + + def forward(self, batch, ret, log_dict=None, train=True): + infer = not train + mask = batch['y_mask'].to(self.device) + mel = batch['mel'].to(self.device) + mel = self.downsampler(mel) + + mel_feat = self.mel_encoder(mel.transpose(1,2)).transpose(1,2) + if not infer: + exp = batch['exp'].to(self.device) + pose = batch['pose'].to(self.device) + if self.in_dim == 71: + x = torch.cat([exp, pose], dim=-1) # [B, T, C=64 + 7] + elif self.in_dim == 64: + x = exp + elif self.in_dim == 7: + x = pose + if self.enc_no_cond: + x_recon, loss_commit, z_p, m_q, logs_q = self.vae(x=x, x_mask=mask, g=torch.zeros_like(mel_feat), infer=False) + else: + x_recon, loss_commit, z_p, m_q, logs_q = self.vae(x=x, x_mask=mask, g=mel_feat, infer=False) + loss_commit = loss_commit.reshape([]) + ret['pred'] = x_recon + ret['mask'] = mask + ret['loss_commit'] = loss_commit + return x_recon, loss_commit, m_q, logs_q + else: + x_recon, z_p = self.vae(x=None, x_mask=mask, g=mel_feat, infer=True) + return x_recon + + # def __get_feat(self, exp, pose): + # diff_exp = exp[:-1, :] - exp[1:, :] + # exp_std = (np.std(exp, axis = 0) - self.exp_std_mean) / self.exp_std_std + # diff_exp_std = (np.std(diff_exp, axis = 0) - self.exp_diff_std_mean) / self.exp_diff_std_std + + # diff_pose = pose[:-1, :] - pose[1:, :] + # diff_pose_std = (np.std(diff_pose, axis = 0) - self.pose_diff_std_mean) / self.pose_diff_std_std + + # return np.concatenate((exp_std, diff_exp_std, diff_pose_std)) + + def num_params(self, model, print_out=True, model_name="model"): + parameters = filter(lambda p: p.requires_grad, model.parameters()) + parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 + if print_out: + print(f'| {model_name} Trainable Parameters: %.3fM' % parameters) + return parameters diff --git a/Geneface_main/GeneFace/modules/audio2pose/gmm_utils.py b/Geneface_main/GeneFace/modules/audio2pose/gmm_utils.py new file mode 100644 index 00000000..ef935df2 --- /dev/null +++ b/Geneface_main/GeneFace/modules/audio2pose/gmm_utils.py @@ -0,0 +1,103 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + + +class GMMLogLoss(nn.Module): + ''' compute the GMM loss between model output and the groundtruth data. + Args: + ncenter: numbers of gaussian distribution + ndim: dimension of each gaussian distribution + sigma_min: current we do not use it. + ''' + def __init__(self, ncenter, ndim, sigma_min=0.03): + super(GMMLogLoss,self).__init__() + self.ncenter = ncenter + self.ndim = ndim + self.sigma_min = sigma_min + + def forward(self, output, target): + ''' + Args: + output: [b, T, ncenter + ncenter * ndim * 2]: + [:, :, : ncenter] shows each gaussian probability + [:, :, ncenter : ncenter + ndim * ncenter] shows the average values of each dimension of each gaussian + [: ,:, ncenter + ndim * ncenter : ncenter + ndim * 2 * ncenter] show the negative log sigma of each dimension of each gaussian + target: [b, T, ndim], the ground truth target landmark data is shown here + To maximize the log-likelihood equals to minimize the negative log-likelihood. + NOTE: It is unstable to directly compute the log results of sigma, e.g. ln(-0.1) as we need to clip the sigma results + into positive. Hence here we predict the negative log sigma results to avoid numerical instablility, which mean: + `` sigma = 1/exp(predict), predict = -ln(sigma) `` + Also, it will be just the 'B' term below! + Currently we only implement single gaussian distribution, hence the first values of pred are meaningless. + For single gaussian distribution: + L(mu, sigma) = -n/2 * ln(2pi * sigma^2) - 1 / (2 x sigma^2) * sum^n (x_i - mu)^2 (n for prediction times, n=1 for one frame, x_i for gt) + = -1/2 * ln(2pi) - 1/2 * ln(sigma^2) - 1/(2 x sigma^2) * (x - mu)^2 + == min -L(mu, sgima) = 0.5 x ln(2pi) + 0.5 x ln(sigma^2) + 1/(2 x sigma^2) * (x - mu)^2 + = 0.5 x ln_2PI + ln(sigma) + 0.5 x (MU_DIFF/sigma)^2 + = A - B + C + In batch and Time sample, b and T are summed and averaged. + ''' + b, T, _ = target.shape + # read prediction paras + mus = output[:, :, self.ncenter : (self.ncenter + self.ncenter * self.ndim)].view(b, T, self.ncenter, self.ndim) # [b, T, ncenter, ndim] + + # apply min sigma + neg_log_sigmas_out = output[:, :, (self.ncenter + self.ncenter * self.ndim):].view(b, T, self.ncenter, self.ndim) # [b, T, ncenter, ndim] + inv_sigmas_min = torch.ones(neg_log_sigmas_out.size()).cuda() * (1. / self.sigma_min) + inv_sigmas_min_log = torch.log(inv_sigmas_min) + neg_log_sigmas = torch.min(neg_log_sigmas_out, inv_sigmas_min_log) + + inv_sigmas = torch.exp(neg_log_sigmas) + # replicate the target of ncenter to minus mu + target_rep = target.unsqueeze(2).expand(b, T, self.ncenter, self.ndim) # [b, T, ncenter, ndim] + + MU_DIFF = target_rep - mus # [b, T, ncenter, ndim] + # sigma process + # A = 0.5 * math.log(2 * math.pi) # 0.9189385332046727 + # B = neg_log_sigmas # [b, T, ncenter, ndim] + # C = 0.5 * (MU_DIFF * inv_sigmas)**2 # [b, T, ncenter, ndim] + # negative_loglikelihood = A - B + C # [b, T, ncenter, ndim] + + # return negative_loglikelihood.mean() + return (MU_DIFF**2).mean() + +def Sample_GMM(gmm_params, ncenter, ndim, weight_smooth=0.0, sigma_scale=0.0): + '''Sample values from a given a GMM distribution. + Args: + gmm_params: [b, target_length, (2 * ndim + 1) * ncenter], including the + distribution weights, average and sigma + ncenter: numbers of gaussian distribution + ndim: dimension of each gaussian distribution + weight_smooth: float, smooth the gaussian distribution weights + sigma_scale: float, adjust the gaussian scale, larger for sharper prediction, + 0 for zero sigma which always return average values + Returns: + current_sample: [b,t,c=ndim] + ''' + b, T, _ = gmm_params.shape + gmm_params_cpu = gmm_params.cpu().view(-1, (2 * ndim + 1) * ncenter) + # compute each distrubution probability + prob = F.softmax(gmm_params_cpu[:, : ncenter] * (1 + weight_smooth), dim=1) + # select the gaussian distribution according to their weights + selected_idx = torch.multinomial(prob, num_samples=1, replacement=True) + + mu = gmm_params_cpu[:, ncenter : ncenter + ncenter * ndim] + # please note that we use -logsigma as output, hence here we need to take the negative + sigma = torch.exp(-gmm_params_cpu[:, ncenter + ncenter * ndim:]) * sigma_scale + # print('sigma average:', sigma.mean()) + + selected_sigma = torch.empty(b*T, ndim).float() + selected_mu = torch.empty(b*T, ndim).float() + current_sample = torch.randn(b*T, ndim).float() + + for i in range(b*T): + idx = selected_idx[i, 0] + selected_sigma[i, :] = sigma[i, idx * ndim:(idx + 1) * ndim] + selected_mu[i, :] = mu[i, idx * ndim:(idx + 1) * ndim] + + # sample with sel sigma and sel mean + current_sample = current_sample * selected_sigma + selected_mu + + return current_sample.reshape(b, T, -1) diff --git a/Geneface_main/GeneFace/modules/audio2pose/models.py b/Geneface_main/GeneFace/modules/audio2pose/models.py new file mode 100644 index 00000000..1a774ad6 --- /dev/null +++ b/Geneface_main/GeneFace/modules/audio2pose/models.py @@ -0,0 +1,320 @@ +import torch +import numpy as np +import torch.nn as nn +from torchvision import models +from torch.nn import functional as F +from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence +import tqdm +from modules.audio2pose.gmm_utils import Sample_GMM +from utils.commons.tensor_utils import convert_to_tensor + +class Audio2PoseModel(nn.Module): + def __init__(self, recept_field=100): + super(Audio2PoseModel, self).__init__() + self.audio_encoder = nn.Sequential( + # nn.Linear(in_features=1024*2, out_features=256), + nn.Linear(in_features=2*29, out_features=256), + nn.LeakyReLU(0.2), + nn.Linear(256, 256) + ) + self.backbone = WaveNet() + # self.recept_field = 30 + self.recept_field = recept_field + + def forward(self, audio, history_pose_velocity): + """ + audio: a fixed window of audio representation, [b, t=30, c=512] + history_pose_velocity: [b, t=30, c=12] + pred_pose_velocity_params: the GMM params of pose_and_velocity at t+1 steps, [b, c=12*2+1] + """ + audio = self.audio_encoder(audio) + ret = self.backbone(history_pose_velocity, audio) # [b, t, c] + # pred_pose_velocity_params = ret[:, -1, :] # [b, c] + # return pred_pose_velocity_params + return ret + + def autoregressive_infer(self, long_audio, init_pose=None): + """ + long_audio: [T, c=512] + init_pose: euler_trans, [6,], note that trans is subtracted by mean_trans! + """ + n_frames = len(long_audio) + pred_pose_and_velocity_lst = [] + + audio_insert = long_audio[0:1].repeat([self.recept_field-1,1]) + long_audio = torch.cat([audio_insert, long_audio], dim=0) + history_pose_and_velocity = torch.zeros([self.recept_field, 12]).float().to(long_audio.device) + if init_pose is not None: + init_pose = convert_to_tensor(init_pose).float().to(long_audio.device).unsqueeze(0).repeat([self.recept_field, 1]) # [self.recept_field, 6] + history_pose_and_velocity[:,:6] = init_pose + + with torch.no_grad(): + for i in tqdm.tqdm(range(n_frames), desc='generating headpose'): + audio_window = long_audio[i: i+self.recept_field].unsqueeze(0) # [b=1, t=30, c=512] + history_info = history_pose_and_velocity.unsqueeze(0) # [b=1, t=30, c=12] + pred_pose_and_velocity_gmm_params = self.forward(audio_window, history_info)[:,-1,:] # [b=1, c=12*2+1] + pred_pose_and_velocity = Sample_GMM(pred_pose_and_velocity_gmm_params.unsqueeze(1),ncenter=1,ndim=12,sigma_scale=0.0).to(long_audio.device) # [b=1,t=1,c=12] + pred_pose_and_velocity_lst.append(pred_pose_and_velocity.cpu().squeeze()) # [c=12] + history_pose_and_velocity = torch.cat([history_pose_and_velocity[1:,:], pred_pose_and_velocity.squeeze(0)],dim=0) # [29,c=12] + [1, c=12] ==> [30, c=12] + pred_pose_and_velocity = torch.stack(pred_pose_and_velocity_lst) # [T, c=12] + pred_pose = pred_pose_and_velocity[:,:6] + return pred_pose + + +class WaveNet(nn.Module): + ''' + We use WaveNet as the backbone of Audio2Pose model. + + Args: + batch_size: number of batch size + residual_layers: number of layers in each residual blocks + residual_blocks: number of residual blocks + dilation_channels: number of channels for the dilated convolution + residual_channels: number of channels for the residual connections + skip_channels: number of channels for the skip connections + end_channels: number of channels for the end convolution + classes: Number of possible values each sample can have as output + kernel_size: size of dilation convolution kernel + output_length(int): Number of samples that are generated for each input + use_bias: whether bias is used in each layer. + cond(bool): whether condition information are applied. if cond == True: + cond_channels: channel number of condition information + `` loss(str): GMM loss is adopted. `` + ''' + def __init__(self, + # residual_layers = 7, + residual_layers = 3, + residual_blocks = 2, + dilation_channels = 128, + residual_channels = 128, + skip_channels = 256, + kernel_size = 2, + use_bias = True, + cond = True, + input_channels = 12, + ncenter = 1, + ndim = 12, + output_channels = (2*12+1)*1, + cond_channels = 256, + activation = 'leakyrelu'): + super(WaveNet, self).__init__() + + self.layers = residual_layers + self.blocks = residual_blocks + self.dilation_channels = dilation_channels + self.residual_channels = residual_channels + self.skip_channels = skip_channels + self.input_channels = input_channels + self.ncenter = ncenter + self.ndim = ndim + self.output_channels = output_channels + self.kernel_size = kernel_size + # self.output_length = output_length + self.bias = use_bias + self.cond = cond + self.cond_channels = cond_channels + + # build modules + self.dilations = [] + self.dilation_queues = [] + residual_blocks = [] + self.receptive_field = 1 + + # 1x1 convolution to create channels + self.start_conv1 = nn.Conv1d(in_channels=self.input_channels, + out_channels=self.residual_channels, + kernel_size=1, + bias=True) + self.start_conv2 = nn.Conv1d(in_channels=self.residual_channels, + out_channels=self.residual_channels, + kernel_size=1, + bias=True) + if activation == 'relu': + self.activation = nn.ReLU(inplace = True) + elif activation == 'leakyrelu': + self.activation = nn.LeakyReLU(0.2) + self.drop_out2D = nn.Dropout2d(p=0.5) + + # build residual blocks + for b in range(self.blocks): + new_dilation = 1 + additional_scope = kernel_size - 1 + for i in range(self.layers): + # create current residual block + residual_blocks.append(residual_block(dilation = new_dilation, + dilation_channels = self.dilation_channels, + residual_channels = self.residual_channels, + skip_channels = self.skip_channels, + kernel_size = self.kernel_size, + use_bias = self.bias, + cond = self.cond, + cond_channels = self.cond_channels)) + new_dilation *= 2 + + self.receptive_field += additional_scope + additional_scope *= 2 + + self.residual_blocks = nn.ModuleList(residual_blocks) + # end convolutions + + self.end_conv_1 = nn.Conv1d(in_channels = self.skip_channels, + out_channels = self.output_channels, + kernel_size = 1, + bias = True) + self.end_conv_2 = nn.Conv1d(in_channels = self.output_channels, + out_channels = self.output_channels, + kernel_size = 1, + bias = True) + + + def parameter_count(self): + par = list(self.parameters()) + s = sum([np.prod(list(d.size())) for d in par]) + return s + + def forward(self, inp, cond=None): + ''' + Args: + inp: [b, T, ndim] + cond: [b, T, nfeature] + Returns: + res: [b, T, ndim] + ''' + inp = inp.transpose(1, 2) + if cond is not None: + cond = cond.transpose(1, 2) + # dropout + x = self.drop_out2D(inp) + + # preprocess + x = self.activation(self.start_conv1(x)) + x = self.activation(self.start_conv2(x)) + skip = 0 + for i, dilation_block in enumerate(self.residual_blocks): + x, current_skip = self.residual_blocks[i](x, cond) + skip += current_skip + + # postprocess + res = self.end_conv_1(self.activation(skip)) + res = self.end_conv_2(self.activation(res)) + + # cut the output size + # res = res[:, :, -self.output_length:] # [b, ndim, T] + res = res.transpose(1, 2) # [b, T, ndim] + return res + + +class residual_block(nn.Module): + ''' + This is the implementation of a residual block in wavenet model. Every + residual block takes previous block's output as input. The forward pass of + each residual block can be illusatrated as below: + + ######################### Current Residual Block ########################## + # |-----------------------*residual*--------------------| # + # | | # + # | |-- dilated conv -- tanh --| | # + # -> -|-- pad--| * ---- |-- 1x1 -- + --> *input* # + # |-- dilated conv -- sigm --| | # + # 1x1 # + # | # + # ---------------------------------------------> + -------------> *skip* # + ########################################################################### + As shown above, each residual block returns two value: 'input' and 'skip': + 'input' is indeed this block's output and also is the next block's input. + 'skip' is the skip data which will be added finally to compute the prediction. + The input args own the same meaning in the WaveNet class. + + ''' + def __init__(self, + dilation, + dilation_channels = 32, + residual_channels = 32, + skip_channels = 256, + kernel_size = 2, + use_bias = False, + cond = True, + cond_channels = 128): + super(residual_block, self).__init__() + + self.dilation = dilation + self.dilation_channels = dilation_channels + self.residual_channels = residual_channels + self.skip_channels = skip_channels + self.kernel_size = kernel_size + self.bias = use_bias + self.cond = cond + self.cond_channels = cond_channels + # zero padding to the left of the sequence. + self.padding = (int((self.kernel_size - 1) * self.dilation), 0) + + # dilated convolutions + self.filter_conv= nn.Conv1d(in_channels = self.residual_channels, + out_channels = self.dilation_channels, + kernel_size = self.kernel_size, + dilation = self.dilation, + bias = self.bias) + + self.gate_conv = nn.Conv1d(in_channels = self.residual_channels, + out_channels = self.dilation_channels, + kernel_size = self.kernel_size, + dilation = self.dilation, + bias = self.bias) + + # 1x1 convolution for residual connections + self.residual_conv = nn.Conv1d(in_channels = self.dilation_channels, + out_channels = self.residual_channels, + kernel_size = 1, + bias = self.bias) + + # 1x1 convolution for skip connections + self.skip_conv = nn.Conv1d(in_channels = self.dilation_channels, + out_channels = self.skip_channels, + kernel_size = 1, + bias = self.bias) + + # condition conv, no dilation + if self.cond == True: + self.cond_filter_conv = nn.Conv1d(in_channels = self.cond_channels, + out_channels = self.dilation_channels, + kernel_size = 1, + bias = True) + self.cond_gate_conv = nn.Conv1d(in_channels = self.cond_channels, + out_channels = self.dilation_channels, + kernel_size = 1, + bias = True) + + + def forward(self, inp, cond=None): + if self.cond is True and cond is None: + raise RuntimeError("set using condition to true, but no cond tensor inputed") + + x_pad = F.pad(inp, self.padding) + # filter + filt = self.filter_conv(x_pad) + # gate + gate = self.gate_conv(x_pad) + + if self.cond == True and cond is not None: + filter_cond = self.cond_filter_conv(cond) + gate_cond = self.cond_gate_conv(cond) + # add cond results + filt = filt + filter_cond + gate = gate + gate_cond + + # element-wise multiple + filt = torch.tanh(filt) + gate = torch.sigmoid(gate) + x = filt * gate + + # residual and skip + residual = self.residual_conv(x) + inp + skip = self.skip_conv(x) + return residual, skip + + +if __name__ == '__main__': + audio2pose_model = Audio2PoseModel() + audio = torch.rand([128, 512]) + pred_pose = audio2pose_model.autoregressive_infer(audio) + print(pred_pose.shape) \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/nerfs/adnerf/adnerf.py b/Geneface_main/GeneFace/modules/nerfs/adnerf/adnerf.py new file mode 100644 index 00000000..643c7f33 --- /dev/null +++ b/Geneface_main/GeneFace/modules/nerfs/adnerf/adnerf.py @@ -0,0 +1,44 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modules.nerfs.commons.embedders import FreqEmbedder +from modules.nerfs.adnerf.backbone import NeRFBackbone, AudioNet, AudioAttNet + + +class ADNeRF(nn.Module): + def __init__(self, hparams=None): + super().__init__() + self.hparams = hparams + self.pos_embedder = FreqEmbedder(in_dim=3, multi_res=10, use_log_bands=True, include_input=True) + self.view_embedder = FreqEmbedder(in_dim=3, multi_res=4, use_log_bands=True, include_input=True) + pos_dim = self.pos_embedder.out_dim + view_dim = self.view_embedder.out_dim + self.cond_dim = hparams['cond_dim'] + self.model_coarse = NeRFBackbone(pos_dim=pos_dim, cond_dim=self.cond_dim, view_dim=view_dim, hid_dim=hparams['hidden_size'], num_density_linears=8, num_color_linears=3, skip_layer_indices=[4]) + self.model_fine = NeRFBackbone(pos_dim=pos_dim, cond_dim=self.cond_dim, view_dim=view_dim, hid_dim=hparams['hidden_size'], num_density_linears=8, num_color_linears=3, skip_layer_indices=[4]) + + self.deepspeech_win_size = 16 + self.smo_win_size = 8 + self.aud_net = AudioNet(in_dim=29, out_dim=self.cond_dim, win_size=self.deepspeech_win_size) + self.audatt_net = AudioAttNet(in_out_dim=self.cond_dim, seq_len=self.smo_win_size) + + def forward(self, pos, cond_feat, view, run_model_fine=True, **kwargs): + out = {} + pos_embed = self.pos_embedder(pos) + view_embed = self.view_embedder(view) + if run_model_fine: + rgb_sigma = self.model_fine(pos_embed, cond_feat, view_embed) + else: + rgb_sigma = self.model_coarse(pos_embed, cond_feat, view_embed) + out['rgb_sigma'] = rgb_sigma + return out + + def cal_cond_feat(self, cond, with_att=False): + cond_feat = self.aud_net(cond) + if with_att: + cond_feat = self.audatt_net(cond_feat) + return cond_feat + + + \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/nerfs/adnerf/adnerf_torso.py b/Geneface_main/GeneFace/modules/nerfs/adnerf/adnerf_torso.py new file mode 100644 index 00000000..8052359f --- /dev/null +++ b/Geneface_main/GeneFace/modules/nerfs/adnerf/adnerf_torso.py @@ -0,0 +1,74 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modules.nerfs.commons.embedders import FreqEmbedder +from modules.nerfs.adnerf.backbone import NeRFBackbone, AudioNet, AudioAttNet + + +class ADNeRFTorso(nn.Module): + def __init__(self, hparams=None): + super().__init__() + self.hparams = hparams + self.pos_embedder = FreqEmbedder(in_dim=3, multi_res=10, use_log_bands=True, include_input=True) + self.view_embedder = FreqEmbedder(in_dim=3, multi_res=4, use_log_bands=True, include_input=True) + self.euler_embedder = FreqEmbedder(in_dim=3, multi_res=6, use_log_bands=True, include_input=True) + self.trans_embedder = FreqEmbedder(in_dim=3, multi_res=6, use_log_bands=True, include_input=True) + + pos_dim = self.pos_embedder.out_dim + view_dim = self.view_embedder.out_dim + nerf_in_cond_dim = hparams['cond_dim'] + self.euler_embedder.out_dim + self.trans_embedder.out_dim + + if hparams.get("use_color", False): + # pixel-level head color condition to prevent head-torso-separation artifacts + color_cond_dim = 16 + self.color_encoder = nn.Sequential(*[ + nn.Linear(3, 16, bias=True), + nn.LeakyReLU(0.02, True), + nn.Linear(16, 32, bias=True), + nn.LeakyReLU(0.02, True), + nn.Linear(32, color_cond_dim, bias=True), + ]) + nerf_in_cond_dim += color_cond_dim + audnet_out_dim = hparams['cond_dim'] + + self.model_coarse = NeRFBackbone(pos_dim=pos_dim, cond_dim=nerf_in_cond_dim, view_dim=view_dim, hid_dim=hparams['hidden_size'], num_density_linears=8, num_color_linears=3, skip_layer_indices=[4]) + self.model_fine = NeRFBackbone(pos_dim=pos_dim, cond_dim=nerf_in_cond_dim, view_dim=view_dim, hid_dim=hparams['hidden_size'], num_density_linears=8, num_color_linears=3, skip_layer_indices=[4]) + + self.deepspeech_win_size = 16 + self.smo_win_size = 8 + self.aud_net = AudioNet(in_dim=29, out_dim=audnet_out_dim, win_size=self.deepspeech_win_size) + self.audatt_net = AudioAttNet(in_out_dim=audnet_out_dim, seq_len=self.smo_win_size) + + def forward(self, pos, cond_feat, view, run_model_fine=True, **kwargs): + out = {} + pos_embed = self.pos_embedder(pos) + view_embed = self.view_embedder(view) + if run_model_fine: + rgb_sigma = self.model_fine(pos_embed, cond_feat, view_embed) + else: + rgb_sigma = self.model_coarse(pos_embed, cond_feat, view_embed) + out['rgb_sigma'] = rgb_sigma + return out + + def cal_cond_feat(self, cond, with_att=False, **kwargs): + cond_feat = self.aud_net(cond) + if with_att: + cond_feat = self.audatt_net(cond_feat) + if cond_feat.ndim == 1: + cond_feat = cond_feat.unsqueeze(0) + euler_embedding = self.euler_embedder(kwargs['euler']).unsqueeze(0).repeat([cond_feat.shape[0],1]) + trans_embedding = self.trans_embedder(kwargs['trans']).unsqueeze(0).repeat([cond_feat.shape[0],1]) + cond_feat = torch.cat([cond_feat, euler_embedding, trans_embedding], dim=-1) + + if self.hparams.get("use_color", False): + color = kwargs['color'] + color_feat = self.color_encoder(color) + cond_feat = cond_feat.reshape([1, -1]) + cond_feat = cond_feat.repeat([color_feat.shape[0], 1]) + cond_feat = torch.cat([cond_feat, color_feat], dim=-1) + + return cond_feat + + + \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/nerfs/adnerf/backbone.py b/Geneface_main/GeneFace/modules/nerfs/adnerf/backbone.py new file mode 100644 index 00000000..803eb025 --- /dev/null +++ b/Geneface_main/GeneFace/modules/nerfs/adnerf/backbone.py @@ -0,0 +1,137 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class AudioNet(nn.Module): + # Audio feature extractor in AD-NeRF + def __init__(self, in_dim=29, out_dim=64, win_size=16): + super(AudioNet, self).__init__() + self.win_size = win_size + self.dim_aud = out_dim + self.encoder_conv = nn.Sequential( # n x 29 x 16 + nn.Conv1d(in_dim, 32, kernel_size=3, stride=2, + padding=1, bias=True), # n x 32 x 8 + nn.LeakyReLU(0.02, True), + nn.Conv1d(32, 32, kernel_size=3, stride=2, + padding=1, bias=True), # n x 32 x 4 + nn.LeakyReLU(0.02, True), + nn.Conv1d(32, 64, kernel_size=3, stride=2, + padding=1, bias=True), # n x 64 x 2 + nn.LeakyReLU(0.02, True), + nn.Conv1d(64, 64, kernel_size=3, stride=2, + padding=1, bias=True), # n x 64 x 1 + nn.LeakyReLU(0.02, True), + ) + self.encoder_fc1 = nn.Sequential( + nn.Linear(64, 64), + nn.LeakyReLU(0.02, True), + nn.Linear(64, out_dim), + ) + + def forward(self, x): + """ + x: [batch, win=16, hid=29] + return: + [batch, out_dim=76] + """ + half_w = int(self.win_size/2) + x = x[:, 8-half_w:8+half_w, :].permute(0, 2, 1) # [b,t=16,c]=>[b,c,t=16] + x = self.encoder_conv(x).squeeze(-1) # [b, c=64, 1] => [b, c] + x = self.encoder_fc1(x).squeeze() # [b,out_dim=76] + return x + + +class AudioAttNet(nn.Module): + # Audio feature attention-based smoother in AD-NeRF + def __init__(self, in_out_dim=64, seq_len=8): + super(AudioAttNet, self).__init__() + self.seq_len = seq_len + self.in_out_dim = in_out_dim + self.attentionConvNet = nn.Sequential( # b x subspace_dim x seq_len + nn.Conv1d(self.in_out_dim, 16, kernel_size=3, + stride=1, padding=1, bias=True), + nn.LeakyReLU(0.02, True), + nn.Conv1d(16, 8, kernel_size=3, stride=1, padding=1, bias=True), + nn.LeakyReLU(0.02, True), + nn.Conv1d(8, 4, kernel_size=3, stride=1, padding=1, bias=True), + nn.LeakyReLU(0.02, True), + nn.Conv1d(4, 2, kernel_size=3, stride=1, padding=1, bias=True), + nn.LeakyReLU(0.02, True), + nn.Conv1d(2, 1, kernel_size=3, stride=1, padding=1, bias=True), + nn.LeakyReLU(0.02, True) + ) + self.attentionNet = nn.Sequential( + nn.Linear(in_features=self.seq_len, out_features=self.seq_len, bias=True), + nn.Softmax(dim=1) + ) + + def forward(self, x): + """ + x: [b=8, c] + return: + [c] + """ + y = x[:, :self.in_out_dim].permute(1, 0).unsqueeze(0) # [b, c] => [1, c, b] + y = self.attentionConvNet(y) # [1,1,b] + y = self.attentionNet(y.view(1, self.seq_len)).view(self.seq_len, 1) # [8, 1] + smoothed_y = torch.sum(y*x, dim=0) # [8,1]*[8,c]=>[8,c]=>[c,] + return smoothed_y + + +class NeRFBackbone(nn.Module): + def __init__(self, pos_dim=3, cond_dim=64, view_dim=3, hid_dim=128, num_density_linears=8, num_color_linears=3, skip_layer_indices=[4]): + super(NeRFBackbone, self).__init__() + self.pos_dim = pos_dim + self.view_dim = view_dim + self.cond_dim = cond_dim + self.hid_dim = hid_dim + self.out_dim = 4 # rgb+sigma + + self.num_density_linears = num_density_linears + self.num_color_linears = num_color_linears + self.skip_layer_indices = skip_layer_indices # specify which layer in density_linears could get the raw input by skip connection + + density_input_dim = pos_dim + cond_dim + self.density_linears = nn.ModuleList( + [nn.Linear(density_input_dim, hid_dim)] + + [nn.Linear(hid_dim, hid_dim) if i not in self.skip_layer_indices else nn.Linear(hid_dim + density_input_dim, hid_dim) for i in range(num_density_linears-1)]) + self.density_out_linear = nn.Linear(hid_dim, 1) + + color_input_dim = view_dim + hid_dim + self.color_linears = nn.ModuleList( + [nn.Linear(color_input_dim, hid_dim//2)] + + [nn.Linear(hid_dim//2, hid_dim//2) for _ in range(num_color_linears-1)]) + self.color_out_linear = nn.Linear(hid_dim//2, 3) + + def forward(self, pos, cond, view): + """ + pos: [bs, n_sample, pos_dim]; the encoding of xyz + cond: [cond_dim,]; condition features + view: [bs, view_dim]; the encoding of view direction + """ + bs, n_sample, _ = pos.shape + if cond.ndim == 1: # [cond_dim] + cond = cond.squeeze()[None, None, :].expand([bs, n_sample, self.cond_dim]) + elif cond.ndim == 2: # [batch, cond_dim] + cond = cond[:, None, :].expand([bs, n_sample, self.cond_dim]) + view = view[:, None, :].expand([bs, n_sample, self.view_dim]) + density_linear_input = torch.cat([pos, cond], dim=-1) + h = density_linear_input + for i in range(len(self.density_linears)): + h = self.density_linears[i](h) + h = F.relu(h) + if i in self.skip_layer_indices: + h = torch.cat([density_linear_input, h], -1) + sigma = self.density_out_linear(h) # [..., 1] + + h = torch.cat([h, view], -1) + for i in range(len(self.color_linears)): + h = self.color_linears[i](h) + h = F.relu(h) + rgb = self.color_out_linear(h) # [..., 3] + + outputs = torch.cat([rgb, sigma], -1) # [..., 4] + return outputs + + diff --git a/Geneface_main/GeneFace/modules/nerfs/commons/embedders.py b/Geneface_main/GeneFace/modules/nerfs/commons/embedders.py new file mode 100644 index 00000000..91dcd638 --- /dev/null +++ b/Geneface_main/GeneFace/modules/nerfs/commons/embedders.py @@ -0,0 +1,45 @@ +import torch +import torch.nn as nn + + +class FreqEmbedder(nn.Module): + # Generate Positional Encoding in NeRF (section 5.1) + def __init__(self, in_dim=3, multi_res=10, use_log_bands=True, include_input=True): + super().__init__() + self.in_dim = in_dim + self.num_freqs = multi_res + self.max_freq_log2 = multi_res - 1 + self.use_log_bands = use_log_bands + self.periodic_fns = [torch.sin, torch.cos] + self.include_input = include_input + + self.embed_fns = None + self.out_dim = None + self.num_embed_fns = None + self.create_embedding_fn() + + def create_embedding_fn(self): + self.embed_fns = [] + self.out_dim = self.num_freqs * len(self.periodic_fns) * self.in_dim + if self.include_input: + self.embed_fns.append(lambda x: x) + self.out_dim += self.in_dim + + if self.use_log_bands: + freq_bands = 2. ** torch.linspace(0., self.max_freq_log2 , steps=self.num_freqs) + else: + freq_bands = torch.linspace(2.**0, 2. ** self.max_freq_log2, steps=self.num_freqs) + + for freq in freq_bands: + for p_fn in self.periodic_fns: + self.embed_fns.append(lambda x, p_fn=p_fn,freq=freq: p_fn(x * freq)) # e.g., torch.cos(x*(2^5)) + self.num_embed_fns = len(self.embed_fns) + + def forward(self, x): + """ + x: [..., in_dim]; xyz or view direction + embedding: [..., out_dim]; the corresponding frequency encoding + """ + embed_lst = [embed_fn(x) for embed_fn in self.embed_fns] # [list of [..., in_dim]] + embedding = torch.cat(embed_lst, dim=-1) # [..., out_dim] + return embedding diff --git a/Geneface_main/GeneFace/modules/nerfs/commons/ray_samplers.py b/Geneface_main/GeneFace/modules/nerfs/commons/ray_samplers.py new file mode 100644 index 00000000..7ef4f5a9 --- /dev/null +++ b/Geneface_main/GeneFace/modules/nerfs/commons/ray_samplers.py @@ -0,0 +1,309 @@ +import torch +import numpy as np +import math +import random +from utils.commons.hparams import hparams + + +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + +def get_rays(H, W, focal, c2w, cx=None, cy=None): + """ + Get the rays emitted from camera to all pixels. + The ray is represented in world coordinate + input: + H: height of the image (in pixel) + W: width of the image (in pixel) + focal: focal length of the camera (in pixel) + c2w: a 3x4 camera-to-world matrix, it should be something like this: + [[r11, r12, r13, t1], + [r21, r22, r23, t2], + [r31, r32, r33, t3],] + cx: center of camera in Width axis + cy: center of camera in Height axis + return: + rays_o: the start point of the ray + rays_d: the direction of the ray. so you can sample the point in the ray with: + xyz = rays_o + rays_d * z_val, where z_val is the distance. + """ + # pytorch's meshgrid has indexing='ij' + i_pixels, j_pixels = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) + i_pixels = i_pixels.t().to(device) + j_pixels = j_pixels.t().to(device) + if cx is None: + cx = W*.5 + if cy is None: + cy = H*.5 + directions = torch.stack([(i_pixels-cx)/focal, -(j_pixels-cy)/focal, -torch.ones_like(i_pixels)], dim=-1) + # Rotate ray directions from camera frame to the world frame + # dot product, equals to: [c2w.dot(dir) for dir in dirs] + rays_d = torch.sum(directions[..., None, :] * c2w[:3, :3], dim=-1) # rays_delta + # origin point of all ray, the camera center in the world coordinate + rays_o = c2w[:3, -1].expand(rays_d.shape) + return rays_o, rays_d + + +class BaseRaySampler: + def __init__(self, N_rays): + super(BaseRaySampler, self).__init__() + self.N_rays = N_rays + + def __call__(self, H, W, focal, c2w): + rays_o, rays_d = get_rays(H, W, focal, c2w) + select_coords = self.sample_rays(H, W).to(device) + rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) + rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) + return rays_o, rays_d, select_coords + + def sample_rays(self, H, W, **kwargs): + raise NotImplementedError + + +class UniformRaySampler(BaseRaySampler): + def __init__(self, N_rays=None): + """ + Uniform RaySampler in vanilla NeRF and AD-NeRF + We use it in the photo reconstruction training (i.e., calculating MSE). + """ + super().__init__(N_rays=N_rays) + + def sample_rays(self, H, W, n_rays=None, rect=None, in_rect_percent=0.9, **kwargs): + """ + rect: [w1, h1, delta_w, delta_h] + """ + if n_rays is None: + n_rays = self.N_rays + coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2) + coords = torch.reshape(coords, [-1, 2]).to(device) # (H * W, 2) + if rect is None: + # uniformly sample the whole image + select_inds = np.random.choice(coords.shape[0], size=[n_rays], replace=False) # (n_rays, ) + select_coords = coords[select_inds].long().to(device) + else: + # uniformly sample from the rect reigon and out-rect, respectively. + w1, h1, delta_w, delta_h = rect + w2 = w1 + delta_w + h2 = h1 + delta_h + rect_inds = (coords[:, 0] >= h1) & (coords[:, 0] <= h2) & (coords[:, 1] >= w1) & ( coords[:, 1] <= w2) # (H*W), boolean mask + + coords_rect = coords[rect_inds] # [num_idx_in_mask, 2] + coords_norect = coords[~rect_inds] + num_rays_in_rect = int(n_rays * in_rect_percent) + num_rays_out_rect = n_rays - num_rays_in_rect + select_inds_rect = np.random.choice(coords_rect.shape[0], size=[num_rays_in_rect], replace=False) # (num_rays_in_rect,) + select_coords_rect = coords_rect[select_inds_rect].long() # (num_rays_in_rect, 2) + select_inds_norect = np.random.choice(coords_norect.shape[0], size=[num_rays_out_rect], replace=False) # (num_rays_out_rect,) + select_coords_norect = coords_norect[select_inds_norect].long() # (num_rays_in_rect) + select_coords = torch.cat((select_coords_rect, select_coords_norect), dim=0) + return select_coords # (n_rays, 2) + + def __call__(self, H, W, focal, c2w, n_rays=None, select_coords=None, rect=None, in_rect_percent=0.9, **kwargs): + rays_o, rays_d = get_rays(H, W, focal, c2w) # [H, W, 3] + if select_coords is None: + select_coords = self.sample_rays(H, W, n_rays, rect, in_rect_percent) # [N_rand, 2] + rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) + rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) + return rays_o, rays_d, select_coords + + def sample_pixels_from_img_with_select_coords(self, img, select_coords): + """ + img: [H*W, 3] + """ + return img[select_coords[:, 0], select_coords[:, 1]] + + +class TorsoUniformRaySampler(BaseRaySampler): + def __init__(self, N_rays=None): + """ + Uniform RaySampler for Torso + We use it in the photo reconstruction training (i.e., calculating MSE). + """ + super().__init__(N_rays=N_rays) + + def sample_rays(self, H, W, n_rays=None, rect=None, in_rect_percent=0.9, **kwargs): + """ + rect: [w1, h1, delta_w, delta_h] + """ + if n_rays is None: + n_rays = self.N_rays + coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2) + coords = torch.reshape(coords, [-1, 2]).to(device) # (H * W, 2) + if rect is None: + rect = [0, H/2, W, H/2] + # uniformly sample from the rect reigon and out-rect, respectively. + w1, h1, delta_w, delta_h = rect + w2 = w1 + delta_w + h2 = h1 + delta_h + rect_inds = (coords[:, 0] >= h1) & (coords[:, 0] <= h2) & (coords[:, 1] >= w1) & ( coords[:, 1] <= w2) # (H*W), boolean mask + + coords_rect = coords[rect_inds] # [num_idx_in_mask, 2] + coords_norect = coords[~rect_inds] + num_rays_in_rect = int(n_rays * in_rect_percent) + num_rays_out_rect = n_rays - num_rays_in_rect + select_inds_rect = np.random.choice(coords_rect.shape[0], size=[num_rays_in_rect], replace=False) # (num_rays_in_rect,) + select_coords_rect = coords_rect[select_inds_rect].long() # (num_rays_in_rect, 2) + select_inds_norect = np.random.choice(coords_norect.shape[0], size=[num_rays_out_rect], replace=False) # (num_rays_out_rect,) + select_coords_norect = coords_norect[select_inds_norect].long() # (num_rays_in_rect) + select_coords = torch.cat((select_coords_rect, select_coords_norect), dim=0) + + return select_coords # (n_rays, 2) + + def __call__(self, H, W, focal, c2w, n_rays=None, select_coords=None, rect=None, in_rect_percent=0.9, **kwargs): + rays_o, rays_d = get_rays(H, W, focal, c2w) # [H, W, 3] + if select_coords is None: + select_coords = self.sample_rays(H, W, n_rays, rect, in_rect_percent, **kwargs) # [N_rand, 2] + rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) + rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) + return rays_o, rays_d, select_coords + + def sample_image_with_select_coords(self, img, select_coords): + """ + img: [H*W, 3] + """ + return img[select_coords[:, 0], select_coords[:, 1]] + + +class FullRaySampler(BaseRaySampler): + def __init__(self, **kwargs): + """ + This sampler directly get all rays to render the whole image. + We only use it in the inference phase. + """ + super().__init__(N_rays=None) + + def sample_rays(self, H, W, **kwargs): + num_h_points = int(H*hparams['infer_scale_factor']) + num_w_points = int(W*hparams['infer_scale_factor']) + h, w = torch.meshgrid([torch.linspace(0,H-1,num_h_points), torch.linspace(0,W-1,num_w_points)]) + h = h.reshape([-1,1]).long() + w = w.reshape([-1,1]).long() + select_coords = torch.cat([h, w], dim=-1)# (n_rays, 2) + return select_coords + + def __call__(self, H, W, focal, c2w): + rays_o, rays_d = get_rays(H, W, focal, c2w) + select_coords = self.sample_rays(H, W).to(device) + rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) + rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) + return rays_o, rays_d, select_coords + + +class PatchRaySampler(BaseRaySampler): + def __init__(self, N_rays, min_scale=0.2, max_scale=1., scale_anneal=-0.1): + """ + A modified version of FlexGridSampler in GRAF + We utilize this sampler to get patch-wise rays for adversarial training + """ + self.N_rays_sqrt = int(math.sqrt(N_rays)) + super(PatchRaySampler, self).__init__(self.N_rays_sqrt**2) + + self.min_scale = min_scale + self.max_scale = max_scale + + # nn.functional.grid_sample grid value range in [-1,1] + self.w, self.h = torch.meshgrid([torch.linspace(-1,1,self.N_rays_sqrt), + torch.linspace(-1,1,self.N_rays_sqrt)]) + self.h = self.h.unsqueeze(2) # [n_rays_sqrt, n_rays_sqrt, 1] + self.w = self.w.unsqueeze(2) # [n_rays_sqrt, n_rays_sqrt, 1] + + # directly return grid for grid_sample + self.return_indices = False + + self.scale_anneal = scale_anneal + + def sample_rays(self, H, W, iterations, n_rays=None, rect=None, **kwargs): + """ + get the uv coordinates of the rays that belongs to a image patch + return: a uv mesh grid of [H, W, 2], + note that different from index-based sampler which + generates grid with values in [0,H] and [0,W] + to utilize F.grid_search, we generate grid with values in [-1,1] + """ + # get the unit mesh grid first, + # 0 denotes the center and -1~1 denotes the boundary of the whole image + if n_rays is None: + sqrt_n_rays = self.N_rays_sqrt + unit_w = self.w # value in [-1,1] + unit_h = self.h # value in [-1,1] + else: + sqrt_n_rays = int(math.sqrt(n_rays)) + unit_w, unit_h = torch.meshgrid([torch.linspace(-1, 1, sqrt_n_rays), + torch.linspace(-1,1,sqrt_n_rays)]) + unit_w = unit_w.unsqueeze(2) # [n_rays_sqrt, n_rays_sqrt, 1] + unit_h = unit_h.unsqueeze(2) + + # then get the scale factor of the unit mesh grid + # k_iter = iterations // 1000 * 3 * self.scale_anneal + # min_scale = max(self.min_scale, self.max_scale * math.exp(k_iter)) + # min_scale = min(0.9, min_scale) + min_scale = self.min_scale + scale_factor = torch.Tensor(1).uniform_(min_scale, self.max_scale) + w = unit_w * scale_factor + h = unit_h * scale_factor + + if rect is None: + max_offset_w = 1-scale_factor.item() + max_offset_h = 1-scale_factor.item() + h_offset = torch.Tensor(1).uniform_(0, max_offset_h) * (torch.randint(high=2,size=(1,)).float()-0.5)*2 + w_offset = torch.Tensor(1).uniform_(0, max_offset_w) * (torch.randint(high=2,size=(1,)).float()-0.5)*2 + else: + w1, h1, delta_w, delta_h = rect + w2 = w1 + delta_w + h2 = h1 + delta_h + + # rule1. the edge of patch shall not out of the image + # rule2. the center of patch shall not out of the rect + min_offset_w = max(scale_factor.item()-1, (w1-W//2)/(W//2)) + min_offset_h = max(scale_factor.item()-1, (h1-H//2)/(H//2)) + max_offset_w = min(1-scale_factor.item(), (w2-W//2)/(W//2)) + max_offset_h = min(1-scale_factor.item(), (h2-H//2)/(H//2)) + h_offset = torch.Tensor(1).uniform_(min_offset_h, max_offset_h) + w_offset = torch.Tensor(1).uniform_(min_offset_w, max_offset_w) + h += h_offset + w += w_offset + hw = torch.cat([h,w], dim=2) # [H, W, 2], it is necessary to keep the H,W axis for grid_sampling + select_coords = hw + return select_coords + + def __call__(self, H, W, focal, c2w, iterations, n_rays=None, rect=None, **kwargs): + rays_o, rays_d = get_rays(H, W, focal, c2w) + select_coords = self.sample_rays(H, W, iterations, n_rays, rect).to(rays_o.device) + # instead of int index, this is float index, and we need to sample rays with interpolation + rays_o = torch.nn.functional.grid_sample(rays_o.permute(2,0,1).unsqueeze(0), + select_coords.unsqueeze(0), mode='bilinear', align_corners=True)[0] + rays_d = torch.nn.functional.grid_sample(rays_d.permute(2,0,1).unsqueeze(0), + select_coords.unsqueeze(0), mode='bilinear', align_corners=True)[0] + rays_o = rays_o.permute(1,2,0).view(-1, 3) + rays_d = rays_d.permute(1,2,0).view(-1, 3) + return rays_o, rays_d, select_coords + + def sample_image_with_select_coords(self, img, select_coords): + """ + img: [H, W, 3] + """ + img = img.permute(2, 0, 1) # [3, H, W] + sampled_patch = torch.nn.functional.grid_sample(img.unsqueeze(0), + select_coords.unsqueeze(0), mode='bilinear', align_corners=True)[0] + sampled_patch = sampled_patch.permute(1, 2, 0) # [H, W, 3] + return sampled_patch + + +if __name__ == '__main__': + H, W = 800, 800 + focal = 1200 + rect = [0,0,100,100] + c2w = torch.FloatTensor([[1,1,1,1],[1,1,1,1],[1,1,1,1]]) + ray_sampler = BaseRaySampler(10000) + uniform_ray_sampler = UniformRaySampler(10000) + full_ray_sampler = FullRaySampler() + flex_ray_sampler = PatchRaySampler(10000) + uniform_ray_sampler(H,W,focal,c2w,10000,rect=rect) + import cv2 + img = cv2.imread("experimiental_yerfor/r_0.png") + img = torch.from_numpy(img).permute([2,0,1]).float() + flex_ray_sampler.iterations = 50000 + rays_o, rays_d, select_coords = flex_ray_sampler(H,W,focal, c2w,rect=[0,0,800,800], iterations=1000) + # gt_patch = flex_ray_sampler.sample_image_with_select_coords(select_coords, img) + # gt_patch = gt_patch.permute([1,2,0]).numpy().astype(np.uint8) + # cv2.imwrite("experimiental_yerfor/sampled_iteration=50000.png", gt_patch) \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/nerfs/commons/volume_rendering.py b/Geneface_main/GeneFace/modules/nerfs/commons/volume_rendering.py new file mode 100644 index 00000000..c1d53910 --- /dev/null +++ b/Geneface_main/GeneFace/modules/nerfs/commons/volume_rendering.py @@ -0,0 +1,286 @@ +import torch +import torch.nn.functional as F +import numpy as np + +from modules.nerfs.commons.ray_samplers import get_rays + +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + +def raw2outputs(raw, z_vals, rays_d, bc_rgb, raw_noise_std=0, white_bkgd=False): + """Transforms model's predictions to semantically meaningful values. + Args: + raw: [num_rays, num_samples along ray, 4]. Prediction from model, rgb+sigma + z_vals: [num_rays, num_samples along ray]. Integration time. + rays_d: [num_rays, 3]. Direction of each ray. + Returns: + rgb_map: [num_rays, 3]. Estimated RGB color of a ray. + disp_map: [num_rays]. Disparity map. Inverse of depth map. + acc_map: [num_rays]. Sum of weights along each ray. + weights: [num_rays, num_samples]. Weights assigned to each sampled color. + depth_map: [num_rays]. Estimated distance to object. + """ + def raw2alpha(raw, dists, act_fn=F.relu): + """ + Args: + raw: predicted sigma; [N_rays, N_samples] + dists: delta distance; [N_rays, N_samples] + return: + alpha: normalized volume density (occulusion degree) in the delta distance; [N_rays, N_samples] + """ + return 1. - torch.exp(-(act_fn(raw)+1e-6)*dists) + + dists = z_vals[..., 1:] - z_vals[..., :-1] # delta_z_vals ,[N_rays, N_samples-1] + dists = torch.cat([dists, torch.Tensor([1e10]).to(device).expand(dists[..., :1].shape)], -1) # add infintely far, [N_rays, N_samples] + + dists = dists * torch.norm(rays_d[..., None, :], dim=-1) # [N_rays, N_samples] + + rgb = torch.sigmoid(raw[..., :3]) # [N_rays, N_samples, 3] + rgb = torch.cat((rgb[:, :-1, :], bc_rgb.unsqueeze(1)), dim=1) # replace the last sample point with background color + noise = 0. + if raw_noise_std > 0.: + noise = torch.randn(raw[..., 3].shape) * raw_noise_std + + alpha = raw2alpha(raw[..., 3] + noise, dists) # [N_rays, N_samples] + importance_weights = alpha * \ + torch.cumprod( + torch.cat([torch.ones((alpha.shape[0], 1)).to(device), 1.-alpha + 1e-10], -1), -1)[:, :-1] + rgb_map = torch.sum(importance_weights[..., None] * rgb, -2) # [N_rays, 3] + + rgb_map_fg = torch.sum(importance_weights[:, :-1, None]*rgb[:, :-1, :], -2) + + depth_map = torch.sum(importance_weights * z_vals, -1) + disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), + depth_map / torch.sum(importance_weights, -1)) + accu_map = torch.sum(importance_weights, -1) + + if white_bkgd: + rgb_map = rgb_map + (1.-accu_map[..., None]) + + return rgb_map, disp_map, accu_map, importance_weights, depth_map, rgb_map_fg + + +def sample_pdf(bins, weights, N_samples, det=False): + # Get pdf + weights = weights + 1e-5 # prevent nans + pdf = weights / torch.sum(weights, -1, keepdim=True) + cdf = torch.cumsum(pdf, -1) + # (batch, len(bins)) + cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) + + # Take uniform samples + if det: + u = torch.linspace(0., 1., steps=N_samples) + u = u.expand(list(cdf.shape[:-1]) + [N_samples]) + else: + u = torch.rand(list(cdf.shape[:-1]) + [N_samples]) + + + # Invert CDF + u = u.to(device).contiguous() + inds = torch.searchsorted(cdf, u, right=True) + below = torch.max(torch.zeros_like(inds-1), inds-1) + above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds) + inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) + + # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) + # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) + matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] + cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) + bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) + + denom = (cdf_g[..., 1]-cdf_g[..., 0]) + denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) + t = (u-cdf_g[..., 0])/denom + samples = bins_g[..., 0] + t * (bins_g[..., 1]-bins_g[..., 0]) + + return samples + +def render_rays(ray_batch, + bc_rgb, + cond, + network_fn, + N_samples, + return_raw=False, + linear_disp=False, + perturb=1., + N_importance=0, + white_bkgd=False, + raw_noise_std=0., + **kwargs + ): + """Volumetric rendering. + Args: + ray_batch: array of shape [batch_size, ...]. All information necessary + for sampling along a ray, including: ray origin, ray direction, min + dist, max dist, and unit-magnitude viewing direction. + network_fn: function. Model for predicting RGB and density at each point + in space. + network_query_fn: function used for passing queries to network_fn. + N_samples: int. Number of different times to sample along each ray. + retraw: bool. If True, include model's raw, unprocessed predictions. + lindisp: bool. If True, sample linearly in inverse depth rather than in depth. + perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified + random points in time. + N_importance: int. Number of additional times to sample along each ray. + These samples are only passed to network_fine. + network_fine: "fine" network with same spec as network_fn. + white_bkgd: bool. If True, assume a white background. + raw_noise_std: ... + verbose: bool. If True, print more debugging info. + Returns: + rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model. + disp_map: [num_rays]. Disparity map. 1 / depth. + acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model. + raw: [num_rays, num_samples, 4]. Raw predictions from model. + rgb0: See rgb_map. Output for coarse model. + disp0: See disp_map. Output for coarse model. + acc0: See acc_map. Output for coarse model. + z_std: [num_rays]. Standard deviation of distances along ray for each + sample. + """ + N_rays = ray_batch.shape[0] + rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6] # [N_rays, 3] each + viewdirs = ray_batch[:, -3:] if ray_batch.shape[-1] > 8 else None + bounds = torch.reshape(ray_batch[..., 6:8], [-1, 1, 2]) # [near, far] + near, far = bounds[..., 0], bounds[..., 1] # [-1,1] + + t_vals = torch.linspace(0., 1., steps=N_samples).to(device) # [64,] + if not linear_disp: + z_vals = near * (1.-t_vals) + far * (t_vals) + else: + z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals)) + + z_vals = z_vals.expand([N_rays, N_samples]) # [1024, 64] + + if perturb > 0.: # default 1., set it to add noise in z_vals + # get intervals between samples + mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) # [1024, 63] + upper = torch.cat([mids, z_vals[..., -1:]], -1) # [1024, 64] + lower = torch.cat([z_vals[..., :1], mids], -1) # [1024, 64] + # stratified samples in those intervals + t_rand = torch.rand(z_vals.shape).to(device) + + t_rand[..., -1] = 1.0 + z_vals = lower + (upper - lower) * t_rand + pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None] # [N_rays, N_samples, 3] + + # coarse_network_out = network_fn.forward(pts, cond, viewdirs, run_model_fine=False) # [N_rays, N_samples, rgbd] + coarse_network_out = network_fn.forward(pts, cond, viewdirs,run_model_fine=False, **kwargs) # [N_rays, N_samples, rgbd] + raw = coarse_network_out['rgb_sigma'] + + rgb_map, disp_map, acc_map, weights, depth_map, rgb_map_fg = raw2outputs( + raw, z_vals, rays_d, bc_rgb, raw_noise_std, white_bkgd) + + if N_importance > 0: # default 128 + # additionally run on fine model + rgb_map_0, disp_map_0, acc_map_0, last_weight_0, rgb_map_fg_0 = rgb_map, disp_map, acc_map, weights[..., -1], rgb_map_fg + + z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) + z_samples = sample_pdf( + z_vals_mid, weights[..., 1:-1], N_importance, det=(perturb == 0.)) + z_samples = z_samples.detach() + + z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1) + pts = rays_o[..., None, :] + rays_d[..., None, :] * \ + z_vals[..., :, None] # [N_rays, N_samples + N_importance, 3] + + # fine_network_out = network_fn.forward(pts, cond, viewdirs, run_model_fine=True) + fine_network_out = network_fn.forward(pts, cond, viewdirs, run_model_fine=True, **kwargs) # [N_rays, N_samples, rgbd] + + raw = fine_network_out['rgb_sigma'] + + rgb_map, disp_map, acc_map, weights, depth_map, rgb_map_fg = raw2outputs( + raw, z_vals, rays_d, bc_rgb, raw_noise_std, white_bkgd) + + ret = {'rgb_map': rgb_map, 'disp_map': disp_map, 'acc_map': acc_map, 'rgb_map_fg': rgb_map_fg} + if return_raw: + ret['raw'] = raw + if N_importance > 0: + ret['rgb_map_coarse'] = rgb_map_0 + ret['disp_map_coarse'] = disp_map_0 + ret['accu_map_coarse'] = acc_map_0 + ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # [N_rays] + ret['last_weight'] = weights[..., -1] + ret['last_weight0'] = last_weight_0 + ret['rgb_map_fg0'] = rgb_map_fg_0 + + for k in ret: + if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()): + print(f"! [Numerical Error] {k} contains nan or inf.") + return ret + + +def batchify_render_rays(rays_flat, bc_rgb, cond, chunk, network_fn, N_samples, N_importance, **kwargs): + """Render rays in smaller minibatches to avoid OOM. + """ + all_ret = {} + for i in range(0, rays_flat.shape[0], chunk): + if cond.squeeze().ndim == 1: + ret = render_rays(rays_flat[i:i+chunk], bc_rgb[i:i+chunk], + cond, network_fn, N_samples, N_importance=N_importance, **kwargs) + elif cond.squeeze().ndim == 2: + ret = render_rays(rays_flat[i:i+chunk], bc_rgb[i:i+chunk], + cond[i:i+chunk], network_fn, N_samples, N_importance=N_importance, **kwargs) + + for k in ret: + if k not in all_ret: + all_ret[k] = [] + all_ret[k].append(ret[k]) + + all_ret = {k: torch.cat(all_ret[k], 0) for k in all_ret} + return all_ret + + +def render_dynamic_face(H, W, focal, cx, cy, chunk=1024, rays_o=None, rays_d=None, bc_rgb=None, cond=None, + c2w=None, near=0., far=1., use_viewdirs=True, c2w_staticcam=None, + network_fn=None, N_samples=None, N_importance=None, + **kwargs): + """ + bc_rgb: [H,W,3] + """ + if bc_rgb is not None: + bc_rgb = bc_rgb.reshape(-1, 3) # [H*W, 3] + + if c2w is not None: + # special case to render full image + rays_o, rays_d = get_rays(H, W, focal, c2w, cx, cy) + else: + # use provided ray batch + rays_o, rays_d = rays_o, rays_d + + if use_viewdirs: + # provide ray directions as input + viewdirs = rays_d + if c2w_staticcam is not None: + # special case to visualize effect of viewdirs + rays_o, rays_d = get_rays(H, W, focal, c2w_staticcam, cx, cy) + viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True) + viewdirs = torch.reshape(viewdirs, [-1, 3]).float() + + sh = rays_d.shape # [..., 3] + + # Create ray batch + rays_o = torch.reshape(rays_o, [-1, 3]).float() + rays_d = torch.reshape(rays_d, [-1, 3]).float() + + near, far = near * \ + torch.ones_like(rays_d[..., :1]), far * \ + torch.ones_like(rays_d[..., :1]) + rays = torch.cat([rays_o, rays_d, near, far], -1) # [N, 8] + if use_viewdirs: + rays = torch.cat([rays, viewdirs], -1) # [N,11=rays_o3+rays_d3+nearfar2+viewdir3] + + # Render and reshape + all_ret = batchify_render_rays(rays, bc_rgb, cond, chunk, network_fn=network_fn, N_samples=N_samples, N_importance=N_importance, **kwargs) + for k in all_ret: + k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:]) + all_ret[k] = torch.reshape(all_ret[k], k_sh) + + k_extract = ['rgb_map', 'disp_map', 'acc_map', 'last_weight', 'rgb_map_fg'] + ret_list = [all_ret[k] for k in k_extract] + ret_dict = {k: all_ret[k] for k in all_ret if k not in k_extract} + return ret_list + [ret_dict] + + +if __name__ == '__main__': + get_rays(450, 450, 1200, ) \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/nerfs/lm3d_nerf/cond_encoder.py b/Geneface_main/GeneFace/modules/nerfs/lm3d_nerf/cond_encoder.py new file mode 100644 index 00000000..24c2c438 --- /dev/null +++ b/Geneface_main/GeneFace/modules/nerfs/lm3d_nerf/cond_encoder.py @@ -0,0 +1,91 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class AudioNet(nn.Module): + # Audio feature extractor in AD-NeRF + def __init__(self, in_dim=29, out_dim=64, win_size=16): + super(AudioNet, self).__init__() + self.win_size = win_size + if win_size == 1: + strides = [1,1,1,1] + elif win_size == 2: + strides = [2,1,1,1] + elif win_size in [3, 4]: + strides = [2,2,1,1] + elif win_size in [5, 8]: + strides = [2,2,2,1] + elif win_size == 16: + strides = [2,2,2,2] + else: + raise ValueError("unsupported win_size") + self.dim_aud = out_dim + self.encoder_conv = nn.Sequential( # n x 29 x 16 + nn.Conv1d(in_dim, 32, kernel_size=3, stride=strides[0], + padding=1, bias=True), # n x 32 x 8 + nn.LeakyReLU(0.02, True), + nn.Conv1d(32, 32, kernel_size=3, stride=strides[1], + padding=1, bias=True), # n x 32 x 4 + nn.LeakyReLU(0.02, True), + nn.Conv1d(32, 64, kernel_size=3, stride=strides[2], + padding=1, bias=True), # n x 64 x 2 + nn.LeakyReLU(0.02, True), + nn.Conv1d(64, 64, kernel_size=3, stride=strides[3], + padding=1, bias=True), # n x 64 x 1 + nn.LeakyReLU(0.02, True), + ) + self.encoder_fc1 = nn.Sequential( + nn.Linear(64, 64), + nn.LeakyReLU(0.02, True), + nn.Linear(64, out_dim), + ) + + def forward(self, x): + """ + x: [batch, win=16, hid=29] + return: + [batch, out_dim=76] + """ + half_w = int(self.win_size/2) + x = x.permute(0, 2, 1) # [b,t=16,c]=>[b,c,t=16] + x = self.encoder_conv(x).squeeze(-1) # [b, c=64, 1] => [b, c] + x = self.encoder_fc1(x).squeeze() # [b,out_dim=76] + return x + +class AudioAttNet(nn.Module): + # Audio feature attention-based smoother in AD-NeRF + def __init__(self, in_out_dim=64, seq_len=8): + super(AudioAttNet, self).__init__() + self.seq_len = seq_len + self.in_out_dim = in_out_dim + self.attentionConvNet = nn.Sequential( # b x subspace_dim x seq_len + nn.Conv1d(self.in_out_dim, 16, kernel_size=3, + stride=1, padding=1, bias=True), + nn.LeakyReLU(0.02, True), + nn.Conv1d(16, 8, kernel_size=3, stride=1, padding=1, bias=True), + nn.LeakyReLU(0.02, True), + nn.Conv1d(8, 4, kernel_size=3, stride=1, padding=1, bias=True), + nn.LeakyReLU(0.02, True), + nn.Conv1d(4, 2, kernel_size=3, stride=1, padding=1, bias=True), + nn.LeakyReLU(0.02, True), + nn.Conv1d(2, 1, kernel_size=3, stride=1, padding=1, bias=True), + nn.LeakyReLU(0.02, True) + ) + self.attentionNet = nn.Sequential( + nn.Linear(in_features=self.seq_len, out_features=self.seq_len, bias=True), + nn.Softmax(dim=1) + ) + + def forward(self, x): + """ + x: [b=8, c] + return: + [c] + """ + y = x[:, :self.in_out_dim].permute(1, 0).unsqueeze(0) # [b, c] => [1, c, b] + y = self.attentionConvNet(y) # [1,1,b] + y = self.attentionNet(y.view(1, self.seq_len)).view(self.seq_len, 1) # [8, 1] + smoothed_y = torch.sum(y*x, dim=0) # [8,1]*[8,c]=>[8,c]=>[c,] + return smoothed_y + diff --git a/Geneface_main/GeneFace/modules/nerfs/lm3d_nerf/lm3d_nerf.py b/Geneface_main/GeneFace/modules/nerfs/lm3d_nerf/lm3d_nerf.py new file mode 100644 index 00000000..28636a01 --- /dev/null +++ b/Geneface_main/GeneFace/modules/nerfs/lm3d_nerf/lm3d_nerf.py @@ -0,0 +1,127 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modules.nerfs.commons.embedders import FreqEmbedder +from modules.nerfs.adnerf.backbone import NeRFBackbone +from modules.nerfs.lm3d_nerf.cond_encoder import AudioNet, AudioAttNet +from modules.nerfs.commons.volume_rendering import render_dynamic_face + +from utils.commons.hparams import hparams + + +class Lm3dNeRF(nn.Module): + def __init__(self, hparams=None): + super().__init__() + self.hparams = hparams + self.pos_embedder = FreqEmbedder(in_dim=3, multi_res=10, use_log_bands=True, include_input=True) + self.view_embedder = FreqEmbedder(in_dim=3, multi_res=4, use_log_bands=True, include_input=True) + pos_dim = self.pos_embedder.out_dim + view_dim = self.view_embedder.out_dim + nerf_cond_dim = lm3d_out_dim = hparams['cond_dim'] + self.model_coarse = NeRFBackbone(pos_dim=pos_dim, cond_dim=nerf_cond_dim, view_dim=view_dim, hid_dim=hparams['hidden_size'], num_density_linears=8, num_color_linears=3, skip_layer_indices=[4]) + self.model_fine = NeRFBackbone(pos_dim=pos_dim, cond_dim=nerf_cond_dim, view_dim=view_dim, hid_dim=hparams['hidden_size'], num_density_linears=8, num_color_linears=3, skip_layer_indices=[4]) + + cond_in_dim = 68 * 3 + if hparams['use_window_cond']: + self.lm3d_win_size = hparams['cond_win_size'] + self.smo_win_size = hparams['smo_win_size'] + self.lm_encoder = AudioNet(in_dim=cond_in_dim, out_dim=lm3d_out_dim, win_size=self.lm3d_win_size) + if hparams['with_att']: + self.lmatt_encoder = AudioAttNet(in_out_dim=lm3d_out_dim, seq_len=self.smo_win_size) + else: + self.lm_encoder = nn.Sequential(*[ + nn.Linear(cond_in_dim, 32, bias=True), + nn.LeakyReLU(0.02, True), + nn.Linear(32, 32, bias=True), + nn.LeakyReLU(0.02, True), + nn.Linear(32, 64, bias=True), + nn.LeakyReLU(0.02, True), + nn.Linear(64, lm3d_out_dim, bias=True), + ]) + + def forward(self, pos, cond_feat, view, run_model_fine=True, **kwargs): + out = {} + pos_embed = self.pos_embedder(pos) + view_embed = self.view_embedder(view) + if run_model_fine: + rgb_sigma = self.model_fine(pos_embed, cond_feat, view_embed) + else: + rgb_sigma = self.model_coarse(pos_embed, cond_feat, view_embed) + out['rgb_sigma'] = rgb_sigma + return out + + def cal_cond_feat(self, cond, with_att=False): + cond_feat = self.lm_encoder(cond) + if with_att: + cond_feat = self.lmatt_encoder(cond_feat) + return cond_feat + + ########################## + # forward the model + ########################## + def run_model(self, sample, infer=False): + """ + render or train on a single-frame + :param sample: a batch of data + :param infer: bool, run in infer mode + :return: + if not infer: + return losses, model_out + if infer: + return model_out + """ + cond = sample['cond_win'] if hparams['use_window_cond'] else sample['cond'] + cond_wins = sample['cond_wins'] + H = sample['H'] + W = sample['W'] + focal = sample['focal'] + cx = sample['cx'] + cy = sample['cy'] + near = sample['near'] + far = sample['far'] + bg_img = sample['bg_img'] + c2w = sample['c2w'] + c2w_t0 = sample['c2w_t0'] + t = sample['t'] + + with_att = hparams['with_att'] and (self.global_step >= self.no_smo_iterations) + if with_att: + cond_feat = self.model.cal_cond_feat(cond_wins, with_att=True) + else: + cond_feat = self.model.cal_cond_feat(cond, with_att=False) + + if infer: + rays_o, rays_d, _ = self.full_rays_sampler(H, W, focal, c2w) + rgb_pred, disp, acc, _, _, extras = render_dynamic_face(H, W, focal, cx, cy, rays_o=rays_o, rays_d=rays_d, + bc_rgb=bg_img, + chunk=2048, + c2w=None, cond=cond_feat, near=near, far=far, + network_fn=self.model, N_samples=self.n_samples_per_ray, N_importance=self.n_samples_per_ray_fine, + c2w_t=c2w, c2w_t0=c2w_t0,t=t, + ) + model_out = { + "rgb_map" : rgb_pred + } + return model_out + else: + rays_o, rays_d, select_coords = self.rays_sampler(H, W, focal, c2w, n_rays=None, rect=sample['rect'], in_rect_percent=hparams['in_rect_percent'], iterations=self.global_step) + target = sample['head_img'] + rgb_gt = self.rays_sampler.sample_pixels_from_img_with_select_coords(target, select_coords) + rgb_bc = self.rays_sampler.sample_pixels_from_img_with_select_coords(sample['bg_img'], select_coords) + + rgb_pred, disp, acc, _, _, extras = render_dynamic_face(H, W, focal, cx, cy, rays_o=rays_o, rays_d=rays_d, + bc_rgb=rgb_bc,chunk=self.chunk, c2w=None, cond=cond_feat, near=near, far=far, + network_fn=self.model, N_samples=self.n_samples_per_ray, N_importance=self.n_samples_per_ray_fine, + c2w_t=c2w, c2w_t0=c2w_t0,t=t,) + losses_out = {} + losses_out['mse_loss'] = torch.mean((rgb_pred - rgb_gt) ** 2) + if 'rgb_map_coarse' in extras: + losses_out['mse_loss_coarse'] = torch.mean((extras['rgb_map_coarse'] - rgb_gt) ** 2) + model_out = { + "rgb_map": rgb_pred + } + return losses_out, model_out + + + \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/postnet/__pycache__/lle.cpython-39.pyc b/Geneface_main/GeneFace/modules/postnet/__pycache__/lle.cpython-39.pyc new file mode 100644 index 00000000..c7936239 Binary files /dev/null and b/Geneface_main/GeneFace/modules/postnet/__pycache__/lle.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/modules/postnet/__pycache__/models.cpython-39.pyc b/Geneface_main/GeneFace/modules/postnet/__pycache__/models.cpython-39.pyc new file mode 100644 index 00000000..c5c4ca41 Binary files /dev/null and b/Geneface_main/GeneFace/modules/postnet/__pycache__/models.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/modules/postnet/lle.py b/Geneface_main/GeneFace/modules/postnet/lle.py new file mode 100644 index 00000000..89b09c41 --- /dev/null +++ b/Geneface_main/GeneFace/modules/postnet/lle.py @@ -0,0 +1,106 @@ +import torch +import numpy as np +from numpy.linalg import solve + +from utils.commons.tensor_utils import convert_to_tensor + + +def find_k_nearest_neighbors(feats, feat_database, K=10): + """ + KNN (K-nearest neighbor), return the index of k-nearest neighbors in the feat_database + args: + feats: [N_sample_in_batch, C] + feats_database: [N_sample_in_dataset, C] + K: the number of topK nearest neighbors + return: + ind: [N_sample_in_batch, K=10] the index of K nearest neighbors in the database, CPU tensor + """ + feats = convert_to_tensor(feats) + feat_database = convert_to_tensor(feat_database) + # Training + feat_base_norm = (feat_database ** 2).sum(-1) # [N_sample_in_database,] + # start computing KNN ... + feats_norm = (feats ** 2).sum(-1) # [N_sample_in_batch,] + # calculate distance via : (x-y)^2 = x^2 + y^2 - 2xy + distance_mat = (feats_norm.view(-1, 1) + feat_base_norm.view(1, -1) - 2 * feats @ feat_database.t()) # [N_sample_in_batch, N_sample_in_database] + # get the index of k nearest neighbors + ind = distance_mat.topk(K, dim=1, largest=False).indices + return ind + +def solve_LLE_projection_batch(feat, feat_base): + """ + Find LLE projection weights given feat base and target feat + Project a batch of feat vector into a linear combination of feat_base + TODO: perform this process in a mini-batch. + ======================================= + We need to solve the following function + ``` + min|| feat - \sum_0^k{w_i} * feat_base_i ||, s.t. \sum_0^k{w_i}=1 + ``` + equals to: + ft = w1*f1 + w2*f2 + ... + wk*fk, s.t. w1+w2+...+wk=1 + = (1-w2-...-wk)*f1 + w2*f2 + ... + wk*fk + ft-f1 = w2*(f2-f1) + w3*(f3-f1) + ... + wk*(fk-f1) + ft-f1 = (f2-f1, f3-f1, ..., fk-f1) dot (w2, w3, ..., wk).T + B = A dot w_, here, B: [ndim,] A: [ndim, k-1], w_: [k-1,] + Finally, + ft' = (1-w2-..wk, w2, ..., wk) dot (f1, f2, ..., fk) + ======================================= + args: + feat: [N_sample_in_batch, C], the feat to be preocessed + feat_base: [N_sample_in_batch, K, C], the base vectors to represent the feat + return: + weights: [N, K], the linear weights of K base vectors, sums to 1 + fear_fuse: [N, C], the processed feat + """ + feat = convert_to_tensor(feat) + feat_base = convert_to_tensor(feat_base) + N, K, C = feat_base.shape + if K == 1: + weights = torch.ones([N, 1]) + feat_fuse = feat_base[:, 0, ] + errors = None + else: + weights = torch.zeros(N, K) + B = feat - feat_base[:, 0, :] # [N, C] + A = (feat_base[:, 1:, :] - feat_base[:, 0:1, :]).transpose(1,2) # [N, C, K-1] + AT = A.transpose(1,2) # [N, K-1, C] + # solve the AX=B with Least square method + # where X [N, K-1] is the weights[1:] we want to learn + # AT*A*X=AT*B ==> X = inv(ATA)*AT*B + ATA = torch.bmm(AT, A) # [N, K-1, K-1] + inv_ATA = torch.inverse(ATA) # [N, K-1, K-1] + X = torch.bmm(torch.bmm(inv_ATA, AT), B.unsqueeze(2)).squeeze() # [N, K-1] + weights[:, 1:] = X + weights[:, 0] = torch.ones_like(weights[:, 0]) - X.sum(dim=1) + feat_fuse = torch.bmm(weights.unsqueeze(1), feat_base).squeeze(1) # [N,1,K] @ [N,K,C] ==> [N,1,C] ==> [N, C] + errors = (torch.bmm(A,X.unsqueeze(-1)).squeeze() - B).abs().mean(dim=-1) # [N,] + return feat_fuse, errors, weights + +def compute_LLE_projection(feats, feat_database, K=10): + """ + Project the feat into a linear combination of K base vectors in feat_database + args: + feat: [N_sample_in_batch, C], the feat to be processed + feat_database: [N_sample_in_batch, C], all feat datapoints in dataset + K: int, number of K neighbors + return: + weights: [N_sample_K, ] + """ + index_of_K_neighbors_in_database = find_k_nearest_neighbors(feats, feat_database, K) # [N_sample_in_batch, K=10] + feat_base = feat_database[index_of_K_neighbors_in_database] + # print("performing LLE projection ...") + feat_fuse, errors, weights = solve_LLE_projection_batch(feats, feat_base) + # print("LLE projection Done.") + return feat_fuse, errors, weights + + +if __name__ == '__main__': + audio_feats = torch.randn(1000, 64).numpy() + feat_database = torch.randn(10000, 64).numpy() + Knear = 10 + # LLE_percent =1. + ind = find_k_nearest_neighbors(audio_feats, feat_database, K=Knear) + weights, feat_fuse = compute_LLE_projection(audio_feats, feat_database, K=10) + # audio_feats = audio_feats * (1-LLE_percent) + feat_fuse * LLE_percent + print(" ") \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/postnet/models.py b/Geneface_main/GeneFace/modules/postnet/models.py new file mode 100644 index 00000000..5c1a5a42 --- /dev/null +++ b/Geneface_main/GeneFace/modules/postnet/models.py @@ -0,0 +1,106 @@ +import torch +import torch.nn as nn + +class Conv1d(nn.Module): + def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): + super().__init__(*args, **kwargs) + self.conv_block = nn.Sequential( + nn.Conv1d(cin, cout, kernel_size, stride, padding), + nn.BatchNorm1d(cout) + ) + self.act = nn.LeakyReLU(0.2, inplace=True) + self.residual = residual + + def forward(self, x): + out = self.conv_block(x) + if self.residual: + out += x + return self.act(out) + +class CNNPostNet(nn.Module): + def __init__(self, in_out_dim=64): + super().__init__() + self.in_out_dim = in_out_dim + self.block1 = nn.Sequential(*[ + Conv1d(in_out_dim, 128, 3, 1, 1, False), # [B, T=9, C=] + Conv1d(128, 128, 3, 1, 1, True), + Conv1d(128, 128, 3, 1, 1, True), + ]) + self.block2 = nn.Sequential(*[ + Conv1d(128, 256, 3, 1, 1, False), # [B, T=9, C=] + Conv1d(256, 256, 3, 1, 1, True), + Conv1d(256, 256, 3, 1, 1, True), + ]) + self.block3 = nn.Sequential(*[ + Conv1d(256, 128, 3, 1, 1, residual=False), + nn.Conv1d(128, in_out_dim, 1, 1, 0), + ]) + + def forward(self, x): + nopadding_mask = ~ x.abs().sum(-1).eq(0).data # [B, T] + diff_x = self.block1(x.transpose(1, 2)).transpose(1,2) * nopadding_mask.unsqueeze(2) + diff_x = self.block2(diff_x.transpose(1, 2)).transpose(1,2) * nopadding_mask.unsqueeze(2) + diff_x = self.block3(diff_x.transpose(1, 2)).transpose(1,2) * nopadding_mask.unsqueeze(2) + refine_x = x + diff_x + return refine_x + + +class PitchContourCNNPostNet(nn.Module): + def __init__(self, in_out_dim=64, pitch_dim=32): + super().__init__() + self.in_out_dim = in_out_dim + self.block1 = nn.Sequential(*[ + Conv1d(in_out_dim+pitch_dim, 128, 3, 1, 1, False), # [B, T=9, C=] + Conv1d(128, 128, 3, 1, 1, True), + Conv1d(128, 128, 3, 1, 1, True), + ]) + self.block2 = nn.Sequential(*[ + Conv1d(128, 256, 3, 1, 1, False), # [B, T=9, C=] + Conv1d(256, 256, 3, 1, 1, True), + Conv1d(256, 256, 3, 1, 1, True), + ]) + self.block3 = nn.Sequential(*[ + Conv1d(256, 128, 3, 1, 1, residual=False), + nn.Conv1d(128, in_out_dim, 1, 1, 0), + ]) + + def forward(self, x, pitch): + nopadding_mask = ~ x.abs().sum(-1).eq(0).data # [B, T] + inp = torch.cat([x,pitch],dim=-1) + diff_x = self.block1(inp.transpose(1, 2)).transpose(1,2) * nopadding_mask.unsqueeze(2) + diff_x = self.block2(diff_x.transpose(1, 2)).transpose(1,2) * nopadding_mask.unsqueeze(2) + diff_x = self.block3(diff_x.transpose(1, 2)).transpose(1,2) * nopadding_mask.unsqueeze(2) + refine_x = x + diff_x + return refine_x + + +class MLPDiscriminator(nn.Module): + def __init__(self, in_dim=64): + super().__init__() + self.in_dim = in_dim + self.backbone = nn.Sequential(*[ + nn.Linear(in_dim, 128), + nn.LeakyReLU(0.2, inplace=True), + nn.Dropout(0.25), + nn.Linear(128, 256), + nn.LeakyReLU(0.25, inplace=True), + nn.Dropout(0.25), + nn.Linear(256, 256), + nn.LeakyReLU(0.25, inplace=True), + nn.Dropout(0.25), + nn.Linear(256, 128), + nn.LeakyReLU(0.25, inplace=True), + nn.Dropout(0.25), + nn.Linear(128, 1, bias=False) + ]) + def forward(self, x): + x_mask = x.sum(-1).ne(0) # [b, T] + x_flatten = x[x_mask.unsqueeze(2).repeat([1,1,self.in_dim])].reshape([-1,self.in_dim]) + validity = self.backbone(x_flatten) + return [validity] + +if __name__ == '__main__': + net = CNNPostNet() + x = torch.rand(2, 9, 64) + y = net(x) + print(y.shape) \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/__pycache__/cond_encoder.cpython-39.pyc b/Geneface_main/GeneFace/modules/radnerfs/__pycache__/cond_encoder.cpython-39.pyc new file mode 100644 index 00000000..8d78149f Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/__pycache__/cond_encoder.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/__pycache__/radnerf.cpython-39.pyc b/Geneface_main/GeneFace/modules/radnerfs/__pycache__/radnerf.cpython-39.pyc new file mode 100644 index 00000000..47b5e319 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/__pycache__/radnerf.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/__pycache__/radnerf_torso.cpython-39.pyc b/Geneface_main/GeneFace/modules/radnerfs/__pycache__/radnerf_torso.cpython-39.pyc new file mode 100644 index 00000000..37138e4c Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/__pycache__/radnerf_torso.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/__pycache__/renderer.cpython-39.pyc b/Geneface_main/GeneFace/modules/radnerfs/__pycache__/renderer.cpython-39.pyc new file mode 100644 index 00000000..cb324434 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/__pycache__/renderer.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/__pycache__/utils.cpython-39.pyc b/Geneface_main/GeneFace/modules/radnerfs/__pycache__/utils.cpython-39.pyc new file mode 100644 index 00000000..52f701fa Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/__pycache__/utils.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/cond_encoder.py b/Geneface_main/GeneFace/modules/radnerfs/cond_encoder.py new file mode 100644 index 00000000..20853bd2 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/cond_encoder.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# Audio feature extractor +class AudioNet(nn.Module): + def __init__(self, dim_in=29, dim_aud=64, win_size=16): + super(AudioNet, self).__init__() + self.win_size = win_size + self.dim_aud = dim_aud + if win_size == 1: + strides = [1,1,1,1] + elif win_size == 2: + strides = [2,1,1,1] + elif win_size in [3, 4]: + strides = [2,2,1,1] + elif win_size == [5, 8]: + strides = [2,2,2,1] + elif win_size == 16: + strides = [2,2,2,2] + else: + raise ValueError("unsupported win_size") + self.encoder_conv = nn.Sequential( # n x 29 x 16 + nn.Conv1d(dim_in, 32, kernel_size=3, stride=strides[0], + padding=1, bias=True), # n x 32 x 8 + nn.LeakyReLU(0.02, True), + nn.Conv1d(32, 32, kernel_size=3, stride=strides[1], + padding=1, bias=True), # n x 32 x 4 + nn.LeakyReLU(0.02, True), + nn.Conv1d(32, 64, kernel_size=3, stride=strides[2], + padding=1, bias=True), # n x 64 x 2 + nn.LeakyReLU(0.02, True), + nn.Conv1d(64, 64, kernel_size=3, stride=strides[3], + padding=1, bias=True), # n x 64 x 1 + nn.LeakyReLU(0.02, True), + ) + self.encoder_fc1 = nn.Sequential( + nn.Linear(64, 64), + nn.LeakyReLU(0.02, True), + nn.Linear(64, dim_aud), + ) + + def forward(self, x): + """ + x: [b, t_window, c] + """ + half_w = int(self.win_size/2) + x = x.permute(0, 2, 1) # [b,t=16,c]=>[b,c,t=16] + x = self.encoder_conv(x).squeeze(-1) # [b, c=64, 1] => [b, c] + x = self.encoder_fc1(x).squeeze() # [b,out_dim=76] + return x + + +class AudioAttNet(nn.Module): + # Audio feature attention-based smoother in AD-NeRF + def __init__(self, in_out_dim=64, seq_len=8): + super(AudioAttNet, self).__init__() + self.seq_len = seq_len + self.in_out_dim = in_out_dim + self.attentionConvNet = nn.Sequential( # b x subspace_dim x seq_len + nn.Conv1d(self.in_out_dim, 16, kernel_size=3, + stride=1, padding=1, bias=True), + nn.LeakyReLU(0.02, True), + nn.Conv1d(16, 8, kernel_size=3, stride=1, padding=1, bias=True), + nn.LeakyReLU(0.02, True), + nn.Conv1d(8, 4, kernel_size=3, stride=1, padding=1, bias=True), + nn.LeakyReLU(0.02, True), + nn.Conv1d(4, 2, kernel_size=3, stride=1, padding=1, bias=True), + nn.LeakyReLU(0.02, True), + nn.Conv1d(2, 1, kernel_size=3, stride=1, padding=1, bias=True), + nn.LeakyReLU(0.02, True) + ) + self.attentionNet = nn.Sequential( + nn.Linear(in_features=self.seq_len, out_features=self.seq_len, bias=True), + nn.Softmax(dim=1) + ) + + def forward(self, x): + """ + x: [b=8, c] + return: + [c] + """ + y = x[:, :self.in_out_dim].permute(1, 0).unsqueeze(0) # [b, c] => [1, c, b] + y = self.attentionConvNet(y) # [1,1,b] + y = self.attentionNet(y.view(1, self.seq_len)).view(self.seq_len, 1) # [8, 1] + smoothed_y = torch.sum(y*x, dim=0) # [8,1]*[8,c]=>[8,c]=>[c,] + return smoothed_y + + +class MLP(nn.Module): + def __init__(self, dim_in, dim_out, dim_hidden, num_layers): + super().__init__() + self.dim_in = dim_in + self.dim_out = dim_out + self.dim_hidden = dim_hidden + self.num_layers = num_layers + + net = [] + for l in range(num_layers): + net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=False)) + + self.net = nn.ModuleList(net) + + def forward(self, x): + for l in range(self.num_layers): + x = self.net[l](x) + if l != self.num_layers - 1: + x = F.relu(x, inplace=True) + return x \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/__pycache__/encoding.cpython-39.pyc b/Geneface_main/GeneFace/modules/radnerfs/encoders/__pycache__/encoding.cpython-39.pyc new file mode 100644 index 00000000..628ad6a1 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/encoders/__pycache__/encoding.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/encoding.py b/Geneface_main/GeneFace/modules/radnerfs/encoders/encoding.py new file mode 100644 index 00000000..d05b74ad --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/encoding.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def get_encoder(encoding, input_dim=3, + multires=6, + degree=4, + num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False, + interpolation='linear', + **kwargs): + + if encoding == 'None': + return lambda x, **kwargs: x, input_dim + + elif encoding == 'frequency': + from modules.radnerfs.encoders.freqencoder import FreqEncoder + encoder = FreqEncoder(input_dim=input_dim, degree=multires) + + elif encoding == 'spherical_harmonics': + from modules.radnerfs.encoders.shencoder import SHEncoder + encoder = SHEncoder(input_dim=input_dim, degree=degree) + + elif encoding == 'hashgrid': + from modules.radnerfs.encoders.gridencoder import GridEncoder + encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners, interpolation=interpolation, **kwargs) + + elif encoding == 'tiledgrid': + from modules.radnerfs.encoders.gridencoder import GridEncoder + encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners, interpolation=interpolation, **kwargs) + + else: + raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, spherical_harmonics, hashgrid, tiledgrid]') + + return encoder, encoder.output_dim \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/__init__.py b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/__init__.py new file mode 100644 index 00000000..69ec49cf --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/__init__.py @@ -0,0 +1 @@ +from .freq import FreqEncoder \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/__pycache__/__init__.cpython-39.pyc b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 00000000..d3f345d2 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/__pycache__/__init__.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/__pycache__/freq.cpython-39.pyc b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/__pycache__/freq.cpython-39.pyc new file mode 100644 index 00000000..7fe1f698 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/__pycache__/freq.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/backend.py b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/backend.py new file mode 100644 index 00000000..3bd9131a --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/backend.py @@ -0,0 +1,41 @@ +import os +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', + '-use_fast_math' +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +_backend = load(name='_freqencoder', + extra_cflags=c_flags, + extra_cuda_cflags=nvcc_flags, + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'freqencoder.cu', + 'bindings.cpp', + ]], + ) + +__all__ = ['_backend'] \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/build/lib.linux-x86_64-cpython-39/_freqencoder.cpython-39-x86_64-linux-gnu.so b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/build/lib.linux-x86_64-cpython-39/_freqencoder.cpython-39-x86_64-linux-gnu.so new file mode 100644 index 00000000..c54fe895 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/build/lib.linux-x86_64-cpython-39/_freqencoder.cpython-39-x86_64-linux-gnu.so differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/.ninja_deps b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/.ninja_deps new file mode 100644 index 00000000..978ff3b0 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/.ninja_deps differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/.ninja_log b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/.ninja_log new file mode 100644 index 00000000..0536dd14 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/.ninja_log @@ -0,0 +1,7 @@ +# ninja log v5 +0 9696 1734544812438584295 /home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/src/bindings.o 22e7fe38871cfad7 +0 29662 1734544832408579648 /home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/src/freqencoder.o 3f67d6a89a5c2de0 +3 11264 1734727531907615244 /output/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/src/bindings.o 25289ba77012d03f +4 39558 1734727560207894560 /output/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/src/freqencoder.o dec9848bde2c8ff2 +4 10884 1734800909579676078 /output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/src/bindings.o 76fd4063a35b9937 +4 39405 1734800938107955836 /output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/src/freqencoder.o 491eedfb22ae30d7 diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/build.ninja b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/build.ninja new file mode 100644 index 00000000..ceab3cee --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/build.ninja @@ -0,0 +1,29 @@ +ninja_required_version = 1.3 +cxx = c++ +nvcc = /usr/local/cuda/bin/nvcc + +cflags = -pthread -B /output/GeneFace_Reproduction/conda/compiler_compat -Wno-unused-result -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /output/GeneFace_Reproduction/conda/include -I/output/GeneFace_Reproduction/conda/include -fPIC -O2 -isystem /output/GeneFace_Reproduction/conda/include -fPIC -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include/TH -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include/THC -I/usr/local/cuda/include -I/output/GeneFace_Reproduction/conda/include/python3.9 -c +post_cflags = -O3 -std=c++14 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_freqencoder -D_GLIBCXX_USE_CXX11_ABI=0 +cuda_cflags = -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include/TH -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include/THC -I/usr/local/cuda/include -I/output/GeneFace_Reproduction/conda/include/python3.9 -c +cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -use_fast_math -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_freqencoder -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 +ldflags = + +rule compile + command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags + depfile = $out.d + deps = gcc + +rule cuda_compile + depfile = $out.d + deps = gcc + command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags + + + +build /output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/src/bindings.o: compile /output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/src/bindings.cpp +build /output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/src/freqencoder.o: cuda_compile /output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/src/freqencoder.cu + + + + + diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/src/bindings.o b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/src/bindings.o new file mode 100644 index 00000000..53bd6aa3 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/src/bindings.o differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/src/freqencoder.o b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/src/freqencoder.o new file mode 100644 index 00000000..db550204 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/src/freqencoder.o differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/src/bindings.o b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/src/bindings.o new file mode 100644 index 00000000..2ce62078 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/src/bindings.o differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/src/freqencoder.o b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/src/freqencoder.o new file mode 100644 index 00000000..f2c47a74 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/src/freqencoder.o differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/freq.py b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/freq.py new file mode 100644 index 00000000..5cba1e66 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/freq.py @@ -0,0 +1,77 @@ +import numpy as np + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import _freqencoder as _backend +except ImportError: + from .backend import _backend + + +class _freq_encoder(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision + def forward(ctx, inputs, degree, output_dim): + # inputs: [B, input_dim], float + # RETURN: [B, F], float + + if not inputs.is_cuda: inputs = inputs.cuda() + inputs = inputs.contiguous() + + B, input_dim = inputs.shape # batch size, coord dim + + outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) + + _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs) + + ctx.save_for_backward(inputs, outputs) + ctx.dims = [B, input_dim, degree, output_dim] + + return outputs + + @staticmethod + #@once_differentiable + @custom_bwd + def backward(ctx, grad): + # grad: [B, C * C] + + grad = grad.contiguous() + inputs, outputs = ctx.saved_tensors + B, input_dim, degree, output_dim = ctx.dims + + grad_inputs = torch.zeros_like(inputs) + _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs) + + return grad_inputs, None, None + + +freq_encode = _freq_encoder.apply + + +class FreqEncoder(nn.Module): + def __init__(self, input_dim=3, degree=4): + super().__init__() + + self.input_dim = input_dim + self.degree = degree + self.output_dim = input_dim + input_dim * 2 * degree + + def __repr__(self): + return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}" + + def forward(self, inputs, **kwargs): + # inputs: [..., input_dim] + # return: [..., ] + + prefix_shape = list(inputs.shape[:-1]) + inputs = inputs.reshape(-1, self.input_dim) + + outputs = freq_encode(inputs, self.degree, self.output_dim) + + outputs = outputs.reshape(prefix_shape + [self.output_dim]) + + return outputs \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/freqencoder.egg-info/PKG-INFO b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/freqencoder.egg-info/PKG-INFO new file mode 100644 index 00000000..c280ab19 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/freqencoder.egg-info/PKG-INFO @@ -0,0 +1,3 @@ +Metadata-Version: 2.1 +Name: freqencoder +Version: 0.0.0 diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/freqencoder.egg-info/SOURCES.txt b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/freqencoder.egg-info/SOURCES.txt new file mode 100644 index 00000000..0f053162 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/freqencoder.egg-info/SOURCES.txt @@ -0,0 +1,7 @@ +setup.py +/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/src/bindings.cpp +/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/freqencoder/src/freqencoder.cu +freqencoder.egg-info/PKG-INFO +freqencoder.egg-info/SOURCES.txt +freqencoder.egg-info/dependency_links.txt +freqencoder.egg-info/top_level.txt \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/freqencoder.egg-info/dependency_links.txt b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/freqencoder.egg-info/dependency_links.txt new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/freqencoder.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/freqencoder.egg-info/top_level.txt b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/freqencoder.egg-info/top_level.txt new file mode 100644 index 00000000..85f88eba --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/freqencoder.egg-info/top_level.txt @@ -0,0 +1 @@ +_freqencoder diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/setup.py b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/setup.py new file mode 100644 index 00000000..3eb4af77 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/setup.py @@ -0,0 +1,51 @@ +import os +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', + '-use_fast_math' +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +setup( + name='freqencoder', # package name, import this to use python API + ext_modules=[ + CUDAExtension( + name='_freqencoder', # extension name, import this to use CUDA API + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'freqencoder.cu', + 'bindings.cpp', + ]], + extra_compile_args={ + 'cxx': c_flags, + 'nvcc': nvcc_flags, + } + ), + ], + cmdclass={ + 'build_ext': BuildExtension, + } +) \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/src/bindings.cpp b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/src/bindings.cpp new file mode 100644 index 00000000..bb5f285a --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/src/bindings.cpp @@ -0,0 +1,8 @@ +#include + +#include "freqencoder.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)"); + m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)"); +} \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/src/freqencoder.cu b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/src/freqencoder.cu new file mode 100644 index 00000000..de378840 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/src/freqencoder.cu @@ -0,0 +1,129 @@ +#include + +#include +#include +#include + +#include +#include + +#include +#include + +#include + + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") + +inline constexpr __device__ float PI() { return 3.141592653589793f; } + +template +__host__ __device__ T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + +// inputs: [B, D] +// outputs: [B, C], C = D + D * deg * 2 +__global__ void kernel_freq( + const float * __restrict__ inputs, + uint32_t B, uint32_t D, uint32_t deg, uint32_t C, + float * outputs +) { + // parallel on per-element + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; + if (t >= B * C) return; + + // get index + const uint32_t b = t / C; + const uint32_t c = t - b * C; // t % C; + + // locate + inputs += b * D; + outputs += t; + + // write self + if (c < D) { + outputs[0] = inputs[c]; + // write freq + } else { + const uint32_t col = c / D - 1; + const uint32_t d = c % D; + const uint32_t freq = col / 2; + const float phase_shift = (col % 2) * (PI() / 2); + outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift); + } +} + +// grad: [B, C], C = D + D * deg * 2 +// outputs: [B, C] +// grad_inputs: [B, D] +__global__ void kernel_freq_backward( + const float * __restrict__ grad, + const float * __restrict__ outputs, + uint32_t B, uint32_t D, uint32_t deg, uint32_t C, + float * grad_inputs +) { + // parallel on per-element + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; + if (t >= B * D) return; + + const uint32_t b = t / D; + const uint32_t d = t - b * D; // t % D; + + // locate + grad += b * C; + outputs += b * C; + grad_inputs += t; + + // register + float result = grad[d]; + grad += D; + outputs += D; + + for (uint32_t f = 0; f < deg; f++) { + result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]); + grad += 2 * D; + outputs += 2 * D; + } + + // write + grad_inputs[0] = result; +} + + +void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) { + CHECK_CUDA(inputs); + CHECK_CUDA(outputs); + + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(outputs); + + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(outputs); + + static constexpr uint32_t N_THREADS = 128; + + kernel_freq<<>>(inputs.data_ptr(), B, D, deg, C, outputs.data_ptr()); +} + + +void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) { + CHECK_CUDA(grad); + CHECK_CUDA(outputs); + CHECK_CUDA(grad_inputs); + + CHECK_CONTIGUOUS(grad); + CHECK_CONTIGUOUS(outputs); + CHECK_CONTIGUOUS(grad_inputs); + + CHECK_IS_FLOATING(grad); + CHECK_IS_FLOATING(outputs); + CHECK_IS_FLOATING(grad_inputs); + + static constexpr uint32_t N_THREADS = 128; + + kernel_freq_backward<<>>(grad.data_ptr(), outputs.data_ptr(), B, D, deg, C, grad_inputs.data_ptr()); +} \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/src/freqencoder.h b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/src/freqencoder.h new file mode 100644 index 00000000..34f28c79 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/freqencoder/src/freqencoder.h @@ -0,0 +1,10 @@ +# pragma once + +#include +#include + +// _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs) +void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs); + +// _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs) +void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs); \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/__init__.py b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/__init__.py new file mode 100644 index 00000000..f1476cef --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/__init__.py @@ -0,0 +1 @@ +from .grid import GridEncoder \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/__pycache__/__init__.cpython-39.pyc b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 00000000..c1aff715 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/__pycache__/__init__.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/__pycache__/grid.cpython-39.pyc b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/__pycache__/grid.cpython-39.pyc new file mode 100644 index 00000000..712b4cf0 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/__pycache__/grid.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/backend.py b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/backend.py new file mode 100644 index 00000000..d99acb1f --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/backend.py @@ -0,0 +1,40 @@ +import os +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +_backend = load(name='_grid_encoder', + extra_cflags=c_flags, + extra_cuda_cflags=nvcc_flags, + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'gridencoder.cu', + 'bindings.cpp', + ]], + ) + +__all__ = ['_backend'] \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/build/lib.linux-x86_64-cpython-39/_gridencoder.cpython-39-x86_64-linux-gnu.so b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/build/lib.linux-x86_64-cpython-39/_gridencoder.cpython-39-x86_64-linux-gnu.so new file mode 100644 index 00000000..072e9415 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/build/lib.linux-x86_64-cpython-39/_gridencoder.cpython-39-x86_64-linux-gnu.so differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/.ninja_deps b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/.ninja_deps new file mode 100644 index 00000000..2212342e Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/.ninja_deps differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/.ninja_log b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/.ninja_log new file mode 100644 index 00000000..e64fd952 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/.ninja_log @@ -0,0 +1,7 @@ +# ninja log v5 +0 8545 1734544875488560913 /home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/src/bindings.o 3bc810f1e4b85d3f +0 85171 1734544952108549649 /home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/src/gridencoder.o 88955bbd8104c5fc +3 10749 1734727617948464440 /output/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/src/bindings.o 5367f8c6d8b974d9 +3 105913 1734727713117403725 /output/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/src/gridencoder.o 3e30f7bf4daa23a +5 10979 1734800995968523232 /output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/src/bindings.o 8526c2103c5cd388 +6 106172 1734801091161456723 /output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/src/gridencoder.o d2783a10e8413287 diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/build.ninja b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/build.ninja new file mode 100644 index 00000000..70e01b31 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/build.ninja @@ -0,0 +1,29 @@ +ninja_required_version = 1.3 +cxx = c++ +nvcc = /usr/local/cuda/bin/nvcc + +cflags = -pthread -B /output/GeneFace_Reproduction/conda/compiler_compat -Wno-unused-result -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /output/GeneFace_Reproduction/conda/include -I/output/GeneFace_Reproduction/conda/include -fPIC -O2 -isystem /output/GeneFace_Reproduction/conda/include -fPIC -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include/TH -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include/THC -I/usr/local/cuda/include -I/output/GeneFace_Reproduction/conda/include/python3.9 -c +post_cflags = -O3 -std=c++14 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_gridencoder -D_GLIBCXX_USE_CXX11_ABI=0 +cuda_cflags = -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include/TH -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include/THC -I/usr/local/cuda/include -I/output/GeneFace_Reproduction/conda/include/python3.9 -c +cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_gridencoder -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 +ldflags = + +rule compile + command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags + depfile = $out.d + deps = gcc + +rule cuda_compile + depfile = $out.d + deps = gcc + command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags + + + +build /output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/src/bindings.o: compile /output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/src/bindings.cpp +build /output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/src/gridencoder.o: cuda_compile /output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/src/gridencoder.cu + + + + + diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/src/bindings.o b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/src/bindings.o new file mode 100644 index 00000000..ff8cda99 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/src/bindings.o differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/src/gridencoder.o b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/src/gridencoder.o new file mode 100644 index 00000000..31af8798 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/src/gridencoder.o differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/src/bindings.o b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/src/bindings.o new file mode 100644 index 00000000..e125d2b3 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/src/bindings.o differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/src/gridencoder.o b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/src/gridencoder.o new file mode 100644 index 00000000..4b3a45be Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/src/gridencoder.o differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/grid.py b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/grid.py new file mode 100644 index 00000000..32b8bead --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/grid.py @@ -0,0 +1,185 @@ +import numpy as np + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import _gridencoder as _backend +except ImportError: + from .backend import _backend + +_gridtype_to_id = { + 'hash': 0, + 'tiled': 1, +} + +_interp_to_id = { + 'linear': 0, + 'smoothstep': 1, +} + +class _grid_encode(Function): + @staticmethod + @custom_fwd + def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False, interpolation=0): + # inputs: [B, D], float in [0, 1] + # embeddings: [sO, C], float + # offsets: [L + 1], int + # RETURN: [B, F], float + + inputs = inputs.contiguous() + + B, D = inputs.shape # batch size, coord dim + L = offsets.shape[0] - 1 # level + C = embeddings.shape[1] # embedding dim for each level + S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f + H = base_resolution # base resolution + + # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision) + # if C % 2 != 0, force float, since half for atomicAdd is very slow. + if torch.is_autocast_enabled() and C % 2 == 0: + embeddings = embeddings.to(torch.half) + + # L first, optimize cache for cuda kernel, but needs an extra permute later + outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype) + + if calc_grad_inputs: + dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype) + else: + dy_dx = None + + _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, dy_dx, gridtype, align_corners, interpolation) + + # permute back to [B, L * C] + outputs = outputs.permute(1, 0, 2).reshape(B, L * C) + + ctx.save_for_backward(inputs, embeddings, offsets, dy_dx) + ctx.dims = [B, D, C, L, S, H, gridtype, interpolation] + ctx.align_corners = align_corners + + return outputs + + @staticmethod + #@once_differentiable + @custom_bwd + def backward(ctx, grad): + + inputs, embeddings, offsets, dy_dx = ctx.saved_tensors + B, D, C, L, S, H, gridtype, interpolation = ctx.dims + align_corners = ctx.align_corners + + # grad: [B, L * C] --> [L, B, C] + grad = grad.view(B, L, C).permute(1, 0, 2).contiguous() + + grad_embeddings = torch.zeros_like(embeddings) + + if dy_dx is not None: + grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype) + else: + grad_inputs = None + + _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interpolation) + + if dy_dx is not None: + grad_inputs = grad_inputs.to(inputs.dtype) + + return grad_inputs, grad_embeddings, None, None, None, None, None, None, None + + + +grid_encode = _grid_encode.apply + + +class GridEncoder(nn.Module): + def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False, interpolation='linear'): + super().__init__() + + # the finest resolution desired at the last level, if provided, overridee per_level_scale + if desired_resolution is not None: + per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1)) + + self.input_dim = input_dim # coord dims, 2 or 3 + self.num_levels = num_levels # num levels, each level multiply resolution by 2 + self.level_dim = level_dim # encode channels per level + self.per_level_scale = per_level_scale # multiply resolution by this scale at each level. + self.log2_hashmap_size = log2_hashmap_size + self.base_resolution = base_resolution + self.output_dim = num_levels * level_dim + self.gridtype = gridtype + self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash" + self.interpolation = interpolation + self.interp_id = _interp_to_id[interpolation] # "linear" or "smoothstep" + self.align_corners = align_corners + + # allocate parameters + offsets = [] + offset = 0 + self.max_params = 2 ** log2_hashmap_size + for i in range(num_levels): + resolution = int(np.ceil(base_resolution * per_level_scale ** i)) + params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number + params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible + offsets.append(offset) + offset += params_in_level + offsets.append(offset) + offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)) + self.register_buffer('offsets', offsets) + + self.n_params = offsets[-1] * level_dim + + # parameters + self.embeddings = nn.Parameter(torch.empty(offset, level_dim)) + + self.reset_parameters() + + def reset_parameters(self): + std = 1e-4 + self.embeddings.data.uniform_(-std, std) + + def __repr__(self): + return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners} interpolation={self.interpolation}" + + def forward(self, inputs, bound=1): + # inputs: [..., input_dim], normalized real world positions in [-bound, bound] + # return: [..., num_levels * level_dim] + + inputs = (inputs + bound) / (2 * bound) # map to [0, 1] + + #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item()) + + prefix_shape = list(inputs.shape[:-1]) + inputs = inputs.view(-1, self.input_dim) + + outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners, self.interp_id) + outputs = outputs.view(prefix_shape + [self.output_dim]) + + #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) + + return outputs + + # always run in float precision! + @torch.cuda.amp.autocast(enabled=False) + def grad_total_variation(self, weight=1e-7, inputs=None, bound=1, B=1000000): + # inputs: [..., input_dim], float in [-b, b], location to calculate TV loss. + + D = self.input_dim + C = self.embeddings.shape[1] # embedding dim for each level + L = self.offsets.shape[0] - 1 # level + S = np.log2(self.per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f + H = self.base_resolution # base resolution + + if inputs is None: + # randomized in [0, 1] + inputs = torch.rand(B, self.input_dim, device=self.embeddings.device) + else: + inputs = (inputs + bound) / (2 * bound) # map to [0, 1] + inputs = inputs.view(-1, self.input_dim) + B = inputs.shape[0] + + if self.embeddings.grad is None: + raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!') + + _backend.grad_total_variation(inputs, self.embeddings, self.embeddings.grad, self.offsets, weight, B, D, C, L, S, H, self.gridtype_id, self.align_corners) \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/gridencoder.egg-info/PKG-INFO b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/gridencoder.egg-info/PKG-INFO new file mode 100644 index 00000000..c89f13a5 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/gridencoder.egg-info/PKG-INFO @@ -0,0 +1,3 @@ +Metadata-Version: 2.1 +Name: gridencoder +Version: 0.0.0 diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/gridencoder.egg-info/SOURCES.txt b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/gridencoder.egg-info/SOURCES.txt new file mode 100644 index 00000000..4e5df689 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/gridencoder.egg-info/SOURCES.txt @@ -0,0 +1,7 @@ +setup.py +/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/src/bindings.cpp +/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/gridencoder/src/gridencoder.cu +gridencoder.egg-info/PKG-INFO +gridencoder.egg-info/SOURCES.txt +gridencoder.egg-info/dependency_links.txt +gridencoder.egg-info/top_level.txt \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/gridencoder.egg-info/dependency_links.txt b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/gridencoder.egg-info/dependency_links.txt new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/gridencoder.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/gridencoder.egg-info/top_level.txt b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/gridencoder.egg-info/top_level.txt new file mode 100644 index 00000000..0ab3d303 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/gridencoder.egg-info/top_level.txt @@ -0,0 +1 @@ +_gridencoder diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/setup.py b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/setup.py new file mode 100644 index 00000000..714bf1ca --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/setup.py @@ -0,0 +1,50 @@ +import os +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +setup( + name='gridencoder', # package name, import this to use python API + ext_modules=[ + CUDAExtension( + name='_gridencoder', # extension name, import this to use CUDA API + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'gridencoder.cu', + 'bindings.cpp', + ]], + extra_compile_args={ + 'cxx': c_flags, + 'nvcc': nvcc_flags, + } + ), + ], + cmdclass={ + 'build_ext': BuildExtension, + } +) \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/src/bindings.cpp b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/src/bindings.cpp new file mode 100644 index 00000000..93dea943 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/src/bindings.cpp @@ -0,0 +1,9 @@ +#include + +#include "gridencoder.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)"); + m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)"); + m.def("grad_total_variation", &grad_total_variation, "grad_total_variation (CUDA)"); +} \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/src/gridencoder.cu b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/src/gridencoder.cu new file mode 100644 index 00000000..22d95328 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/src/gridencoder.cu @@ -0,0 +1,644 @@ +#include +#include +#include + +#include +#include + +#include +#include + +#include +#include + + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") + + +// just for compatability of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF... program will never reach here! + __device__ inline at::Half atomicAdd(at::Half *address, at::Half val) { + // requires CUDA >= 10 and ARCH >= 70 + // this is very slow compared to float or __half2, never use it. + //return atomicAdd(reinterpret_cast<__half*>(address), val); +} + + +template +__host__ __device__ inline T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + +template +__host__ __device__ inline T clamp(const T v, const T2 lo, const T2 hi) { + return min(max(v, lo), hi); +} + +template +__device__ inline T smoothstep(T val) { + return val*val*(3.0f - 2.0f * val); +} + +template +__device__ inline T smoothstep_derivative(T val) { + return 6*val*(1.0f - val); +} + + +template +__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) { + + // coherent type of hashing + constexpr uint32_t primes[7] = { 1u, 2654435761u, 805459861u, 3674653429u, 2097192037u, 1434869437u, 2165219737u }; + + uint32_t result = 0; + #pragma unroll + for (uint32_t i = 0; i < D; ++i) { + result ^= pos_grid[i] * primes[i]; + } + + return result; +} + + +template +__device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) { + uint32_t stride = 1; + uint32_t index = 0; + + #pragma unroll + for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) { + index += pos_grid[d] * stride; + stride *= align_corners ? resolution: (resolution + 1); + } + + // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97. + // gridtype: 0 == hash, 1 == tiled + if (gridtype == 0 && stride > hashmap_size) { + index = fast_hash(pos_grid); + } + + return (index % hashmap_size) * C + ch; +} + + +template +__global__ void kernel_grid( + const float * __restrict__ inputs, + const scalar_t * __restrict__ grid, + const int * __restrict__ offsets, + scalar_t * __restrict__ outputs, + const uint32_t B, const uint32_t L, const float S, const uint32_t H, + scalar_t * __restrict__ dy_dx, + const uint32_t gridtype, + const bool align_corners, + const uint32_t interp +) { + const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x; + + if (b >= B) return; + + const uint32_t level = blockIdx.y; + + // locate + grid += (uint32_t)offsets[level] * C; + inputs += b * D; + outputs += level * B * C + b * C; + + // check input range (should be in [0, 1]) + bool flag_oob = false; + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if (inputs[d] < 0 || inputs[d] > 1) { + flag_oob = true; + } + } + // if input out of bound, just set output to 0 + if (flag_oob) { + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + outputs[ch] = 0; + } + if (dy_dx) { + dy_dx += b * D * L * C + level * D * C; // B L D C + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + dy_dx[d * C + ch] = 0; + } + } + } + return; + } + + const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; + const float scale = exp2f(level * S) * H - 1.0f; + const uint32_t resolution = (uint32_t)ceil(scale) + 1; + + // calculate coordinate (always use float for precision!) + float pos[D]; + float pos_deriv[D]; // linear deriv is default to 1 + uint32_t pos_grid[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f); + pos_grid[d] = floorf(pos[d]); + pos[d] -= (float)pos_grid[d]; + // smoothstep instead of linear + if (interp == 1) { + pos_deriv[d] = smoothstep_derivative(pos[d]); + pos[d] = smoothstep(pos[d]); + } else { + pos_deriv[d] = 1.0f; + } + } + + //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]); + + // interpolate + scalar_t results[C] = {0}; // temp results in register + + #pragma unroll + for (uint32_t idx = 0; idx < (1 << D); idx++) { + float w = 1; + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if ((idx & (1 << d)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = pos_grid[d] + 1; + } + } + + uint32_t index = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); + + // writing to register (fast) + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + results[ch] += w * grid[index + ch]; + } + + //printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]); + } + + // writing to global memory (slow) + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + outputs[ch] = results[ch]; + } + + // prepare dy_dx + // differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9 + if (dy_dx) { + + dy_dx += b * D * L * C + level * D * C; // B L D C + + #pragma unroll + for (uint32_t gd = 0; gd < D; gd++) { + + scalar_t results_grad[C] = {0}; + + #pragma unroll + for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) { + float w = scale; + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t nd = 0; nd < D - 1; nd++) { + const uint32_t d = (nd >= gd) ? (nd + 1) : nd; + + if ((idx & (1 << nd)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = pos_grid[d] + 1; + } + } + + pos_grid_local[gd] = pos_grid[gd]; + uint32_t index_left = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); + pos_grid_local[gd] = pos_grid[gd] + 1; + uint32_t index_right = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]) * pos_deriv[gd]; + } + } + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + dy_dx[gd * C + ch] = results_grad[ch]; + } + } + } +} + + +template +__global__ void kernel_grid_backward( + const scalar_t * __restrict__ grad, + const float * __restrict__ inputs, + const scalar_t * __restrict__ grid, + const int * __restrict__ offsets, + scalar_t * __restrict__ grad_grid, + const uint32_t B, const uint32_t L, const float S, const uint32_t H, + const uint32_t gridtype, + const bool align_corners, + const uint32_t interp +) { + const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C; + if (b >= B) return; + + const uint32_t level = blockIdx.y; + const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C; + + // locate + grad_grid += offsets[level] * C; + inputs += b * D; + grad += level * B * C + b * C + ch; // L, B, C + + const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; + const float scale = exp2f(level * S) * H - 1.0f; + const uint32_t resolution = (uint32_t)ceil(scale) + 1; + + // check input range (should be in [0, 1]) + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if (inputs[d] < 0 || inputs[d] > 1) { + return; // grad is init as 0, so we simply return. + } + } + + // calculate coordinate + float pos[D]; + uint32_t pos_grid[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f); + pos_grid[d] = floorf(pos[d]); + pos[d] -= (float)pos_grid[d]; + // smoothstep instead of linear + if (interp == 1) { + pos[d] = smoothstep(pos[d]); + } + } + + scalar_t grad_cur[N_C] = {0}; // fetch to register + #pragma unroll + for (uint32_t c = 0; c < N_C; c++) { + grad_cur[c] = grad[c]; + } + + // interpolate + #pragma unroll + for (uint32_t idx = 0; idx < (1 << D); idx++) { + float w = 1; + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if ((idx & (1 << d)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = pos_grid[d] + 1; + } + } + + uint32_t index = get_grid_index(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local); + + // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0 + // TODO: use float which is better than __half, if N_C % 2 != 0 + if (std::is_same::value && N_C % 2 == 0) { + #pragma unroll + for (uint32_t c = 0; c < N_C; c += 2) { + // process two __half at once (by interpreting as a __half2) + __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])}; + atomicAdd((__half2*)&grad_grid[index + c], v); + } + // float, or __half when N_C % 2 != 0 (which means C == 1) + } else { + #pragma unroll + for (uint32_t c = 0; c < N_C; c++) { + atomicAdd(&grad_grid[index + c], w * grad_cur[c]); + } + } + } +} + + +template +__global__ void kernel_input_backward( + const scalar_t * __restrict__ grad, + const scalar_t * __restrict__ dy_dx, + scalar_t * __restrict__ grad_inputs, + uint32_t B, uint32_t L +) { + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; + if (t >= B * D) return; + + const uint32_t b = t / D; + const uint32_t d = t - b * D; + + dy_dx += b * L * D * C; + + scalar_t result = 0; + + # pragma unroll + for (int l = 0; l < L; l++) { + # pragma unroll + for (int ch = 0; ch < C; ch++) { + result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch]; + } + } + + grad_inputs[t] = result; +} + + +template +void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + static constexpr uint32_t N_THREAD = 512; + const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 }; + switch (C) { + case 1: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 2: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 4: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 8: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + +// inputs: [B, D], float, in [0, 1] +// embeddings: [sO, C], float +// offsets: [L + 1], uint32_t +// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.) +// H: base resolution +// dy_dx: [B, L * D * C] +template +void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + switch (D) { + case 2: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 3: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 4: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 5: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + +template +void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + static constexpr uint32_t N_THREAD = 256; + const uint32_t N_C = std::min(2u, C); // n_features_per_thread + const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 }; + switch (C) { + case 1: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp); + if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + case 2: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp); + if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + case 4: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp); + if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + case 8: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp); + if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + + +// grad: [L, B, C], float +// inputs: [B, D], float, in [0, 1] +// embeddings: [sO, C], float +// offsets: [L + 1], uint32_t +// grad_embeddings: [sO, C] +// H: base resolution +template +void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + switch (D) { + case 2: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break; + case 3: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break; + case 4: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break; + case 5: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + + + +void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + CHECK_CUDA(inputs); + CHECK_CUDA(embeddings); + CHECK_CUDA(offsets); + CHECK_CUDA(outputs); + // CHECK_CUDA(dy_dx); + + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(embeddings); + CHECK_CONTIGUOUS(offsets); + CHECK_CONTIGUOUS(outputs); + // CHECK_CONTIGUOUS(dy_dx); + + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(embeddings); + CHECK_IS_INT(offsets); + CHECK_IS_FLOATING(outputs); + // CHECK_IS_FLOATING(dy_dx); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + embeddings.scalar_type(), "grid_encode_forward", ([&] { + grid_encode_forward_cuda(inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), outputs.data_ptr(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr() : nullptr, gridtype, align_corners, interp); + })); +} + +void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + CHECK_CUDA(grad); + CHECK_CUDA(inputs); + CHECK_CUDA(embeddings); + CHECK_CUDA(offsets); + CHECK_CUDA(grad_embeddings); + // CHECK_CUDA(dy_dx); + // CHECK_CUDA(grad_inputs); + + CHECK_CONTIGUOUS(grad); + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(embeddings); + CHECK_CONTIGUOUS(offsets); + CHECK_CONTIGUOUS(grad_embeddings); + // CHECK_CONTIGUOUS(dy_dx); + // CHECK_CONTIGUOUS(grad_inputs); + + CHECK_IS_FLOATING(grad); + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(embeddings); + CHECK_IS_INT(offsets); + CHECK_IS_FLOATING(grad_embeddings); + // CHECK_IS_FLOATING(dy_dx); + // CHECK_IS_FLOATING(grad_inputs); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "grid_encode_backward", ([&] { + grid_encode_backward_cuda(grad.data_ptr(), inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), grad_embeddings.data_ptr(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr() : nullptr, grad_inputs.has_value() ? grad_inputs.value().data_ptr() : nullptr, gridtype, align_corners, interp); + })); + +} + + +template +__global__ void kernel_grad_tv( + const scalar_t * __restrict__ inputs, + const scalar_t * __restrict__ grid, + scalar_t * __restrict__ grad, + const int * __restrict__ offsets, + const float weight, + const uint32_t B, const uint32_t L, const float S, const uint32_t H, + const uint32_t gridtype, + const bool align_corners +) { + const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x; + + if (b >= B) return; + + const uint32_t level = blockIdx.y; + + // locate + inputs += b * D; + grid += (uint32_t)offsets[level] * C; + grad += (uint32_t)offsets[level] * C; + + // check input range (should be in [0, 1]) + bool flag_oob = false; + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if (inputs[d] < 0 || inputs[d] > 1) { + flag_oob = true; + } + } + + // if input out of bound, do nothing + if (flag_oob) return; + + const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; + const float scale = exp2f(level * S) * H - 1.0f; + const uint32_t resolution = (uint32_t)ceil(scale) + 1; + + // calculate coordinate + float pos[D]; + uint32_t pos_grid[D]; // [0, resolution] + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f); + pos_grid[d] = floorf(pos[d]); + // pos[d] -= (float)pos_grid[d]; // not used + } + + //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]); + + // total variation on pos_grid + scalar_t results[C] = {0}; // temp results in register + scalar_t idelta[C] = {0}; + + uint32_t index = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid); + + scalar_t w = weight / (2 * D); + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + + uint32_t cur_d = pos_grid[d]; + scalar_t grad_val; + + // right side + if (cur_d < resolution) { + pos_grid[d] = cur_d + 1; + uint32_t index_right = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid); + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + // results[ch] += w * clamp(grid[index + ch] - grid[index_right + ch], -1.0f, 1.0f); + grad_val = (grid[index + ch] - grid[index_right + ch]); + results[ch] += grad_val; + idelta[ch] += grad_val * grad_val; + } + } + + // left side + if (cur_d > 0) { + pos_grid[d] = cur_d - 1; + uint32_t index_left = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid); + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + // results[ch] += w * clamp(grid[index + ch] - grid[index_left + ch], -1.0f, 1.0f); + grad_val = (grid[index + ch] - grid[index_left + ch]); + results[ch] += grad_val; + idelta[ch] += grad_val * grad_val; + } + } + + // reset + pos_grid[d] = cur_d; + } + + // writing to global memory (slow) + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + // index may collide, so use atomic! + atomicAdd(&grad[index + ch], w * results[ch] * rsqrtf(idelta[ch] + 1e-9f)); + } + +} + + +template +void kernel_grad_tv_wrapper(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) { + static constexpr uint32_t N_THREAD = 512; + const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 }; + switch (C) { + case 1: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break; + case 2: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break; + case 4: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break; + case 8: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + + +template +void grad_total_variation_cuda(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) { + switch (D) { + case 2: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break; + case 3: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break; + case 4: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break; + case 5: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + + +void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) { + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + embeddings.scalar_type(), "grad_total_variation", ([&] { + grad_total_variation_cuda(inputs.data_ptr(), embeddings.data_ptr(), grad.data_ptr(), offsets.data_ptr(), weight, B, D, C, L, S, H, gridtype, align_corners); + })); +} \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/src/gridencoder.h b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/src/gridencoder.h new file mode 100644 index 00000000..1b385755 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/gridencoder/src/gridencoder.h @@ -0,0 +1,17 @@ +#ifndef _HASH_ENCODE_H +#define _HASH_ENCODE_H + +#include +#include + +// inputs: [B, D], float, in [0, 1] +// embeddings: [sO, C], float +// offsets: [L + 1], uint32_t +// outputs: [B, L * C], float +// H: base resolution +void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp); +void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp); + +void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners); + +#endif \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/__init__.py b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/__init__.py new file mode 100644 index 00000000..2b55c96e --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/__init__.py @@ -0,0 +1 @@ +from .sphere_harmonics import SHEncoder \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/__pycache__/__init__.cpython-39.pyc b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 00000000..07641f70 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/__pycache__/__init__.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/__pycache__/sphere_harmonics.cpython-39.pyc b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/__pycache__/sphere_harmonics.cpython-39.pyc new file mode 100644 index 00000000..3722eb39 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/__pycache__/sphere_harmonics.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/backend.py b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/backend.py new file mode 100644 index 00000000..cc08a3e9 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/backend.py @@ -0,0 +1,40 @@ +import os +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +_backend = load(name='_sh_encoder', + extra_cflags=c_flags, + extra_cuda_cflags=nvcc_flags, + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'shencoder.cu', + 'bindings.cpp', + ]], + ) + +__all__ = ['_backend'] \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/build/lib.linux-x86_64-cpython-39/_shencoder.cpython-39-x86_64-linux-gnu.so b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/build/lib.linux-x86_64-cpython-39/_shencoder.cpython-39-x86_64-linux-gnu.so new file mode 100644 index 00000000..810ffbd9 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/build/lib.linux-x86_64-cpython-39/_shencoder.cpython-39-x86_64-linux-gnu.so differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/.ninja_deps b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/.ninja_deps new file mode 100644 index 00000000..9eec85a8 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/.ninja_deps differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/.ninja_log b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/.ninja_log new file mode 100644 index 00000000..a1351b74 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/.ninja_log @@ -0,0 +1,7 @@ +# ninja log v5 +0 8801 1734544843628577063 /home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/shencoder/src/bindings.o 70f4c08fcd2ce153 +1 29563 1734544864398561069 /home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/shencoder/src/shencoder.o 326c39b4f9020f3e +3 10906 1734727574520035816 /output/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/shencoder/src/bindings.o 46dcc04c9d469f7b +3 40312 1734727603932326104 /output/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/shencoder/src/shencoder.o 3e556628932a814d +4 10700 1734800952168093713 /output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/shencoder/src/bindings.o 30ac52a8f1cfc13e +5 40338 1734800981812384414 /output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/shencoder/src/shencoder.o 7b5d91f346a32763 diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/build.ninja b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/build.ninja new file mode 100644 index 00000000..5e9bc9d8 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/build.ninja @@ -0,0 +1,29 @@ +ninja_required_version = 1.3 +cxx = c++ +nvcc = /usr/local/cuda/bin/nvcc + +cflags = -pthread -B /output/GeneFace_Reproduction/conda/compiler_compat -Wno-unused-result -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /output/GeneFace_Reproduction/conda/include -I/output/GeneFace_Reproduction/conda/include -fPIC -O2 -isystem /output/GeneFace_Reproduction/conda/include -fPIC -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include/TH -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include/THC -I/usr/local/cuda/include -I/output/GeneFace_Reproduction/conda/include/python3.9 -c +post_cflags = -O3 -std=c++14 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_shencoder -D_GLIBCXX_USE_CXX11_ABI=0 +cuda_cflags = -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include/TH -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include/THC -I/usr/local/cuda/include -I/output/GeneFace_Reproduction/conda/include/python3.9 -c +cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_shencoder -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 +ldflags = + +rule compile + command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags + depfile = $out.d + deps = gcc + +rule cuda_compile + depfile = $out.d + deps = gcc + command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags + + + +build /output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/shencoder/src/bindings.o: compile /output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/shencoder/src/bindings.cpp +build /output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/shencoder/src/shencoder.o: cuda_compile /output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/shencoder/src/shencoder.cu + + + + + diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/shencoder/src/bindings.o b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/shencoder/src/bindings.o new file mode 100644 index 00000000..200559b6 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/shencoder/src/bindings.o differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/shencoder/src/shencoder.o b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/shencoder/src/shencoder.o new file mode 100644 index 00000000..bf678ce9 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/encoders/shencoder/src/shencoder.o differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/shencoder/src/bindings.o b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/shencoder/src/bindings.o new file mode 100644 index 00000000..2926f20a Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/shencoder/src/bindings.o differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/shencoder/src/shencoder.o b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/shencoder/src/shencoder.o new file mode 100644 index 00000000..442a5bc3 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/shencoder/src/shencoder.o differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/setup.py b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/setup.py new file mode 100644 index 00000000..342a6015 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/setup.py @@ -0,0 +1,50 @@ +import os +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +setup( + name='shencoder', # package name, import this to use python API + ext_modules=[ + CUDAExtension( + name='_shencoder', # extension name, import this to use CUDA API + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'shencoder.cu', + 'bindings.cpp', + ]], + extra_compile_args={ + 'cxx': c_flags, + 'nvcc': nvcc_flags, + } + ), + ], + cmdclass={ + 'build_ext': BuildExtension, + } +) \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/shencoder.egg-info/PKG-INFO b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/shencoder.egg-info/PKG-INFO new file mode 100644 index 00000000..8c586951 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/shencoder.egg-info/PKG-INFO @@ -0,0 +1,3 @@ +Metadata-Version: 2.1 +Name: shencoder +Version: 0.0.0 diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/shencoder.egg-info/SOURCES.txt b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/shencoder.egg-info/SOURCES.txt new file mode 100644 index 00000000..f754a846 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/shencoder.egg-info/SOURCES.txt @@ -0,0 +1,7 @@ +setup.py +/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/shencoder/src/bindings.cpp +/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/encoders/shencoder/src/shencoder.cu +shencoder.egg-info/PKG-INFO +shencoder.egg-info/SOURCES.txt +shencoder.egg-info/dependency_links.txt +shencoder.egg-info/top_level.txt \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/shencoder.egg-info/dependency_links.txt b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/shencoder.egg-info/dependency_links.txt new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/shencoder.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/shencoder.egg-info/top_level.txt b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/shencoder.egg-info/top_level.txt new file mode 100644 index 00000000..51b4d5fb --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/shencoder.egg-info/top_level.txt @@ -0,0 +1 @@ +_shencoder diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/sphere_harmonics.py b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/sphere_harmonics.py new file mode 100644 index 00000000..7bab24e6 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/sphere_harmonics.py @@ -0,0 +1,87 @@ +import numpy as np + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import _shencoder as _backend +except ImportError: + from .backend import _backend + +class _sh_encoder(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision + def forward(ctx, inputs, degree, calc_grad_inputs=False): + # inputs: [B, input_dim], float in [-1, 1] + # RETURN: [B, F], float + + inputs = inputs.contiguous() + B, input_dim = inputs.shape # batch size, coord dim + output_dim = degree ** 2 + + outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) + + if calc_grad_inputs: + dy_dx = torch.empty(B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device) + else: + dy_dx = None + + _backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, dy_dx) + + ctx.save_for_backward(inputs, dy_dx) + ctx.dims = [B, input_dim, degree] + + return outputs + + @staticmethod + #@once_differentiable + @custom_bwd + def backward(ctx, grad): + # grad: [B, C * C] + + inputs, dy_dx = ctx.saved_tensors + + if dy_dx is not None: + grad = grad.contiguous() + B, input_dim, degree = ctx.dims + grad_inputs = torch.zeros_like(inputs) + _backend.sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs) + return grad_inputs, None, None + else: + return None, None, None + + + +sh_encode = _sh_encoder.apply + + +class SHEncoder(nn.Module): + def __init__(self, input_dim=3, degree=4): + super().__init__() + + self.input_dim = input_dim # coord dims, must be 3 + self.degree = degree # 0 ~ 4 + self.output_dim = degree ** 2 + + assert self.input_dim == 3, "SH encoder only support input dim == 3" + assert self.degree > 0 and self.degree <= 8, "SH encoder only supports degree in [1, 8]" + + def __repr__(self): + return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}" + + def forward(self, inputs, size=1): + # inputs: [..., input_dim], normalized real world positions in [-size, size] + # return: [..., degree^2] + + inputs = inputs / size # [-1, 1] + + prefix_shape = list(inputs.shape[:-1]) + inputs = inputs.reshape(-1, self.input_dim) + + outputs = sh_encode(inputs, self.degree, inputs.requires_grad) + outputs = outputs.reshape(prefix_shape + [self.output_dim]) + + return outputs \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/src/bindings.cpp b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/src/bindings.cpp new file mode 100644 index 00000000..595b5b3a --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/src/bindings.cpp @@ -0,0 +1,8 @@ +#include + +#include "shencoder.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("sh_encode_forward", &sh_encode_forward, "SH encode forward (CUDA)"); + m.def("sh_encode_backward", &sh_encode_backward, "SH encode backward (CUDA)"); +} \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/src/shencoder.cu b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/src/shencoder.cu new file mode 100644 index 00000000..a92e4ab7 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/src/shencoder.cu @@ -0,0 +1,439 @@ +#include + +#include +#include +#include + +#include +#include + +#include +#include + +#include + + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") + + +template +__host__ __device__ T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + +template +__global__ void kernel_sh( + const scalar_t * __restrict__ inputs, + scalar_t * outputs, + uint32_t B, uint32_t D, uint32_t C, + scalar_t * dy_dx +) { + const uint32_t b = threadIdx.x + blockIdx.x * blockDim.x; + if (b >= B) return; + + const uint32_t C2 = C * C; + + // locate + inputs += b * D; + outputs += b * C2; + + scalar_t x = inputs[0], y = inputs[1], z = inputs[2]; + + scalar_t xy=x*y, xz=x*z, yz=y*z, x2=x*x, y2=y*y, z2=z*z, xyz=xy*z; + scalar_t x4=x2*x2, y4=y2*y2, z4=z2*z2; + scalar_t x6=x4*x2, y6=y4*y2, z6=z4*z2; + + auto write_sh = [&]() { + outputs[0] = 0.28209479177387814f ; // 1/(2*sqrt(pi)) + if (C <= 1) { return; } + outputs[1] = -0.48860251190291987f*y ; // -sqrt(3)*y/(2*sqrt(pi)) + outputs[2] = 0.48860251190291987f*z ; // sqrt(3)*z/(2*sqrt(pi)) + outputs[3] = -0.48860251190291987f*x ; // -sqrt(3)*x/(2*sqrt(pi)) + if (C <= 2) { return; } + outputs[4] = 1.0925484305920792f*xy ; // sqrt(15)*xy/(2*sqrt(pi)) + outputs[5] = -1.0925484305920792f*yz ; // -sqrt(15)*yz/(2*sqrt(pi)) + outputs[6] = 0.94617469575755997f*z2 - 0.31539156525251999f ; // sqrt(5)*(3*z2 - 1)/(4*sqrt(pi)) + outputs[7] = -1.0925484305920792f*xz ; // -sqrt(15)*xz/(2*sqrt(pi)) + outputs[8] = 0.54627421529603959f*x2 - 0.54627421529603959f*y2 ; // sqrt(15)*(x2 - y2)/(4*sqrt(pi)) + if (C <= 3) { return; } + outputs[9] = 0.59004358992664352f*y*(-3.0f*x2 + y2) ; // sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi)) + outputs[10] = 2.8906114426405538f*xy*z ; // sqrt(105)*xy*z/(2*sqrt(pi)) + outputs[11] = 0.45704579946446572f*y*(1.0f - 5.0f*z2) ; // sqrt(42)*y*(1 - 5*z2)/(8*sqrt(pi)) + outputs[12] = 0.3731763325901154f*z*(5.0f*z2 - 3.0f) ; // sqrt(7)*z*(5*z2 - 3)/(4*sqrt(pi)) + outputs[13] = 0.45704579946446572f*x*(1.0f - 5.0f*z2) ; // sqrt(42)*x*(1 - 5*z2)/(8*sqrt(pi)) + outputs[14] = 1.4453057213202769f*z*(x2 - y2) ; // sqrt(105)*z*(x2 - y2)/(4*sqrt(pi)) + outputs[15] = 0.59004358992664352f*x*(-x2 + 3.0f*y2) ; // sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi)) + if (C <= 4) { return; } + outputs[16] = 2.5033429417967046f*xy*(x2 - y2) ; // 3*sqrt(35)*xy*(x2 - y2)/(4*sqrt(pi)) + outputs[17] = 1.7701307697799304f*yz*(-3.0f*x2 + y2) ; // 3*sqrt(70)*yz*(-3*x2 + y2)/(8*sqrt(pi)) + outputs[18] = 0.94617469575756008f*xy*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*xy*(7*z2 - 1)/(4*sqrt(pi)) + outputs[19] = 0.66904654355728921f*yz*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*yz*(3 - 7*z2)/(8*sqrt(pi)) + outputs[20] = -3.1735664074561294f*z2 + 3.7024941420321507f*z4 + 0.31735664074561293f ; // 3*(-30*z2 + 35*z4 + 3)/(16*sqrt(pi)) + outputs[21] = 0.66904654355728921f*xz*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*xz*(3 - 7*z2)/(8*sqrt(pi)) + outputs[22] = 0.47308734787878004f*(x2 - y2)*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*(x2 - y2)*(7*z2 - 1)/(8*sqrt(pi)) + outputs[23] = 1.7701307697799304f*xz*(-x2 + 3.0f*y2) ; // 3*sqrt(70)*xz*(-x2 + 3*y2)/(8*sqrt(pi)) + outputs[24] = -3.7550144126950569f*x2*y2 + 0.62583573544917614f*x4 + 0.62583573544917614f*y4 ; // 3*sqrt(35)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) + if (C <= 5) { return; } + outputs[25] = 0.65638205684017015f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(154)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + outputs[26] = 8.3026492595241645f*xy*z*(x2 - y2) ; // 3*sqrt(385)*xy*z*(x2 - y2)/(4*sqrt(pi)) + outputs[27] = -0.48923829943525038f*y*(3.0f*x2 - y2)*(9.0f*z2 - 1.0f) ; // -sqrt(770)*y*(3*x2 - y2)*(9*z2 - 1)/(32*sqrt(pi)) + outputs[28] = 4.7935367849733241f*xy*z*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xy*z*(3*z2 - 1)/(4*sqrt(pi)) + outputs[29] = 0.45294665119569694f*y*(14.0f*z2 - 21.0f*z4 - 1.0f) ; // sqrt(165)*y*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) + outputs[30] = 0.1169503224534236f*z*(-70.0f*z2 + 63.0f*z4 + 15.0f) ; // sqrt(11)*z*(-70*z2 + 63*z4 + 15)/(16*sqrt(pi)) + outputs[31] = 0.45294665119569694f*x*(14.0f*z2 - 21.0f*z4 - 1.0f) ; // sqrt(165)*x*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) + outputs[32] = 2.3967683924866621f*z*(x2 - y2)*(3.0f*z2 - 1.0f) ; // sqrt(1155)*z*(x2 - y2)*(3*z2 - 1)/(8*sqrt(pi)) + outputs[33] = -0.48923829943525038f*x*(x2 - 3.0f*y2)*(9.0f*z2 - 1.0f) ; // -sqrt(770)*x*(x2 - 3*y2)*(9*z2 - 1)/(32*sqrt(pi)) + outputs[34] = 2.0756623148810411f*z*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(385)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) + outputs[35] = 0.65638205684017015f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(154)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) + if (C <= 6) { return; } + outputs[36] = 1.3663682103838286f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // sqrt(6006)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) + outputs[37] = 2.3666191622317521f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(2002)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + outputs[38] = 2.0182596029148963f*xy*(x2 - y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*xy*(x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) + outputs[39] = -0.92120525951492349f*yz*(3.0f*x2 - y2)*(11.0f*z2 - 3.0f) ; // -sqrt(2730)*yz*(3*x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) + outputs[40] = 0.92120525951492349f*xy*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*xy*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) + outputs[41] = 0.58262136251873131f*yz*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*yz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) + outputs[42] = 6.6747662381009842f*z2 - 20.024298714302954f*z4 + 14.684485723822165f*z6 - 0.31784601133814211f ; // sqrt(13)*(105*z2 - 315*z4 + 231*z6 - 5)/(32*sqrt(pi)) + outputs[43] = 0.58262136251873131f*xz*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*xz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) + outputs[44] = 0.46060262975746175f*(x2 - y2)*(11.0f*z2*(3.0f*z2 - 1.0f) - 7.0f*z2 + 1.0f) ; // sqrt(2730)*(x2 - y2)*(11*z2*(3*z2 - 1) - 7*z2 + 1)/(64*sqrt(pi)) + outputs[45] = -0.92120525951492349f*xz*(x2 - 3.0f*y2)*(11.0f*z2 - 3.0f) ; // -sqrt(2730)*xz*(x2 - 3*y2)*(11*z2 - 3)/(32*sqrt(pi)) + outputs[46] = 0.50456490072872406f*(11.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(91)*(11*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) + outputs[47] = 2.3666191622317521f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(2002)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) + outputs[48] = 10.247761577878714f*x2*y4 - 10.247761577878714f*x4*y2 + 0.6831841051919143f*x6 - 0.6831841051919143f*y6 ; // sqrt(6006)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) + if (C <= 7) { return; } + outputs[49] = 0.70716273252459627f*y*(-21.0f*x2*y4 + 35.0f*x4*y2 - 7.0f*x6 + y6) ; // 3*sqrt(715)*y*(-21*x2*y4 + 35*x4*y2 - 7*x6 + y6)/(64*sqrt(pi)) + outputs[50] = 5.2919213236038001f*xy*z*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 3*sqrt(10010)*xy*z*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) + outputs[51] = -0.51891557872026028f*y*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // -3*sqrt(385)*y*(13*z2 - 1)*(-10*x2*y2 + 5*x4 + y4)/(64*sqrt(pi)) + outputs[52] = 4.1513246297620823f*xy*z*(x2 - y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xy*z*(x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) + outputs[53] = -0.15645893386229404f*y*(3.0f*x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -3*sqrt(35)*y*(3*x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) + outputs[54] = 0.44253269244498261f*xy*z*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xy*z*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) + outputs[55] = 0.090331607582517306f*y*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ; // sqrt(105)*y*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) + outputs[56] = 0.068284276912004949f*z*(315.0f*z2 - 693.0f*z4 + 429.0f*z6 - 35.0f) ; // sqrt(15)*z*(315*z2 - 693*z4 + 429*z6 - 35)/(32*sqrt(pi)) + outputs[57] = 0.090331607582517306f*x*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ; // sqrt(105)*x*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) + outputs[58] = 0.07375544874083044f*z*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) - 187.0f*z2 + 45.0f) ; // sqrt(70)*z*(x2 - y2)*(143*z2*(3*z2 - 1) - 187*z2 + 45)/(64*sqrt(pi)) + outputs[59] = -0.15645893386229404f*x*(x2 - 3.0f*y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -3*sqrt(35)*x*(x2 - 3*y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) + outputs[60] = 1.0378311574405206f*z*(13.0f*z2 - 3.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(385)*z*(13*z2 - 3)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) + outputs[61] = -0.51891557872026028f*x*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // -3*sqrt(385)*x*(13*z2 - 1)*(-10*x2*y2 + x4 + 5*y4)/(64*sqrt(pi)) + outputs[62] = 2.6459606618019f*z*(15.0f*x2*y4 - 15.0f*x4*y2 + x6 - y6) ; // 3*sqrt(10010)*z*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) + outputs[63] = 0.70716273252459627f*x*(-35.0f*x2*y4 + 21.0f*x4*y2 - x6 + 7.0f*y6) ; // 3*sqrt(715)*x*(-35*x2*y4 + 21*x4*y2 - x6 + 7*y6)/(64*sqrt(pi)) + }; + + write_sh(); + + if (dy_dx) { + scalar_t *dx = dy_dx + b * D * C2; + scalar_t *dy = dx + C2; + scalar_t *dz = dy + C2; + + auto write_sh_dx = [&]() { + dx[0] = 0.0f ; // 0 + if (C <= 1) { return; } + dx[1] = 0.0f ; // 0 + dx[2] = 0.0f ; // 0 + dx[3] = -0.48860251190291992f ; // -sqrt(3)/(2*sqrt(pi)) + if (C <= 2) { return; } + dx[4] = 1.0925484305920792f*y ; // sqrt(15)*y/(2*sqrt(pi)) + dx[5] = 0.0f ; // 0 + dx[6] = 0.0f ; // 0 + dx[7] = -1.0925484305920792f*z ; // -sqrt(15)*z/(2*sqrt(pi)) + dx[8] = 1.0925484305920792f*x ; // sqrt(15)*x/(2*sqrt(pi)) + if (C <= 3) { return; } + dx[9] = -3.5402615395598609f*xy ; // -3*sqrt(70)*xy/(4*sqrt(pi)) + dx[10] = 2.8906114426405538f*yz ; // sqrt(105)*yz/(2*sqrt(pi)) + dx[11] = 0.0f ; // 0 + dx[12] = 0.0f ; // 0 + dx[13] = 0.45704579946446572f - 2.2852289973223288f*z2 ; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi)) + dx[14] = 2.8906114426405538f*xz ; // sqrt(105)*xz/(2*sqrt(pi)) + dx[15] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi)) + if (C <= 4) { return; } + dx[16] = 2.5033429417967046f*y*(3.0f*x2 - y2) ; // 3*sqrt(35)*y*(3*x2 - y2)/(4*sqrt(pi)) + dx[17] = -10.620784618679583f*xy*z ; // -9*sqrt(70)*xy*z/(4*sqrt(pi)) + dx[18] = 0.94617469575756008f*y*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*y*(7*z2 - 1)/(4*sqrt(pi)) + dx[19] = 0.0f ; // 0 + dx[20] = 0.0f ; // 0 + dx[21] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi)) + dx[22] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi)) + dx[23] = 5.3103923093397913f*z*(-x2 + y2) ; // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi)) + dx[24] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ; // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi)) + if (C <= 5) { return; } + dx[25] = 13.127641136803401f*xy*(-x2 + y2) ; // 15*sqrt(154)*xy*(-x2 + y2)/(8*sqrt(pi)) + dx[26] = 8.3026492595241645f*yz*(3.0f*x2 - y2) ; // 3*sqrt(385)*yz*(3*x2 - y2)/(4*sqrt(pi)) + dx[27] = 2.9354297966115022f*xy*(1.0f - 9.0f*z2) ; // 3*sqrt(770)*xy*(1 - 9*z2)/(16*sqrt(pi)) + dx[28] = 4.7935367849733241f*yz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*yz*(3*z2 - 1)/(4*sqrt(pi)) + dx[29] = 0.0f ; // 0 + dx[30] = 0.0f ; // 0 + dx[31] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ; // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) + dx[32] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi)) + dx[33] = -13.209434084751759f*x2*z2 + 1.4677148983057511f*x2 + 13.209434084751759f*y2*z2 - 1.4677148983057511f*y2 ; // 3*sqrt(770)*(-9*x2*z2 + x2 + 9*y2*z2 - y2)/(32*sqrt(pi)) + dx[34] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ; // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi)) + dx[35] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ; // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) + if (C <= 6) { return; } + dx[36] = 4.0991046311514854f*y*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // 3*sqrt(6006)*y*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi)) + dx[37] = 47.332383244635047f*xy*z*(-x2 + y2) ; // 15*sqrt(2002)*xy*z*(-x2 + y2)/(8*sqrt(pi)) + dx[38] = 2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) + dx[39] = 5.5272315570895412f*xy*z*(3.0f - 11.0f*z2) ; // 3*sqrt(2730)*xy*z*(3 - 11*z2)/(16*sqrt(pi)) + dx[40] = 0.92120525951492349f*y*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*y*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) + dx[41] = 0.0f ; // 0 + dx[42] = 0.0f ; // 0 + dx[43] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) + dx[44] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) + dx[45] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) + dx[46] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi)) + dx[47] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ; // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) + dx[48] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) + if (C <= 7) { return; } + dx[49] = 9.9002782553443485f*xy*(10.0f*x2*y2 - 3.0f*x4 - 3.0f*y4) ; // 21*sqrt(715)*xy*(10*x2*y2 - 3*x4 - 3*y4)/(32*sqrt(pi)) + dx[50] = 15.875763970811402f*yz*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // 9*sqrt(10010)*yz*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi)) + dx[51] = -10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // -15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi)) + dx[52] = 4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) + dx[53] = 0.93875360317376422f*xy*(66.0f*z2 - 143.0f*z4 - 3.0f) ; // 9*sqrt(35)*xy*(66*z2 - 143*z4 - 3)/(32*sqrt(pi)) + dx[54] = 0.44253269244498261f*yz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*yz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) + dx[55] = 0.0f ; // 0 + dx[56] = 0.0f ; // 0 + dx[57] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ; // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) + dx[58] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) + dx[59] = 30.97886890473422f*x2*z2 - 67.120882626924143f*x2*z4 - 1.4081304047606462f*x2 - 30.97886890473422f*y2*z2 + 67.120882626924143f*y2*z4 + 1.4081304047606462f*y2 ; // 9*sqrt(35)*(66*x2*z2 - 143*x2*z4 - 3*x2 - 66*y2*z2 + 143*y2*z4 + 3*y2)/(64*sqrt(pi)) + dx[60] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi)) + dx[61] = -0.51891557872026028f*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 4.0f*x2*(x2 - 5.0f*y2) + x4 + 5.0f*y4) ; // -3*sqrt(385)*(13*z2 - 1)*(-10*x2*y2 + 4*x2*(x2 - 5*y2) + x4 + 5*y4)/(64*sqrt(pi)) + dx[62] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) + dx[63] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi)) + }; + + auto write_sh_dy = [&]() { + dy[0] = 0.0f ; // 0 + if (C <= 1) { return; } + dy[1] = -0.48860251190291992f ; // -sqrt(3)/(2*sqrt(pi)) + dy[2] = 0.0f ; // 0 + dy[3] = 0.0f ; // 0 + if (C <= 2) { return; } + dy[4] = 1.0925484305920792f*x ; // sqrt(15)*x/(2*sqrt(pi)) + dy[5] = -1.0925484305920792f*z ; // -sqrt(15)*z/(2*sqrt(pi)) + dy[6] = 0.0f ; // 0 + dy[7] = 0.0f ; // 0 + dy[8] = -1.0925484305920792f*y ; // -sqrt(15)*y/(2*sqrt(pi)) + if (C <= 3) { return; } + dy[9] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi)) + dy[10] = 2.8906114426405538f*xz ; // sqrt(105)*xz/(2*sqrt(pi)) + dy[11] = 0.45704579946446572f - 2.2852289973223288f*z2 ; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi)) + dy[12] = 0.0f ; // 0 + dy[13] = 0.0f ; // 0 + dy[14] = -2.8906114426405538f*yz ; // -sqrt(105)*yz/(2*sqrt(pi)) + dy[15] = 3.5402615395598609f*xy ; // 3*sqrt(70)*xy/(4*sqrt(pi)) + if (C <= 4) { return; } + dy[16] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ; // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi)) + dy[17] = 5.3103923093397913f*z*(-x2 + y2) ; // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi)) + dy[18] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi)) + dy[19] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi)) + dy[20] = 0.0f ; // 0 + dy[21] = 0.0f ; // 0 + dy[22] = 0.94617469575756008f*y*(1.0f - 7.0f*z2) ; // 3*sqrt(5)*y*(1 - 7*z2)/(4*sqrt(pi)) + dy[23] = 10.620784618679583f*xy*z ; // 9*sqrt(70)*xy*z/(4*sqrt(pi)) + dy[24] = 2.5033429417967046f*y*(-3.0f*x2 + y2) ; // 3*sqrt(35)*y*(-3*x2 + y2)/(4*sqrt(pi)) + if (C <= 5) { return; } + dy[25] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ; // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) + dy[26] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ; // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi)) + dy[27] = -1.4677148983057511f*(x2 - y2)*(9.0f*z2 - 1.0f) ; // -3*sqrt(770)*(x2 - y2)*(9*z2 - 1)/(32*sqrt(pi)) + dy[28] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi)) + dy[29] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ; // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) + dy[30] = 0.0f ; // 0 + dy[31] = 0.0f ; // 0 + dy[32] = 4.7935367849733241f*yz*(1.0f - 3.0f*z2) ; // sqrt(1155)*yz*(1 - 3*z2)/(4*sqrt(pi)) + dy[33] = 2.9354297966115022f*xy*(9.0f*z2 - 1.0f) ; // 3*sqrt(770)*xy*(9*z2 - 1)/(16*sqrt(pi)) + dy[34] = 8.3026492595241645f*yz*(-3.0f*x2 + y2) ; // 3*sqrt(385)*yz*(-3*x2 + y2)/(4*sqrt(pi)) + dy[35] = 13.127641136803401f*xy*(x2 - y2) ; // 15*sqrt(154)*xy*(x2 - y2)/(8*sqrt(pi)) + if (C <= 6) { return; } + dy[36] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) + dy[37] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ; // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) + dy[38] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi)) + dy[39] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) + dy[40] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) + dy[41] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) + dy[42] = 0.0f ; // 0 + dy[43] = 0.0f ; // 0 + dy[44] = 0.92120525951492349f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // sqrt(2730)*y*(18*z2 - 33*z4 - 1)/(32*sqrt(pi)) + dy[45] = 5.5272315570895412f*xy*z*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(16*sqrt(pi)) + dy[46] = -2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) + dy[47] = 47.332383244635047f*xy*z*(x2 - y2) ; // 15*sqrt(2002)*xy*z*(x2 - y2)/(8*sqrt(pi)) + dy[48] = 4.0991046311514854f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(6006)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + if (C <= 7) { return; } + dy[49] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi)) + dy[50] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) + dy[51] = 0.51891557872026028f*(13.0f*z2 - 1.0f)*(10.0f*x2*y2 - 5.0f*x4 + 4.0f*y2*(5.0f*x2 - y2) - y4) ; // 3*sqrt(385)*(13*z2 - 1)*(10*x2*y2 - 5*x4 + 4*y2*(5*x2 - y2) - y4)/(64*sqrt(pi)) + dy[52] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi)) + dy[53] = -0.46937680158688211f*(x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -9*sqrt(35)*(x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) + dy[54] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) + dy[55] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ; // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) + dy[56] = 0.0f ; // 0 + dy[57] = 0.0f ; // 0 + dy[58] = 0.44253269244498261f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 3*sqrt(70)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi)) + dy[59] = 0.93875360317376422f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ; // 9*sqrt(35)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi)) + dy[60] = -4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // -3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) + dy[61] = 10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // 15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi)) + dy[62] = 15.875763970811402f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 9*sqrt(10010)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + dy[63] = 9.9002782553443485f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 21*sqrt(715)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) + }; + + auto write_sh_dz = [&]() { + dz[0] = 0.0f ; // 0 + if (C <= 1) { return; } + dz[1] = 0.0f ; // 0 + dz[2] = 0.48860251190291992f ; // sqrt(3)/(2*sqrt(pi)) + dz[3] = 0.0f ; // 0 + if (C <= 2) { return; } + dz[4] = 0.0f ; // 0 + dz[5] = -1.0925484305920792f*y ; // -sqrt(15)*y/(2*sqrt(pi)) + dz[6] = 1.8923493915151202f*z ; // 3*sqrt(5)*z/(2*sqrt(pi)) + dz[7] = -1.0925484305920792f*x ; // -sqrt(15)*x/(2*sqrt(pi)) + dz[8] = 0.0f ; // 0 + if (C <= 3) { return; } + dz[9] = 0.0f ; // 0 + dz[10] = 2.8906114426405538f*xy ; // sqrt(105)*xy/(2*sqrt(pi)) + dz[11] = -4.5704579946446566f*yz ; // -5*sqrt(42)*yz/(4*sqrt(pi)) + dz[12] = 5.597644988851731f*z2 - 1.1195289977703462f ; // 3*sqrt(7)*(5*z2 - 1)/(4*sqrt(pi)) + dz[13] = -4.5704579946446566f*xz ; // -5*sqrt(42)*xz/(4*sqrt(pi)) + dz[14] = 1.4453057213202769f*x2 - 1.4453057213202769f*y2 ; // sqrt(105)*(x2 - y2)/(4*sqrt(pi)) + dz[15] = 0.0f ; // 0 + if (C <= 4) { return; } + dz[16] = 0.0f ; // 0 + dz[17] = 1.7701307697799304f*y*(-3.0f*x2 + y2) ; // 3*sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi)) + dz[18] = 13.246445740605839f*xy*z ; // 21*sqrt(5)*xy*z/(2*sqrt(pi)) + dz[19] = 2.0071396306718676f*y*(1.0f - 7.0f*z2) ; // 9*sqrt(10)*y*(1 - 7*z2)/(8*sqrt(pi)) + dz[20] = 14.809976568128603f*pow(z, 3) - 6.3471328149122579f*z ; // (105*z**3 - 45*z)/(4*sqrt(pi)) + dz[21] = 2.0071396306718676f*x*(1.0f - 7.0f*z2) ; // 9*sqrt(10)*x*(1 - 7*z2)/(8*sqrt(pi)) + dz[22] = 6.6232228703029197f*z*(x2 - y2) ; // 21*sqrt(5)*z*(x2 - y2)/(4*sqrt(pi)) + dz[23] = 1.7701307697799304f*x*(-x2 + 3.0f*y2) ; // 3*sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi)) + dz[24] = 0.0f ; // 0 + if (C <= 5) { return; } + dz[25] = 0.0f ; // 0 + dz[26] = 8.3026492595241645f*xy*(x2 - y2) ; // 3*sqrt(385)*xy*(x2 - y2)/(4*sqrt(pi)) + dz[27] = 8.8062893898345074f*yz*(-3.0f*x2 + y2) ; // 9*sqrt(770)*yz*(-3*x2 + y2)/(16*sqrt(pi)) + dz[28] = 4.7935367849733241f*xy*(9.0f*z2 - 1.0f) ; // sqrt(1155)*xy*(9*z2 - 1)/(4*sqrt(pi)) + dz[29] = 12.682506233479513f*yz*(1.0f - 3.0f*z2) ; // 7*sqrt(165)*yz*(1 - 3*z2)/(4*sqrt(pi)) + dz[30] = -24.559567715218954f*z2 + 36.839351572828434f*z4 + 1.754254836801354f ; // 15*sqrt(11)*(-14*z2 + 21*z4 + 1)/(16*sqrt(pi)) + dz[31] = 12.682506233479513f*xz*(1.0f - 3.0f*z2) ; // 7*sqrt(165)*xz*(1 - 3*z2)/(4*sqrt(pi)) + dz[32] = 2.3967683924866621f*(x2 - y2)*(9.0f*z2 - 1.0f) ; // sqrt(1155)*(x2 - y2)*(9*z2 - 1)/(8*sqrt(pi)) + dz[33] = 8.8062893898345074f*xz*(-x2 + 3.0f*y2) ; // 9*sqrt(770)*xz*(-x2 + 3*y2)/(16*sqrt(pi)) + dz[34] = -12.453973889286246f*x2*y2 + 2.0756623148810411f*x4 + 2.0756623148810411f*y4 ; // 3*sqrt(385)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) + dz[35] = 0.0f ; // 0 + if (C <= 6) { return; } + dz[36] = 0.0f ; // 0 + dz[37] = 2.3666191622317521f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(2002)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + dz[38] = 44.401711264127719f*xy*z*(x2 - y2) ; // 33*sqrt(91)*xy*z*(x2 - y2)/(4*sqrt(pi)) + dz[39] = -2.7636157785447706f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(2730)*y*(3*x2 - y2)*(11*z2 - 1)/(32*sqrt(pi)) + dz[40] = 11.054463114179082f*xy*z*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(8*sqrt(pi)) + dz[41] = 2.9131068125936568f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // 5*sqrt(273)*y*(18*z2 - 33*z4 - 1)/(16*sqrt(pi)) + dz[42] = 2.6699064952403937f*z*(-30.0f*z2 + 33.0f*z4 + 5.0f) ; // 21*sqrt(13)*z*(-30*z2 + 33*z4 + 5)/(16*sqrt(pi)) + dz[43] = 2.9131068125936568f*x*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // 5*sqrt(273)*x*(18*z2 - 33*z4 - 1)/(16*sqrt(pi)) + dz[44] = 5.5272315570895412f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(16*sqrt(pi)) + dz[45] = -2.7636157785447706f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(2730)*x*(x2 - 3*y2)*(11*z2 - 1)/(32*sqrt(pi)) + dz[46] = 11.10042781603193f*z*(-6.0f*x2*y2 + x4 + y4) ; // 33*sqrt(91)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) + dz[47] = 2.3666191622317521f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(2002)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) + dz[48] = 0.0f ; // 0 + if (C <= 7) { return; } + dz[49] = 0.0f ; // 0 + dz[50] = 5.2919213236038001f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 3*sqrt(10010)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) + dz[51] = 13.491805046726766f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 39*sqrt(385)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + dz[52] = 12.453973889286248f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // 9*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(8*sqrt(pi)) + dz[53] = -6.8841930899409371f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // -33*sqrt(35)*yz*(3*x2 - y2)*(13*z2 - 3)/(16*sqrt(pi)) + dz[54] = 2.2126634622249131f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ; // 15*sqrt(70)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi)) + dz[55] = 1.6259689364853116f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 9*sqrt(105)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi)) + dz[56] = 64.528641681844675f*z2 - 236.60501950009714f*z4 + 205.05768356675085f*z6 - 2.3899496919201733f ; // 7*sqrt(15)*(135*z2 - 495*z4 + 429*z6 - 5)/(32*sqrt(pi)) + dz[57] = 1.6259689364853116f*xz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 9*sqrt(105)*xz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi)) + dz[58] = 0.07375544874083044f*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) + 132.0f*z2*(13.0f*z2 - 5.0f) - 187.0f*z2 + 45.0f) ; // sqrt(70)*(x2 - y2)*(143*z2*(3*z2 - 1) + 132*z2*(13*z2 - 5) - 187*z2 + 45)/(64*sqrt(pi)) + dz[59] = -6.8841930899409371f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // -33*sqrt(35)*xz*(x2 - 3*y2)*(13*z2 - 3)/(16*sqrt(pi)) + dz[60] = 3.1134934723215619f*(13.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 9*sqrt(385)*(13*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) + dz[61] = 13.491805046726766f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 39*sqrt(385)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) + dz[62] = 39.6894099270285f*x2*y4 - 39.6894099270285f*x4*y2 + 2.6459606618019f*x6 - 2.6459606618019f*y6 ; // 3*sqrt(10010)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) + dz[63] = 0.0f ; // 0 + }; + write_sh_dx(); + write_sh_dy(); + write_sh_dz(); + } +} + + +template +__global__ void kernel_sh_backward( + const scalar_t * __restrict__ grad, + const scalar_t * __restrict__ inputs, + uint32_t B, uint32_t D, uint32_t C, + const scalar_t * __restrict__ dy_dx, + scalar_t * grad_inputs +) { + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; + const uint32_t b = t / D; + if (b >= B) return; + + const uint32_t d = t - b * D; + const uint32_t C2 = C * C; + + // locate + grad += b * C2; + dy_dx += b * D * C2 + d * C2; + + for (int ch = 0; ch < C2; ch++) { + grad_inputs[t] += grad[ch] * dy_dx[ch]; + //printf("t=%d, b=%d, d=%d, ch=%d, grad=%f (+= %f * %f)\n", t, b, d, ch, grad_inputs[t], grad[ch], dy_dx[ch]); + } + +} + +// inputs: [B, D], float, in [0, 1] +// outputs: [B, L * C], float +template +void sh_encode_forward_cuda(const scalar_t *inputs, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx) { + static constexpr uint32_t N_THREADS = 256; + kernel_sh<<>>(inputs, outputs, B, D, C, dy_dx); +} + + +template +void sh_encode_backward_cuda(const scalar_t *grad, const scalar_t *inputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx, scalar_t *grad_inputs) { + static constexpr uint32_t N_THREADS = 256; + kernel_sh_backward<<>>(grad, inputs, B, D, C, dy_dx, grad_inputs); +} + + +void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional dy_dx) { + CHECK_CUDA(inputs); + CHECK_CUDA(outputs); + // CHECK_CUDA(dy_dx); + + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(outputs); + // CHECK_CONTIGUOUS(dy_dx); + + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(outputs); + // CHECK_IS_FLOATING(dy_dx); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + inputs.scalar_type(), "sh_encode_forward_cuda", ([&] { + sh_encode_forward_cuda(inputs.data_ptr(), outputs.data_ptr(), B, D, C, dy_dx.has_value() ? dy_dx.value().data_ptr() : nullptr); + })); +} + +void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs) { + CHECK_CUDA(grad); + CHECK_CUDA(inputs); + CHECK_CUDA(dy_dx); + CHECK_CUDA(grad_inputs); + + CHECK_CONTIGUOUS(grad); + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(dy_dx); + CHECK_CONTIGUOUS(grad_inputs); + + CHECK_IS_FLOATING(grad); + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(dy_dx); + CHECK_IS_FLOATING(grad_inputs); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "sh_encode_backward_cuda", ([&] { + sh_encode_backward_cuda(grad.data_ptr(), inputs.data_ptr(), B, D, C, dy_dx.data_ptr(), grad_inputs.data_ptr()); + })); +} \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/src/shencoder.h b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/src/shencoder.h new file mode 100644 index 00000000..f9e89fac --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/encoders/shencoder/src/shencoder.h @@ -0,0 +1,10 @@ +# pragma once + +#include +#include + +// inputs: [B, D], float, in [-1, 1] +// outputs: [B, F], float + +void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional dy_dx); +void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs); \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/radnerf.py b/Geneface_main/GeneFace/modules/radnerfs/radnerf.py new file mode 100644 index 00000000..7173e57a --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/radnerf.py @@ -0,0 +1,130 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modules.radnerfs.encoders.encoding import get_encoder +from modules.radnerfs.renderer import NeRFRenderer +from modules.radnerfs.cond_encoder import AudioNet, AudioAttNet, MLP +from modules.radnerfs.utils import trunc_exp + + +class RADNeRF(NeRFRenderer): + def __init__(self, hparams): + super().__init__(hparams) + self.hparams = hparams + if hparams['cond_type'] == 'esperanto': + self.cond_in_dim = 44 + elif hparams['cond_type'] == 'deepspeech': + self.cond_in_dim = 29 + elif hparams['cond_type'] == 'idexp_lm3d_normalized': + self.cond_in_dim = 68*3 + else: + raise NotImplementedError() + + # a prenet that processes the raw condition + self.cond_out_dim = hparams['cond_out_dim'] + self.cond_win_size = hparams['cond_win_size'] + self.smo_win_size = hparams['smo_win_size'] + self.cond_prenet = AudioNet(self.cond_in_dim, self.cond_out_dim, win_size=self.cond_win_size) + + # a attention net that smoothes the condition feat sequence + self.with_att = hparams['with_att'] + if self.with_att: + self.cond_att_net = AudioAttNet(self.cond_out_dim, seq_len=self.smo_win_size) + + # a ambient network that predict the 2D ambient coordinate + # the ambient grid models the dynamic of canonical face + # by predict ambient coords given cond_feat, we can be driven the face by either audio or landmark! + self.grid_type = hparams['grid_type'] # tiledgrid or hashgrid + self.grid_interpolation_type = hparams['grid_interpolation_type'] # smoothstep or linear + self.position_embedder, self.position_embedding_dim = get_encoder(self.grid_type, input_dim=3, num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=hparams['log2_hashmap_size'], desired_resolution=hparams['desired_resolution'] * self.bound, interpolation=self.grid_interpolation_type) + self.num_layers_ambient = hparams['num_layers_ambient'] + self.hidden_dim_ambient = hparams['hidden_dim_ambient'] + self.ambient_out_dim = hparams['ambient_out_dim'] + self.ambient_net = MLP(self.position_embedding_dim + self.cond_out_dim, self.ambient_out_dim, self.hidden_dim_ambient, self.num_layers_ambient) + # the learnable ambient grid + self.ambient_embedder, self.ambient_embedding_dim = get_encoder(self.grid_type, input_dim=hparams['ambient_out_dim'], num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=hparams['log2_hashmap_size'], desired_resolution=hparams['desired_resolution'], interpolation=self.grid_interpolation_type) + + # sigma network + self.num_layers_sigma = hparams['num_layers_sigma'] + self.hidden_dim_sigma = hparams['hidden_dim_sigma'] + self.geo_feat_dim = hparams['geo_feat_dim'] + + self.sigma_net = MLP(self.position_embedding_dim + self.ambient_embedding_dim, 1 + self.geo_feat_dim, self.hidden_dim_sigma, self.num_layers_sigma) + + # color network + self.num_layers_color = hparams['num_layers_color'] + self.hidden_dim_color = hparams['hidden_dim_color'] + self.direction_embedder, self.direction_embedding_dim = get_encoder('spherical_harmonics') + self.color_net = MLP(self.direction_embedding_dim + self.geo_feat_dim + self.individual_embedding_dim, 3, self.hidden_dim_color, self.num_layers_color) + + def cal_cond_feat(self, cond): + """ + cond: [B, T, Ç] + if deepspeech, [1/8, T=16, 29] + if eserpanto, [1/8, T=16, 44] + if idexp_lm3d_normalized, [1/5, T=1, 204] + """ + cond_feat = self.cond_prenet(cond) + if self.with_att: + cond_feat = self.cond_att_net(cond_feat) # [1, 64] + return cond_feat + + def forward(self, position, direction, cond_feat, individual_code): + """ + position: [N, 3], position, in [-bound, bound] + direction: [N, 3], direction, nomalized in [-1, 1] + cond_feat: [1, cond_dim], condition encoding, generated by self.cal_cond_feat + individual_code: [1, ind_dim], individual code for each timestep + """ + cond_feat = cond_feat.repeat(position.shape[0], 1) # [1,cond_dim] ==> [N, cond_dim] + pos_feat = self.position_embedder(position, bound=self.bound) # spatial feat f after E^3_{spatial} 3D grid in the paper + + # ambient + ambient_inp = torch.cat([pos_feat, cond_feat], dim=1) # audio feat and spatial feat + ambient_logit = self.ambient_net(ambient_inp).float() # the MLP after AFE in paper, use float(), prevent performance drop due to amp + ambient_pos = torch.tanh(ambient_logit) # normalized to [-1, 1], act as the coordinate in the 2D ambient tilegrid + ambient_feat = self.ambient_embedder(ambient_pos, bound=1) # E^2_{audio} in paper, 2D grid + + # sigma + h = torch.cat([pos_feat, ambient_feat], dim=-1) + h = self.sigma_net(h) + sigma = trunc_exp(h[..., 0]) + geo_feat = h[..., 1:] + + # color + direction_feat = self.direction_embedder(direction) + if individual_code is not None: + color_inp = torch.cat([direction_feat, geo_feat, individual_code.repeat(position.shape[0], 1)], dim=-1) + else: + color_inp = torch.cat([direction_feat, geo_feat], dim=-1) + color_logit = self.color_net(color_inp) + # sigmoid activation for rgb + color = torch.sigmoid(color_logit) + + return sigma, color, ambient_pos + + def density(self, position, cond_feat, e=None): + """ + Calculate Density, this is a sub-process of self.forward + """ + cond_feat = cond_feat.repeat(position.shape[0], 1) # [1,cond_dim] ==> [N, cond_dim] + pos_feat = self.position_embedder(position, bound=self.bound) # spatial feat f after E^3_{spatial} 3D grid in the paper + + # ambient + ambient_inp = torch.cat([pos_feat, cond_feat], dim=1) # audio feat and spatial feat + ambient_logit = self.ambient_net(ambient_inp).float() # the MLP after AFE in paper, use float(), prevent performance drop due to amp + ambient_pos = torch.tanh(ambient_logit) # normalized to [-1, 1], act as the coordinate in the 2D ambient tilegrid + ambient_feat = self.ambient_embedder(ambient_pos, bound=1) # E^2_{audio} in paper, 2D grid + + # sigma + h = torch.cat([pos_feat, ambient_feat], dim=-1) + h = self.sigma_net(h) + sigma = trunc_exp(h[..., 0]) + geo_feat = h[..., 1:] + + return { + 'sigma': sigma, + 'geo_feat': geo_feat, + } + diff --git a/Geneface_main/GeneFace/modules/radnerfs/radnerf_torso.py b/Geneface_main/GeneFace/modules/radnerfs/radnerf_torso.py new file mode 100644 index 00000000..15f302ab --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/radnerf_torso.py @@ -0,0 +1,241 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import random + +import modules.radnerfs.raymarching as raymarching +from modules.radnerfs.encoders.encoding import get_encoder +from modules.radnerfs.renderer import NeRFRenderer +from modules.radnerfs.radnerf import RADNeRF +from modules.radnerfs.cond_encoder import AudioNet, AudioAttNet, MLP +from modules.radnerfs.utils import trunc_exp +from modules.radnerfs.utils import custom_meshgrid, convert_poses + +from utils.commons.hparams import hparams + + +class RADNeRFTorso(RADNeRF): + def __init__(self, hparams): + super().__init__(hparams) + density_grid_torso = torch.zeros([self.grid_size ** 2]) # [H * H] + self.register_buffer('density_grid_torso', density_grid_torso) + self.mean_density_torso = 0 + self.density_thresh_torso = hparams['density_thresh_torso'] + + self.torso_individual_embedding_num = hparams['individual_embedding_num'] + self.torso_individual_embedding_dim = hparams['torso_individual_embedding_dim'] + if self.torso_individual_embedding_dim > 0: + self.torso_individual_codes = nn.Parameter(torch.randn(self.torso_individual_embedding_num, self.torso_individual_embedding_dim) * 0.1) + + self.torso_pose_embedder, self.pose_embedding_dim = get_encoder('frequency', input_dim=6, multires=4) + self.torso_deform_pos_embedder, self.torso_deform_pos_dim = get_encoder('frequency', input_dim=2, multires=10) # input 2D position + self.torso_embedder, self.torso_in_dim = get_encoder('tiledgrid', input_dim=2, num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=16, desired_resolution=2048) + + deform_net_in_dim = self.torso_deform_pos_dim + self.pose_embedding_dim + self.torso_individual_embedding_dim + canonicial_net_in_dim = self.torso_in_dim + self.torso_deform_pos_dim + self.pose_embedding_dim + self.torso_individual_embedding_dim + if hparams['torso_head_aware']: + head_aware_out_dim = 16 + self.head_color_weights_encoder = nn.Sequential(*[ + nn.Linear(3+1, 16, bias=True), + nn.LeakyReLU(0.02, True), + nn.Linear(16, 32, bias=True), + nn.LeakyReLU(0.02, True), + nn.Linear(32, head_aware_out_dim, bias=True), + ]) + deform_net_in_dim += head_aware_out_dim + canonicial_net_in_dim += head_aware_out_dim + + self.torso_deform_net = MLP(deform_net_in_dim, 2, 64, 3) + self.torso_canonicial_net = MLP(canonicial_net_in_dim, 4, 32, 3) + + def forward_torso(self, x, poses, c=None, image=None, weights_sum=None): + # x: [N, 2] in [-1, 1] + # head poses: [1, 6] + # c: [1, ind_dim], individual code + + # test: shrink x + x = x * hparams['torso_shrink'] + + # deformation-based + enc_pose = self.torso_pose_embedder(poses) + enc_x = self.torso_deform_pos_embedder(x) + + if c is not None: + h = torch.cat([enc_x, enc_pose.repeat(x.shape[0], 1), c.repeat(x.shape[0], 1)], dim=-1) + else: + h = torch.cat([enc_x, enc_pose.repeat(x.shape[0], 1)], dim=-1) + + if hparams['torso_head_aware']: + if image is None: + image = torch.zeros([x.shape[0],3], dtype=h.dtype, device=h.device) + weights_sum = torch.zeros([x.shape[0],1], dtype=h.dtype, device=h.device) + head_color_weights_inp = torch.cat([image, weights_sum],dim=-1) + head_color_weights_encoding = self.head_color_weights_encoder(head_color_weights_inp) + h = torch.cat([h, head_color_weights_encoding],dim=-1) + + dx = self.torso_deform_net(h) + x = (x + dx).clamp(-1, 1).float() + x = self.torso_embedder(x, bound=1) + h = torch.cat([x, h], dim=-1) + h = self.torso_canonicial_net(h) + alpha = torch.sigmoid(h[..., :1]) + color = torch.sigmoid(h[..., 1:]) + + return alpha, color, dx + + def render(self, rays_o, rays_d, cond, bg_coords, poses, index=0, dt_gamma=0, bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs): + # rays_o, rays_d: [B, N, 3], assumes B == 1 + # cond: [B, 29, 16] + # bg_coords: [1, N, 2] + # return: pred_rgb: [B, N, 3] + + ### run head nerf with no_grad to get the renderred head + with torch.no_grad(): + prefix = rays_o.shape[:-1] + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + bg_coords = bg_coords.contiguous().view(-1, 2) + N = rays_o.shape[0] # N = B * N, in fact + device = rays_o.device + results = {} + # pre-calculate near far + nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer, self.min_near) + nears = nears.detach() + fars = fars.detach() + # encode audio + cond_feat = self.cal_cond_feat(cond) # [1, 64] + if self.individual_embedding_dim > 0: + if self.training: + ind_code = self.individual_embeddings[index] + # use a fixed ind code for the unknown test data. + else: + ind_code = self.individual_embeddings[0] + else: + ind_code = None + if self.training: + # setup counter + counter = self.step_counter[self.local_step % 16] + counter.zero_() # set to 0 + self.local_step += 1 + xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps) + sigmas, rgbs, ambient = self(xyzs, dirs, cond_feat, ind_code) + sigmas = self.density_scale * sigmas + #print(f'valid RGB query ratio: {mask.sum().item() / mask.shape[0]} (total = {mask.sum().item()})') + weights_sum, ambient_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, ambient.abs().sum(-1), deltas, rays) + # for training only + results['weights_sum'] = weights_sum + results['ambient'] = ambient_sum + else: + dtype = torch.float32 + weights_sum = torch.zeros(N, dtype=dtype, device=device) + depth = torch.zeros(N, dtype=dtype, device=device) + image = torch.zeros(N, 3, dtype=dtype, device=device) + n_alive = N + rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N] + rays_t = nears.clone() # [N] + step = 0 + while step < max_steps: + # count alive rays + n_alive = rays_alive.shape[0] + # exit loop + if n_alive <= 0: + break + # decide compact_steps + n_step = max(min(N // n_alive, 8), 1) + xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps) + sigmas, rgbs, ambient = self(xyzs, dirs, cond_feat, ind_code) + sigmas = self.density_scale * sigmas + raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh) + rays_alive = rays_alive[rays_alive >= 0] + step += n_step + # background + if bg_color is None: + bg_color = 1 + + ### Start Rendering Torso + if self.torso_individual_embedding_dim > 0: + if self.training: + torso_individual_code = self.torso_individual_codes[index] + # use a fixed ind code for the unknown test data. + else: + torso_individual_code = self.torso_individual_codes[0] + else: + torso_individual_code = None + + # 2D density grid for acceleration... + density_thresh_torso = min(self.density_thresh_torso, self.mean_density_torso) + occupancy = F.grid_sample(self.density_grid_torso.view(1, 1, self.grid_size, self.grid_size), bg_coords.view(1, -1, 1, 2), align_corners=True).view(-1) + mask = occupancy > density_thresh_torso + + # masked query of torso + torso_alpha = torch.zeros([N, 1], device=device) + torso_color = torch.zeros([N, 3], device=device) + + if mask.any(): + if hparams['torso_head_aware']: + if random.random() < 0.5: + torso_alpha_mask, torso_color_mask, deform = self.forward_torso(bg_coords[mask], poses, torso_individual_code, image[mask], weights_sum.unsqueeze(-1)[mask]) + else: + torso_alpha_mask, torso_color_mask, deform = self.forward_torso(bg_coords[mask], poses, torso_individual_code, None, None) + else: + torso_alpha_mask, torso_color_mask, deform = self.forward_torso(bg_coords[mask], poses, torso_individual_code) + torso_alpha[mask] = torso_alpha_mask.float() + torso_color[mask] = torso_color_mask.float() + results['deform'] = deform + # first mix torso with background + bg_color = torso_color * torso_alpha + bg_color * (1 - torso_alpha) + results['torso_alpha_map'] = torso_alpha + results['torso_rgb_map'] = bg_color + # then mix the head image with the torso_bg + image = image + (1 - weights_sum).unsqueeze(-1) * bg_color + image = image.view(*prefix, 3) + image = image.clamp(0, 1) + depth = torch.clamp(depth - nears, min=0) / (fars - nears) + depth = depth.view(*prefix) + results['depth_map'] = depth + results['rgb_map'] = image # head_image if train, else com_image + + return results + + @torch.no_grad() + def update_extra_state(self, decay=0.95, S=128): + # forbid updating head if is training torso... + # only update torso density grid + tmp_grid_torso = torch.zeros_like(self.density_grid_torso) + + # random pose, random ind_code + rand_idx = random.randint(0, self.poses.shape[0] - 1) + pose = convert_poses(self.poses[[rand_idx]]).to(self.density_bitfield.device) + + if self.torso_individual_embedding_dim > 0: + ind_code = self.torso_individual_codes[[rand_idx]] + else: + ind_code = None + + X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + + half_grid_size = 1 / self.grid_size + + for xs in X: + for ys in Y: + xx, yy = custom_meshgrid(xs, ys) + coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], dim=-1) # [N, 2], in [0, 128) + indices = (coords[:, 1] * self.grid_size + coords[:, 0]).long() # NOTE: xy transposed! + xys = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 2] in [-1, 1] + xys = xys * (1 - half_grid_size) + # add noise in [-hgs, hgs] + xys += (torch.rand_like(xys) * 2 - 1) * half_grid_size + # query density + alphas, _, _ = self.forward_torso(xys, pose, ind_code) # [N, 1] + + # assign + tmp_grid_torso[indices] = alphas.squeeze(1).float() + + # dilate + tmp_grid_torso = tmp_grid_torso.view(1, 1, self.grid_size, self.grid_size) + tmp_grid_torso = F.max_pool2d(tmp_grid_torso, kernel_size=5, stride=1, padding=2) + tmp_grid_torso = tmp_grid_torso.view(-1) + + self.density_grid_torso = torch.maximum(self.density_grid_torso * decay, tmp_grid_torso) + self.mean_density_torso = torch.mean(self.density_grid_torso).item() diff --git a/Geneface_main/GeneFace/modules/radnerfs/raymarching/__init__.py b/Geneface_main/GeneFace/modules/radnerfs/raymarching/__init__.py new file mode 100644 index 00000000..26d3cc6d --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/raymarching/__init__.py @@ -0,0 +1 @@ +from .raymarching import * \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/raymarching/__pycache__/__init__.cpython-39.pyc b/Geneface_main/GeneFace/modules/radnerfs/raymarching/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 00000000..8002961e Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/raymarching/__pycache__/__init__.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/raymarching/__pycache__/raymarching.cpython-39.pyc b/Geneface_main/GeneFace/modules/radnerfs/raymarching/__pycache__/raymarching.cpython-39.pyc new file mode 100644 index 00000000..5e0ff5c1 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/raymarching/__pycache__/raymarching.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/raymarching/backend.py b/Geneface_main/GeneFace/modules/radnerfs/raymarching/backend.py new file mode 100644 index 00000000..d8f65d6f --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/raymarching/backend.py @@ -0,0 +1,40 @@ +import os +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +_backend = load(name='_raymarching_face', + extra_cflags=c_flags, + extra_cuda_cflags=nvcc_flags, + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'raymarching.cu', + 'bindings.cpp', + ]], + ) + +__all__ = ['_backend'] \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/raymarching/build/lib.linux-x86_64-cpython-39/_raymarching_face.cpython-39-x86_64-linux-gnu.so b/Geneface_main/GeneFace/modules/radnerfs/raymarching/build/lib.linux-x86_64-cpython-39/_raymarching_face.cpython-39-x86_64-linux-gnu.so new file mode 100644 index 00000000..3e84b26c Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/raymarching/build/lib.linux-x86_64-cpython-39/_raymarching_face.cpython-39-x86_64-linux-gnu.so differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/.ninja_deps b/Geneface_main/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/.ninja_deps new file mode 100644 index 00000000..1178fae7 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/.ninja_deps differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/.ninja_log b/Geneface_main/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/.ninja_log new file mode 100644 index 00000000..23751e08 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/.ninja_log @@ -0,0 +1,7 @@ +# ninja log v5 +0 9676 1734544964548549022 /home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/raymarching/src/bindings.o d44c8657275f4fca +0 29459 1734544984338544964 /home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/raymarching/src/raymarching.o 3dfa6976fffa9354 +3 11384 1734727728357554141 /output/GeneFace_reproduction/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/output/GeneFace_reproduction/GeneFace/modules/radnerfs/raymarching/src/bindings.o 39df9da0bf6cd4c3 +3 40521 1734727757505841825 /output/GeneFace_reproduction/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/output/GeneFace_reproduction/GeneFace/modules/radnerfs/raymarching/src/raymarching.o 67eddab14736585c +4 11484 1734801106313605309 /output/GeneFace_Reproduction/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/raymarching/src/bindings.o 48494375ed4a525f +5 40347 1734801135185888440 /output/GeneFace_Reproduction/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/raymarching/src/raymarching.o 25f431ed67794e9d diff --git a/Geneface_main/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/build.ninja b/Geneface_main/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/build.ninja new file mode 100644 index 00000000..d870b4a6 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/build.ninja @@ -0,0 +1,29 @@ +ninja_required_version = 1.3 +cxx = c++ +nvcc = /usr/local/cuda/bin/nvcc + +cflags = -pthread -B /output/GeneFace_Reproduction/conda/compiler_compat -Wno-unused-result -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /output/GeneFace_Reproduction/conda/include -I/output/GeneFace_Reproduction/conda/include -fPIC -O2 -isystem /output/GeneFace_Reproduction/conda/include -fPIC -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include/TH -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include/THC -I/usr/local/cuda/include -I/output/GeneFace_Reproduction/conda/include/python3.9 -c +post_cflags = -O3 -std=c++14 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_raymarching_face -D_GLIBCXX_USE_CXX11_ABI=0 +cuda_cflags = -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include/TH -I/output/GeneFace_Reproduction/conda/lib/python3.9/site-packages/torch/include/THC -I/usr/local/cuda/include -I/output/GeneFace_Reproduction/conda/include/python3.9 -c +cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_raymarching_face -D_GLIBCXX_USE_CXX11_ABI=0 +ldflags = + +rule compile + command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags + depfile = $out.d + deps = gcc + +rule cuda_compile + depfile = $out.d + deps = gcc + command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags + + + +build /output/GeneFace_Reproduction/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/raymarching/src/bindings.o: compile /output/GeneFace_Reproduction/GeneFace/modules/radnerfs/raymarching/src/bindings.cpp +build /output/GeneFace_Reproduction/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/raymarching/src/raymarching.o: cuda_compile /output/GeneFace_Reproduction/GeneFace/modules/radnerfs/raymarching/src/raymarching.cu + + + + + diff --git a/Geneface_main/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/raymarching/src/bindings.o b/Geneface_main/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/raymarching/src/bindings.o new file mode 100644 index 00000000..bad379c1 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/raymarching/src/bindings.o differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/raymarching/src/raymarching.o b/Geneface_main/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/raymarching/src/raymarching.o new file mode 100644 index 00000000..d35f439b Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/home/dedfaf/GeneFace_reproduction/GeneFace/modules/radnerfs/raymarching/src/raymarching.o differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/raymarching/src/bindings.o b/Geneface_main/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/raymarching/src/bindings.o new file mode 100644 index 00000000..266d3b1b Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/raymarching/src/bindings.o differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/raymarching/src/raymarching.o b/Geneface_main/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/raymarching/src/raymarching.o new file mode 100644 index 00000000..b21cfd98 Binary files /dev/null and b/Geneface_main/GeneFace/modules/radnerfs/raymarching/build/temp.linux-x86_64-cpython-39/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/raymarching/src/raymarching.o differ diff --git a/Geneface_main/GeneFace/modules/radnerfs/raymarching/raymarching.py b/Geneface_main/GeneFace/modules/radnerfs/raymarching/raymarching.py new file mode 100644 index 00000000..22dd0441 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/raymarching/raymarching.py @@ -0,0 +1,423 @@ +import numpy as np +import time + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import _raymarching_face as _backend +except ImportError: + from .backend import _backend + +# ---------------------------------------- +# utils +# ---------------------------------------- + +class _near_far_from_aabb(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, rays_o, rays_d, aabb, min_near=0.2): + ''' near_far_from_aabb, CUDA implementation + Calculate rays' intersection time (near and far) with aabb + Args: + rays_o: float, [N, 3] + rays_d: float, [N, 3] + aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax) + min_near: float, scalar + Returns: + nears: float, [N] + fars: float, [N] + ''' + if not rays_o.is_cuda: rays_o = rays_o.cuda() + if not rays_d.is_cuda: rays_d = rays_d.cuda() + + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + N = rays_o.shape[0] # num rays + + nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device) + fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device) + + _backend.near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars) + + return nears, fars + +near_far_from_aabb = _near_far_from_aabb.apply + + +class _sph_from_ray(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, rays_o, rays_d, radius): + ''' sph_from_ray, CUDA implementation + get spherical coordinate on the background sphere from rays. + Assume rays_o are inside the Sphere(radius). + Args: + rays_o: [N, 3] + rays_d: [N, 3] + radius: scalar, float + Return: + coords: [N, 2], in [-1, 1], theta and phi on a sphere. (further-surface) + ''' + if not rays_o.is_cuda: rays_o = rays_o.cuda() + if not rays_d.is_cuda: rays_d = rays_d.cuda() + + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + N = rays_o.shape[0] # num rays + + coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device) + + _backend.sph_from_ray(rays_o, rays_d, radius, N, coords) + + return coords + +sph_from_ray = _sph_from_ray.apply + + +class _morton3D(Function): + @staticmethod + def forward(ctx, coords): + ''' morton3D, CUDA implementation + Args: + coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...) + TODO: check if the coord range is valid! (current 128 is safe) + Returns: + indices: [N], int32, in [0, 128^3) + + ''' + if not coords.is_cuda: coords = coords.cuda() + + N = coords.shape[0] + + indices = torch.empty(N, dtype=torch.int32, device=coords.device) + + _backend.morton3D(coords.int(), N, indices) + + return indices + +morton3D = _morton3D.apply + +class _morton3D_invert(Function): + @staticmethod + def forward(ctx, indices): + ''' morton3D_invert, CUDA implementation + Args: + indices: [N], int32, in [0, 128^3) + Returns: + coords: [N, 3], int32, in [0, 128) + + ''' + if not indices.is_cuda: indices = indices.cuda() + + N = indices.shape[0] + + coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device) + + _backend.morton3D_invert(indices.int(), N, coords) + + return coords + +morton3D_invert = _morton3D_invert.apply + + +class _packbits(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, grid, thresh, bitfield=None): + ''' packbits, CUDA implementation + Pack up the density grid into a bit field to accelerate ray marching. + Args: + grid: float, [C, H * H * H], assume H % 2 == 0 + thresh: float, threshold + Returns: + bitfield: uint8, [C, H * H * H / 8] + ''' + if not grid.is_cuda: grid = grid.cuda() + grid = grid.contiguous() + + C = grid.shape[0] + H3 = grid.shape[1] + N = C * H3 // 8 + + if bitfield is None: + bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device) + + _backend.packbits(grid, N, thresh, bitfield) + + return bitfield + +packbits = _packbits.apply + + +class _morton3D_dilation(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, grid): + ''' max pooling with morton coord, CUDA implementation + or maybe call it dilation... we don't support adjust kernel size. + Args: + grid: float, [C, H * H * H], assume H % 2 == 0 + Returns: + grid_dilate: float, [C, H * H * H], assume H % 2 == 0bitfield: uint8, [C, H * H * H / 8] + ''' + if not grid.is_cuda: grid = grid.cuda() + grid = grid.contiguous() + + C = grid.shape[0] + H3 = grid.shape[1] + H = int(np.cbrt(H3)) + grid_dilation = torch.empty_like(grid) + + _backend.morton3D_dilation(grid, C, H, grid_dilation) + + return grid_dilation + +morton3D_dilation = _morton3D_dilation.apply + +# ---------------------------------------- +# train functions +# ---------------------------------------- + +class _march_rays_train(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1, perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024): + ''' march rays to generate points (forward only) + Args: + rays_o/d: float, [N, 3] + bound: float, scalar + density_bitfield: uint8: [CHHH // 8] + C: int + H: int + nears/fars: float, [N] + step_counter: int32, (2), used to count the actual number of generated points. + mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeded this threshold.) + perturb: bool + align: int, pad output so its size is dividable by align, set to -1 to disable. + force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays. + dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance) + max_steps: int, max number of sampled points along each ray, also affect min_stepsize. + Returns: + xyzs: float, [M, 3], all generated points' coords. (all rays concated, need to use `rays` to extract points belonging to each ray) + dirs: float, [M, 3], all generated points' view dirs. + deltas: float, [M, 2], first is delta_t, second is rays_t + rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 1] + rays[i, 2]] --> points belonging to rays[i, 0] + ''' + + if not rays_o.is_cuda: rays_o = rays_o.cuda() + if not rays_d.is_cuda: rays_d = rays_d.cuda() + if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda() + + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + density_bitfield = density_bitfield.contiguous() + + N = rays_o.shape[0] # num rays + M = N * max_steps # init max points number in total + + # running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp) + # It estimate the max points number to enable faster training, but will lead to random ignored rays if underestimated. + if not force_all_rays and mean_count > 0: + if align > 0: + mean_count += align - mean_count % align + M = mean_count + + xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) + dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) + deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) + rays = torch.empty(N, 3, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps + + if step_counter is None: + step_counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device) # point counter, ray counter + + if perturb: + noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device) + else: + noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device) + + _backend.march_rays_train(rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars, xyzs, dirs, deltas, rays, step_counter, noises) # m is the actually used points number + + #print(step_counter, M) + + # only used at the first (few) epochs. + if force_all_rays or mean_count <= 0: + m = step_counter[0].item() # D2H copy + if align > 0: + m += align - m % align + xyzs = xyzs[:m] + dirs = dirs[:m] + deltas = deltas[:m] + + torch.cuda.empty_cache() + + ctx.save_for_backward(rays, deltas) + + return xyzs, dirs, deltas, rays + + # to support optimizing camera poses. + @staticmethod + @custom_bwd + def backward(ctx, grad_xyzs, grad_dirs, grad_deltas, grad_rays): + # grad_xyzs/dirs: [M, 3] + + rays, deltas = ctx.saved_tensors + + N = rays.shape[0] + M = grad_xyzs.shape[0] + + grad_rays_o = torch.zeros(N, 3, device=rays.device) + grad_rays_d = torch.zeros(N, 3, device=rays.device) + + _backend.march_rays_train_backward(grad_xyzs, grad_dirs, rays, deltas, N, M, grad_rays_o, grad_rays_d) + + return grad_rays_o, grad_rays_d, None, None, None, None, None, None, None, None, None, None, None, None, None + +march_rays_train = _march_rays_train.apply + + +class _composite_rays_train(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, sigmas, rgbs, ambient, deltas, rays, T_thresh=1e-4): + ''' composite rays' rgbs, according to the ray marching formula. + Args: + rgbs: float, [M, 3] + sigmas: float, [M,] + ambient: float, [M,] (after summing up the last dimension) + deltas: float, [M, 2] + rays: int32, [N, 3] + Returns: + weights_sum: float, [N,], the alpha channel + depth: float, [N, ], the Depth + image: float, [N, 3], the RGB channel (after multiplying alpha!) + ''' + + sigmas = sigmas.contiguous() + rgbs = rgbs.contiguous() + ambient = ambient.contiguous() + + M = sigmas.shape[0] + N = rays.shape[0] + + weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) + ambient_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) + depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) + image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device) + + _backend.composite_rays_train_forward(sigmas, rgbs, ambient, deltas, rays, M, N, T_thresh, weights_sum, ambient_sum, depth, image) + + ctx.save_for_backward(sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, depth, image) + ctx.dims = [M, N, T_thresh] + + return weights_sum, ambient_sum, depth, image + + @staticmethod + @custom_bwd + def backward(ctx, grad_weights_sum, grad_ambient_sum, grad_depth, grad_image): + + # NOTE: grad_depth is not used now! It won't be propagated to sigmas. + + grad_weights_sum = grad_weights_sum.contiguous() + grad_ambient_sum = grad_ambient_sum.contiguous() + grad_image = grad_image.contiguous() + + sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, depth, image = ctx.saved_tensors + M, N, T_thresh = ctx.dims + + grad_sigmas = torch.zeros_like(sigmas) + grad_rgbs = torch.zeros_like(rgbs) + grad_ambient = torch.zeros_like(ambient) + + _backend.composite_rays_train_backward(grad_weights_sum, grad_ambient_sum, grad_image, sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs, grad_ambient) + + return grad_sigmas, grad_rgbs, grad_ambient, None, None, None + + +composite_rays_train = _composite_rays_train.apply + +# ---------------------------------------- +# infer functions +# ---------------------------------------- + +class _march_rays(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, align=-1, perturb=False, dt_gamma=0, max_steps=1024): + ''' march rays to generate points (forward only, for inference) + Args: + n_alive: int, number of alive rays + n_step: int, how many steps we march + rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive) + rays_t: float, [N], the alive rays' time, we only use the first n_alive. + rays_o/d: float, [N, 3] + bound: float, scalar + density_bitfield: uint8: [CHHH // 8] + C: int + H: int + nears/fars: float, [N] + align: int, pad output so its size is dividable by align, set to -1 to disable. + perturb: bool/int, int > 0 is used as the random seed. + dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance) + max_steps: int, max number of sampled points along each ray, also affect min_stepsize. + Returns: + xyzs: float, [n_alive * n_step, 3], all generated points' coords + dirs: float, [n_alive * n_step, 3], all generated points' view dirs. + deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth). + ''' + + if not rays_o.is_cuda: rays_o = rays_o.cuda() + if not rays_d.is_cuda: rays_d = rays_d.cuda() + + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + M = n_alive * n_step + + if align > 0: + M += align - (M % align) + + xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) + dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) + deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth + + if perturb: + # torch.manual_seed(perturb) # test_gui uses spp index as seed + noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device) + else: + noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device) + + _backend.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, deltas, noises) + + return xyzs, dirs, deltas + +march_rays = _march_rays.apply + + +class _composite_rays(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float + def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh=1e-2): + ''' composite rays' rgbs, according to the ray marching formula. (for inference) + Args: + n_alive: int, number of alive rays + n_step: int, how many steps we march + rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive) + rays_t: float, [N], the alive rays' time + sigmas: float, [n_alive * n_step,] + rgbs: float, [n_alive * n_step, 3] + deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth). + In-place Outputs: + weights_sum: float, [N,], the alpha channel + depth: float, [N,], the depth value + image: float, [N, 3], the RGB channel (after multiplying alpha!) + ''' + _backend.composite_rays(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image) + return tuple() + + +composite_rays = _composite_rays.apply \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/raymarching/raymarching_face.egg-info/PKG-INFO b/Geneface_main/GeneFace/modules/radnerfs/raymarching/raymarching_face.egg-info/PKG-INFO new file mode 100644 index 00000000..9c4a56e3 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/raymarching/raymarching_face.egg-info/PKG-INFO @@ -0,0 +1,3 @@ +Metadata-Version: 2.1 +Name: raymarching_face +Version: 0.0.0 diff --git a/Geneface_main/GeneFace/modules/radnerfs/raymarching/raymarching_face.egg-info/SOURCES.txt b/Geneface_main/GeneFace/modules/radnerfs/raymarching/raymarching_face.egg-info/SOURCES.txt new file mode 100644 index 00000000..d11a8f9a --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/raymarching/raymarching_face.egg-info/SOURCES.txt @@ -0,0 +1,7 @@ +setup.py +/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/raymarching/src/bindings.cpp +/output/GeneFace_Reproduction/GeneFace/modules/radnerfs/raymarching/src/raymarching.cu +raymarching_face.egg-info/PKG-INFO +raymarching_face.egg-info/SOURCES.txt +raymarching_face.egg-info/dependency_links.txt +raymarching_face.egg-info/top_level.txt \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/raymarching/raymarching_face.egg-info/dependency_links.txt b/Geneface_main/GeneFace/modules/radnerfs/raymarching/raymarching_face.egg-info/dependency_links.txt new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/raymarching/raymarching_face.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/Geneface_main/GeneFace/modules/radnerfs/raymarching/raymarching_face.egg-info/top_level.txt b/Geneface_main/GeneFace/modules/radnerfs/raymarching/raymarching_face.egg-info/top_level.txt new file mode 100644 index 00000000..8b34112c --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/raymarching/raymarching_face.egg-info/top_level.txt @@ -0,0 +1 @@ +_raymarching_face diff --git a/Geneface_main/GeneFace/modules/radnerfs/raymarching/setup.py b/Geneface_main/GeneFace/modules/radnerfs/raymarching/setup.py new file mode 100644 index 00000000..6a7e62f7 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/raymarching/setup.py @@ -0,0 +1,63 @@ +import os +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + # '-lineinfo', # to debug illegal memory access + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +''' +Usage: + +python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory) + +python setup.py install # build extensions and install (copy) to PATH. +pip install . # ditto but better (e.g., dependency & metadata handling) + +python setup.py develop # build extensions and install (symbolic) to PATH. +pip install -e . # ditto but better (e.g., dependency & metadata handling) + +''' +setup( + name='raymarching_face', # package name, import this to use python API + ext_modules=[ + CUDAExtension( + name='_raymarching_face', # extension name, import this to use CUDA API + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'raymarching.cu', + 'bindings.cpp', + ]], + extra_compile_args={ + 'cxx': c_flags, + 'nvcc': nvcc_flags, + } + ), + ], + cmdclass={ + 'build_ext': BuildExtension, + } +) \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/raymarching/src/bindings.cpp b/Geneface_main/GeneFace/modules/radnerfs/raymarching/src/bindings.cpp new file mode 100644 index 00000000..589de244 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/raymarching/src/bindings.cpp @@ -0,0 +1,21 @@ +#include + +#include "raymarching.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // utils + m.def("packbits", &packbits, "packbits (CUDA)"); + m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)"); + m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)"); + m.def("morton3D", &morton3D, "morton3D (CUDA)"); + m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)"); + m.def("morton3D_dilation", &morton3D_dilation, "morton3D_dilation (CUDA)"); + // train + m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)"); + m.def("march_rays_train_backward", &march_rays_train_backward, "march_rays_train_backward (CUDA)"); + m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)"); + m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)"); + // infer + m.def("march_rays", &march_rays, "march rays (CUDA)"); + m.def("composite_rays", &composite_rays, "composite rays (CUDA)"); +} \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/raymarching/src/raymarching.cu b/Geneface_main/GeneFace/modules/radnerfs/raymarching/src/raymarching.cu new file mode 100644 index 00000000..ae5839bc --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/raymarching/src/raymarching.cu @@ -0,0 +1,1038 @@ +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") + + +inline constexpr __device__ float SQRT3() { return 1.7320508075688772f; } +inline constexpr __device__ float RSQRT3() { return 0.5773502691896258f; } +inline constexpr __device__ float PI() { return 3.141592653589793f; } +inline constexpr __device__ float RPI() { return 0.3183098861837907f; } + + +template +inline __host__ __device__ T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + +inline __host__ __device__ float signf(const float x) { + return copysignf(1.0, x); +} + +inline __host__ __device__ float clamp(const float x, const float min, const float max) { + return fminf(max, fmaxf(min, x)); +} + +inline __host__ __device__ void swapf(float& a, float& b) { + float c = a; a = b; b = c; +} + +inline __device__ int mip_from_pos(const float x, const float y, const float z, const float max_cascade) { + const float mx = fmaxf(fabsf(x), fmaxf(fabs(y), fabs(z))); + int exponent; + frexpf(mx, &exponent); // [0, 0.5) --> -1, [0.5, 1) --> 0, [1, 2) --> 1, [2, 4) --> 2, ... + return fminf(max_cascade - 1, fmaxf(0, exponent)); +} + +inline __device__ int mip_from_dt(const float dt, const float H, const float max_cascade) { + const float mx = dt * H * 0.5; + int exponent; + frexpf(mx, &exponent); + return fminf(max_cascade - 1, fmaxf(0, exponent)); +} + +inline __host__ __device__ uint32_t __expand_bits(uint32_t v) +{ + v = (v * 0x00010001u) & 0xFF0000FFu; + v = (v * 0x00000101u) & 0x0F00F00Fu; + v = (v * 0x00000011u) & 0xC30C30C3u; + v = (v * 0x00000005u) & 0x49249249u; + return v; +} + +inline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z) +{ + uint32_t xx = __expand_bits(x); + uint32_t yy = __expand_bits(y); + uint32_t zz = __expand_bits(z); + return xx | (yy << 1) | (zz << 2); +} + +inline __host__ __device__ uint32_t __morton3D_invert(uint32_t x) +{ + x = x & 0x49249249; + x = (x | (x >> 2)) & 0xc30c30c3; + x = (x | (x >> 4)) & 0x0f00f00f; + x = (x | (x >> 8)) & 0xff0000ff; + x = (x | (x >> 16)) & 0x0000ffff; + return x; +} + + +//////////////////////////////////////////////////// +///////////// utils ///////////// +//////////////////////////////////////////////////// + +// rays_o/d: [N, 3] +// nears/fars: [N] +// scalar_t should always be float in use. +template +__global__ void kernel_near_far_from_aabb( + const scalar_t * __restrict__ rays_o, + const scalar_t * __restrict__ rays_d, + const scalar_t * __restrict__ aabb, + const uint32_t N, + const float min_near, + scalar_t * nears, scalar_t * fars +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + rays_o += n * 3; + rays_d += n * 3; + + const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; + const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; + const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; + + // get near far (assume cube scene) + float near = (aabb[0] - ox) * rdx; + float far = (aabb[3] - ox) * rdx; + if (near > far) swapf(near, far); + + float near_y = (aabb[1] - oy) * rdy; + float far_y = (aabb[4] - oy) * rdy; + if (near_y > far_y) swapf(near_y, far_y); + + if (near > far_y || near_y > far) { + nears[n] = fars[n] = std::numeric_limits::max(); + return; + } + + if (near_y > near) near = near_y; + if (far_y < far) far = far_y; + + float near_z = (aabb[2] - oz) * rdz; + float far_z = (aabb[5] - oz) * rdz; + if (near_z > far_z) swapf(near_z, far_z); + + if (near > far_z || near_z > far) { + nears[n] = fars[n] = std::numeric_limits::max(); + return; + } + + if (near_z > near) near = near_z; + if (far_z < far) far = far_z; + + if (near < min_near) near = min_near; + + nears[n] = near; + fars[n] = far; +} + + +void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + rays_o.scalar_type(), "near_far_from_aabb", ([&] { + kernel_near_far_from_aabb<<>>(rays_o.data_ptr(), rays_d.data_ptr(), aabb.data_ptr(), N, min_near, nears.data_ptr(), fars.data_ptr()); + })); +} + + +// rays_o/d: [N, 3] +// radius: float +// coords: [N, 2] +template +__global__ void kernel_sph_from_ray( + const scalar_t * __restrict__ rays_o, + const scalar_t * __restrict__ rays_d, + const float radius, + const uint32_t N, + scalar_t * coords +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + rays_o += n * 3; + rays_d += n * 3; + coords += n * 2; + + const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; + const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; + const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; + + // solve t from || o + td || = radius + const float A = dx * dx + dy * dy + dz * dz; + const float B = ox * dx + oy * dy + oz * dz; // in fact B / 2 + const float C = ox * ox + oy * oy + oz * oz - radius * radius; + + const float t = (- B + sqrtf(B * B - A * C)) / A; // always use the larger solution (positive) + + // solve theta, phi (assume y is the up axis) + const float x = ox + t * dx, y = oy + t * dy, z = oz + t * dz; + const float theta = atan2(sqrtf(x * x + z * z), y); // [0, PI) + const float phi = atan2(z, x); // [-PI, PI) + + // normalize to [-1, 1] + coords[0] = 2 * theta * RPI() - 1; + coords[1] = phi * RPI(); +} + + +void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + rays_o.scalar_type(), "sph_from_ray", ([&] { + kernel_sph_from_ray<<>>(rays_o.data_ptr(), rays_d.data_ptr(), radius, N, coords.data_ptr()); + })); +} + + +// coords: int32, [N, 3] +// indices: int32, [N] +__global__ void kernel_morton3D( + const int * __restrict__ coords, + const uint32_t N, + int * indices +) { + // parallel + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + coords += n * 3; + indices[n] = __morton3D(coords[0], coords[1], coords[2]); +} + + +void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices) { + static constexpr uint32_t N_THREAD = 128; + kernel_morton3D<<>>(coords.data_ptr(), N, indices.data_ptr()); +} + + +// indices: int32, [N] +// coords: int32, [N, 3] +__global__ void kernel_morton3D_invert( + const int * __restrict__ indices, + const uint32_t N, + int * coords +) { + // parallel + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + coords += n * 3; + + const int ind = indices[n]; + + coords[0] = __morton3D_invert(ind >> 0); + coords[1] = __morton3D_invert(ind >> 1); + coords[2] = __morton3D_invert(ind >> 2); +} + + +void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords) { + static constexpr uint32_t N_THREAD = 128; + kernel_morton3D_invert<<>>(indices.data_ptr(), N, coords.data_ptr()); +} + + +// grid: float, [C, H, H, H] +// N: int, C * H * H * H / 8 +// density_thresh: float +// bitfield: uint8, [N] +template +__global__ void kernel_packbits( + const scalar_t * __restrict__ grid, + const uint32_t N, + const float density_thresh, + uint8_t * bitfield +) { + // parallel per byte + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + grid += n * 8; + + uint8_t bits = 0; + + #pragma unroll + for (uint8_t i = 0; i < 8; i++) { + bits |= (grid[i] > density_thresh) ? ((uint8_t)1 << i) : 0; + } + + bitfield[n] = bits; +} + + +void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grid.scalar_type(), "packbits", ([&] { + kernel_packbits<<>>(grid.data_ptr(), N, density_thresh, bitfield.data_ptr()); + })); +} + + +// grid: float, [C, H, H, H] +__global__ void kernel_morton3D_dilation( + const float * __restrict__ grid, + const uint32_t C, + const uint32_t H, + float * __restrict__ grid_dilation +) { + // parallel per byte + const uint32_t H3 = H * H * H; + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= C * H3) return; + + // locate + const uint32_t c = n / H3; + const uint32_t ind = n - c * H3; + + const uint32_t x = __morton3D_invert(ind >> 0); + const uint32_t y = __morton3D_invert(ind >> 1); + const uint32_t z = __morton3D_invert(ind >> 2); + + // manual max pool + float res = grid[n]; + + if (x + 1 < H) res = fmaxf(res, grid[c * H3 + __morton3D(x + 1, y, z)]); + if (x > 0) res = fmaxf(res, grid[c * H3 + __morton3D(x - 1, y, z)]); + if (y + 1 < H) res = fmaxf(res, grid[c * H3 + __morton3D(x, y + 1, z)]); + if (y > 0) res = fmaxf(res, grid[c * H3 + __morton3D(x, y - 1, z)]); + if (z + 1 < H) res = fmaxf(res, grid[c * H3 + __morton3D(x, y, z + 1)]); + if (z > 0) res = fmaxf(res, grid[c * H3 + __morton3D(x, y, z - 1)]); + + // write + grid_dilation[n] = res; +} + +void morton3D_dilation(const at::Tensor grid, const uint32_t C, const uint32_t H, at::Tensor grid_dilation) { + static constexpr uint32_t N_THREAD = 128; + + kernel_morton3D_dilation<<>>(grid.data_ptr(), C, H, grid_dilation.data_ptr()); +} + +//////////////////////////////////////////////////// +///////////// training ///////////// +//////////////////////////////////////////////////// + +// rays_o/d: [N, 3] +// grid: [CHHH / 8] +// xyzs, dirs, deltas: [M, 3], [M, 3], [M, 2] +// dirs: [M, 3] +// rays: [N, 3], idx, offset, num_steps +template +__global__ void kernel_march_rays_train( + const scalar_t * __restrict__ rays_o, + const scalar_t * __restrict__ rays_d, + const uint8_t * __restrict__ grid, + const float bound, + const float dt_gamma, const uint32_t max_steps, + const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, + const scalar_t* __restrict__ nears, + const scalar_t* __restrict__ fars, + scalar_t * xyzs, scalar_t * dirs, scalar_t * deltas, + int * rays, + int * counter, + const scalar_t* __restrict__ noises +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + rays_o += n * 3; + rays_d += n * 3; + + // ray marching + const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; + const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; + const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; + const float rH = 1 / (float)H; + const float H3 = H * H * H; + + const float near = nears[n]; + const float far = fars[n]; + const float noise = noises[n]; + + const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H; + const float dt_min = fminf(dt_max, 2 * SQRT3() / max_steps); + + float t0 = near; + + // perturb + t0 += clamp(t0 * dt_gamma, dt_min, dt_max) * noise; + + // first pass: estimation of num_steps + float t = t0; + uint32_t num_steps = 0; + + //if (t < far) printf("valid ray %d t=%f near=%f far=%f \n", n, t, near, far); + + while (t < far && num_steps < max_steps) { + // current point + const float x = clamp(ox + t * dx, -bound, bound); + const float y = clamp(oy + t * dy, -bound, bound); + const float z = clamp(oz + t * dz, -bound, bound); + + const float dt = clamp(t * dt_gamma, dt_min, dt_max); + + // get mip level + const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1] + + const float mip_bound = fminf(scalbnf(1.0f, level), bound); + const float mip_rbound = 1 / mip_bound; + + // convert to nearest grid position + const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + + const uint32_t index = level * H3 + __morton3D(nx, ny, nz); + const bool occ = grid[index / 8] & (1 << (index % 8)); + + // if occpuied, advance a small step, and write to output + //if (n == 0) printf("t=%f density=%f vs thresh=%f step=%d\n", t, density, density_thresh, num_steps); + + if (occ) { + num_steps++; + t += dt; + // else, skip a large step (basically skip a voxel grid) + } else { + // calc distance to next voxel + const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx; + const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy; + const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz; + + const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz))); + // step until next voxel + do { + t += clamp(t * dt_gamma, dt_min, dt_max); + } while (t < tt); + } + } + + //printf("[n=%d] num_steps=%d, near=%f, far=%f, dt=%f, max_steps=%f\n", n, num_steps, near, far, dt_min, (far - near) / dt_min); + + // second pass: really locate and write points & dirs + uint32_t point_index = atomicAdd(counter, num_steps); + uint32_t ray_index = atomicAdd(counter + 1, 1); + + //printf("[n=%d] num_steps=%d, point_index=%d, ray_index=%d\n", n, num_steps, point_index, ray_index); + + // write rays + rays[ray_index * 3] = n; + rays[ray_index * 3 + 1] = point_index; + rays[ray_index * 3 + 2] = num_steps; + + if (num_steps == 0) return; + if (point_index + num_steps > M) return; + + xyzs += point_index * 3; + dirs += point_index * 3; + deltas += point_index * 2; + + t = t0; + uint32_t step = 0; + + while (t < far && step < num_steps) { + // current point + const float x = clamp(ox + t * dx, -bound, bound); + const float y = clamp(oy + t * dy, -bound, bound); + const float z = clamp(oz + t * dz, -bound, bound); + + const float dt = clamp(t * dt_gamma, dt_min, dt_max); + + // get mip level + const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1] + + const float mip_bound = fminf(scalbnf(1.0f, level), bound); + const float mip_rbound = 1 / mip_bound; + + // convert to nearest grid position + const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + + // query grid + const uint32_t index = level * H3 + __morton3D(nx, ny, nz); + const bool occ = grid[index / 8] & (1 << (index % 8)); + + // if occpuied, advance a small step, and write to output + if (occ) { + // write step + xyzs[0] = x; + xyzs[1] = y; + xyzs[2] = z; + dirs[0] = dx; + dirs[1] = dy; + dirs[2] = dz; + t += dt; + deltas[0] = dt; + deltas[1] = t; // used to calc depth + xyzs += 3; + dirs += 3; + deltas += 2; + step++; + // else, skip a large step (basically skip a voxel grid) + } else { + // calc distance to next voxel + const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx; + const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy; + const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz; + const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz))); + // step until next voxel + do { + t += clamp(t * dt_gamma, dt_min, dt_max); + } while (t < tt); + } + } +} + +void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + rays_o.scalar_type(), "march_rays_train", ([&] { + kernel_march_rays_train<<>>(rays_o.data_ptr(), rays_d.data_ptr(), grid.data_ptr(), bound, dt_gamma, max_steps, N, C, H, M, nears.data_ptr(), fars.data_ptr(), xyzs.data_ptr(), dirs.data_ptr(), deltas.data_ptr(), rays.data_ptr(), counter.data_ptr(), noises.data_ptr()); + })); +} + + +// grad_xyzs/dirs: [M, 3] +// rays: [N, 3] +// deltas: [M, 2] +// grad_rays_o/d: [N, 3] +template +__global__ void kernel_march_rays_train_backward( + const scalar_t * __restrict__ grad_xyzs, + const scalar_t * __restrict__ grad_dirs, + const int * __restrict__ rays, + const scalar_t * __restrict__ deltas, + const uint32_t N, const uint32_t M, + scalar_t * grad_rays_o, + scalar_t * grad_rays_d +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + grad_rays_o += n * 3; + grad_rays_d += n * 3; + + uint32_t index = rays[n * 3]; + uint32_t offset = rays[n * 3 + 1]; + uint32_t num_steps = rays[n * 3 + 2]; + + // empty ray, or ray that exceed max step count. + if (num_steps == 0 || offset + num_steps > M) return; + + grad_xyzs += offset * 3; + grad_dirs += offset * 3; + deltas += offset * 2; + + // accumulate + uint32_t step = 0; + while (step < num_steps) { + + grad_rays_o[0] += grad_xyzs[0]; + grad_rays_o[1] += grad_xyzs[1]; + grad_rays_o[2] += grad_xyzs[2]; + + grad_rays_d[0] += grad_xyzs[0] * deltas[1] + grad_dirs[0]; + grad_rays_d[1] += grad_xyzs[1] * deltas[1] + grad_dirs[1]; + grad_rays_d[2] += grad_xyzs[2] * deltas[1] + grad_dirs[2]; + + // locate + grad_xyzs += 3; + grad_dirs += 3; + deltas += 2; + + step++; + } +} + +void march_rays_train_backward(const at::Tensor grad_xyzs, const at::Tensor grad_dirs, const at::Tensor rays, const at::Tensor deltas, const uint32_t N, const uint32_t M, at::Tensor grad_rays_o, at::Tensor grad_rays_d) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_xyzs.scalar_type(), "march_rays_train_backward", ([&] { + kernel_march_rays_train_backward<<>>(grad_xyzs.data_ptr(), grad_dirs.data_ptr(), rays.data_ptr(), deltas.data_ptr(), N, M, grad_rays_o.data_ptr(), grad_rays_d.data_ptr()); + })); +} + + +// sigmas: [M] +// rgbs: [M, 3] +// deltas: [M, 2] +// rays: [N, 3], idx, offset, num_steps +// weights_sum: [N], final pixel alpha +// depth: [N,] +// image: [N, 3] +template +__global__ void kernel_composite_rays_train_forward( + const scalar_t * __restrict__ sigmas, + const scalar_t * __restrict__ rgbs, + const scalar_t * __restrict__ ambient, + const scalar_t * __restrict__ deltas, + const int * __restrict__ rays, + const uint32_t M, const uint32_t N, const float T_thresh, + scalar_t * weights_sum, + scalar_t * ambient_sum, + scalar_t * depth, + scalar_t * image +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + uint32_t index = rays[n * 3]; + uint32_t offset = rays[n * 3 + 1]; + uint32_t num_steps = rays[n * 3 + 2]; + + // empty ray, or ray that exceed max step count. + if (num_steps == 0 || offset + num_steps > M) { + weights_sum[index] = 0; + ambient_sum[index] = 0; + depth[index] = 0; + image[index * 3] = 0; + image[index * 3 + 1] = 0; + image[index * 3 + 2] = 0; + return; + } + + sigmas += offset; + rgbs += offset * 3; + ambient += offset; + deltas += offset * 2; + + // accumulate + uint32_t step = 0; + + scalar_t T = 1.0f; + scalar_t r = 0, g = 0, b = 0, ws = 0, d = 0, amb = 0; + + while (step < num_steps) { + + const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); + const scalar_t weight = alpha * T; + + r += weight * rgbs[0]; + g += weight * rgbs[1]; + b += weight * rgbs[2]; + + d += weight * deltas[1]; + + ws += weight; + + amb += ambient[0]; + + T *= 1.0f - alpha; + + // minimal remained transmittence + if (T < T_thresh) break; + + //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); + + // locate + sigmas++; + rgbs += 3; + ambient++; + deltas += 2; + + step++; + } + + //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); + + // write + weights_sum[index] = ws; // weights_sum + ambient_sum[index] = amb; + depth[index] = d; + image[index * 3] = r; + image[index * 3 + 1] = g; + image[index * 3 + 2] = b; +} + + +void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor ambient_sum, at::Tensor depth, at::Tensor image) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + sigmas.scalar_type(), "composite_rays_train_forward", ([&] { + kernel_composite_rays_train_forward<<>>(sigmas.data_ptr(), rgbs.data_ptr(), ambient.data_ptr(), deltas.data_ptr(), rays.data_ptr(), M, N, T_thresh, weights_sum.data_ptr(), ambient_sum.data_ptr(), depth.data_ptr(), image.data_ptr()); + })); +} + + +// grad_weights_sum: [N,] +// grad: [N, 3] +// sigmas: [M] +// rgbs: [M, 3] +// deltas: [M, 2] +// rays: [N, 3], idx, offset, num_steps +// weights_sum: [N,], weights_sum here +// image: [N, 3] +// grad_sigmas: [M] +// grad_rgbs: [M, 3] +template +__global__ void kernel_composite_rays_train_backward( + const scalar_t * __restrict__ grad_weights_sum, + const scalar_t * __restrict__ grad_ambient_sum, + const scalar_t * __restrict__ grad_image, + const scalar_t * __restrict__ sigmas, + const scalar_t * __restrict__ rgbs, + const scalar_t * __restrict__ ambient, + const scalar_t * __restrict__ deltas, + const int * __restrict__ rays, + const scalar_t * __restrict__ weights_sum, + const scalar_t * __restrict__ ambient_sum, + const scalar_t * __restrict__ image, + const uint32_t M, const uint32_t N, const float T_thresh, + scalar_t * grad_sigmas, + scalar_t * grad_rgbs, + scalar_t * grad_ambient +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + uint32_t index = rays[n * 3]; + uint32_t offset = rays[n * 3 + 1]; + uint32_t num_steps = rays[n * 3 + 2]; + + if (num_steps == 0 || offset + num_steps > M) return; + + grad_weights_sum += index; + grad_ambient_sum += index; + grad_image += index * 3; + weights_sum += index; + ambient_sum += index; + image += index * 3; + + sigmas += offset; + rgbs += offset * 3; + ambient += offset; + deltas += offset * 2; + + grad_sigmas += offset; + grad_rgbs += offset * 3; + grad_ambient += offset; + + // accumulate + uint32_t step = 0; + + scalar_t T = 1.0f; + const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0]; + scalar_t r = 0, g = 0, b = 0, ws = 0; + + while (step < num_steps) { + + const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); + const scalar_t weight = alpha * T; + + r += weight * rgbs[0]; + g += weight * rgbs[1]; + b += weight * rgbs[2]; + // amb += weight * ambient[0]; + ws += weight; + + T *= 1.0f - alpha; + + // check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation. + // write grad_rgbs + grad_rgbs[0] = grad_image[0] * weight; + grad_rgbs[1] = grad_image[1] * weight; + grad_rgbs[2] = grad_image[2] * weight; + + // write grad_ambient + grad_ambient[0] = grad_ambient_sum[0]; + + // write grad_sigmas + grad_sigmas[0] = deltas[0] * ( + grad_image[0] * (T * rgbs[0] - (r_final - r)) + + grad_image[1] * (T * rgbs[1] - (g_final - g)) + + grad_image[2] * (T * rgbs[2] - (b_final - b)) + + // grad_ambient_sum[0] * (T * ambient[0] - (amb_final - amb)) + + grad_weights_sum[0] * (1 - ws_final) + ); + + //printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r); + // minimal remained transmittence + if (T < T_thresh) break; + + // locate + sigmas++; + rgbs += 3; + // ambient++; + deltas += 2; + grad_sigmas++; + grad_rgbs += 3; + grad_ambient++; + + step++; + } +} + + +void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_ambient_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor ambient_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_ambient) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_image.scalar_type(), "composite_rays_train_backward", ([&] { + kernel_composite_rays_train_backward<<>>(grad_weights_sum.data_ptr(), grad_ambient_sum.data_ptr(), grad_image.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), ambient.data_ptr(), deltas.data_ptr(), rays.data_ptr(), weights_sum.data_ptr(), ambient_sum.data_ptr(), image.data_ptr(), M, N, T_thresh, grad_sigmas.data_ptr(), grad_rgbs.data_ptr(), grad_ambient.data_ptr()); + })); +} + + +//////////////////////////////////////////////////// +///////////// infernce ///////////// +//////////////////////////////////////////////////// + +template +__global__ void kernel_march_rays( + const uint32_t n_alive, + const uint32_t n_step, + const int* __restrict__ rays_alive, + const scalar_t* __restrict__ rays_t, + const scalar_t* __restrict__ rays_o, + const scalar_t* __restrict__ rays_d, + const float bound, + const float dt_gamma, const uint32_t max_steps, + const uint32_t C, const uint32_t H, + const uint8_t * __restrict__ grid, + const scalar_t* __restrict__ nears, + const scalar_t* __restrict__ fars, + scalar_t* xyzs, scalar_t* dirs, scalar_t* deltas, + const scalar_t* __restrict__ noises +) { + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= n_alive) return; + + const int index = rays_alive[n]; // ray id + const float noise = noises[n]; + + // locate + rays_o += index * 3; + rays_d += index * 3; + xyzs += n * n_step * 3; + dirs += n * n_step * 3; + deltas += n * n_step * 2; + + const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; + const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; + const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; + const float rH = 1 / (float)H; + const float H3 = H * H * H; + + float t = rays_t[index]; // current ray's t + const float near = nears[index], far = fars[index]; + + const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H; + const float dt_min = fminf(dt_max, 2 * SQRT3() / max_steps); + + // march for n_step steps, record points + uint32_t step = 0; + + // introduce some randomness + t += clamp(t * dt_gamma, dt_min, dt_max) * noise; + + while (t < far && step < n_step) { + // current point + const float x = clamp(ox + t * dx, -bound, bound); + const float y = clamp(oy + t * dy, -bound, bound); + const float z = clamp(oz + t * dz, -bound, bound); + + const float dt = clamp(t * dt_gamma, dt_min, dt_max); + + // get mip level + const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1] + + const float mip_bound = fminf(scalbnf(1, level), bound); + const float mip_rbound = 1 / mip_bound; + + // convert to nearest grid position + const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + + const uint32_t index = level * H3 + __morton3D(nx, ny, nz); + const bool occ = grid[index / 8] & (1 << (index % 8)); + + // if occpuied, advance a small step, and write to output + if (occ) { + // write step + xyzs[0] = x; + xyzs[1] = y; + xyzs[2] = z; + dirs[0] = dx; + dirs[1] = dy; + dirs[2] = dz; + // calc dt + t += dt; + deltas[0] = dt; + deltas[1] = t; // used to calc depth + // step + xyzs += 3; + dirs += 3; + deltas += 2; + step++; + + // else, skip a large step (basically skip a voxel grid) + } else { + // calc distance to next voxel + const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx; + const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy; + const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz; + const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz))); + // step until next voxel + do { + t += clamp(t * dt_gamma, dt_min, dt_max); + } while (t < tt); + } + } +} + + +void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor near, const at::Tensor far, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises) { + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + rays_o.scalar_type(), "march_rays", ([&] { + kernel_march_rays<<>>(n_alive, n_step, rays_alive.data_ptr(), rays_t.data_ptr(), rays_o.data_ptr(), rays_d.data_ptr(), bound, dt_gamma, max_steps, C, H, grid.data_ptr(), near.data_ptr(), far.data_ptr(), xyzs.data_ptr(), dirs.data_ptr(), deltas.data_ptr(), noises.data_ptr()); + })); +} + + +template +__global__ void kernel_composite_rays( + const uint32_t n_alive, + const uint32_t n_step, + const float T_thresh, + int* rays_alive, + scalar_t* rays_t, + const scalar_t* __restrict__ sigmas, + const scalar_t* __restrict__ rgbs, + const scalar_t* __restrict__ deltas, + scalar_t* weights_sum, scalar_t* depth, scalar_t* image +) { + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= n_alive) return; + + const int index = rays_alive[n]; // ray id + + // locate + sigmas += n * n_step; + rgbs += n * n_step * 3; + deltas += n * n_step * 2; + + rays_t += index; + weights_sum += index; + depth += index; + image += index * 3; + + scalar_t t = rays_t[0]; // current ray's t + + scalar_t weight_sum = weights_sum[0]; + scalar_t d = depth[0]; + scalar_t r = image[0]; + scalar_t g = image[1]; + scalar_t b = image[2]; + + // accumulate + uint32_t step = 0; + while (step < n_step) { + + // ray is terminated if delta == 0 + if (deltas[0] == 0) break; + + const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); + + /* + T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j) + w_i = alpha_i * T_i + --> + T_i = 1 - \sum_{j=0}^{i-1} w_j + */ + const scalar_t T = 1 - weight_sum; + const scalar_t weight = alpha * T; + weight_sum += weight; + + t = deltas[1]; + d += weight * t; + r += weight * rgbs[0]; + g += weight * rgbs[1]; + b += weight * rgbs[2]; + + //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); + + // ray is terminated if T is too small + // use a larger bound to further accelerate inference + if (T < T_thresh) break; + + // locate + sigmas++; + rgbs += 3; + deltas += 2; + step++; + } + + //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); + + // rays_alive = -1 means ray is terminated early. + if (step < n_step) { + rays_alive[n] = -1; + } else { + rays_t[0] = t; + } + + weights_sum[0] = weight_sum; // this is the thing I needed! + depth[0] = d; + image[0] = r; + image[1] = g; + image[2] = b; +} + + +void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights, at::Tensor depth, at::Tensor image) { + static constexpr uint32_t N_THREAD = 128; + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + image.scalar_type(), "composite_rays", ([&] { + kernel_composite_rays<<>>(n_alive, n_step, T_thresh, rays_alive.data_ptr(), rays_t.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), deltas.data_ptr(), weights.data_ptr(), depth.data_ptr(), image.data_ptr()); + })); +} \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/raymarching/src/raymarching.h b/Geneface_main/GeneFace/modules/radnerfs/raymarching/src/raymarching.h new file mode 100644 index 00000000..e7d9b219 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/raymarching/src/raymarching.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + + +void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars); +void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords); +void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices); +void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords); +void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield); +void morton3D_dilation(const at::Tensor grid, const uint32_t C, const uint32_t H, at::Tensor grid_dilation); + +void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises); +void march_rays_train_backward(const at::Tensor grad_xyzs, const at::Tensor grad_dirs, const at::Tensor rays, const at::Tensor deltas, const uint32_t N, const uint32_t M, at::Tensor grad_rays_o, at::Tensor grad_rays_d); +void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor ambient_sum, at::Tensor depth, at::Tensor image); +void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_ambient_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor ambient_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_ambient); + +void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises); +void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); \ No newline at end of file diff --git a/Geneface_main/GeneFace/modules/radnerfs/renderer.py b/Geneface_main/GeneFace/modules/radnerfs/renderer.py new file mode 100644 index 00000000..4f9673d2 --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/renderer.py @@ -0,0 +1,368 @@ +import math +import trimesh +import numpy as np +import random + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import modules.radnerfs.raymarching as raymarching +from modules.radnerfs.utils import custom_meshgrid, get_audio_features, euler_angles_to_matrix, convert_poses + + +def sample_pdf(bins, weights, n_samples, det=False): + # This implementation is from NeRF + # bins: [B, T], old_z_vals + # weights: [B, T - 1], bin weights. + # return: [B, n_samples], new_z_vals + + # Get pdf + weights = weights + 1e-5 # prevent nans + pdf = weights / torch.sum(weights, -1, keepdim=True) + cdf = torch.cumsum(pdf, -1) + cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) + # Take uniform samples + if det: + u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(weights.device) + u = u.expand(list(cdf.shape[:-1]) + [n_samples]) + else: + u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device) + + # Invert CDF + u = u.contiguous() + inds = torch.searchsorted(cdf, u, right=True) + below = torch.max(torch.zeros_like(inds - 1), inds - 1) + above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) + inds_g = torch.stack([below, above], -1) # (B, n_samples, 2) + + matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] + cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) + bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) + + denom = (cdf_g[..., 1] - cdf_g[..., 0]) + denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) + t = (u - cdf_g[..., 0]) / denom + samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) + + return samples + + +def plot_pointcloud(pc, color=None): + # pc: [N, 3] + # color: [N, 3/4] + print('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0)) + pc = trimesh.PointCloud(pc, color) + # axis + axes = trimesh.creation.axis(axis_length=4) + # sphere + sphere = trimesh.creation.icosphere(radius=1) + trimesh.Scene([pc, axes, sphere]).show() + + +class NeRFRenderer(nn.Module): + def __init__(self, hparams): + super().__init__() + self.bound = hparams['bound'] + self.cascade = 1 + math.ceil(math.log2(hparams['bound'])) + self.grid_size = hparams['grid_size'] + self.density_scale = 1 + + self.min_near = hparams['min_near'] + self.density_thresh = hparams['density_thresh'] + + self.cuda_ray = hparams['cuda_ray'] + + # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax) + # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing. + aabb_train = torch.FloatTensor([-self.bound, -self.bound/2, -self.bound, self.bound, self.bound/2, self.bound]) + aabb_infer = aabb_train.clone() + self.register_buffer('aabb_train', aabb_train) + self.register_buffer('aabb_infer', aabb_infer) + + # individual codes + self.individual_embedding_num = hparams['individual_embedding_num'] + self.individual_embedding_dim = hparams['individual_embedding_dim'] + if self.individual_embedding_dim > 0: + self.individual_embeddings = nn.Parameter(torch.randn(self.individual_embedding_num, self.individual_embedding_dim) * 0.1) + + # 3D head density grid + density_grid = torch.zeros([self.cascade, self.grid_size ** 3]) # [CAS, H * H * H] + density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8) # [CAS * H * H * H // 8] + self.register_buffer('density_grid', density_grid) # points of the grid + self.register_buffer('density_bitfield', density_bitfield) # use 8 bit [0~255] to represent 8 points of a cube, if grid[i]>density threshold, set this bit to 1, so each cube can be represent as 0-255 + self.mean_density = 0 + self.iter_density = 0 + + # step counter + step_counter = torch.zeros(16, 2, dtype=torch.int32) # 16 is hardcoded for averaging... + self.register_buffer('step_counter', step_counter) + self.mean_count = 0 + self.local_step = 0 + + def cal_cond_feat(self, cond): + raise NotImplementedError() + + def forward(self, x, d): + raise NotImplementedError() + + # separated density and color query (can accelerate non-cuda-ray mode.) + def density(self, x): + raise NotImplementedError() + + def color(self, x, d, mask=None, **kwargs): + raise NotImplementedError() + + def reset_extra_state(self): + if not self.cuda_ray: + return + # density grid + self.density_grid.zero_() + self.mean_density = 0 + self.iter_density = 0 + # step counter + self.step_counter.zero_() + self.mean_count = 0 + self.local_step = 0 + + @torch.no_grad() + def mark_untrained_grid(self, poses, intrinsic, S=64): + # poses: [B, 4, 4] + # intrinsic: [3, 3] + + if not self.cuda_ray: + return + + if isinstance(poses, np.ndarray): + poses = torch.from_numpy(poses) + + B = poses.shape[0] + + fx, fy, cx, cy = intrinsic + + ori_device = self.density_bitfield.device + self.density_bitfield = self.density_bitfield.cuda() + self.density_grid = self.density_grid.cuda() + + X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + + count = torch.zeros_like(self.density_grid) + poses = poses.to(count.device) + + # 5-level loop, forgive me... + + for xs in X: + for ys in Y: + for zs in Z: + + # construct points + xx, yy, zz = custom_meshgrid(xs, ys, zs) + coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128) + indices = raymarching.morton3D(coords).long() # [N] + world_xyzs = (2 * coords.float() / (self.grid_size - 1) - 1).unsqueeze(0) # [1, N, 3] in [-1, 1] + + # cascading + for cas in range(self.cascade): + bound = min(2 ** cas, self.bound) + half_grid_size = bound / self.grid_size + # scale to current cascade's resolution + cas_world_xyzs = world_xyzs * (bound - half_grid_size) + + # split batch to avoid OOM + head = 0 + while head < B: + tail = min(head + S, B) + + # world2cam transform (poses is c2w, so we need to transpose it. Another transpose is needed for batched matmul, so the final form is without transpose.) + cam_xyzs = cas_world_xyzs - poses[head:tail, :3, 3].unsqueeze(1) + cam_xyzs = cam_xyzs @ poses[head:tail, :3, :3] # [S, N, 3] + + # query if point is covered by any camera + mask_z = cam_xyzs[:, :, 2] > 0 # [S, N] + mask_x = torch.abs(cam_xyzs[:, :, 0]) < cx / fx * cam_xyzs[:, :, 2] + half_grid_size * 2 + mask_y = torch.abs(cam_xyzs[:, :, 1]) < cy / fy * cam_xyzs[:, :, 2] + half_grid_size * 2 + mask = (mask_z & mask_x & mask_y).sum(0).reshape(-1) # [N] + + # update count + count[cas, indices] += mask + head += S + + # mark untrained grid as -1 + self.density_grid[count == 0] = -1 + self.density_bitfield = self.density_bitfield.to(ori_device) + self.density_grid = self.density_grid.to(ori_device) + #print(f'[mark untrained grid] {(count == 0).sum()} from {resolution ** 3 * self.cascade}') + + @torch.no_grad() + def update_extra_state(self, decay=0.95, S=128): + # call before each epoch to update extra states. + if not self.cuda_ray: + return + # use random cond (different expressions should have similar density grid...) + rand_idx = random.randint(0, self.conds.shape[0] - 1) + cond = get_audio_features(self.conds, 2, rand_idx).to(self.density_bitfield.device) + + # encode audio + enc_a = self.cal_cond_feat(cond) + + ### update density grid + tmp_grid = torch.zeros_like(self.density_grid) + + # full update + X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + + for xs in X: + for ys in Y: + for zs in Z: + + # construct points + xx, yy, zz = custom_meshgrid(xs, ys, zs) + coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128) + indices = raymarching.morton3D(coords).long() # [N] + xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1] + + # cascading + for cas in range(self.cascade): + bound = min(2 ** cas, self.bound) + half_grid_size = bound / self.grid_size + # scale to current cascade's resolution + cas_xyzs = xyzs * (bound - half_grid_size) + # add noise in [-hgs, hgs] + cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size + # query density + sigmas = self.density(cas_xyzs, enc_a)['sigma'].reshape(-1).detach().to(tmp_grid.dtype) + sigmas *= self.density_scale + # assign + tmp_grid[cas, indices] = sigmas + + # dilate the density_grid (less aggressive culling) + tmp_grid = raymarching.morton3D_dilation(tmp_grid) + + # ema update + valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0) + self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask]) + self.mean_density = torch.mean(self.density_grid.clamp(min=0)).item() # -1 non-training regions are viewed as 0 density. + self.iter_density += 1 + + # convert to bitfield + density_thresh = min(self.mean_density, self.density_thresh) + # each point in bitfield (a 8 bit uint) represents 8 points in density grid, 1 means the density is larger than density_threshold + self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield) + + ### update step counter + total_step = min(16, self.local_step) + if total_step > 0: + self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step) + self.local_step = 0 + + + def render(self, rays_o, rays_d, cond, bg_coords, poses, index=0, dt_gamma=0, bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs): + # rays_o, rays_d: [B, N, 3], assumes B == 1 + # cond: [B, 29, 16] + # bg_coords: [1, N, 2] + # return: pred_rgb: [B, N, 3] + + prefix = rays_o.shape[:-1] + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + bg_coords = bg_coords.contiguous().view(-1, 2) + + N = rays_o.shape[0] # N = B * N, in fact + device = rays_o.device + + results = {} + + # pre-calculate near far + nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer, self.min_near) + nears = nears.detach() + fars = fars.detach() + + # encode audio + cond_feat = self.cal_cond_feat(cond) # [1, 64] + + if self.individual_embedding_dim > 0: + if self.training: + ind_code = self.individual_embeddings[index] + # use a fixed ind code for the unknown test data. + else: + ind_code = self.individual_embeddings[0] + else: + ind_code = None + + if self.training: + # setup counter + counter = self.step_counter[self.local_step % 16] + counter.zero_() # set to 0 + self.local_step += 1 + + xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps) + + sigmas, rgbs, ambient = self(xyzs, dirs, cond_feat, ind_code) + sigmas = self.density_scale * sigmas + + #print(f'valid RGB query ratio: {mask.sum().item() / mask.shape[0]} (total = {mask.sum().item()})') + + weights_sum, ambient_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, ambient.abs().sum(-1), deltas, rays) + + # for training only + results['weights_sum'] = weights_sum + results['ambient'] = ambient_sum + else: + + dtype = torch.float32 + + weights_sum = torch.zeros(N, dtype=dtype, device=device) + depth = torch.zeros(N, dtype=dtype, device=device) + image = torch.zeros(N, 3, dtype=dtype, device=device) + + n_alive = N + rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N] + rays_t = nears.clone() # [N] + + step = 0 + + while step < max_steps: + + # count alive rays + n_alive = rays_alive.shape[0] + + # exit loop + if n_alive <= 0: + break + + # decide compact_steps + n_step = max(min(N // n_alive, 8), 1) + + xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps) + + sigmas, rgbs, ambient = self(xyzs, dirs, cond_feat, ind_code) + sigmas = self.density_scale * sigmas + + raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh) + + rays_alive = rays_alive[rays_alive >= 0] + + # print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}') + + step += n_step + + # background + if bg_color is None: + bg_color = 1 + + image = image + (1 - weights_sum).unsqueeze(-1) * bg_color + image = image.view(*prefix, 3) + image = image.clamp(0, 1) + + depth = torch.clamp(depth - nears, min=0) / (fars - nears) + depth = depth.view(*prefix) + + results['depth_map'] = depth + results['rgb_map'] = image # head_image if train, else com_image + + return results + diff --git a/Geneface_main/GeneFace/modules/radnerfs/utils.py b/Geneface_main/GeneFace/modules/radnerfs/utils.py new file mode 100644 index 00000000..75ec969f --- /dev/null +++ b/Geneface_main/GeneFace/modules/radnerfs/utils.py @@ -0,0 +1,429 @@ +import os +import glob +import tqdm +import math +import random +import warnings +import tensorboardX + +import numpy as np +import pandas as pd + +import time +from datetime import datetime + +import cv2 +import matplotlib.pyplot as plt + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +from torch.autograd import Function +from torch.cuda.amp import custom_bwd, custom_fwd +import torch.distributed as dist +from torch.utils.data import Dataset, DataLoader + +import trimesh +import mcubes + +from utils.commons.hparams import hparams +from packaging import version as pver +import imageio +import lpips + + +class _trunc_exp(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) # cast to float32 + def forward(ctx, x): + ctx.save_for_backward(x) + return torch.exp(x) + + @staticmethod + @custom_bwd + def backward(ctx, g): + x = ctx.saved_tensors[0] + return g * torch.exp(x.clamp(-15, 15)) + +trunc_exp = _trunc_exp.apply + + +# ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50 +def nerf_matrix_to_ngp(pose, scale=4, offset=[0, 0, 0]): + new_pose = np.array([ + [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale + offset[0]], + [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale + offset[1]], + [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale + offset[2]], + [0, 0, 0, 1], + ], dtype=np.float32) + return new_pose + + +def custom_meshgrid(*args): + # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid + if pver.parse(torch.__version__) < pver.parse('1.10'): + return torch.meshgrid(*args) + else: + return torch.meshgrid(*args, indexing='ij') + + +def get_audio_features(features, att_mode, index): + if att_mode == 0: + return features[[index]] + elif att_mode == 1: + left = index - hparams['smo_win_size'] + pad_left = 0 + if left < 0: + pad_left = -left + left = 0 + auds = features[left:index] + if pad_left > 0: + # pad may be longer than auds, so do not use zeros_like + auds = torch.cat([torch.zeros(pad_left, *auds.shape[1:], device=auds.device, dtype=auds.dtype), auds], dim=0) + return auds + elif att_mode == 2: + left = index - hparams['smo_win_size']//2 + right = index + (hparams['smo_win_size']-hparams['smo_win_size']//2) + pad_left = 0 + pad_right = 0 + if left < 0: + pad_left = -left + left = 0 + if right > features.shape[0]: + pad_right = right - features.shape[0] + right = features.shape[0] + auds = features[left:right] + if pad_left > 0: + auds = torch.cat([torch.zeros_like(auds[:pad_left]), auds], dim=0) + if pad_right > 0: + auds = torch.cat([auds, torch.zeros_like(auds[:pad_right])], dim=0) # [8, 16] + return auds + else: + raise NotImplementedError(f'wrong att_mode: {att_mode}') + + +@torch.jit.script +def linear_to_srgb(x): + return torch.where(x < 0.0031308, 12.92 * x, 1.055 * x ** 0.41666 - 0.055) + + +@torch.jit.script +def srgb_to_linear(x): + return torch.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4) + + +# copied from pytorch3d +def _angle_from_tan( + axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool +) -> torch.Tensor: + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in data as a tensor + of shape (...). + """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def _index_from_letter(letter: str) -> int: + if letter == "X": + return 0 + if letter == "Y": + return 1 + if letter == "Z": + return 2 + raise ValueError("letter must be either X, Y or Z.") + + +def matrix_to_euler_angles(matrix: torch.Tensor, convention: str = 'XYZ') -> torch.Tensor: + """ + Convert rotations given as rotation matrices to Euler angles in radians. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3). + """ + # if len(convention) != 3: + # raise ValueError("Convention must have 3 letters.") + # if convention[1] in (convention[0], convention[2]): + # raise ValueError(f"Invalid convention {convention}.") + # for letter in convention: + # if letter not in ("X", "Y", "Z"): + # raise ValueError(f"Invalid letter {letter} in convention string.") + # if matrix.size(-1) != 3 or matrix.size(-2) != 3: + # raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin( + matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) + ) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan( + convention[0], convention[1], matrix[..., i2], False, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0, :], True, tait_bryan + ), + ) + return torch.stack(o, -1) + + +@torch.cuda.amp.autocast(enabled=False) +def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor: + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + elif axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + elif axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + else: + raise ValueError("letter must be either X, Y or Z.") + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +@torch.cuda.amp.autocast(enabled=False) +def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str='XYZ') -> torch.Tensor: + """ + Convert rotations given as Euler angles in radians to rotation matrices. + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + # print(euler_angles, euler_angles.dtype) + + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = [ + _axis_angle_rotation(c, e) + for c, e in zip(convention, torch.unbind(euler_angles, -1)) + ] + + return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2]) + + +@torch.cuda.amp.autocast(enabled=False) +def convert_poses(poses): + # poses: [B, 4, 4] + # return [B, 3], 4 rot, 3 trans + out = torch.empty(poses.shape[0], 6, dtype=torch.float32, device=poses.device) + out[:, :3] = matrix_to_euler_angles(poses[:, :3, :3]) + out[:, 3:] = poses[:, :3, 3] + return out + + +@torch.cuda.amp.autocast(enabled=False) +def get_bg_coords(H, W, device): + X = torch.arange(H, device=device) / (H - 1) * 2 - 1 # in [-1, 1] + Y = torch.arange(W, device=device) / (W - 1) * 2 - 1 # in [-1, 1] + xs, ys = custom_meshgrid(X, Y) + bg_coords = torch.cat([xs.reshape(-1, 1), ys.reshape(-1, 1)], dim=-1).unsqueeze(0) # [1, H*W, 2], in [-1, 1] + return bg_coords + + +@torch.cuda.amp.autocast(enabled=False) +def get_rays(poses, intrinsics, H, W, N=-1, patch_size=1, rect=None): + ''' get rays + Args: + poses: [B, 4, 4], cam2world + intrinsics: [4] + H, W, N: int + Returns: + rays_o, rays_d: [B, N, 3] + inds: [B, N] + ''' + + device = poses.device + B = poses.shape[0] + fx, fy, cx, cy = intrinsics + + if rect is not None: + xmin, xmax, ymin, ymax = rect + N = (xmax - xmin) * (ymax - ymin) + + i, j = custom_meshgrid(torch.linspace(0, W-1, W, device=device), torch.linspace(0, H-1, H, device=device)) # float + i = i.t().reshape([1, H*W]).expand([B, H*W]) + 0.5 + j = j.t().reshape([1, H*W]).expand([B, H*W]) + 0.5 + + results = {} + + if N > 0: + N = min(N, H*W) + + if patch_size > 1: + + # random sample left-top cores. + # NOTE: this impl will lead to less sampling on the image corner pixels... but I don't have other ideas. + num_patch = N // (patch_size ** 2) + inds_x = torch.randint(0, H - patch_size, size=[num_patch], device=device) + inds_y = torch.randint(0, W - patch_size, size=[num_patch], device=device) + inds = torch.stack([inds_x, inds_y], dim=-1) # [np, 2] + + # create meshgrid for each patch + pi, pj = custom_meshgrid(torch.arange(patch_size, device=device), torch.arange(patch_size, device=device)) + offsets = torch.stack([pi.reshape(-1), pj.reshape(-1)], dim=-1) # [p^2, 2] + + inds = inds.unsqueeze(1) + offsets.unsqueeze(0) # [np, p^2, 2] + inds = inds.view(-1, 2) # [N, 2] + inds = inds[:, 0] * W + inds[:, 1] # [N], flatten + + inds = inds.expand([B, N]) + + # only get rays in the specified rect + elif rect is not None: + # assert B == 1 + mask = torch.zeros(H, W, dtype=torch.bool, device=device) + xmin, xmax, ymin, ymax = rect + mask[xmin:xmax, ymin:ymax] = 1 + inds = torch.where(mask.view(-1))[0] # [nzn] + inds = inds.unsqueeze(0) # [1, N] + + else: + inds = torch.randint(0, H*W, size=[N], device=device) # may duplicate + inds = inds.expand([B, N]) + + i = torch.gather(i, -1, inds) + j = torch.gather(j, -1, inds) + else: + inds = torch.arange(H*W, device=device).expand([B, H*W]) + + results['i'] = i + results['j'] = j + results['inds'] = inds + + zs = torch.ones_like(i) + xs = (i - cx) / fx * zs + ys = (j - cy) / fy * zs + directions = torch.stack((xs, ys, zs), dim=-1) + directions = directions / torch.norm(directions, dim=-1, keepdim=True) + rays_d = directions @ poses[:, :3, :3].transpose(-1, -2) # (B, N, 3) + + rays_o = poses[..., :3, 3] # [B, 3] + rays_o = rays_o[..., None, :].expand_as(rays_d) # [B, N, 3] + + results['rays_o'] = rays_o #.clone() + results['rays_d'] = rays_d + return results + + +def seed_everything(seed): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + #torch.backends.cudnn.deterministic = True + #torch.backends.cudnn.benchmark = True + + +def torch_vis_2d(x, renormalize=False): + # x: [3, H, W] or [1, H, W] or [H, W] + import matplotlib.pyplot as plt + import numpy as np + import torch + + if isinstance(x, torch.Tensor): + if len(x.shape) == 3: + x = x.permute(1,2,0).squeeze() + x = x.detach().cpu().numpy() + + print(f'[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}') + + x = x.astype(np.float32) + + # renormalize + if renormalize: + x = (x - x.min(axis=0, keepdims=True)) / (x.max(axis=0, keepdims=True) - x.min(axis=0, keepdims=True) + 1e-8) + + plt.imshow(x) + plt.show() + + +def extract_fields(bound_min, bound_max, resolution, query_func, S=128): + + X = torch.linspace(bound_min[0], bound_max[0], resolution).split(S) + Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(S) + Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(S) + + u = np.zeros([resolution, resolution, resolution], dtype=np.float32) + with torch.no_grad(): + for xi, xs in enumerate(X): + for yi, ys in enumerate(Y): + for zi, zs in enumerate(Z): + xx, yy, zz = custom_meshgrid(xs, ys, zs) + pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [S, 3] + val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() # [S, 1] --> [x, y, z] + u[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys), zi * S: zi * S + len(zs)] = val + return u + + +def extract_geometry(bound_min, bound_max, resolution, threshold, query_func): + #print('threshold: {}'.format(threshold)) + u = extract_fields(bound_min, bound_max, resolution, query_func) + + #print(u.shape, u.max(), u.min(), np.percentile(u, 50)) + + vertices, triangles = mcubes.marching_cubes(u, threshold) + + b_max_np = bound_max.detach().cpu().numpy() + b_min_np = bound_min.detach().cpu().numpy() + + vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :] + return vertices, triangles diff --git a/Geneface_main/GeneFace/modules/syncnet/__pycache__/models.cpython-39.pyc b/Geneface_main/GeneFace/modules/syncnet/__pycache__/models.cpython-39.pyc new file mode 100644 index 00000000..0d01aea0 Binary files /dev/null and b/Geneface_main/GeneFace/modules/syncnet/__pycache__/models.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/modules/syncnet/models.py b/Geneface_main/GeneFace/modules/syncnet/models.py new file mode 100644 index 00000000..c2b146fd --- /dev/null +++ b/Geneface_main/GeneFace/modules/syncnet/models.py @@ -0,0 +1,102 @@ +import torch +from torch import nn +from torch.nn import functional as F + + +class Conv1d(nn.Module): + def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): + super().__init__(*args, **kwargs) + self.conv_block = nn.Sequential( + nn.Conv1d(cin, cout, kernel_size, stride, padding), + nn.BatchNorm1d(cout) + ) + self.act = nn.ReLU() + self.residual = residual + + def forward(self, x): + out = self.conv_block(x) + if self.residual: + out += x + return self.act(out) + +class LandmarkHubertSyncNet(nn.Module): + def __init__(self, lm_dim=60): + super(LandmarkHubertSyncNet, self).__init__() + + # hubert = torch.rand(B, 1, , t=10) + self.hubert_encoder = nn.Sequential( + Conv1d(1024, 128, kernel_size=3, stride=1, padding=1), + + Conv1d(128, 128, kernel_size=3, stride=1, padding=1), + Conv1d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv1d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + + Conv1d(128, 256, kernel_size=3, stride=2, padding=1), + Conv1d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + Conv1d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + + Conv1d(256, 512, kernel_size=3, stride=2, padding=1), + Conv1d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), + Conv1d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), + + Conv1d(512, 512, kernel_size=3, stride=1, padding=1), + Conv1d(512, 512, kernel_size=3, stride=1, padding=0), + Conv1d(512, 512, kernel_size=1, stride=1, padding=0),) + + + # mouth = torch.rand(B, 20*3, t=5) + self.mouth_encoder = nn.Sequential( + Conv1d(lm_dim, 96, kernel_size=3, stride=1, padding=1), + + Conv1d(96, 128, kernel_size=3, stride=1, padding=1), + Conv1d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv1d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + + Conv1d(128, 256, kernel_size=3, stride=2, padding=1), + Conv1d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + Conv1d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + + Conv1d(256, 512, kernel_size=3, stride=1, padding=1), + Conv1d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), + Conv1d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), + + Conv1d(512, 512, kernel_size=3, stride=1, padding=1), + Conv1d(512, 512, kernel_size=3, stride=1, padding=0), + Conv1d(512, 512, kernel_size=1, stride=1, padding=0),) + self.lm_dim = lm_dim + self.logloss = nn.BCELoss() + def forward(self, hubert, mouth_lm): + # hubert := (B, T=10, C=1024) + # mouth_lm3d := (B, T=5, C=60) + hubert = hubert.transpose(1,2) + mouth_lm = mouth_lm.transpose(1,2) + mouth_embedding = self.mouth_encoder(mouth_lm) + audio_embedding = self.hubert_encoder(hubert) + audio_embedding = audio_embedding.view(audio_embedding.size(0), -1) + mouth_embedding = mouth_embedding.view(mouth_embedding.size(0), -1) + audio_embedding = F.normalize(audio_embedding, p=2, dim=1) + mouth_embedding = F.normalize(mouth_embedding, p=2, dim=1) + return audio_embedding, mouth_embedding + + def cal_sync_loss(self, audio_embedding, mouth_embedding, label): + if isinstance(label, torch.Tensor): + gt_d = label.float().view(-1,1).to(audio_embedding.device) + else: + gt_d = (torch.ones([audio_embedding.shape[0],1]) * label).float().to(audio_embedding.device) # int + d = nn.functional.cosine_similarity(audio_embedding, mouth_embedding) + loss = self.logloss(d.unsqueeze(1), gt_d) + return loss, d + + def cal_cosine_similarity(self, audio_embedding, mouth_embedding): + d = nn.functional.cosine_similarity(audio_embedding, mouth_embedding) + return d + + +if __name__ == '__main__': + syncnet = LandmarkHubertSyncNet(lm_dim=204) + hubert = torch.rand(2, 10, 1024) + lm = torch.rand(2, 5, 204) + mel_embedding, exp_embedding = syncnet(hubert, lm) + label = torch.tensor([1., 0.]) + loss = syncnet.cal_sync_loss(mel_embedding, exp_embedding, label) + print(" ") \ No newline at end of file diff --git a/Geneface_main/GeneFace/scripts/infer_adnerf.sh b/Geneface_main/GeneFace/scripts/infer_adnerf.sh new file mode 100644 index 00000000..f6de628e --- /dev/null +++ b/Geneface_main/GeneFace/scripts/infer_adnerf.sh @@ -0,0 +1,11 @@ +export CUDA_VISIBLE_DEVICES=0 +# export CUDA_VISIBLE_DEVICES=0,1,2,3 # now we support multi-gpu inference! +export Video_ID=May +export Wav_ID=zozo # the .wav file should locate at `data/raw/val_wavs/.wav` + +python inference/nerfs/adnerf_infer.py \ + --config=checkpoints/${Video_ID}/adnerf_torso/config.yaml \ + --exp_name=${Video_ID}/adnerf_torso \ + --hparams=infer_audio_source_name=data/raw/val_wavs/${Wav_ID}.wav,\ +infer_out_video_name=infer_out/ADNeRF/${Video_ID}/pred_video/${Wav_ID}.mp4 \ + --reset diff --git a/Geneface_main/GeneFace/scripts/infer_audio2pose.sh b/Geneface_main/GeneFace/scripts/infer_audio2pose.sh new file mode 100644 index 00000000..85bf1cbd --- /dev/null +++ b/Geneface_main/GeneFace/scripts/infer_audio2pose.sh @@ -0,0 +1,9 @@ +export CUDA_VISIBLE_DEVICES=0 +export Video_ID=May +export Wav_ID=zozo + +python inference/audio2pose/audio2pose_infer.py \ + --config=checkpoints/${Video_ID}/audio2pose/config.yaml \ + --hparams=infer_audio_source_name=data/raw/val_wavs/${Wav_ID}.wav,\ +infer_out_npy_name=infer_out/${Video_ID}/pred_c2w/${Wav_ID}.npy \ + --reset \ No newline at end of file diff --git a/Geneface_main/GeneFace/scripts/infer_lm3d_nerf.sh b/Geneface_main/GeneFace/scripts/infer_lm3d_nerf.sh new file mode 100644 index 00000000..43efa0b9 --- /dev/null +++ b/Geneface_main/GeneFace/scripts/infer_lm3d_nerf.sh @@ -0,0 +1,39 @@ +export PYTHONPATH=. +export CUDA_VISIBLE_DEVICES=0 +# export CUDA_VISIBLE_DEVICES=0,1 # now we support multi-gpu inference! +export Video_ID=May +export Wav_ID=zozo # the .wav file should locate at `data/raw/val_wavs/.wav` +export n_samples_per_ray=32 # during training 64 +# export n_samples_per_ray=16 # you can use a smaller value (e.g., 16) to accelerate the inference +export n_samples_per_ray_fine=128 # during training 128 +# export n_samples_per_ray_fine=32 # you can use a smaller value (e.g., 32) to accelerate the inference +export infer_scale_factor=1.0 # scale of output resolution, defautlt 1.0 -> 512x512 image + +# postprocessing params +export infer_lm3d_clamp_std=3.0 # typically 1.~5., reduce it when blurry or bad cases occurs +export infer_lm3d_lle_percent=0. # 0.~1., enlarge it when blurry or bad cases occurs +export infer_lm3d_smooth_sigma=0. # typically 0.~3., enlarge it when blurry or bad cases occurs + +# use head pose from the dataset +python inference/nerfs/lm3d_nerf_infer.py \ + --config=checkpoints/${Video_ID}/lm3d_nerf_torso/config.yaml \ + --hparams=infer_audio_source_name=data/raw/val_wavs/${Wav_ID}.wav,\ +infer_cond_name=infer_out/${Video_ID}/pred_lm3d/${Wav_ID}.npy,\ +infer_out_video_name=infer_out/${Video_ID}/pred_video/${Wav_ID}_infer_torso_nerf.mp4,\ +n_samples_per_ray=${n_samples_per_ray},n_samples_per_ray_fine=${n_samples_per_ray_fine},\ +infer_scale_factor=${infer_scale_factor},\ +infer_lm3d_clamp_std=${infer_lm3d_clamp_std},\ +infer_lm3d_lle_percent=${infer_lm3d_lle_percent},\ +infer_lm3d_smooth_sigma=${infer_lm3d_smooth_sigma} \ + --infer + +# use the head pose predicted by audio2pose model +# python inference/nerfs/lm3d_nerf_infer.py \ +# --config=checkpoints/${Video_ID}/lm3d_nerf_torso/config.yaml \ +# --hparams=infer_audio_source_name=data/raw/val_wavs/${Wav_ID}.wav,\ +# infer_cond_name=infer_out/${Video_ID}/pred_lm3d/${Wav_ID}.npy,\ +# infer_out_video_name=infer_out/${Video_ID}/pred_video/${Wav_ID}_pred_pose.mp4,\ +# n_samples_per_ray=${n_samples_per_ray},n_samples_per_ray_fine=${n_samples_per_ray_fine},\ +# infer_scale_factor=${infer_scale_factor},\ +# infer_c2w_name=infer_out/${Video_ID}/pred_c2w/${Wav_ID}.npy \ +# --infer diff --git a/Geneface_main/GeneFace/scripts/infer_lm3d_radnerf.sh b/Geneface_main/GeneFace/scripts/infer_lm3d_radnerf.sh new file mode 100644 index 00000000..9da4358a --- /dev/null +++ b/Geneface_main/GeneFace/scripts/infer_lm3d_radnerf.sh @@ -0,0 +1,13 @@ +export PYTHONPATH=. +export CUDA_VISIBLE_DEVICES=0 +# export CUDA_VISIBLE_DEVICES=0,1 # now we support multi-gpu inference! +export Video_ID=May +export Wav_ID=zozo # the .wav file should locate at `data/raw/val_wavs/.wav` + +python inference/nerfs/lm3d_radnerf_infer.py \ + --config=checkpoints/${Video_ID}/lm3d_radnerf_torso/config.yaml \ + --hparams=infer_audio_source_name=data/raw/val_wavs/${Wav_ID}.wav,\ +infer_cond_name=infer_out/${Video_ID}/pred_lm3d/${Wav_ID}.npy,\ +infer_out_video_name=infer_out/${Video_ID}/pred_video/${Wav_ID}_radnerf_torso_smo.mp4\ + --infer + diff --git a/Geneface_main/GeneFace/scripts/infer_lm3d_radnerf_May.sh b/Geneface_main/GeneFace/scripts/infer_lm3d_radnerf_May.sh new file mode 100644 index 00000000..b55bf092 --- /dev/null +++ b/Geneface_main/GeneFace/scripts/infer_lm3d_radnerf_May.sh @@ -0,0 +1,13 @@ +export PYTHONPATH=. +export CUDA_VISIBLE_DEVICES=0 +# export CUDA_VISIBLE_DEVICES=0,1 # now we support multi-gpu inference! +export Video_ID=May +export Wav_ID=May # the .wav file should locate at `data/raw/val_wavs/.wav` + +python inference/nerfs/lm3d_radnerf_infer.py \ + --config=checkpoints/${Video_ID}/lm3d_radnerf_torso/config.yaml \ + --hparams=infer_audio_source_name=data/raw/val_wavs/${Wav_ID}.wav,\ +infer_cond_name=infer_out/${Video_ID}/pred_lm3d/${Wav_ID}.npy,\ +infer_out_video_name=infer_out/${Video_ID}/pred_video/${Wav_ID}_radnerf_torso_smo.mp4\ + --infer + diff --git a/Geneface_main/GeneFace/scripts/infer_lm3d_radnerf_SY.sh b/Geneface_main/GeneFace/scripts/infer_lm3d_radnerf_SY.sh new file mode 100644 index 00000000..5345a2ae --- /dev/null +++ b/Geneface_main/GeneFace/scripts/infer_lm3d_radnerf_SY.sh @@ -0,0 +1,13 @@ +export PYTHONPATH=. +export CUDA_VISIBLE_DEVICES=0 +# export CUDA_VISIBLE_DEVICES=0,1 # now we support multi-gpu inference! +export Video_ID=May +export Wav_ID=SY # the .wav file should locate at `data/raw/val_wavs/.wav` + +python inference/nerfs/lm3d_radnerf_infer.py \ + --config=checkpoints/${Video_ID}/lm3d_radnerf_torso/config.yaml \ + --hparams=infer_audio_source_name=data/raw/val_wavs/${Wav_ID}.wav,\ +infer_cond_name=infer_out/${Video_ID}/pred_lm3d/${Wav_ID}.npy,\ +infer_out_video_name=infer_out/${Video_ID}/pred_video/${Wav_ID}_radnerf_torso_smo.mp4\ + --infer + diff --git a/Geneface_main/GeneFace/scripts/infer_postnet.sh b/Geneface_main/GeneFace/scripts/infer_postnet.sh new file mode 100644 index 00000000..0690b17c --- /dev/null +++ b/Geneface_main/GeneFace/scripts/infer_postnet.sh @@ -0,0 +1,11 @@ +export CUDA_VISIBLE_DEVICES=0 +export Video_ID=May +export Wav_ID=zozo +export Postnet_Ckpt_Steps=4000 # please reach to `docs/train_models.md` to get some tips about how to select an approprate ckpt_steps! + +python inference/postnet/postnet_infer.py \ + --config=checkpoints/${Video_ID}/lm3d_postnet_sync/config.yaml \ + --hparams=infer_audio_source_name=data/raw/val_wavs/${Wav_ID}.wav,\ +infer_out_npy_name=infer_out/${Video_ID}/pred_lm3d/${Wav_ID}.npy,\ +infer_ckpt_steps=${Postnet_Ckpt_Steps} \ + --reset \ No newline at end of file diff --git a/Geneface_main/GeneFace/scripts/infer_postnet_May.sh b/Geneface_main/GeneFace/scripts/infer_postnet_May.sh new file mode 100644 index 00000000..47de2dc7 --- /dev/null +++ b/Geneface_main/GeneFace/scripts/infer_postnet_May.sh @@ -0,0 +1,11 @@ +export CUDA_VISIBLE_DEVICES=0 +export Video_ID=May +export Wav_ID=May +export Postnet_Ckpt_Steps=4000 # please reach to `docs/train_models.md` to get some tips about how to select an approprate ckpt_steps! + +python inference/postnet/postnet_infer.py \ + --config=checkpoints/${Video_ID}/lm3d_postnet_sync/config.yaml \ + --hparams=infer_audio_source_name=data/raw/val_wavs/${Wav_ID}.wav,\ +infer_out_npy_name=infer_out/${Video_ID}/pred_lm3d/${Wav_ID}.npy,\ +infer_ckpt_steps=${Postnet_Ckpt_Steps} \ + --reset diff --git a/Geneface_main/GeneFace/scripts/infer_postnet_SY.sh b/Geneface_main/GeneFace/scripts/infer_postnet_SY.sh new file mode 100644 index 00000000..2ca9b6ea --- /dev/null +++ b/Geneface_main/GeneFace/scripts/infer_postnet_SY.sh @@ -0,0 +1,11 @@ +export CUDA_VISIBLE_DEVICES=0 +export Video_ID=May +export Wav_ID=SY +export Postnet_Ckpt_Steps=4000 # please reach to `docs/train_models.md` to get some tips about how to select an approprate ckpt_steps! + +python inference/postnet/postnet_infer.py \ + --config=checkpoints/${Video_ID}/lm3d_postnet_sync/config.yaml \ + --hparams=infer_audio_source_name=data/raw/val_wavs/${Wav_ID}.wav,\ +infer_out_npy_name=infer_out/${Video_ID}/pred_lm3d/${Wav_ID}.npy,\ +infer_ckpt_steps=${Postnet_Ckpt_Steps} \ + --reset diff --git a/Geneface_main/GeneFace/scripts/infer_radnerf_gui.sh b/Geneface_main/GeneFace/scripts/infer_radnerf_gui.sh new file mode 100644 index 00000000..36c3a0ce --- /dev/null +++ b/Geneface_main/GeneFace/scripts/infer_radnerf_gui.sh @@ -0,0 +1,13 @@ +export PYTHONPATH=. +export CUDA_VISIBLE_DEVICES=0 +# export CUDA_VISIBLE_DEVICES=0,1 # now we support multi-gpu inference! +export Video_ID=May +export Wav_ID=zozo # the .wav file should locate at `data/raw/val_wavs/.wav` + +# use head pose from the dataset +python inference/nerfs/radnerf_gui.py \ + --config=checkpoints/${Video_ID}/lm3d_radnerf_torso/config.yaml \ + --hparams=infer_audio_source_name=data/raw/val_wavs/${Wav_ID}.wav,\ +infer_cond_name=infer_out/${Video_ID}/pred_lm3d/${Wav_ID}.npy + + diff --git a/Geneface_main/GeneFace/scripts/test_audio2motion.sh b/Geneface_main/GeneFace/scripts/test_audio2motion.sh new file mode 100644 index 00000000..0f712aee --- /dev/null +++ b/Geneface_main/GeneFace/scripts/test_audio2motion.sh @@ -0,0 +1,16 @@ +export CUDA_VISIBLE_DEVICES=3 +export Video_ID=May +export Wav_ID=zozo +# export Audio2motion_Steps=2000 # please reach to `docs/train_models.md` to get some tips about how to select an approprate ckpt_steps! +for Audio2motion_Steps in 2000 4000 8000 16000 20000 24000 28000 32000 36000 40000 +do + python inference/audio2motion/audio2motion_infer.py \ + --config=checkpoints/lrs3/lm3d_vae/config.yaml \ + --hparams=infer_audio_source_name=data/raw/val_wavs/${Wav_ID}.wav,\ +infer_out_npy_name=infer_out/audio2motion/pred_lm3d/step${Audio2motion_Steps}_${Wav_ID}.npy,\ +infer_ckpt_steps=${Audio2motion_Steps} \ + --reset + + python utils/visualization/lm_visualizer.py --npy_name=infer_out/audio2motion/pred_lm3d/step${Audio2motion_Steps}_${Wav_ID}.npy \ +--audio_name=data/raw/val_wavs/${Wav_ID}.wav --out_path=infer_out/audio2motion/visualizd_lm3d/step${Audio2motion_Steps}_${Wav_ID}.mp4 +done \ No newline at end of file diff --git a/Geneface_main/GeneFace/scripts/train_audio2motion.sh b/Geneface_main/GeneFace/scripts/train_audio2motion.sh new file mode 100644 index 00000000..c32a61a8 --- /dev/null +++ b/Geneface_main/GeneFace/scripts/train_audio2motion.sh @@ -0,0 +1 @@ +CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config=egs/datasets/lrs3/lm3d_vae_sync.yaml --exp_name=lrs3/lm3d_vae_sync --reset diff --git a/Geneface_main/GeneFace/scripts/train_audio2pose.sh b/Geneface_main/GeneFace/scripts/train_audio2pose.sh new file mode 100644 index 00000000..0ccdd286 --- /dev/null +++ b/Geneface_main/GeneFace/scripts/train_audio2pose.sh @@ -0,0 +1,6 @@ +export CUDA_VISIBLE_DEVICES=0 +export Video_ID=Zhang2 + +python tasks/run.py --config=egs/datasets/videos/${Video_ID}/audio2pose.yaml \ + --exp_name=${Video_ID}/audio2pose \ + --reset diff --git a/Geneface_main/GeneFace/scripts/train_nerf.sh b/Geneface_main/GeneFace/scripts/train_nerf.sh new file mode 100644 index 00000000..736e9dae --- /dev/null +++ b/Geneface_main/GeneFace/scripts/train_nerf.sh @@ -0,0 +1,7 @@ +export Video_ID=May +# binarize the dataset +python data_gen/nerf/binarizer.py --config=egs/datasets/videos/${Video_ID}/lm3d_nerf.yaml +# train Head NeRF +CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config=egs/datasets/videos/${Video_ID}/lm3d_nerf.yaml --exp_name=${Video_ID}/lm3d_nerf --reset +# train Torso NeRF +CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config=egs/datasets/videos/${Video_ID}/lm3d_nerf_torso.yaml --exp_name=${Video_ID}/lm3d_nerf_torso --reset \ No newline at end of file diff --git a/Geneface_main/GeneFace/scripts/train_postnet.sh b/Geneface_main/GeneFace/scripts/train_postnet.sh new file mode 100644 index 00000000..073474ee --- /dev/null +++ b/Geneface_main/GeneFace/scripts/train_postnet.sh @@ -0,0 +1,2 @@ +export Video_ID=May +CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config=egs/datasets/videos/${Video_ID}/lm3d_postnet_sync.yaml --exp_name=${Video_ID}/lm3d_postnet_sync --reset diff --git a/Geneface_main/GeneFace/scripts/train_radnerf.sh b/Geneface_main/GeneFace/scripts/train_radnerf.sh new file mode 100644 index 00000000..ec70d569 --- /dev/null +++ b/Geneface_main/GeneFace/scripts/train_radnerf.sh @@ -0,0 +1,7 @@ +export Video_ID=May +# binarize the dataset +python data_gen/nerf/binarizer.py --config=egs/datasets/videos/${Video_ID}/lm3d_radnerf.yaml +# train Head NeRF +CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config=egs/datasets/videos/${Video_ID}/lm3d_radnerf.yaml --exp_name=${Video_ID}/lm3d_radnerf --reset +# train Torso NeRF +CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config=egs/datasets/videos/${Video_ID}/lm3d_radnerf_torso.yaml --exp_name=${Video_ID}/lm3d_radnerf_torso --reset \ No newline at end of file diff --git a/Geneface_main/GeneFace/scripts/train_syncnet.sh b/Geneface_main/GeneFace/scripts/train_syncnet.sh new file mode 100644 index 00000000..8c3e4611 --- /dev/null +++ b/Geneface_main/GeneFace/scripts/train_syncnet.sh @@ -0,0 +1 @@ +CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config=egs/datasets/lrs3/lm3d_syncnet.yaml --exp_name=lrs3/syncnet --reset diff --git a/Geneface_main/GeneFace/tasks/audio2motion/__pycache__/lm3d_vae_sync.cpython-39.pyc b/Geneface_main/GeneFace/tasks/audio2motion/__pycache__/lm3d_vae_sync.cpython-39.pyc new file mode 100644 index 00000000..d342ed50 Binary files /dev/null and b/Geneface_main/GeneFace/tasks/audio2motion/__pycache__/lm3d_vae_sync.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/tasks/audio2motion/dataset_utils/__pycache__/euler2quaterion.cpython-39.pyc b/Geneface_main/GeneFace/tasks/audio2motion/dataset_utils/__pycache__/euler2quaterion.cpython-39.pyc new file mode 100644 index 00000000..442cc810 Binary files /dev/null and b/Geneface_main/GeneFace/tasks/audio2motion/dataset_utils/__pycache__/euler2quaterion.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/tasks/audio2motion/dataset_utils/__pycache__/lrs3_dataset.cpython-39.pyc b/Geneface_main/GeneFace/tasks/audio2motion/dataset_utils/__pycache__/lrs3_dataset.cpython-39.pyc new file mode 100644 index 00000000..74ea0460 Binary files /dev/null and b/Geneface_main/GeneFace/tasks/audio2motion/dataset_utils/__pycache__/lrs3_dataset.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/tasks/audio2motion/dataset_utils/euler2quaterion.py b/Geneface_main/GeneFace/tasks/audio2motion/dataset_utils/euler2quaterion.py new file mode 100644 index 00000000..e3fd35af --- /dev/null +++ b/Geneface_main/GeneFace/tasks/audio2motion/dataset_utils/euler2quaterion.py @@ -0,0 +1,35 @@ +import numpy as np +import torch +import math +import numba +from scipy.spatial.transform import Rotation as R + +def euler2quaterion(euler, use_radian=True): + """ + euler: np.array, [batch, 3] + return: the quaterion, np.array, [batch, 4] + """ + r = R.from_euler('xyz',euler, degrees=not use_radian) + return r.as_quat() + +def quaterion2euler(quat, use_radian=True): + """ + quat: np.array, [batch, 4] + return: the euler, np.array, [batch, 3] + """ + r = R.from_quat(quat) + return r.as_euler('xyz', degrees=not use_radian) + +def rot2quaterion(rot): + r = R.from_matrix(rot) + return r.as_quat() + +def quaterion2rot(quat): + r = R.from_quat(quat) + return r.as_matrix() + +if __name__ == '__main__': + euler = np.array([89.999,89.999,89.999] * 100).reshape([100,3]) + q = euler2quaterion(euler, use_radian=False) + e = quaterion2euler(q, use_radian=False) + print(" ") diff --git a/Geneface_main/GeneFace/tasks/audio2motion/dataset_utils/lrs3_dataset.py b/Geneface_main/GeneFace/tasks/audio2motion/dataset_utils/lrs3_dataset.py new file mode 100644 index 00000000..8d9f42aa --- /dev/null +++ b/Geneface_main/GeneFace/tasks/audio2motion/dataset_utils/lrs3_dataset.py @@ -0,0 +1,317 @@ +import torch +import numpy as np +import pickle as pkl +import os, sys +import math, random +from torch.utils.data import Dataset, DataLoader +import tqdm + +from utils.commons.hparams import hparams +from utils.commons.tensor_utils import convert_to_tensor +from data_util.face3d_helper import Face3DHelper + +from utils.commons.indexed_datasets import IndexedDataset +from tasks.audio2motion.dataset_utils.euler2quaterion import euler2quaterion, quaterion2euler + +class LRS3SeqDataset(Dataset): + def __init__(self, prefix='train'): + self.db_key = prefix + self.ds_path = hparams['binary_data_dir'] + self.ds = None + self.sizes = None + self.ordered_indices() + self.memory_cache = {} # we use hash table to accelerate indexing + self.face3d_helper = Face3DHelper('deep_3drecon/BFM') + self.x_multiply = 8 + if hparams['load_db_to_memory']: + self.load_db_to_memory() + + @property + def _sizes(self): + return self.sizes + + def __len__(self): + return len(self._sizes) + + def _cal_avatar_style_encoding(self, exp, pose): + diff_exp = exp[:-1, :] - exp[1:, :] + exp_std = (torch.std(exp, dim = 0) - self.exp_std_mean) / self.exp_std_std + diff_exp_std = (torch.std(diff_exp, dim = 0) - self.exp_diff_std_mean) / self.exp_diff_std_std + + diff_pose = pose[:-1, :] - pose[1:, :] + diff_pose_std = (torch.std(diff_pose, dim = 0) - self.pose_diff_std_mean) / self.pose_diff_std_std + + return torch.cat((exp_std, diff_exp_std, diff_pose_std)) # [135,] + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + sizes_fname = os.path.join(self.ds_path, f"sizes_{self.db_key}.npy") + if os.path.exists(sizes_fname): + sizes = np.load(sizes_fname, allow_pickle=True) + self.sizes = sizes + if self.sizes is None: + self.sizes = [] + print("Counting the size of each item in dataset...") + ds = IndexedDataset(f"{self.ds_path}/{self.db_key}") + for i_sample in tqdm.trange(len(ds)): + sample = ds[i_sample] + if sample is None: + size = 0 + else: + x = sample['mel'] + size = x.shape[-1] # time step in audio + self.sizes.append(size) + np.save(sizes_fname, self.sizes) + indices = np.arange(len(self)) + indices = indices[np.argsort(np.array(self.sizes)[indices], kind='mergesort')] + return indices + + def batch_by_size(self, indices, max_tokens=None, max_sentences=None, + required_batch_size_multiple=1): + """ + Yield mini-batches of indices bucketed by size. Batches may contain + sequences of different lengths. + + Args: + indices (List[int]): ordered list of dataset indices + num_tokens_fn (callable): function that returns the number of tokens at + a given index + max_tokens (int, optional): max number of tokens in each batch + (default: None). + max_sentences (int, optional): max number of sentences in each + batch (default: None). + required_batch_size_multiple (int, optional): require batch size to + be a multiple of N (default: 1). + """ + def _is_batch_full(batch, num_tokens, max_tokens, max_sentences): + if len(batch) == 0: + return 0 + if len(batch) == max_sentences: + return 1 + if num_tokens > max_tokens: + return 1 + return 0 + + num_tokens_fn = lambda x: self.sizes[x] + max_tokens = max_tokens if max_tokens is not None else 60000 + max_sentences = max_sentences if max_sentences is not None else 512 + bsz_mult = required_batch_size_multiple + + sample_len = 0 + sample_lens = [] + batch = [] + batches = [] + for i in range(len(indices)): + idx = indices[i] + num_tokens = num_tokens_fn(idx) + sample_lens.append(num_tokens) + sample_len = max(sample_len, num_tokens) + + assert sample_len <= max_tokens, ( + "sentence at index {} of size {} exceeds max_tokens " + "limit of {}!".format(idx, sample_len, max_tokens) + ) + num_tokens = (len(batch) + 1) * sample_len + + if _is_batch_full(batch, num_tokens, max_tokens, max_sentences): + mod_len = max( + bsz_mult * (len(batch) // bsz_mult), + len(batch) % bsz_mult, + ) + batches.append(batch[:mod_len]) + batch = batch[mod_len:] + sample_lens = sample_lens[mod_len:] + sample_len = max(sample_lens) if len(sample_lens) > 0 else 0 + batch.append(idx) + if len(batch) > 0: + batches.append(batch) + return batches + + def decode_pose(self, pose): + """ + pose [B, T, C=7=4+3] + """ + b,t,_ = pose.shape + if self.normalize_target: + pose = pose * self.pose_std + self.pose_mean + translations = pose[:, :, :3].cpu().numpy() # [B, T, 3] + angles = pose[:, :, 3:].cpu().numpy() # [B, T, 4] + angles = quaterion2euler(angles.reshape([b*t,4])) # [B*T, 3] + angles = angles.reshape([b,t,3]) + return angles, translations + + def load_db_to_memory(self): + for idx in tqdm.trange(len(self), desc='Loading database to memory...'): + raw_item = self._get_item(idx) + if raw_item is None: + continue + item = {} + item_id = raw_item['item_id'] # str: "_" + item['item_id'] = item_id + # audio-related features + mel = raw_item['mel'] + hubert = raw_item['hubert'] + item['mel'] = torch.from_numpy(mel).float() # [T_x, c=80] + item['hubert'] = torch.from_numpy(hubert).float() # [T_x, c=80] + if 'f0' in raw_item.keys(): + f0 = raw_item['f0'] + item['f0'] = torch.from_numpy(f0).float() # [T_x,] + # video-related features + coeff = raw_item['coeff'] # [T_y ~= T_x//2, c=257] + exp = coeff[:, 80:144] + item['exp'] = torch.from_numpy(exp).float() # [T_y, c=64] + translation = coeff[:, 254:257] # [T_y, c=3] + angles = euler2quaterion(coeff[:, 224:227]) # # [T_y, c=4] + pose = np.concatenate([translation, angles], axis=1) + item['pose'] = torch.from_numpy(pose).float() # [T_y, c=4+3] + + # Load identity for landmark construction + item['identity'] = torch.from_numpy(raw_item['coeff'][..., :80]).float() + + # Load lm3d + t_lm, dim_lm, _ = raw_item['idexp_lm3d'].shape # [T, 68, 3] + item['idexp_lm3d'] = torch.from_numpy(raw_item['idexp_lm3d']).reshape(t_lm, -1).float() + eye_idexp_lm3d, mouth_idexp_lm3d = self.face3d_helper.get_eye_mouth_lm_from_lm3d(raw_item['idexp_lm3d']) + item['eye_idexp_lm3d'] = convert_to_tensor(eye_idexp_lm3d).reshape(t_lm, -1).float() + item['mouth_idexp_lm3d'] = convert_to_tensor(mouth_idexp_lm3d).reshape(t_lm, -1).float() + item['ref_mean_lm3d'] = item['idexp_lm3d'].mean(dim=0).reshape([204,]) + + self.memory_cache[idx] = item + + def _get_item(self, index): + """ + This func is necessary to open files in multi-threads! + """ + if self.ds is None: + self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}') + return self.ds[index] + + def __getitem__(self, idx): + if hparams['load_db_to_memory']: + return self.memory_cache[idx] + + raw_item = self._get_item(idx) + if raw_item is None: + print("loading from binary data failed!") + return None + item = {} + item_id = raw_item['item_id'] # str: "_" + item['item_id'] = item_id + # audio-related features + mel = raw_item['mel'] + hubert = raw_item['hubert'] + item['mel'] = torch.from_numpy(mel).float() # [T_x, c=80] + item['hubert'] = torch.from_numpy(hubert).float() # [T_x, c=80] + if 'f0' in raw_item.keys(): + f0 = raw_item['f0'] + item['f0'] = torch.from_numpy(f0).float() # [T_x,] + # video-related features + coeff = raw_item['coeff'] # [T_y ~= T_x//2, c=257] + exp = coeff[:, 80:144] + item['exp'] = torch.from_numpy(exp).float() # [T_y, c=64] + translation = coeff[:, 254:257] # [T_y, c=3] + angles = euler2quaterion(coeff[:, 224:227]) # # [T_y, c=4] + pose = np.concatenate([translation, angles], axis=1) + item['pose'] = torch.from_numpy(pose).float() # [T_y, c=4+3] + + # Load identity for landmark construction + item['identity'] = torch.from_numpy(raw_item['coeff'][..., :80]).float() + + # Load lm3d + t_lm, dim_lm, _ = raw_item['idexp_lm3d'].shape # [T, 68, 3] + item['idexp_lm3d'] = torch.from_numpy(raw_item['idexp_lm3d']).reshape(t_lm, -1).float() + eye_idexp_lm3d, mouth_idexp_lm3d = self.face3d_helper.get_eye_mouth_lm_from_lm3d(raw_item['idexp_lm3d']) + item['eye_idexp_lm3d'] = convert_to_tensor(eye_idexp_lm3d).reshape(t_lm, -1).float() + item['mouth_idexp_lm3d'] = convert_to_tensor(mouth_idexp_lm3d).reshape(t_lm, -1).float() + + # item = self.memory_cache[idx] + item['ref_mean_lm3d'] = item['idexp_lm3d'].mean(dim=0).reshape([204,]) + return item + + @staticmethod + def _collate_2d(values, max_len=None, pad_value=0): + """ + Convert a list of 2d tensors into a padded 3d tensor. + values: list of Batch tensors with shape [T, C] + return: [B, T, C] + """ + max_len = max(v.size(0) for v in values) if max_len is None else max_len + hidden_dim = values[0].size(1) + batch_size = len(values) + ret = torch.ones([batch_size, max_len, hidden_dim],dtype=values[0].dtype) * pad_value + for i, v in enumerate(values): + ret[i, :v.shape[0], :].copy_(v) + return ret + + def collater(self, samples): + none_idx = [] + for i in range(len(samples)): + if samples[i] is None: + none_idx.append(i) + for i in sorted(none_idx, reverse=True): + del samples[i] + if len(samples) == 0: + return None + batch = {} + item_names = [s['item_id'] for s in samples] + # style_batch = torch.stack([s["style"] for s in samples], dim=0) # [b, 135] + x_len = max(s['mel'].size(0) for s in samples) + x_len = x_len + (self.x_multiply - (x_len % self.x_multiply)) % self.x_multiply + y_len = x_len // 2 + mel_batch = self._collate_2d([s["mel"] for s in samples], max_len=x_len, pad_value=0) # [b, t_max_y, 64] + hubert_batch = self._collate_2d([s["hubert"] for s in samples], max_len=x_len, pad_value=0) # [b, t_max_y, 64] + exp_batch = self._collate_2d([s["exp"] for s in samples], max_len=y_len, pad_value=0) # [b, t_max_y, 64] + pose_batch = self._collate_2d([s["pose"] for s in samples], max_len=y_len, pad_value=0) # [b, t_max_y, 64] + + idexp_lm3d = self._collate_2d([s["idexp_lm3d"] for s in samples], max_len=y_len, pad_value=0) # [b, t_max, 1] + ref_mean_lm3d = torch.stack([s['ref_mean_lm3d'] for s in samples], dim=0) # [b, h=204*5] + mouth_idexp_lm3d = self._collate_2d([s["mouth_idexp_lm3d"] for s in samples], max_len=y_len, pad_value=0) # [b, t_max, 1] + + x_mask = (mel_batch.abs().sum(dim=-1) > 0).float() # [b, t_max_x] + y_mask = (pose_batch.abs().sum(dim=-1) > 0).float() # [b, t_max_y] + + batch.update({ + 'item_id': item_names, + # 'style': style_batch, + 'mel': mel_batch, + 'hubert': hubert_batch, + 'x_mask': x_mask, + 'exp': exp_batch, + 'pose': pose_batch, + 'y_mask': y_mask, + 'idexp_lm3d': idexp_lm3d, + 'ref_mean_lm3d': ref_mean_lm3d, + 'mouth_idexp_lm3d': mouth_idexp_lm3d, + }) + + if 'f0' in samples[0].keys(): + f0_batch = self._collate_2d([s["f0"].reshape([-1,1]) for s in samples], max_len=x_len, pad_value=0).squeeze(-1) # [b, t_max_y] + batch['f0'] = f0_batch + return batch + + def get_dataloader(self): + shuffle = True if self.db_key == 'train' else False + max_tokens = 60000 + batches_idx = self.batch_by_size(self.ordered_indices(), max_tokens=max_tokens) + batches_idx = batches_idx * 50 + random.shuffle(batches_idx) + loader = DataLoader(self, pin_memory=True,collate_fn=self.collater, batch_sampler=batches_idx, num_workers=4) + return loader + + +if __name__ == '__main__': + from utils.commons.hparams import set_hparams + set_hparams() + ds = LRS3SeqDataset('train') + ret = ds[0] + loader = ds.get_dataloader() + pbar = tqdm.tqdm(total=len(ds.batch_by_size(ds.ordered_indices()))) + # for i in tqdm.trange(len(ds)): + # ds[i] + for batch in loader: + pbar.update(1) + + pbar = tqdm.tqdm(total=len(ds.batch_by_size(ds.ordered_indices()))) + for batch in loader: + pbar.update(1) diff --git a/Geneface_main/GeneFace/tasks/audio2motion/lm3d_vae_sync.py b/Geneface_main/GeneFace/tasks/audio2motion/lm3d_vae_sync.py new file mode 100644 index 00000000..538f1356 --- /dev/null +++ b/Geneface_main/GeneFace/tasks/audio2motion/lm3d_vae_sync.py @@ -0,0 +1,198 @@ +import torch +import numpy as np +import os + +from utils.commons.base_task import BaseTask +from utils.commons.dataset_utils import data_loader +from utils.commons.hparams import hparams +from utils.commons.ckpt_utils import load_ckpt +from utils.commons.tensor_utils import tensors_to_scalars, convert_to_np +from utils.nn.model_utils import print_arch, get_device_of_model, not_requires_grad +from utils.nn.schedulers import ExponentialSchedule +from utils.nn.grad import get_grad_norm + +from modules.audio2motion.vae import VAEModel +from tasks.audio2motion.dataset_utils.lrs3_dataset import LRS3SeqDataset +from tasks.syncnet.lm3d_syncnet import SyncNetTask + +from data_util.face3d_helper import Face3DHelper + +class VAESyncAudio2MotionTask(BaseTask): + def __init__(self): + super().__init__() + self.dataset_cls = LRS3SeqDataset + self.enable_sync = False # enables when sync loss is lower than 0.5! + self.face3d_helper = Face3DHelper(use_gpu=True) + + def build_model(self): + self.syncnet_task = SyncNetTask() + self.syncnet_task.build_model() + load_ckpt(self.syncnet_task.model, hparams["syncnet_work_dir"], steps=hparams["syncnet_ckpt_steps"]) + not_requires_grad(self.syncnet_task) + self.syncnet_task.eval() + + self.model = VAEModel(in_out_dim=68*3) + return self.model + + def build_optimizer(self, model): + self.optimizer = optimizer = torch.optim.Adam( + model.parameters(), + lr=hparams['lr'], + betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2'])) + return optimizer + + def build_scheduler(self, optimizer): + return ExponentialSchedule(optimizer, hparams['lr'], hparams['warmup_updates']) + + @data_loader + def train_dataloader(self): + train_dataset = self.dataset_cls(prefix='train') + self.train_dl = train_dataset.get_dataloader() + return self.train_dl + + @data_loader + def val_dataloader(self): + val_dataset = self.dataset_cls(prefix='val') + self.val_dl = val_dataset.get_dataloader() + return self.val_dl + + @data_loader + def test_dataloader(self): + val_dataset = self.dataset_cls(prefix='val') + self.val_dl = val_dataset.get_dataloader() + return self.val_dl + + ########################## + # training and validation + ########################## + def run_model(self, sample, infer=False, temperature=1.0, sync_batch_size=1024): + """ + render or train on a single-frame + :param sample: a batch of data + :param infer: bool, run in infer mode + :return: + if not infer: + return losses, model_out + if infer: + return model_out + """ + model_out = {} + if infer: + self.model(sample, model_out, train=False, temperature=temperature) + else: + sample['y'] = sample['idexp_lm3d'] + self.model(sample, model_out, train=True) + + if not infer: + # forward the syncnet to get sync_loss + losses_out = {} + pred_lm3d = model_out['pred'] + pred_lm3d = pred_lm3d.reshape(pred_lm3d.size(0), pred_lm3d.size(1), 68, 3) + + _, pred_mouth_lm3d = self.face3d_helper.get_eye_mouth_lm_from_lm3d_batch(pred_lm3d) + syncnet_sample = { + 'mouth_idexp_lm3d': pred_mouth_lm3d.reshape(pred_mouth_lm3d.size(0), pred_mouth_lm3d.size(1), -1), + 'hubert': sample['hubert'], + 'y_mask': model_out['mask'] + } + syncnet_out = self.syncnet_task.run_model(syncnet_sample, infer=True, batch_size=sync_batch_size) + losses_out['sync'] = syncnet_out['sync_loss'] + + x_gt = sample['idexp_lm3d'] + x_pred = model_out['pred'] + x_mask = model_out['mask'] + losses_out['mse'] = self.mse_loss(x_gt, x_pred, x_mask) + losses_out['continuity'] = self.continuity_loss(x_gt, x_pred, x_mask) + losses_out['kl'] = model_out['loss_kl'] + return losses_out, model_out + else: + return model_out + + def _training_step(self, sample, batch_idx, optimizer_idx): + loss_output, model_out = self.run_model(sample) + loss_weights = { + 'kl': hparams['lambda_kl'], + 'mse': 1.0, + 'continuity': 3.0, + 'sync': hparams.get('lambda_sync', 0.01) if self.enable_sync else 0. + } + total_loss = sum([loss_weights.get(k, 1) * v for k, v in loss_output.items() if isinstance(v, torch.Tensor) and v.requires_grad]) + + return total_loss, loss_output + + def validation_start(self): + pass + + @torch.no_grad() + def validation_step(self, sample, batch_idx): + outputs = {} + outputs['losses'] = {} + outputs['losses'], model_out = self.run_model(sample, infer=False, sync_batch_size=10000) + outputs = tensors_to_scalars(outputs) + if outputs['losses']['sync'] <= 0.75 and not self.enable_sync: + self.enable_sync = True + return outputs + + def validation_end(self, outputs): + return super().validation_end(outputs) + + ##################### + # Testing + ##################### + def test_start(self): + self.gen_dir = os.path.join(hparams['work_dir'], f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}') + os.makedirs(self.gen_dir, exist_ok=True) + + @torch.no_grad() + def test_step(self, sample, batch_idx): + """ + :param sample: + :param batch_idx: + :return: + """ + outputs = {} + outputs['losses'], model_out = self.run_model(sample, infer=True) + pred_exp = model_out['pred'] + self.save_result(pred_exp, "pred_exp_val" , self.gen_dir) + if hparams['save_gt']: + base_fn = f"gt_exp_val" + self.save_result(sample['exp'], base_fn , self.gen_dir) + return outputs + + def test_end(self, outputs): + pass + + @staticmethod + def save_result(exp_arr, base_fname, gen_dir): + exp_arr = convert_to_np(exp_arr) + np.save(f"{gen_dir}/{base_fname}.npy", exp_arr) + + def get_grad(self, opt_idx): + grad_dict = { + 'grad/model': get_grad_norm(self.model), + } + return grad_dict + + def mse_loss(self, x_gt, x_pred, x_mask): + # mean squared error, l2 loss + error = (x_pred - x_gt) * x_mask[:,:, None] + num_frame = x_mask.sum() + n_dim = 68*3 + return (error ** 2).sum() / (num_frame * n_dim) + + def mae_loss(self, x_gt, x_pred, x_mask): + # mean absolute error, l1 loss + error = (x_pred - x_gt) * x_mask[:,:, None] + num_frame = x_mask.sum() + n_dim = 68*3 + return error.abs().sum() / (num_frame * n_dim) + + def continuity_loss(self, x_gt, x_pred, x_mask): + # continuity loss, borrowed from + diff_x_pred = x_pred[:,1:] - x_pred[:,:-1] + diff_x_gt = x_gt[:,1:] - x_gt[:,:-1] + error = (diff_x_pred[:,:,:] - diff_x_gt[:,:,:]) * x_mask[:,1:,None] + init_error = x_pred[:,0,:] - x_gt[:,0,:] + num_frame = x_mask.sum() + n_dim = 68*3 + return (error.pow(2).sum() + init_error.pow(2).sum()) / (num_frame * n_dim) \ No newline at end of file diff --git a/Geneface_main/GeneFace/tasks/audio2motion/lm3d_vae_sync_pitch.py b/Geneface_main/GeneFace/tasks/audio2motion/lm3d_vae_sync_pitch.py new file mode 100644 index 00000000..40884328 --- /dev/null +++ b/Geneface_main/GeneFace/tasks/audio2motion/lm3d_vae_sync_pitch.py @@ -0,0 +1,198 @@ +import torch +import numpy as np +import os + +from utils.commons.base_task import BaseTask +from utils.commons.dataset_utils import data_loader +from utils.commons.hparams import hparams +from utils.commons.ckpt_utils import load_ckpt +from utils.commons.tensor_utils import tensors_to_scalars, convert_to_np +from utils.nn.model_utils import print_arch, get_device_of_model, not_requires_grad +from utils.nn.schedulers import ExponentialSchedule +from utils.nn.grad import get_grad_norm + +from modules.audio2motion.vae import PitchContourVAEModel +from tasks.audio2motion.dataset_utils.lrs3_dataset import LRS3SeqDataset +from tasks.syncnet.lm3d_syncnet import SyncNetTask + +from data_util.face3d_helper import Face3DHelper + +class VAESyncAudio2MotionTask(BaseTask): + def __init__(self): + super().__init__() + self.dataset_cls = LRS3SeqDataset + self.enable_sync = False # enables when sync loss is lower than 0.5! + self.face3d_helper = Face3DHelper(use_gpu=True) + + def build_model(self): + self.syncnet_task = SyncNetTask() + self.syncnet_task.build_model() + load_ckpt(self.syncnet_task.model, hparams["syncnet_work_dir"], steps=hparams["syncnet_ckpt_steps"]) + not_requires_grad(self.syncnet_task) + self.syncnet_task.eval() + + self.model = PitchContourVAEModel(in_out_dim=68*3) + return self.model + + def build_optimizer(self, model): + self.optimizer = optimizer = torch.optim.Adam( + model.parameters(), + lr=hparams['lr'], + betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2'])) + return optimizer + + def build_scheduler(self, optimizer): + return ExponentialSchedule(optimizer, hparams['lr'], hparams['warmup_updates']) + + @data_loader + def train_dataloader(self): + train_dataset = self.dataset_cls(prefix='train') + self.train_dl = train_dataset.get_dataloader() + return self.train_dl + + @data_loader + def val_dataloader(self): + val_dataset = self.dataset_cls(prefix='val') + self.val_dl = val_dataset.get_dataloader() + return self.val_dl + + @data_loader + def test_dataloader(self): + val_dataset = self.dataset_cls(prefix='val') + self.val_dl = val_dataset.get_dataloader() + return self.val_dl + + ########################## + # training and validation + ########################## + def run_model(self, sample, infer=False, temperature=1.0, sync_batch_size=1024): + """ + render or train on a single-frame + :param sample: a batch of data + :param infer: bool, run in infer mode + :return: + if not infer: + return losses, model_out + if infer: + return model_out + """ + model_out = {} + if infer: + self.model(sample, model_out, train=False, temperature=temperature) + else: + sample['y'] = sample['idexp_lm3d'] + self.model(sample, model_out, train=True) + + if not infer: + # forward the syncnet to get sync_loss + losses_out = {} + pred_lm3d = model_out['pred'] + pred_lm3d = pred_lm3d.reshape(pred_lm3d.size(0), pred_lm3d.size(1), 68, 3) + + _, pred_mouth_lm3d = self.face3d_helper.get_eye_mouth_lm_from_lm3d_batch(pred_lm3d) + syncnet_sample = { + 'mouth_idexp_lm3d': pred_mouth_lm3d.reshape(pred_mouth_lm3d.size(0), pred_mouth_lm3d.size(1), -1), + 'hubert': sample['hubert'], + 'y_mask': model_out['mask'], + } + syncnet_out = self.syncnet_task.run_model(syncnet_sample, infer=True, batch_size=sync_batch_size) + losses_out['sync'] = syncnet_out['sync_loss'] + + x_gt = sample['idexp_lm3d'] + x_pred = model_out['pred'] + x_mask = model_out['mask'] + losses_out['mse'] = self.mse_loss(x_gt, x_pred, x_mask) + losses_out['continuity'] = self.continuity_loss(x_gt, x_pred, x_mask) + losses_out['kl'] = model_out['loss_kl'] + return losses_out, model_out + else: + return model_out + + def _training_step(self, sample, batch_idx, optimizer_idx): + loss_output, model_out = self.run_model(sample) + loss_weights = { + 'kl': hparams['lambda_kl'], + 'mse': 1.0, + 'continuity': 3.0, + 'sync': hparams.get('lambda_sync', 0.01) if self.enable_sync else 0. + } + total_loss = sum([loss_weights.get(k, 1) * v for k, v in loss_output.items() if isinstance(v, torch.Tensor) and v.requires_grad]) + + return total_loss, loss_output + + def validation_start(self): + pass + + @torch.no_grad() + def validation_step(self, sample, batch_idx): + outputs = {} + outputs['losses'] = {} + outputs['losses'], model_out = self.run_model(sample, infer=False, sync_batch_size=10000) + outputs = tensors_to_scalars(outputs) + if outputs['losses']['sync'] <= 0.75 and not self.enable_sync: + self.enable_sync = True + return outputs + + def validation_end(self, outputs): + return super().validation_end(outputs) + + ##################### + # Testing + ##################### + def test_start(self): + self.gen_dir = os.path.join(hparams['work_dir'], f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}') + os.makedirs(self.gen_dir, exist_ok=True) + + @torch.no_grad() + def test_step(self, sample, batch_idx): + """ + :param sample: + :param batch_idx: + :return: + """ + outputs = {} + outputs['losses'], model_out = self.run_model(sample, infer=True) + pred_exp = model_out['pred'] + self.save_result(pred_exp, "pred_exp_val" , self.gen_dir) + if hparams['save_gt']: + base_fn = f"gt_exp_val" + self.save_result(sample['exp'], base_fn , self.gen_dir) + return outputs + + def test_end(self, outputs): + pass + + @staticmethod + def save_result(exp_arr, base_fname, gen_dir): + exp_arr = convert_to_np(exp_arr) + np.save(f"{gen_dir}/{base_fname}.npy", exp_arr) + + def get_grad(self, opt_idx): + grad_dict = { + 'grad/model': get_grad_norm(self.model), + } + return grad_dict + + def mse_loss(self, x_gt, x_pred, x_mask): + # mean squared error, l2 loss + error = (x_pred - x_gt) * x_mask[:,:, None] + num_frame = x_mask.sum() + n_dim = 68*3 + return (error ** 2).sum() / (num_frame * n_dim) + + def mae_loss(self, x_gt, x_pred, x_mask): + # mean absolute error, l1 loss + error = (x_pred - x_gt) * x_mask[:,:, None] + num_frame = x_mask.sum() + n_dim = 68*3 + return error.abs().sum() / (num_frame * n_dim) + + def continuity_loss(self, x_gt, x_pred, x_mask): + # continuity loss, borrowed from + diff_x_pred = x_pred[:,1:] - x_pred[:,:-1] + diff_x_gt = x_gt[:,1:] - x_gt[:,:-1] + error = (diff_x_pred[:,:,:] - diff_x_gt[:,:,:]) * x_mask[:,1:,None] + init_error = x_pred[:,0,:] - x_gt[:,0,:] + num_frame = x_mask.sum() + n_dim = 68*3 + return (error.pow(2).sum() + init_error.pow(2).sum()) / (num_frame * n_dim) \ No newline at end of file diff --git a/Geneface_main/GeneFace/tasks/audio2pose/audio2pose.py b/Geneface_main/GeneFace/tasks/audio2pose/audio2pose.py new file mode 100644 index 00000000..10f8f20e --- /dev/null +++ b/Geneface_main/GeneFace/tasks/audio2pose/audio2pose.py @@ -0,0 +1,115 @@ +from utils.commons.base_task import BaseTask + +import torch +import numpy as np +import os + +from utils.commons.base_task import BaseTask +from utils.commons.dataset_utils import data_loader +from utils.commons.hparams import hparams +from utils.commons.ckpt_utils import load_ckpt +from utils.commons.tensor_utils import tensors_to_scalars, convert_to_np +from utils.nn.model_utils import print_arch, get_device_of_model, not_requires_grad +from utils.nn.schedulers import ExponentialSchedule +from utils.nn.grad import get_grad_norm + +from modules.audio2pose.models import Audio2PoseModel +from modules.audio2pose.gmm_utils import GMMLogLoss +from tasks.audio2pose.dataset_utils import Audio2PoseDataset + +class Audio2PoseTask(BaseTask): + def __init__(self): + super().__init__() + self.dataset_cls = Audio2PoseDataset + self.gmm_loss_fn = GMMLogLoss(ncenter=1, ndim=12, sigma_min=0.03) + + def build_model(self): + self.model = Audio2PoseModel(hparams['reception_field']) + return self.model + + def build_optimizer(self, model): + self.optimizer = optimizer = torch.optim.Adam( + model.parameters(), + lr=hparams['lr'], + betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2'])) + return optimizer + + def build_scheduler(self, optimizer): + return ExponentialSchedule(optimizer, hparams['lr'], hparams['warmup_updates']) + + @data_loader + def train_dataloader(self): + train_dataset = self.dataset_cls() + self.train_dl = train_dataset.get_dataloader() + return self.train_dl + + @data_loader + def val_dataloader(self): + val_dataset = self.dataset_cls() + self.val_dl = val_dataset.get_dataloader() + return self.val_dl + + @data_loader + def test_dataloader(self): + val_dataset = self.dataset_cls(prefix='val') + self.val_dl = val_dataset.get_dataloader() + return self.val_dl + + ########################## + # training and validation + ########################## + def run_model(self, sample): + """ + render or train on a single-frame + :param sample: a batch of data + :param infer: bool, run in infer mode + :return: + if not infer: + return losses, model_out + if infer: + return model_out + """ + model_out = {} + losses_out = {} + audio_window = sample['audio_window'] + history_pose_and_velocity = sample['history_pose_and_velocity'] + target_pose_and_velocity = sample['target_pose_and_velocity'] + + ret = self.model.forward(audio_window, history_pose_and_velocity) + pred_pose_velocity_gmm_params = ret[:,-1, :] + + model_out['pred_pose_velocity_gmm_params'] = pred_pose_velocity_gmm_params + losses_out['gmm_loss'] = self.gmm_loss_fn(pred_pose_velocity_gmm_params.unsqueeze(1), target_pose_and_velocity.unsqueeze(1)) + losses_out['history_gmm_loss'] = self.gmm_loss_fn(ret[:-1], history_pose_and_velocity[1:]) + + return losses_out, model_out + + + def _training_step(self, sample, batch_idx, optimizer_idx): + loss_output, model_out = self.run_model(sample) + loss_weights = { + 'gmm_loss': 1.0, + } + total_loss = sum([loss_weights.get(k, 1) * v for k, v in loss_output.items() if isinstance(v, torch.Tensor) and v.requires_grad]) + + return total_loss, loss_output + + def validation_start(self): + pass + + @torch.no_grad() + def validation_step(self, sample, batch_idx): + outputs = {} + outputs['losses'] = {} + outputs['losses'], model_out = self.run_model(sample) + outputs = tensors_to_scalars(outputs) + return outputs + + def validation_end(self, outputs): + return super().validation_end(outputs) + + def get_grad(self, opt_idx): + grad_dict = { + 'grad/model': get_grad_norm(self.model), + } + return grad_dict diff --git a/Geneface_main/GeneFace/tasks/audio2pose/dataset_utils.py b/Geneface_main/GeneFace/tasks/audio2pose/dataset_utils.py new file mode 100644 index 00000000..a295955c --- /dev/null +++ b/Geneface_main/GeneFace/tasks/audio2pose/dataset_utils.py @@ -0,0 +1,82 @@ +import os +import tqdm +import random +import torch +import numpy as np +from torch.utils.data import DataLoader +from utils.commons.hparams import hparams, set_hparams +from utils.commons.tensor_utils import convert_to_tensor +from utils.commons.euler2rot import euler_trans_2_c2w, c2w_to_euler_trans + + +class Audio2PoseDataset(torch.utils.data.Dataset): + def __init__(self, data_dir=None): + super().__init__() + self.data_dir = os.path.join(hparams['binary_data_dir'], hparams['video_id']) if data_dir is None else data_dir + binary_file_name = os.path.join(self.data_dir, "trainval_dataset.npy") + ds_dict = np.load(binary_file_name, allow_pickle=True).tolist() + self.samples = [convert_to_tensor(sample) for sample in ds_dict['train_samples']] + [convert_to_tensor(sample) for sample in ds_dict['val_samples']] + self.num_samples = len(ds_dict['train_samples']) + len(ds_dict['val_samples']) + self.audio_lst = [None] * self.num_samples + self.pose_lst = [None] * self.num_samples + self.euler_lst = [None] * self.num_samples + self.trans_lst = [None] * self.num_samples + for i in range(self.num_samples): + sample = self.samples[i] + # audio_win_size = sample['hubert_win'].shape[0] + # audio = sample['hubert_win'][audio_win_size//2-1:audio_win_size//2+1].reshape([2*1024]) + audio = sample['deepspeech_win'][7:9,:].reshape([2*29]) + self.audio_lst[i] = audio + self.euler_lst[i] = sample['euler'] + self.trans_lst[i] = sample['trans'] + # todo: 计算mean trans + self.mean_trans = torch.stack(self.trans_lst).mean(dim=0) + self.trans_lst = [self.trans_lst[i]-self.mean_trans for i in range(self.num_samples)] + self.pose_lst = [torch.cat([self.euler_lst[i], self.trans_lst[i]], dim=-1) for i in range(self.num_samples)] + self.pose_velocity_lst = [torch.zeros_like(self.pose_lst[0])]+[self.pose_lst[i+1] - self.pose_lst[i] for i in range(0,self.num_samples-1)] + + self.audio_lst = torch.stack(self.audio_lst) + self.pose_lst = torch.stack(self.pose_lst) + self.pose_velocity_lst = torch.stack(self.pose_velocity_lst) + # self.reception_field = 30 + self.reception_field = hparams['reception_field'] + self.target_length = 5 + + def __getitem__(self, idx): + if idx < self.reception_field or idx > self.num_samples - self.target_length: + idx = random.randint(self.reception_field, self.num_samples - self.target_length) + sample = { + 'idx': idx, + 'audio': self.audio_lst[idx-self.reception_field: idx], # [t=30, c=512] + 'history_pose': self.pose_lst[idx-self.reception_field: idx], # [t=30, c=6] + 'history_velocity': self.pose_velocity_lst[idx-self.reception_field: idx], # [t=30, c=6] + 'target_pose': self.pose_lst[idx], # [c=6] + 'target_velocity': self.pose_velocity_lst[idx] # [c=6] + } + sample['history_pose_and_velocity'] = torch.cat([sample['history_pose'], sample['history_velocity']], dim=-1) # [t=30, c=12] + sample['target_pose_and_velocity'] = torch.cat([sample['target_pose'], sample['target_velocity']], dim=-1) # [c=12] + return sample + + def __len__(self): + return len(self.samples) + + def collater(self, samples): + batch = { + 'idx' : [s['idx'] for s in samples], + 'audio_window': torch.stack([s['audio'] for s in samples]), # [b, t=30, c=512] + 'history_pose_and_velocity': torch.stack([s['history_pose_and_velocity'] for s in samples]), # [b, t=30, t=12] + 'target_pose_and_velocity': torch.stack([s['target_pose_and_velocity'] for s in samples]), # [b, t=12] + } + return batch + + def get_dataloader(self, batch_size=64): + loader = DataLoader(self,batch_size=batch_size, pin_memory=False,collate_fn=self.collater, shuffle=True, num_workers=4) + return loader + +if __name__ == '__main__': + set_hparams() + ds = Audio2PoseDataset(data_dir='data/binary/videos/May') + dl = ds.get_dataloader() + for batch in dl: + print("") + print("done!") \ No newline at end of file diff --git a/Geneface_main/GeneFace/tasks/nerfs/__pycache__/dataset_utils.cpython-39.pyc b/Geneface_main/GeneFace/tasks/nerfs/__pycache__/dataset_utils.cpython-39.pyc new file mode 100644 index 00000000..e1d6c21c Binary files /dev/null and b/Geneface_main/GeneFace/tasks/nerfs/__pycache__/dataset_utils.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/tasks/nerfs/adnerf.py b/Geneface_main/GeneFace/tasks/nerfs/adnerf.py new file mode 100644 index 00000000..c916291b --- /dev/null +++ b/Geneface_main/GeneFace/tasks/nerfs/adnerf.py @@ -0,0 +1,252 @@ +import torch +import torch.nn as nn +import numpy as np +import os +import cv2 + +from modules.nerfs.commons.volume_rendering import render_dynamic_face +from modules.nerfs.adnerf.adnerf import ADNeRF +from modules.nerfs.commons.ray_samplers import UniformRaySampler, FullRaySampler, PatchRaySampler + +from utils.commons.image_utils import to8b +from utils.commons.base_task import BaseTask +from utils.commons.dataset_utils import data_loader +from utils.commons.hparams import hparams +from utils.commons.ckpt_utils import load_ckpt +from utils.commons.tensor_utils import tensors_to_scalars, convert_to_np, move_to_cuda +from utils.nn.model_utils import print_arch, num_params +from utils.nn.schedulers import ExponentialScheduleWithAudattNet +from utils.nn.grad import get_grad_norm + +from tasks.nerfs.dataset_utils import NeRFDataset + + +class ADNeRFTask(BaseTask): + def __init__(self): + super().__init__() + self.chunk = 1024 + self.no_smo_iterations = hparams['no_smo_iterations'] + self.n_samples_per_ray = hparams['n_samples_per_ray'] + self.n_samples_per_ray_fine = hparams['n_samples_per_ray_fine'] + self.n_rays = hparams['n_rays'] + self.rays_sampler = UniformRaySampler(N_rays=self.n_rays) + self.full_rays_sampler = FullRaySampler() + self.dataset_cls = NeRFDataset + self.train_dataset = self.dataset_cls(prefix='train') + self.val_dataset = self.dataset_cls(prefix='val') + + def build_model(self): + self.model = ADNeRF(hparams) + self.audatt_net_params = [p for p in self.model.audatt_net.parameters() if p.requires_grad] + self.gen_params_except_audatt_net = [p for k, p in self.model.named_parameters() if (('audatt_net' not in k) and p.requires_grad)] + return self.model + + def on_train_start(self): + super().on_train_start() + for n, m in self.model.named_children(): + num_params(m, model_name=n) + + def build_optimizer(self, model): + self.optimizer = torch.optim.Adam( + self.gen_params_except_audatt_net, + lr=hparams['lr'], + betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2'])) + self.optimizer.add_param_group({ + 'params': self.audatt_net_params, + 'lr': hparams['lr'] * 5, + 'betas': (hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']) + }) + return self.optimizer + + def build_scheduler(self, optimizer): + return ExponentialScheduleWithAudattNet(optimizer, hparams['lr'], hparams['warmup_updates']) + + @data_loader + def train_dataloader(self): + self.train_dl = torch.utils.data.DataLoader(self.train_dataset,collate_fn=self.train_dataset.collater, + batch_size=1, shuffle=True, + num_workers=0, pin_memory=True) + return self.train_dl + + @data_loader + def val_dataloader(self): + self.val_dl = torch.utils.data.DataLoader(self.val_dataset,collate_fn=self.val_dataset.collater, + batch_size=1, shuffle=True, + num_workers=0, pin_memory=True) + return self.val_dl + + @data_loader + def test_dataloader(self): + self.val_dl = torch.utils.data.DataLoader(self.val_dataset,collate_fn=self.val_dataset.collater, + batch_size=1, shuffle=False, + num_workers=0, pin_memory=True) + return self.val_dl + + ########################## + # forward the model + ########################## + def run_model(self, sample, infer=False): + """ + render or train on a single-frame + :param sample: a batch of data + :param infer: bool, run in infer mode + :return: + if not infer: + return losses, model_out + if infer: + return model_out + """ + cond = sample['cond_win'] if hparams['use_window_cond'] else sample['cond'] + cond_wins = sample['cond_wins'] + H = sample['H'] + W = sample['W'] + focal = sample['focal'] + cx = sample['cx'] + cy = sample['cy'] + near = sample['near'] + far = sample['far'] + bg_img = sample['bg_img'] + c2w = sample['c2w'] + c2w_t0 = sample['c2w_t0'] + t = sample['t'] + + with_att = self.global_step >= self.no_smo_iterations + if with_att: + cond_feat = self.model.cal_cond_feat(cond_wins, with_att=True) + else: + cond_feat = self.model.cal_cond_feat(cond, with_att=False) + + if infer: + rays_o, rays_d, _ = self.full_rays_sampler(H, W, focal, c2w) + rgb_pred, disp, acc, _, _, extras = render_dynamic_face(H, W, focal, cx, cy, rays_o=rays_o, rays_d=rays_d, + bc_rgb=bg_img, + chunk=2048, + c2w=None, cond=cond_feat, near=near, far=far, + network_fn=self.model, N_samples=self.n_samples_per_ray, N_importance=self.n_samples_per_ray_fine, + c2w_t=c2w, c2w_t0=c2w_t0,t=t, + ) + model_out = { + "rgb_map" : rgb_pred + } + return model_out + else: + rays_o, rays_d, select_coords = self.rays_sampler(H, W, focal, c2w, n_rays=None, rect=sample['rect'], in_rect_percent=hparams['in_rect_percent'], iterations=self.global_step) + target = sample['head_img'] + rgb_gt = self.rays_sampler.sample_pixels_from_img_with_select_coords(target, select_coords) + rgb_bc = self.rays_sampler.sample_pixels_from_img_with_select_coords(sample['bg_img'], select_coords) + + rgb_pred, disp, acc, _, _, extras = render_dynamic_face(H, W, focal, cx, cy, rays_o=rays_o, rays_d=rays_d, + bc_rgb=rgb_bc,chunk=self.chunk, c2w=None, cond=cond_feat, near=near, far=far, + network_fn=self.model, N_samples=self.n_samples_per_ray, N_importance=self.n_samples_per_ray_fine, + c2w_t=c2w, c2w_t0=c2w_t0,t=t,) + losses_out = {} + losses_out['mse_loss'] = torch.mean((rgb_pred - rgb_gt) ** 2) + if 'rgb_map_coarse' in extras: + losses_out['mse_loss_coarse'] = torch.mean((extras['rgb_map_coarse'] - rgb_gt) ** 2) + model_out = { + "rgb_map": rgb_pred + } + return losses_out, model_out + + ########################## + # training + ########################## + def _training_step(self, sample, batch_idx, optimizer_idx): + loss_output, model_out = self.run_model(sample) + total_loss = sum([v for v in loss_output.values() if isinstance(v, torch.Tensor) and v.requires_grad]) + def mse2psnr(x): return -10. * torch.log(x) / torch.log(torch.Tensor([10.])).to(x.device) + loss_output['head_psnr'] = mse2psnr(loss_output['mse_loss'].detach()) + return total_loss, loss_output + + def on_before_optimization(self, opt_idx): + prefix = f"grad_norm_opt_idx_{opt_idx}" + grad_norm_dict = { + f'{prefix}/model_coarse': get_grad_norm(self.model.model_coarse), + f'{prefix}/model_fine': get_grad_norm(self.model.model_fine), + f'{prefix}/aud_net': get_grad_norm(self.model.aud_net), + f'{prefix}/audatt_net': get_grad_norm(self.model.audatt_net), + } + return grad_norm_dict + + def on_after_optimization(self, epoch, batch_idx, optimizer, optimizer_idx): + self.scheduler.step(self.global_step // hparams['accumulate_grad_batches']) + + ##################### + # Validation + ##################### + def validation_start(self): + if self.global_step % hparams['valid_infer_interval'] == 0: + self.gen_dir = os.path.join(hparams['work_dir'], f'validation_results/validation_{self.trainer.global_step}_{hparams["gen_dir_name"]}') + os.makedirs(self.gen_dir, exist_ok=True) + os.makedirs(f'{self.gen_dir}/imgs', exist_ok=True) + os.makedirs(f'{self.gen_dir}/plot', exist_ok=True) + + @torch.no_grad() + def validation_step(self, sample, batch_idx): + outputs = {} + outputs['losses'] = {} + outputs['losses'], model_out = self.run_model(sample) + outputs['total_loss'] = sum(outputs['losses'].values()) + outputs['nsamples'] = 1 + outputs = tensors_to_scalars(outputs) + if self.global_step % hparams['valid_infer_interval'] == 0 \ + and batch_idx < hparams['num_valid_plots']: + # idx_lst = [291,156,540,113,28] + num_val_samples = len(self.val_dataset) + interval = (num_val_samples-1) // 4 + idx_lst = [i * interval for i in range(5)] + sample = move_to_cuda(self.val_dataset[idx_lst[batch_idx]]) + infer_outputs = self.run_model(sample, infer=True) + rgb_pred = infer_outputs['rgb_map'] + H, W = sample['H'], sample['W'] + img_pred = rgb_pred.reshape([H, W, 3]) + gen_dir = self.gen_dir + base_fn = f"frame_{sample['idx']}" + self.save_result(img_pred, base_fn , gen_dir) + target = sample['head_img'] + img_gt = target.reshape([H, W, 3]) + if hparams['save_gt']: + base_fn = f"frame_{sample['idx']}_gt" + self.save_result(img_gt, base_fn , gen_dir) + return outputs + + def validation_end(self, outputs): + return super().validation_end(outputs) + + ##################### + # Testing + ##################### + def test_start(self): + self.gen_dir = os.path.join(hparams['work_dir'], f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}') + os.makedirs(self.gen_dir, exist_ok=True) + os.makedirs(f'{self.gen_dir}/imgs', exist_ok=True) + os.makedirs(f'{self.gen_dir}/plot', exist_ok=True) + + @torch.no_grad() + def test_step(self, sample, batch_idx): + outputs = self.run_model(sample, infer=True) + rgb_pred = outputs['rgb_map'] + H, W = sample['H'], sample['W'] + img_pred = rgb_pred.reshape([H, W, 3]) + gen_dir = self.gen_dir + base_fn = f"frame_{sample['idx']}" + self.save_result(img_pred, base_fn , gen_dir) + target = sample['gt_img'] if hparams['use_pos_deform'] else sample['head_img'] + img_gt = target.reshape([H, W, 3]) + if hparams['save_gt']: + base_fn = f"frame_{sample['idx']}_gt" + self.save_result(img_gt, base_fn , gen_dir) + outputs['losses'] = (img_gt - img_pred).mean() + return outputs + + def test_end(self, outputs): + pass + + ##################### + # Visualization utils + ##################### + @staticmethod + def save_result(rgb, base_fname, gen_dir): + rgb = convert_to_np(rgb * 255.).astype(np.uint8) + bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) + cv2.imwrite(f"{gen_dir}/imgs/{base_fname}.jpg", bgr) diff --git a/Geneface_main/GeneFace/tasks/nerfs/adnerf_torso.py b/Geneface_main/GeneFace/tasks/nerfs/adnerf_torso.py new file mode 100644 index 00000000..042f5ca9 --- /dev/null +++ b/Geneface_main/GeneFace/tasks/nerfs/adnerf_torso.py @@ -0,0 +1,230 @@ +import torch +import torch.nn as nn +import numpy as np +import tqdm +import time +import imageio +import os +import cv2 + +from utils.commons.hparams import hparams +from utils.commons.ckpt_utils import load_ckpt +from utils.nn.model_utils import print_arch +from utils.nn.grad import get_grad_norm, GradBuffer +from utils.nn.model_utils import print_arch, get_device_of_model, not_requires_grad +from utils.commons.tensor_utils import tensors_to_scalars, convert_to_np, move_to_cuda, convert_to_tensor +from utils.commons.ckpt_utils import load_ckpt +from utils.nn.schedulers import ExponentialSchedule, ExponentialScheduleWithAudattNet + +from modules.nerfs.adnerf.adnerf import ADNeRF as ADNeRF_head +from modules.nerfs.adnerf.adnerf_torso import ADNeRFTorso as ADNeRF_torso +from modules.nerfs.commons.volume_rendering import render_dynamic_face +from modules.nerfs.commons.ray_samplers import TorsoUniformRaySampler + +from tasks.nerfs.adnerf import ADNeRFTask + + +class ADNeRFTorsoTask(ADNeRFTask): + def __init__(self): + super().__init__() + self.torso_rays_sampler = TorsoUniformRaySampler(self.n_rays) + + def build_model(self): + self.head_model = ADNeRF_head(hparams) + head_model_dir = hparams['head_model_dir'] + load_ckpt(self.head_model, head_model_dir) + + self.model = ADNeRF_torso(hparams) + self.audatt_net_params = [p for p in self.model.audatt_net.parameters() if p.requires_grad] + self.gen_params_except_audatt_net = [p for k, p in self.model.named_parameters() if (('audatt_net' not in k) and p.requires_grad)] + return self.model + + def build_optimizer(self, model): + self.optimizer = torch.optim.Adam( + self.gen_params_except_audatt_net, + lr=hparams['lr'], + betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2'])) + self.optimizer.add_param_group({ + 'params': self.audatt_net_params, + 'lr': hparams['lr'] * 5, + 'betas': (hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']) + }) + return self.optimizer + + def build_scheduler(self, optimizer): + return ExponentialScheduleWithAudattNet(optimizer, hparams['lr'], hparams['warmup_updates']) + + def run_model(self, sample, infer=False, run_head_mode=False): + """ + render or train on a single-frame + :param sample: a batch of data + :param infer: bool, run in infer mode + :return: + if not infer: + return losses, model_out + if infer: + return model_out + """ + cond = sample['cond_win'] if hparams['use_window_cond'] else sample['cond'] + cond_wins = sample['cond_wins'] + H = sample['H'] + W = sample['W'] + focal = sample['focal'] + cx = sample['cx'] + cy = sample['cy'] + near = sample['near'] + far = sample['far'] + bg_img = sample['bg_img'] + c2w_t = sample['c2w'] + c2w_t0 = sample['c2w_t0'] + euler = sample['euler'] + euler_t0 = sample['euler_t0'] + trans = sample['trans'] + trans_t0 = sample['trans_t0'] + losses_out = {} + + if infer: + # Inference Phase + with torch.no_grad(): + # render head + cond_feat = self.head_model.cal_cond_feat(cond_wins, with_att=True) + rays_o, rays_d, _ = self.full_rays_sampler(H, W, focal, c2w_t) + rgb_pred, disp, acc, last_weight, rgb_map_fg, extras = render_dynamic_face(H, W, focal, cx, cy, rays_o=rays_o, rays_d=rays_d, + bc_rgb=bg_img, chunk=2048, c2w=None, cond=cond_feat, near=near, far=far, + network_fn=self.head_model, N_samples=self.n_samples_per_ray, N_importance=self.n_samples_per_ray_fine, + run_head_mode=True, + c2w_t=c2w_t, c2w_t0=c2w_t0,t=torch.tensor([0.,]).cuda(), + ) + + # render torso + cond_feat = self.model.cal_cond_feat(cond_wins, color=rgb_pred, euler=sample['euler'], trans=sample['trans'],with_att=True) + rays_o, rays_d, _ = self.full_rays_sampler(H, W, focal, c2w_t0) + rgb_pred_torso, disp_torso, acc_torso, last_weight_torso, rgb_map_fg_torso, extras_torso = render_dynamic_face(H, W, focal, cx, cy, rays_o=rays_o, rays_d=rays_d, + bc_rgb=bg_img,chunk=2048, c2w=None, cond=cond_feat, near=near, far=far, + network_fn=self.model, N_samples=self.n_samples_per_ray, N_importance=self.n_samples_per_ray_fine, + run_head_mode=False, + c2w_t=c2w_t, c2w_t0=c2w_t0,t=torch.tensor([0.,]).cuda(), + euler=euler, euler_t0=euler_t0, trans=trans, trans_t0=trans_t0 + ) + rgb_com = rgb_pred * last_weight_torso.unsqueeze(-1) + rgb_map_fg_torso + + model_out = { + "rgb_map" : rgb_com + } + return model_out + else: + # Training Phase + if run_head_mode: + # Run Head NeRF + with torch.no_grad(): + cond_feat = self.head_model.cal_cond_feat(cond_wins, with_att=True) + + target = sample['head_img'] + rays_o, rays_d, select_coords = self.rays_sampler(H, W, focal, c2w_t, n_rays=None, rect=sample['rect'], in_rect_percent=hparams['in_rect_percent'], iterations=self.global_step) + rgb_gt = self.rays_sampler.sample_pixels_from_img_with_select_coords(target, select_coords) + rgb_bc = self.rays_sampler.sample_pixels_from_img_with_select_coords(sample['bg_img'], select_coords) + rgb_pred, disp, acc, last_weight, rgb_map_fg, extras = render_dynamic_face(H, W, focal, cx, cy, rays_o=rays_o, rays_d=rays_d, + bc_rgb=rgb_bc,chunk=self.chunk, c2w=None, cond=cond_feat, near=near, far=far, + network_fn=self.head_model, N_samples=self.n_samples_per_ray, N_importance=self.n_samples_per_ray_fine, + run_head_mode=True, + c2w_t=c2w_t, c2w_t0=c2w_t0,t=torch.tensor([0.,]).cuda(), + ) + losses_out['head_mse_loss'] = torch.mean((rgb_pred - rgb_gt) ** 2) + if 'rgb_map_coarse' in extras: + losses_out['head_mse_loss_coarse'] = torch.mean((extras['rgb_map_coarse'] - rgb_gt) ** 2) + model_out = { + "rgb_map": rgb_pred + } + else: + # Run Torso NeRF + target = sample['gt_img'] + rect = [0, H/2, W, H/2] # only sample the lower part for torso + + rays_o, rays_d, select_coords = self.torso_rays_sampler(H, W, focal, c2w_t0, n_rays=None, rect=rect, + in_rect_percent=hparams['in_rect_percent']) + rays_o_head, rays_d_head, _ = self.rays_sampler(H, W, focal, c2w_t, select_coords=select_coords) + rgb_gt = self.rays_sampler.sample_pixels_from_img_with_select_coords(target, select_coords) + rgb_bc = self.rays_sampler.sample_pixels_from_img_with_select_coords(sample['bg_img'], select_coords) + + with torch.no_grad(): + cond_feat = self.head_model.cal_cond_feat(cond_wins, with_att=True) + rgb_pred_head, disp_head, acc_head, last_weight_head, rgb_map_fg_head, extras_head = render_dynamic_face(H, W, focal, cx, cy, rays_o=rays_o_head, rays_d=rays_d_head, + bc_rgb=rgb_bc,chunk=self.chunk, c2w=None, cond=cond_feat, near=near, far=far, + network_fn=self.head_model, N_samples=self.n_samples_per_ray, N_importance=self.n_samples_per_ray_fine, + run_head_mode=True, + c2w_t=c2w_t, c2w_t0=c2w_t0,t=torch.tensor([0.,]).cuda(), + ) + + cond_feat = self.model.cal_cond_feat(cond_wins, color=rgb_pred_head, euler=sample['euler'], trans=sample['trans'],with_att=True) + + rgb_pred_torso, disp_torso, acc_torso, last_weight_torso, rgb_map_fg_torso, extras_torso = render_dynamic_face(H, W, focal, cx, cy, rays_o=rays_o, rays_d=rays_d, + bc_rgb=rgb_bc,chunk=self.chunk, c2w=None, cond=cond_feat, near=near, far=far, + network_fn=self.model, N_samples=self.n_samples_per_ray, N_importance=self.n_samples_per_ray_fine, + run_head_mode=run_head_mode, + c2w_t=c2w_t, c2w_t0=c2w_t0,t=torch.tensor([0.,]).cuda(), + euler=euler, euler_t0=euler_t0, trans=trans, trans_t0=trans_t0 + ) + + rgb_com = rgb_pred_head * last_weight_torso.unsqueeze(-1) + rgb_map_fg_torso + losses_out['com_mse_loss'] = torch.mean((rgb_com - rgb_gt) ** 2) + def mse2psnr(x): return -10. * torch.log(x) / torch.log(torch.Tensor([10.])).to(x.device) + losses_out['com_psnr'] = mse2psnr(losses_out['com_mse_loss'].detach()) + if 'rgb_map_coarse' in extras_torso: + rgb_com0 = extras_head['rgb_map_coarse'] * extras_torso['last_weight0'].unsqueeze(-1) + extras_torso['rgb_map_fg0'] + losses_out['com_mse_loss_coarse'] = torch.mean((rgb_com0 - rgb_gt) ** 2) + model_out = { + "rgb_map": rgb_com + } + return losses_out, model_out + + + def _training_step(self, sample, batch_idx, optimizer_idx): + loss_output = {} + loss_weights = {} + ####################### + # TorsoNeRF # + ####################### + loss_output, model_out = self.run_model(sample, infer=False, run_head_mode=False) + def mse2psnr(x): return -10. * torch.log(x) / torch.log(torch.Tensor([10.])).to(x.device) + loss_output['torso_psnr'] = mse2psnr(loss_output['com_mse_loss'].detach()) + total_loss = sum([loss_weights.get(k, 1) * v for k, v in loss_output.items() if isinstance(v, torch.Tensor) and v.requires_grad]) + return total_loss, loss_output + + def on_before_optimization(self, opt_idx): + prefix = f"grad_norm_opt_idx_{opt_idx}" + grad_norm_dict = { + f'{prefix}/model_coarse_torso': get_grad_norm(self.model.model_coarse), + f'{prefix}/model_fine_torso': get_grad_norm(self.model.model_fine), + } + if hparams.get("use_color", False): + grad_norm_dict[f'{prefix}/color_encoder_torso'] = get_grad_norm(self.model.color_encoder) + return grad_norm_dict + + @torch.no_grad() + def validation_step(self, sample, batch_idx): + sample['c2w_t0'] = convert_to_tensor(self.train_dataset.samples[0]['c2w'][:3]).float().to(sample['c2w'].device) + outputs = {} + outputs['losses'] = {} + outputs['losses'], model_out = self.run_model(sample, infer=False, run_head_mode=False) + outputs['total_loss'] = sum(outputs['losses'].values()) + outputs['nsamples'] = 1 + outputs = tensors_to_scalars(outputs) + if self.global_step % hparams['valid_infer_interval'] == 0 \ + and batch_idx < hparams['num_valid_plots']: + idx_interval = (len(self.val_dataset)-1)//(hparams['num_valid_plots']-1) + idx_lst = [i_plot*idx_interval for i_plot in range(hparams['num_valid_plots'])] + sample = move_to_cuda(self.val_dataset[idx_lst[batch_idx]]) + sample['c2w_t0'] = convert_to_tensor(self.train_dataset.samples[0]['c2w'][:3]).float().to(sample['c2w'].device) + infer_outputs = self.run_model(sample, infer=True) + rgb_pred = infer_outputs['rgb_map'] + H, W = sample['H'], sample['W'] + img_pred = rgb_pred.reshape([H, W, 3]) + gen_dir = self.gen_dir + base_fn = f"frame_{sample['idx']}" + self.save_result(img_pred, base_fn , gen_dir) + target = sample['gt_img'] + img_gt = target.reshape([H, W, 3]) + if hparams['save_gt']: + base_fn = f"frame_{sample['idx']}_gt" + self.save_result(img_gt, base_fn , gen_dir) + return outputs \ No newline at end of file diff --git a/Geneface_main/GeneFace/tasks/nerfs/dataset_utils.py b/Geneface_main/GeneFace/tasks/nerfs/dataset_utils.py new file mode 100644 index 00000000..68da6024 --- /dev/null +++ b/Geneface_main/GeneFace/tasks/nerfs/dataset_utils.py @@ -0,0 +1,114 @@ +import os +import tqdm +import torch +import numpy as np +from utils.commons.hparams import hparams, set_hparams +from utils.commons.tensor_utils import convert_to_tensor +from utils.commons.image_utils import load_image_as_uint8_tensor + + +class NeRFDataset(torch.utils.data.Dataset): + def __init__(self, prefix, data_dir=None, cond_type=None): + super().__init__() + self.data_dir = os.path.join(hparams['binary_data_dir'], hparams['video_id']) if data_dir is None else data_dir + self.cond_type = hparams['cond_type'] if cond_type is None else cond_type + binary_file_name = os.path.join(self.data_dir, "trainval_dataset.npy") + ds_dict = np.load(binary_file_name, allow_pickle=True).tolist() + if prefix == 'train': + self.samples = [convert_to_tensor(sample) for sample in ds_dict['train_samples']] + elif prefix == 'val': + self.samples = [convert_to_tensor(sample) for sample in ds_dict['val_samples']] + elif prefix == 'trainval': + self.samples = [convert_to_tensor(sample) for sample in ds_dict['train_samples']] + [convert_to_tensor(sample) for sample in ds_dict['val_samples']] + else: + raise ValueError("prefix should in train/val !") + self.prefix = prefix + self.H = ds_dict['H'] + self.W = ds_dict['W'] + self.focal = ds_dict['focal'] + self.cx = ds_dict['cx'] + self.cy = ds_dict['cy'] + self.near = hparams['near'] # follow AD-NeRF, we dont use near-far in ds_dict + self.far = hparams['far'] # follow AD-NeRF, we dont use near-far in ds_dict + self.bg_img = torch.from_numpy(ds_dict['bg_img']).float() / 255. + self.idexp_lm3d_mean = torch.from_numpy(ds_dict['idexp_lm3d_mean']).float() + self.idexp_lm3d_std = torch.from_numpy(ds_dict['idexp_lm3d_std']).float() + self.max_t = len(ds_dict['train_samples']) + len(ds_dict['val_samples']) + + def __getitem__(self, idx): + raw_sample = self.samples[idx] + + if hparams.get("load_imgs_to_memory", True): + # disable it to save memory usage. + # for 5500 images, it takes 1 minutes to imread, by contrast, only 1s is needed to index them in memory. + # But it reuqires 15GB memory for caching 5500 images at 512x512 resolution. + if 'head_img' not in self.samples[idx].keys(): + self.samples[idx]['head_img'] = load_image_as_uint8_tensor(self.samples[idx]['head_img_fname']) + self.samples[idx]['gt_img'] = load_image_as_uint8_tensor(self.samples[idx]['gt_img_fname']) + head_img = self.samples[idx]['head_img'] + gt_img = self.samples[idx]['gt_img'] + else: + head_img = load_image_as_uint8_tensor(self.samples[idx]['head_img_fname']) + gt_img = load_image_as_uint8_tensor(self.samples[idx]['gt_img_fname']) + + sample = { + 'H': self.H, + 'W': self.W, + 'focal': self.focal, + 'cx': self.cx, + 'cy': self.cy, + 'near': self.near, + 'far': self.far, + 'idx': raw_sample['idx'], + 'rect': raw_sample['face_rect'], + 'bg_img': self.bg_img, + 'c2w': raw_sample['c2w'][:3], + 'euler': raw_sample['euler'], + 'trans': raw_sample['trans'], + 'euler_t0': self.samples[0]['euler'], + 'trans_t0': self.samples[0]['trans'], + 'c2w_t': raw_sample['c2w'][:3], + 'c2w_t0': self.samples[0]['c2w'][:3], + 't': torch.tensor([idx]).float()/ self.max_t, + } + + sample.update({ + 'head_img': head_img.float() / 255., + 'gt_img': gt_img.float() / 255., + }) + + if self.cond_type == 'deepspeech': + sample.update({ + 'cond_win': raw_sample['deepspeech_win'].unsqueeze(0), # [B=1, T=16, C=29] + 'cond_wins': raw_sample['deepspeech_wins'], # [Win=8, T=16, C=29] + }) + elif self.cond_type == 'idexp_lm3d_normalized': + sample['cond'] = raw_sample['idexp_lm3d_normalized'].reshape([1,-1]) # [1, 204] + sample['cond_win'] = raw_sample['idexp_lm3d_normalized_win'].reshape([1, hparams['cond_win_size'],-1]) # [1, T_win, 204] + sample['cond_wins'] = raw_sample['idexp_lm3d_normalized_wins'].reshape([hparams['smo_win_size'], hparams['cond_win_size'],-1]) # [smo_win, T_win, 204] + + if hparams.get("use_hubert", False): + sample['hubert_win'] = raw_sample['hubert_win'].unsqueeze(0) # [Win=8, C=64] + sample['hubert_wins'] = raw_sample['hubert_wins'].unsqueeze(0) # [Win=8, C=64] + sample.update({ + 'deepspeech_win': raw_sample['deepspeech_win'].unsqueeze(0), # [B=1, T=16, C=29] + 'deepspeech_wins': raw_sample['deepspeech_wins'], # [Win=8, T=16, C=29] + }) + else: + raise NotImplementedError + + return sample + + def __len__(self): + return len(self.samples) + + def collater(self, samples): + assert len(samples) == 1 # NeRF only take 1 image for each iteration + return samples[0] + + +if __name__ == '__main__': + set_hparams() + ds = NeRFDataset('train', data_dir='data/binary/videos/May') + ds[0] + print("done!") \ No newline at end of file diff --git a/Geneface_main/GeneFace/tasks/nerfs/lm3d_nerf.py b/Geneface_main/GeneFace/tasks/nerfs/lm3d_nerf.py new file mode 100644 index 00000000..a7ba2cd6 --- /dev/null +++ b/Geneface_main/GeneFace/tasks/nerfs/lm3d_nerf.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn +import numpy as np +import tqdm +import time +import imageio +import os +import cv2 + +from utils.commons.hparams import hparams +from utils.commons.ckpt_utils import load_ckpt +from utils.nn.model_utils import print_arch +from utils.nn.grad import get_grad_norm +from utils.nn.model_utils import num_params +from utils.nn.schedulers import ExponentialSchedule, ExponentialScheduleWithAudattNet + +from tasks.nerfs.adnerf import ADNeRFTask +from modules.nerfs.lm3d_nerf.lm3d_nerf import Lm3dNeRF +from modules.nerfs.commons.volume_rendering import render_dynamic_face + + +class Lm3dNeRFTask(ADNeRFTask): + def __init__(self): + super().__init__() + + def build_model(self): + self.model = Lm3dNeRF(hparams) + if hparams['with_att']: + self.lmatt_encoder_params = [p for p in self.model.lmatt_encoder.parameters() if p.requires_grad] + self.gen_params_except_lmatt_encoder = [p for k, p in self.model.named_parameters() if (('lmatt_encoder' not in k) and p.requires_grad)] + else: + self.gen_params = [p for k, p in self.model.named_parameters() if p.requires_grad] + return self.model + + def build_optimizer(self, model): + if hparams['with_att']: + self.optimizer = torch.optim.Adam( + self.gen_params_except_lmatt_encoder, + lr=hparams['lr'], + betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2'])) + self.optimizer.add_param_group({ + 'params': self.lmatt_encoder_params, + 'lr': hparams['lr'] * 5, + 'betas': (hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']) + }) + else: + self.optimizer = torch.optim.Adam( + self.gen_params, + lr=hparams['lr'], + betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2'])) + return self.optimizer + + def build_scheduler(self, optimizer): + if hparams['with_att']: + return ExponentialScheduleWithAudattNet(optimizer, hparams['lr'], hparams['warmup_updates']) + else: + return ExponentialSchedule(optimizer, hparams['lr'], hparams['warmup_updates']) + + def on_before_optimization(self, opt_idx): + prefix = f"grad_norm_opt_idx_{opt_idx}" + grad_norm_dict = { + f'{prefix}/model_coarse': get_grad_norm(self.model.model_coarse), + f'{prefix}/model_fine': get_grad_norm(self.model.model_fine), + f'{prefix}/lm_encoder': get_grad_norm(self.model.lm_encoder), + } + if hparams['with_att']: + grad_norm_dict[f'{prefix}/lmatt_encoder'] = get_grad_norm(self.model.lmatt_encoder) + return grad_norm_dict diff --git a/Geneface_main/GeneFace/tasks/nerfs/lm3d_nerf_torso.py b/Geneface_main/GeneFace/tasks/nerfs/lm3d_nerf_torso.py new file mode 100644 index 00000000..560f9e33 --- /dev/null +++ b/Geneface_main/GeneFace/tasks/nerfs/lm3d_nerf_torso.py @@ -0,0 +1,215 @@ +import torch +import torch.nn as nn +import numpy as np +import tqdm +import time +import imageio +import os +import cv2 + +from utils.commons.hparams import hparams +from utils.commons.ckpt_utils import load_ckpt +from utils.nn.model_utils import print_arch +from utils.nn.grad import get_grad_norm, GradBuffer +from utils.nn.model_utils import print_arch, get_device_of_model, not_requires_grad +from utils.commons.tensor_utils import tensors_to_scalars, convert_to_np, move_to_cuda, convert_to_tensor + +from tasks.nerfs.adnerf_torso import ADNeRFTorsoTask +from modules.nerfs.lm3d_nerf.lm3d_nerf import Lm3dNeRF as Lm3dNeRF_head +from modules.nerfs.adnerf.adnerf import ADNeRF as ADNeRF_head +from modules.nerfs.adnerf.adnerf_torso import ADNeRFTorso as ADNeRF_torso +from modules.nerfs.commons.volume_rendering import render_dynamic_face +from modules.nerfs.commons.ray_samplers import TorsoUniformRaySampler +from scipy.ndimage import gaussian_filter1d, gaussian_filter + + +class Lm3dNeRFTorsoTask(ADNeRFTorsoTask): + + def build_model(self): + self.head_model = Lm3dNeRF_head(hparams) + head_model_dir = hparams['head_model_dir'] + load_ckpt(self.head_model, head_model_dir) + + self.model = ADNeRF_torso(hparams) + self.audatt_net_params = [p for p in self.model.audatt_net.parameters() if p.requires_grad] + self.gen_params_except_audatt_net = [p for k, p in self.model.named_parameters() if (('audatt_net' not in k) and p.requires_grad)] + return self.model + + + def run_model(self, sample, infer=False, run_head_mode=False): + """ + render or train on a single-frame + :param sample: a batch of data + :param infer: bool, run in infer mode + :return: + if not infer: + return losses, model_out + if infer: + return model_out + """ + cond = sample['cond_win'] if hparams['use_window_cond'] else sample['cond'] + cond_wins = sample['cond_wins'] + H = sample['H'] + W = sample['W'] + focal = sample['focal'] + cx = sample['cx'] + cy = sample['cy'] + near = sample['near'] + far = sample['far'] + bg_img = sample['bg_img'] + c2w_t = sample['c2w'] + c2w_t0 = sample['c2w_t0'] + euler = sample['euler'] + euler_t0 = sample['euler_t0'] + trans = sample['trans'] + trans_t0 = sample['trans_t0'] + losses_out = {} + + with_att = hparams['with_att'] + + if infer: + # Inference Phase + with torch.no_grad(): + # render head + # sample the rays of the whole image + rays_o, rays_d, _ = self.full_rays_sampler(H, W, focal, c2w_t) + + if with_att: + cond_feat = self.head_model.cal_cond_feat(cond_wins, with_att=True) + else: + cond_feat = self.head_model.cal_cond_feat(cond, with_att=False) + rgb_pred, disp, acc, last_weight, rgb_map_fg, extras = render_dynamic_face(H, W, focal, cx, cy, rays_o=rays_o, rays_d=rays_d, + bc_rgb=bg_img, chunk=2048, c2w=None, cond=cond_feat, near=near, far=far, + network_fn=self.head_model, N_samples=self.n_samples_per_ray, N_importance=self.n_samples_per_ray_fine, + run_head_mode=True, + c2w_t=c2w_t, c2w_t0=c2w_t0,t=torch.tensor([0.,]).cuda(), + ) + + # render torso + # sample the rays of the whole image, in the canoical space, i.e., with c2w_t0 + rays_o, rays_d, _ = self.full_rays_sampler(H, W, focal, c2w_t0) + + cond_feat = self.model.cal_cond_feat(sample['deepspeech_wins'],color=rgb_pred, euler=sample['euler'], trans=sample['trans'],with_att=True) + rgb_pred_torso, disp_torso, acc_torso, last_weight_torso, rgb_map_fg_torso, extras_torso = render_dynamic_face(H, W, focal, cx, cy, rays_o=rays_o, rays_d=rays_d, + bc_rgb=bg_img,chunk=2048, c2w=None, cond=cond_feat, near=near, far=far, + network_fn=self.model, N_samples=self.n_samples_per_ray, N_importance=self.n_samples_per_ray_fine, + run_head_mode=False, + c2w_t=c2w_t, c2w_t0=c2w_t0,t=torch.tensor([0.,]).cuda(), + euler=euler, euler_t0=euler_t0, trans=trans, trans_t0=trans_t0 + ) + + if hparams.get("infer_with_more_dynamic_c2w_sequence", False) is True: + """ + Note: enable it only when you find there is overlap problem between head and torso! + Since the torso nerf is modeled in canoical space (i.e., static pose), + it cannot model large-range movements of the torso. + When the head is moving extremely down, + the torso nerf sometimes will render results on the face part, + which leads to shallow artifacts on the face part. + To handle this, we set the rgb_map_fg_torso on the face part to zero. + """ + w_h = int(last_weight.reshape([-1]).shape[0]**0.5) + smo_last_weight = gaussian_filter((last_weight.reshape([w_h,w_h]).cpu() * 255).int().numpy(), sigma=1.).reshape([-1]) / 255. + has_head_mask = convert_to_tensor(smo_last_weight <= 0.3).to(rgb_map_fg.device).bool() # where head has much confidence + def shrink_has_head_mask(has_head_mask): + w_h = int(has_head_mask.reshape([-1]).shape[0]**0.5) + has_head_mask = has_head_mask.reshape([w_h, w_h]) + centered_mask = has_head_mask[1:-1,1:-1] + left_offset_mask = has_head_mask[0:-2,1:-1] + right_offset_mask = has_head_mask[2:,1:-1] + up_offset_mask = has_head_mask[1:-1,0:-2] + down_offset_mask = has_head_mask[1:-1,2:] + mask = torch.bitwise_and(centered_mask, left_offset_mask) + mask = torch.bitwise_and(mask, right_offset_mask) + mask = torch.bitwise_and(mask, up_offset_mask) + mask = torch.bitwise_and(mask, down_offset_mask) + has_head_mask[1:-1,1:-1] = mask + return has_head_mask.reshape([-1,]) + for _ in range(6): + has_head_mask = shrink_has_head_mask(has_head_mask) + disable_torso_mask = has_head_mask + last_weight_torso[disable_torso_mask] = 1 + rgb_map_fg_torso[last_weight_torso==1] = 0 + rgb_com = rgb_pred * last_weight_torso.unsqueeze(-1) + rgb_map_fg_torso + + model_out = { + "rgb_map" : rgb_com + } + return model_out + else: + # Training Phase + if run_head_mode: + # Run Head NeRF + # uniformly sample the rays + rays_o, rays_d, select_coords = self.rays_sampler(H, W, focal, c2w_t, n_rays=None, rect=sample['rect'], in_rect_percent=hparams['in_rect_percent'], iterations=self.global_step) + rgb_gt = self.rays_sampler.sample_pixels_from_img_with_select_coords(sample['head_img'], select_coords) + rgb_bc = self.rays_sampler.sample_pixels_from_img_with_select_coords(sample['bg_img'], select_coords) + with torch.no_grad(): + # calculate the condition + if with_att: + cond_feat = self.head_model.cal_cond_feat(cond_wins, with_att=True) + else: + cond_feat = self.head_model.cal_cond_feat(cond, with_att=False) + # volume rendering to get rgb_pred + rgb_pred, disp, acc, last_weight, rgb_map_fg, extras = render_dynamic_face(H, W, focal, cx, cy, rays_o=rays_o, rays_d=rays_d, + bc_rgb=rgb_bc,chunk=self.chunk, c2w=None, cond=cond_feat, near=near, far=far, + network_fn=self.head_model, N_samples=self.n_samples_per_ray, N_importance=self.n_samples_per_ray_fine, + run_head_mode=True, + c2w_t=c2w_t, c2w_t0=c2w_t0,t=torch.tensor([0.,]).cuda(), + ) + # calculate loss + losses_out['head_mse_loss'] = torch.mean((rgb_pred - rgb_gt) ** 2) + if 'rgb_map_coarse' in extras: + losses_out['head_mse_loss_coarse'] = torch.mean((extras['rgb_map_coarse'] - rgb_gt) ** 2) + model_out = { + "rgb_map": rgb_pred + } + else: + # Run Torso NeRF + # uniformly sample the rays, in the canoical space, i.e., with c2w_t0 + target = sample['gt_img'] + rect = [0, H/2, W, H/2] # only sample the lower part for torso + rays_o, rays_d, select_coords = self.torso_rays_sampler(H, W, focal, c2w_t0, n_rays=None, rect=rect, + in_rect_percent=hparams['in_rect_percent']) + rays_o_head, rays_d_head, _ = self.rays_sampler(H, W, focal, c2w_t, select_coords=select_coords) + rgb_gt = self.rays_sampler.sample_pixels_from_img_with_select_coords(target, select_coords) + rgb_bc = self.rays_sampler.sample_pixels_from_img_with_select_coords(sample['bg_img'], select_coords) + + # render head + with torch.no_grad(): + if with_att: + cond_feat = self.head_model.cal_cond_feat(cond_wins, with_att=True) + else: + cond_feat = self.head_model.cal_cond_feat(cond, with_att=False) + rgb_pred_head, disp_head, acc_head, last_weight_head, rgb_map_fg_head, extras_head = render_dynamic_face(H, W, focal, cx, cy, rays_o=rays_o_head, rays_d=rays_d_head, + bc_rgb=rgb_bc,chunk=self.chunk, c2w=None, cond=cond_feat, near=near, far=far, + network_fn=self.head_model, N_samples=self.n_samples_per_ray, N_importance=self.n_samples_per_ray_fine, + run_head_mode=True, + c2w_t=c2w_t, c2w_t0=c2w_t0,t=torch.tensor([0.,]).cuda(), + ) + + # render torso + # calculate the condition based on deepspeech + cond_feat = self.model.cal_cond_feat(sample['deepspeech_wins'],color=rgb_pred_head, euler=sample['euler'], trans=sample['trans'],with_att=True) + # volume rendering to get rgb_pred_com + rgb_pred_torso, disp_torso, acc_torso, last_weight_torso, rgb_map_fg_torso, extras_torso = render_dynamic_face(H, W, focal, cx, cy, rays_o=rays_o, rays_d=rays_d, + bc_rgb=rgb_bc,chunk=self.chunk, c2w=None, cond=cond_feat, near=near, far=far, + network_fn=self.model, N_samples=self.n_samples_per_ray, N_importance=self.n_samples_per_ray_fine, + run_head_mode=run_head_mode, + c2w_t=c2w_t, c2w_t0=c2w_t0,t=torch.tensor([0.,]).cuda(), + euler=euler, euler_t0=euler_t0, trans=trans, trans_t0=trans_t0 + ) + rgb_com = rgb_pred_head * last_weight_torso.unsqueeze(-1) + rgb_map_fg_torso + + # calculate loss + losses_out['com_mse_loss'] = torch.mean((rgb_com - rgb_gt) ** 2) + def mse2psnr(x): return -10. * torch.log(x) / torch.log(torch.Tensor([10.])).to(x.device) + losses_out['com_psnr'] = mse2psnr(losses_out['com_mse_loss'].detach()) + if 'rgb_map_coarse' in extras_torso: + rgb_com0 = extras_head['rgb_map_coarse'] * extras_torso['last_weight0'].unsqueeze(-1) + extras_torso['rgb_map_fg0'] + losses_out['com_mse_loss_coarse'] = torch.mean((rgb_com0 - rgb_gt) ** 2) + model_out = { + "rgb_map": rgb_com + } + return losses_out, model_out + \ No newline at end of file diff --git a/Geneface_main/GeneFace/tasks/postnet/__pycache__/dataset_utils.cpython-39.pyc b/Geneface_main/GeneFace/tasks/postnet/__pycache__/dataset_utils.cpython-39.pyc new file mode 100644 index 00000000..aa625fee Binary files /dev/null and b/Geneface_main/GeneFace/tasks/postnet/__pycache__/dataset_utils.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/tasks/postnet/__pycache__/lm3d_postnet_adv_sync.cpython-39.pyc b/Geneface_main/GeneFace/tasks/postnet/__pycache__/lm3d_postnet_adv_sync.cpython-39.pyc new file mode 100644 index 00000000..ee7988a3 Binary files /dev/null and b/Geneface_main/GeneFace/tasks/postnet/__pycache__/lm3d_postnet_adv_sync.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/tasks/postnet/dataset_utils.py b/Geneface_main/GeneFace/tasks/postnet/dataset_utils.py new file mode 100644 index 00000000..4e2f8135 --- /dev/null +++ b/Geneface_main/GeneFace/tasks/postnet/dataset_utils.py @@ -0,0 +1,102 @@ +import os +import torch +import numpy as np +from torch.utils.data import DataLoader +from utils.commons.hparams import hparams, set_hparams + +from tasks.audio2motion.dataset_utils.lrs3_dataset import LRS3SeqDataset + + +class PostnetDataset(torch.utils.data.Dataset): + def __init__(self, prefix, data_dir=None): + super().__init__() + self.person_binary_data_dir = os.path.join(hparams['person_binary_data_dir'], hparams['video_id']) if data_dir is None else data_dir + binary_file_name = os.path.join(self.person_binary_data_dir, "trainval_dataset.npy") + person_ds_dict = np.load(binary_file_name, allow_pickle=True).tolist() + mel = person_ds_dict['mel'] + f0 = person_ds_dict['f0'].reshape([-1,1]) + hubert = person_ds_dict['hubert'] + # if len(mel.shape) == 0: # is object + # mel = mel.tolist()['mel'] + train_lm3d_normalized = np.stack([sample['idexp_lm3d_normalized'] for sample in person_ds_dict['train_samples']], axis=0) + train_lm3d = np.stack([sample['idexp_lm3d'] for sample in person_ds_dict['train_samples']], axis=0) + val_lm3d_normalized = np.stack([sample['idexp_lm3d_normalized'] for sample in person_ds_dict['val_samples']], axis=0) + val_lm3d = np.stack([sample['idexp_lm3d'] for sample in person_ds_dict['val_samples']], axis=0) + lm3d_len = train_lm3d_normalized.shape[0] + val_lm3d_normalized.shape[0] + mel_len = mel.shape[0] + if mel_len > 2 * lm3d_len: + mel = mel[:2*lm3d_len] + f0 = f0[:2*lm3d_len] + hubert = hubert[:2*lm3d_len] + elif mel_len < 2 * lm3d_len: + num_to_pad = 2 * lm3d_len - mel_len + mel = np.pad(mel, ((0,num_to_pad),(0,0)), mode="constant") + f0 = np.pad(f0, ((0,num_to_pad),(0,0)), mode="constant") + hubert = np.pad(hubert, ((0,num_to_pad),(0,0)), mode="constant") + + if prefix == 'train': + lm3d_normalized = train_lm3d_normalized + lm3d = train_lm3d + mel = mel[:lm3d_normalized.shape[0]*2] + f0 = f0[:lm3d_normalized.shape[0]*2] + hubert = hubert[:lm3d_normalized.shape[0]*2] + elif prefix == 'val': + lm3d_normalized = val_lm3d_normalized + lm3d = val_lm3d + mel = mel[train_lm3d_normalized.shape[0]*2 : train_lm3d_normalized.shape[0]*2 + lm3d_normalized.shape[0]*2] + f0 = f0[train_lm3d_normalized.shape[0]*2 : train_lm3d_normalized.shape[0]*2 + lm3d_normalized.shape[0]*2] + hubert = hubert[train_lm3d_normalized.shape[0]*2 : train_lm3d_normalized.shape[0]*2 + lm3d_normalized.shape[0]*2] + else: + raise ValueError("prefix should in train/val !") + + + target_x_len = mel.shape[0] // 8 * 8 + target_y_len = target_x_len // 2 + mel = mel[:target_x_len] + f0 = f0[:target_x_len].reshape([-1,]) + hubert = hubert[:target_x_len] + lm3d_normalized = lm3d_normalized[:target_y_len] + lm3d_normalized = lm3d_normalized.reshape(lm3d_normalized.shape[0], -1) + lm3d = lm3d[:target_y_len] + lm3d = lm3d.reshape(lm3d_normalized.shape[0], -1) + + idexp_lm3d_mean = person_ds_dict['idexp_lm3d_mean'][:target_y_len].reshape(1, -1) + idexp_lm3d_std = person_ds_dict['idexp_lm3d_std'][:target_y_len].reshape(1, -1) + self.person_ds = { + 'mel': torch.from_numpy(mel).float().unsqueeze(0), + 'f0': torch.from_numpy(f0).float().unsqueeze(0), + 'hubert': torch.from_numpy(hubert).float().unsqueeze(0), + 'idexp_lm3d_normalized': torch.from_numpy(lm3d_normalized).float().unsqueeze(0), + 'idexp_lm3d': torch.from_numpy(lm3d).float().unsqueeze(0), + 'x_mask': torch.ones([target_x_len,]).float().unsqueeze(0), + 'y_mask': torch.ones([target_y_len,]).float().unsqueeze(0), + 'idexp_lm3d_mean': torch.from_numpy(idexp_lm3d_mean).float().unsqueeze(0), + 'idexp_lm3d_std': torch.from_numpy(idexp_lm3d_std).float().unsqueeze(0), + } + + self.audio2motion_ds = LRS3SeqDataset(prefix) + + def __getitem__(self, idx): + sample = self.audio2motion_ds[idx] + return sample + + def __len__(self): + return len(self.samples) + + def collater(self, samples): + batch = self.audio2motion_ds.collater(samples) + batch['person_ds'] = self.person_ds + return batch + + def get_dataloader(self): + max_tokens = 60000 + batches_idx = self.audio2motion_ds.batch_by_size(self.audio2motion_ds.ordered_indices(), max_tokens=max_tokens) + # loader = DataLoader(self, pin_memory=False,collate_fn=self.collater, batch_sampler=batches_idx, num_workers=0) + loader = DataLoader(self, pin_memory=True,collate_fn=self.collater, batch_sampler=batches_idx, num_workers=4) + return loader + +if __name__ == '__main__': + set_hparams() + ds = PostnetDataset("train") + ds[0] + print("done!") \ No newline at end of file diff --git a/Geneface_main/GeneFace/tasks/postnet/lm3d_postnet_adv_sync.py b/Geneface_main/GeneFace/tasks/postnet/lm3d_postnet_adv_sync.py new file mode 100644 index 00000000..0ba008bf --- /dev/null +++ b/Geneface_main/GeneFace/tasks/postnet/lm3d_postnet_adv_sync.py @@ -0,0 +1,205 @@ +import torch +import importlib + +from utils.commons.base_task import BaseTask +from utils.commons.dataset_utils import data_loader +from utils.commons.hparams import hparams +from utils.commons.ckpt_utils import load_ckpt +from utils.commons.tensor_utils import tensors_to_scalars, convert_to_np +from utils.nn.model_utils import print_arch + +from modules.postnet.models import CNNPostNet, MLPDiscriminator +from tasks.audio2motion.lm3d_vae_sync import VAESyncAudio2MotionTask +from tasks.postnet.dataset_utils import PostnetDataset +from tasks.syncnet.lm3d_syncnet import SyncNetTask + +from data_util.face3d_helper import Face3DHelper + + +class PostnetAdvSyncTask(BaseTask): + def __init__(self): + super().__init__() + self.audio2motion_task = self.build_audio2motion_task() + self.syncnet_task = self.build_syncnet_task() + self.build_disc_model() + self.dataset_cls = PostnetDataset + self.face3d_helper = Face3DHelper(use_gpu=True) + + def build_audio2motion_task(self): + assert hparams['audio2motion_task_cls'] != '' + pkg = ".".join(hparams["audio2motion_task_cls"].split(".")[:-1]) + cls_name = hparams["audio2motion_task_cls"].split(".")[-1] + task_cls = getattr(importlib.import_module(pkg), cls_name) + self.audio2motion_task = task_cls() + self.audio2motion_task.build_model() + audio2motion_work_dir = hparams['audio2motion_work_dir'] + audio2motion_ckpt_steps = hparams["audio2motion_ckpt_steps"] + load_ckpt(self.audio2motion_task.model, audio2motion_work_dir, 'model', steps=audio2motion_ckpt_steps) + self.audio2motion_task.eval() + return self.audio2motion_task + + def build_syncnet_task(self): + self.syncnet_task = SyncNetTask() + self.syncnet_task.build_model() + syncnet_work_dir = hparams["syncnet_work_dir"] + syncnet_ckpt_steps = hparams["syncnet_ckpt_steps"] + load_ckpt(self.syncnet_task.model, syncnet_work_dir, 'model', steps=syncnet_ckpt_steps) + for p in self.syncnet_task.parameters(): + p.requires_grad = False + self.syncnet_task.eval() + return self.syncnet_task + + def build_model(self): + self.model = CNNPostNet(in_out_dim=68*3) + print_arch(self.model) + return self.model + + def build_disc_model(self): + self.disc_model = MLPDiscriminator(in_dim=68*3) + + def build_optimizer(self, model): + self.optimizer_gen = torch.optim.RMSprop(self.model.parameters(), + lr=hparams['postnet_lr'],) + self.optimizer_disc = torch.optim.RMSprop(self.disc_model.parameters(), + lr=hparams['postnet_disc_lr'],) + + return [self.optimizer_gen, self.optimizer_disc] + + + def build_scheduler(self, optimizer): + return [ + VAESyncAudio2MotionTask.build_scheduler(self, optimizer[0]), + torch.optim.lr_scheduler.StepLR(optimizer=optimizer[1], + **hparams["discriminator_scheduler_params"]), + ] + + def on_after_optimization(self, epoch, batch_idx, optimizer, optimizer_idx): + if self.scheduler is not None: + self.scheduler[0].step(self.global_step // hparams['accumulate_grad_batches']) + self.scheduler[1].step(self.global_step // hparams['accumulate_grad_batches']) + + @data_loader + def train_dataloader(self): + train_dataset = self.dataset_cls(prefix='train') + self.train_dl = train_dataset.get_dataloader() + return self.train_dl + + @data_loader + def val_dataloader(self): + val_dataset = self.dataset_cls(prefix='val') + self.val_dl = val_dataset.get_dataloader() + return self.val_dl + + @data_loader + def test_dataloader(self): + val_dataset = self.dataset_cls(prefix='val') + self.val_dl = val_dataset.get_dataloader() + return self.val_dl + + ########################## + # training and validation + ########################## + def run_model(self, sample, infer=False, temperature=1.0): + """ + render or train on a single-frame + :param sample: a batch of data + :param infer: bool, run in infer mode + :return: + if not infer: + return losses, model_out + if infer: + return model_out + """ + losses_out = {} + lrs3_batch = { + 'x_mask': sample['x_mask'], + 'y_mask': sample['y_mask'], + 'hubert': sample['hubert'], + } + ### Perform the audio2motion first + with torch.no_grad(): + model_out = self.audio2motion_task.run_model(lrs3_batch, infer=True, temperature=temperature) + raw_pred_lm3d = model_out['pred'] + + ### Then forward the PostNet + if infer: + refine_pred_lm3d = self.model(raw_pred_lm3d) + model_out = { + 'refine_lm3d': refine_pred_lm3d, + 'raw_lm3d': raw_pred_lm3d + } + return model_out + else: + person_batch = sample['person_ds'] + gt_pred_lm3d_for_person_ds = person_batch['idexp_lm3d'] + with torch.no_grad(): + model_out = self.audio2motion_task.run_model(person_batch, infer=True, temperature=temperature) + raw_pred_lm3d_for_person_ds = model_out['pred'] + refine_pred_lm3d_for_person_ds = self.model(raw_pred_lm3d_for_person_ds) * person_batch['y_mask'].unsqueeze(-1) + if hparams.get("loss_type", 'mse') == 'mse': + losses_out['mse'] = (gt_pred_lm3d_for_person_ds - refine_pred_lm3d_for_person_ds).pow(2).sum() / (person_batch['y_mask'].sum() * 68*3) + else: + losses_out['mae'] = (gt_pred_lm3d_for_person_ds - refine_pred_lm3d_for_person_ds).abs().sum() / (person_batch['y_mask'].sum() * 68*3) + + refine_pred_lm3d = self.model(raw_pred_lm3d) + + ### Calculate Syncnet Loss + refine_pred_lm3d_ = refine_pred_lm3d.reshape(refine_pred_lm3d.size(0), refine_pred_lm3d.size(1), 68, 3) + _, refine_mouth_lm3d = self.face3d_helper.get_eye_mouth_lm_from_lm3d_batch(refine_pred_lm3d_) + syncnet_sample = { + 'idexp_lm3d': refine_pred_lm3d.reshape(refine_pred_lm3d.size(0), refine_pred_lm3d.size(1), -1), + 'mouth_idexp_lm3d': refine_mouth_lm3d.reshape(refine_mouth_lm3d.size(0), refine_mouth_lm3d.size(1), -1), + 'hubert': sample['hubert'], + 'y_mask': sample['y_mask'] + } + syncnet_out = self.syncnet_task.run_model(syncnet_sample, infer=True, batch_size=1024) + sync_loss = syncnet_out['sync_loss'] + losses_out['sync'] = sync_loss + + model_out = { + 'refine_lm3d': refine_pred_lm3d + } + return losses_out, model_out + + def _training_step(self, sample, batch_idx, optimizer_idx): + loss_output = {} + loss_weights = {} + disc_start = self.global_step >= hparams["postnet_disc_start_steps"] and hparams['postnet_lambda_adv'] > 0 + if optimizer_idx == 0: + loss_output, model_out = self.run_model(sample) + loss_weights = { + 'mse': hparams['postnet_lambda_mse'], + } + + pred = model_out['refine_lm3d'] + self.pred = pred.detach() + if disc_start: + disc_conf_neg = self.disc_model(x=pred)[0] + loss_output['adv'] = (1 - disc_conf_neg).pow(2).mean() + loss_weights['adv'] = hparams['postnet_lambda_adv'] + loss_weights['sync'] = hparams['postnet_lambda_sync'] + else: + # train the discriminator + if self.global_step % hparams['postnet_disc_interval'] == 0: + pred = self.pred + p_ = self.disc_model(x=pred)[0] + person_idexp_normalized = sample['person_ds']['idexp_lm3d'] + p = self.disc_model(x=person_idexp_normalized)[0] + loss_output['disc_neg_conf'] = p_.detach().mean().item() + loss_output['disc_pos_conf'] = p.detach().mean().item() + + loss_output["disc_fake_loss"] = (p_ - p_.new_zeros(p_.size())).pow(2).mean() + loss_output["disc_true_loss"] = (p - p.new_ones(p.size())).pow(2).mean() + else: + return None + total_loss = sum([loss_weights.get(k, 1) * v for k, v in loss_output.items() if isinstance(v, torch.Tensor) and v.requires_grad]) + return total_loss, loss_output + + @torch.no_grad() + def validation_step(self, sample, batch_idx): + outputs = {} + outputs['losses'] = {} + outputs['losses'], model_out = self.run_model(sample, infer=False) + outputs = tensors_to_scalars(outputs) + return outputs + diff --git a/Geneface_main/GeneFace/tasks/postnet/lm3d_postnet_adv_sync_pitch.py b/Geneface_main/GeneFace/tasks/postnet/lm3d_postnet_adv_sync_pitch.py new file mode 100644 index 00000000..50c53db2 --- /dev/null +++ b/Geneface_main/GeneFace/tasks/postnet/lm3d_postnet_adv_sync_pitch.py @@ -0,0 +1,227 @@ +import torch +import importlib + +from utils.commons.base_task import BaseTask +from utils.commons.dataset_utils import data_loader +from utils.commons.hparams import hparams +from utils.commons.ckpt_utils import load_ckpt +from utils.commons.tensor_utils import tensors_to_scalars, convert_to_np +from utils.nn.model_utils import print_arch +from utils.commons.pitch_utils import f0_to_coarse + +from modules.postnet.models import PitchContourCNNPostNet, MLPDiscriminator +from tasks.audio2motion.lm3d_vae_sync_pitch import VAESyncAudio2MotionTask +from tasks.postnet.dataset_utils import PostnetDataset +from tasks.syncnet.lm3d_syncnet import SyncNetTask + +from data_util.face3d_helper import Face3DHelper + + +class PostnetAdvSyncTask(BaseTask): + def __init__(self): + super().__init__() + self.audio2motion_task = self.build_audio2motion_task() + self.syncnet_task = self.build_syncnet_task() + self.build_disc_model() + self.dataset_cls = PostnetDataset + self.face3d_helper = Face3DHelper(use_gpu=True) + + def build_audio2motion_task(self): + assert hparams['audio2motion_task_cls'] != '' + pkg = ".".join(hparams["audio2motion_task_cls"].split(".")[:-1]) + cls_name = hparams["audio2motion_task_cls"].split(".")[-1] + task_cls = getattr(importlib.import_module(pkg), cls_name) + self.audio2motion_task = task_cls() + self.audio2motion_task.build_model() + audio2motion_work_dir = hparams['audio2motion_work_dir'] + audio2motion_ckpt_steps = hparams["audio2motion_ckpt_steps"] + load_ckpt(self.audio2motion_task.model, audio2motion_work_dir, 'model', steps=audio2motion_ckpt_steps) + self.audio2motion_task.eval() + self.downsampler = self.audio2motion_task.model.downsampler + self.pitch_embed = self.audio2motion_task.model.pitch_embed + return self.audio2motion_task + + def build_syncnet_task(self): + self.syncnet_task = SyncNetTask() + self.syncnet_task.build_model() + syncnet_work_dir = hparams["syncnet_work_dir"] + syncnet_ckpt_steps = hparams["syncnet_ckpt_steps"] + load_ckpt(self.syncnet_task.model, syncnet_work_dir, 'model', steps=syncnet_ckpt_steps) + for p in self.syncnet_task.parameters(): + p.requires_grad = False + self.syncnet_task.eval() + return self.syncnet_task + + def build_model(self): + self.model = PitchContourCNNPostNet(in_out_dim=68*3, pitch_dim=64) + print_arch(self.model) + return self.model + + def build_disc_model(self): + self.disc_model = MLPDiscriminator(in_dim=68*3) + + def build_optimizer(self, model): + self.optimizer_gen = torch.optim.RMSprop(self.model.parameters(), + lr=hparams['postnet_lr'],) + self.optimizer_disc = torch.optim.RMSprop(self.disc_model.parameters(), + lr=hparams['postnet_disc_lr'],) + + return [self.optimizer_gen, self.optimizer_disc] + + + def build_scheduler(self, optimizer): + return [ + VAESyncAudio2MotionTask.build_scheduler(self, optimizer[0]), + torch.optim.lr_scheduler.StepLR(optimizer=optimizer[1], + **hparams["discriminator_scheduler_params"]), + ] + + def on_after_optimization(self, epoch, batch_idx, optimizer, optimizer_idx): + if self.scheduler is not None: + self.scheduler[0].step(self.global_step // hparams['accumulate_grad_batches']) + self.scheduler[1].step(self.global_step // hparams['accumulate_grad_batches']) + + @data_loader + def train_dataloader(self): + train_dataset = self.dataset_cls(prefix='train') + self.train_dl = train_dataset.get_dataloader() + return self.train_dl + + @data_loader + def val_dataloader(self): + val_dataset = self.dataset_cls(prefix='val') + self.val_dl = val_dataset.get_dataloader() + return self.val_dl + + @data_loader + def test_dataloader(self): + val_dataset = self.dataset_cls(prefix='val') + self.val_dl = val_dataset.get_dataloader() + return self.val_dl + + ########################## + # training and validation + ########################## + def run_model(self, sample, infer=False, temperature=1.0): + """ + render or train on a single-frame + :param sample: a batch of data + :param infer: bool, run in infer mode + :return: + if not infer: + return losses, model_out + if infer: + return model_out + """ + losses_out = {} + lrs3_batch = { + 'x_mask': sample['x_mask'], + 'y_mask': sample['y_mask'], + 'hubert': sample['hubert'], + 'f0': sample['f0'] + } + ### Perform the audio2motion first + with torch.no_grad(): + model_out = self.audio2motion_task.run_model(lrs3_batch, infer=True, temperature=temperature) + raw_pred_lm3d = model_out['pred'] + lrs3_f0 = self.downsampler(lrs3_batch['f0'].unsqueeze(-1)).squeeze(-1) + pitch_lrs3 = self.pitch_embed(f0_to_coarse(lrs3_f0)) + + ### Then forward the PostNet + if infer: + refine_pred_lm3d = self.model(raw_pred_lm3d, pitch_lrs3) + model_out = { + 'refine_lm3d': refine_pred_lm3d, + 'raw_lm3d': raw_pred_lm3d + } + return model_out + else: + person_batch = sample['person_ds'] + gt_pred_lm3d_for_person_ds = person_batch['idexp_lm3d'] + with torch.no_grad(): + model_out = self.audio2motion_task.run_model(person_batch, infer=True, temperature=temperature) + raw_pred_lm3d_for_person_ds = model_out['pred'] + person_f0 = self.downsampler(person_batch['f0'].unsqueeze(-1)).squeeze(-1) + pitch_for_person_ds = self.pitch_embed(f0_to_coarse(person_f0)) + refine_pred_lm3d_for_person_ds = self.model(raw_pred_lm3d_for_person_ds, pitch_for_person_ds) * person_batch['y_mask'].unsqueeze(-1) + if hparams.get("loss_type", 'mse') == 'mse': + losses_out['mse'] = (gt_pred_lm3d_for_person_ds - refine_pred_lm3d_for_person_ds).pow(2).sum() / (person_batch['y_mask'].sum() * 68*3) + else: + losses_out['mae'] = (gt_pred_lm3d_for_person_ds - refine_pred_lm3d_for_person_ds).abs().sum() / (person_batch['y_mask'].sum() * 68*3) + losses_out['continuity'] = self.continuity_loss(gt_pred_lm3d_for_person_ds, refine_pred_lm3d_for_person_ds, person_batch['y_mask']) + refine_pred_lm3d = self.model(raw_pred_lm3d, pitch_lrs3) + + ### Calculate Syncnet Loss + refine_pred_lm3d_ = refine_pred_lm3d.reshape(refine_pred_lm3d.size(0), refine_pred_lm3d.size(1), 68, 3) + _, refine_mouth_lm3d = self.face3d_helper.get_eye_mouth_lm_from_lm3d_batch(refine_pred_lm3d_) + syncnet_sample = { + 'idexp_lm3d': refine_pred_lm3d.reshape(refine_pred_lm3d.size(0), refine_pred_lm3d.size(1), -1), + 'mouth_idexp_lm3d': refine_mouth_lm3d.reshape(refine_mouth_lm3d.size(0), refine_mouth_lm3d.size(1), -1), + 'hubert': sample['hubert'], + 'y_mask': sample['y_mask'] + } + syncnet_out = self.syncnet_task.run_model(syncnet_sample, infer=True, batch_size=1024) + sync_loss = syncnet_out['sync_loss'] + losses_out['sync'] = sync_loss + + # regularization loss + losses_out['reg'] = (((refine_pred_lm3d - raw_pred_lm3d)*lrs3_batch['y_mask'].unsqueeze(-1)) ** 2).sum() / lrs3_batch['y_mask'].sum() + + model_out = { + 'refine_lm3d': refine_pred_lm3d + } + return losses_out, model_out + + def _training_step(self, sample, batch_idx, optimizer_idx): + loss_output = {} + loss_weights = {} + disc_start = self.global_step >= hparams["postnet_disc_start_steps"] and hparams['postnet_lambda_adv'] > 0 + if optimizer_idx == 0: + loss_output, model_out = self.run_model(sample) + loss_weights = { + 'mse': hparams['postnet_lambda_mse'], + 'reg': hparams.get('postnet_lambda_reg',0), + 'continuity': hparams.get('postnet_lambda_continuity',0), + } + + pred = model_out['refine_lm3d'] + self.pred = pred.detach() + if disc_start: + disc_conf_neg = self.disc_model(x=pred)[0] + loss_output['adv'] = (1 - disc_conf_neg).pow(2).mean() + loss_weights['adv'] = hparams['postnet_lambda_adv'] + loss_weights['sync'] = hparams['postnet_lambda_sync'] + else: + # train the discriminator + if self.global_step % hparams['postnet_disc_interval'] == 0: + pred = self.pred + p_ = self.disc_model(x=pred)[0] + person_idexp_normalized = sample['person_ds']['idexp_lm3d'] + p = self.disc_model(x=person_idexp_normalized)[0] + loss_output['disc_neg_conf'] = p_.detach().mean().item() + loss_output['disc_pos_conf'] = p.detach().mean().item() + + loss_output["disc_fake_loss"] = (p_ - p_.new_zeros(p_.size())).pow(2).mean() + loss_output["disc_true_loss"] = (p - p.new_ones(p.size())).pow(2).mean() + else: + return None + total_loss = sum([loss_weights.get(k, 1) * v for k, v in loss_output.items() if isinstance(v, torch.Tensor) and v.requires_grad]) + return total_loss, loss_output + + @torch.no_grad() + def validation_step(self, sample, batch_idx): + outputs = {} + outputs['losses'] = {} + outputs['losses'], model_out = self.run_model(sample, infer=False) + outputs = tensors_to_scalars(outputs) + return outputs + + def continuity_loss(self, x_gt, x_pred, x_mask): + # continuity loss, borrowed from + diff_x_pred = x_pred[:,1:] - x_pred[:,:-1] + diff_x_gt = x_gt[:,1:] - x_gt[:,:-1] + error = (diff_x_pred[:,:,:] - diff_x_gt[:,:,:]) * x_mask[:,1:,None] + init_error = x_pred[:,0,:] - x_gt[:,0,:] + num_frame = x_mask.sum() + n_dim = 68*3 + return (error.pow(2).sum() + init_error.pow(2).sum()) / (num_frame * n_dim) \ No newline at end of file diff --git a/Geneface_main/GeneFace/tasks/radnerfs/__pycache__/dataset_utils.cpython-39.pyc b/Geneface_main/GeneFace/tasks/radnerfs/__pycache__/dataset_utils.cpython-39.pyc new file mode 100644 index 00000000..cbc2d0fc Binary files /dev/null and b/Geneface_main/GeneFace/tasks/radnerfs/__pycache__/dataset_utils.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/tasks/radnerfs/__pycache__/radnerf.cpython-39.pyc b/Geneface_main/GeneFace/tasks/radnerfs/__pycache__/radnerf.cpython-39.pyc new file mode 100644 index 00000000..57879ea9 Binary files /dev/null and b/Geneface_main/GeneFace/tasks/radnerfs/__pycache__/radnerf.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/tasks/radnerfs/__pycache__/radnerf_torso.cpython-39.pyc b/Geneface_main/GeneFace/tasks/radnerfs/__pycache__/radnerf_torso.cpython-39.pyc new file mode 100644 index 00000000..178fbcbd Binary files /dev/null and b/Geneface_main/GeneFace/tasks/radnerfs/__pycache__/radnerf_torso.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/tasks/radnerfs/dataset_utils.py b/Geneface_main/GeneFace/tasks/radnerfs/dataset_utils.py new file mode 100644 index 00000000..17c3d979 --- /dev/null +++ b/Geneface_main/GeneFace/tasks/radnerfs/dataset_utils.py @@ -0,0 +1,221 @@ +import os +import tqdm +import torch +import cv2 +import numpy as np + +from scipy.spatial.transform import Rotation + +from utils.commons.hparams import hparams, set_hparams +from utils.commons.tensor_utils import convert_to_tensor +from utils.commons.image_utils import load_image_as_uint8_tensor + +from modules.radnerfs.utils import get_audio_features, get_rays, get_bg_coords, convert_poses, nerf_matrix_to_ngp + + +def smooth_camera_path(poses, kernel_size=7): + # smooth the camera trajectory (i.e., translation)... + # poses: [N, 4, 4], numpy array + N = poses.shape[0] + K = kernel_size // 2 + + trans = poses[:, :3, 3].copy() # [N, 3] + rots = poses[:, :3, :3].copy() # [N, 3, 3] + + for i in range(N): + start = max(0, i - K) + end = min(N, i + K + 1) + poses[i, :3, 3] = trans[start:end].mean(0) + try: + poses[i, :3, :3] = Rotation.from_matrix(rots[start:end]).mean().as_matrix() + except: + if i == 0: + poses[i, :3, :3] = rots[i] + else: + poses[i, :3, :3] = poses[i-1, :3, :3] + return poses + + +class RADNeRFDataset(torch.utils.data.Dataset): + def __init__(self, prefix, data_dir=None, training=True): + super().__init__() + self.data_dir = os.path.join(hparams['binary_data_dir'], hparams['video_id']) if data_dir is None else data_dir + binary_file_name = os.path.join(self.data_dir, "trainval_dataset.npy") + ds_dict = np.load(binary_file_name, allow_pickle=True).tolist() + if prefix == 'train': + self.samples = [convert_to_tensor(sample) for sample in ds_dict['train_samples']] + elif prefix == 'val': + self.samples = [convert_to_tensor(sample) for sample in ds_dict['val_samples']] + elif prefix == 'trainval': + self.samples = [convert_to_tensor(sample) for sample in ds_dict['train_samples']] + [convert_to_tensor(sample) for sample in ds_dict['val_samples']] + else: + raise ValueError("prefix should in train/val !") + self.prefix = prefix + self.cond_type = hparams['cond_type'] + self.H = ds_dict['H'] + self.W = ds_dict['W'] + self.focal = ds_dict['focal'] + self.cx = ds_dict['cx'] + self.cy = ds_dict['cy'] + self.near = hparams['near'] # follow AD-NeRF, we dont use near-far in ds_dict + self.far = hparams['far'] # follow AD-NeRF, we dont use near-far in ds_dict + if hparams['infer_bg_img_fname'] == '': + # use the default bg_img from dataset + bg_img = torch.from_numpy(ds_dict['bg_img']).float() / 255. + elif hparams['infer_bg_img_fname'] == 'white': # special + bg_img = np.ones((self.H, self.W, 3), dtype=np.float32) + elif hparams['infer_bg_img_fname'] == 'black': # special + bg_img = np.zeros((self.H, self.W, 3), dtype=np.float32) + else: # load from a specificfile + bg_img = cv2.imread(hparams['infer_bg_img_fname'], cv2.IMREAD_UNCHANGED) # [H, W, 3] + if bg_img.shape[0] != self.H or bg_img.shape[1] != self.W: + bg_img = cv2.resize(bg_img, (self.W, self.H), interpolation=cv2.INTER_AREA) + bg_img = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB) + bg_img = bg_img.astype(np.float32) / 255 # [H, W, 3/4] + self.bg_img = convert_to_tensor(bg_img) + + self.idexp_lm3d_mean = torch.from_numpy(ds_dict['idexp_lm3d_mean']).float() + self.idexp_lm3d_std = torch.from_numpy(ds_dict['idexp_lm3d_std']).float() + + fl_x = fl_y = self.focal + self.intrinsics = np.array([fl_x, fl_y, self.cx, self.cy]) + self.poses = torch.from_numpy(np.stack([nerf_matrix_to_ngp(s['c2w'], scale=hparams['camera_scale'], offset=hparams['camera_offset']) for s in self.samples])) + if torch.any(torch.isnan(self.poses)): + raise ValueError("Found NaN in transform_matrix, please check the face_tracker process!") + if not training and hparams['infer_smooth_camera_path']: + smo_poses = smooth_camera_path(self.poses.numpy(), kernel_size=hparams['infer_smooth_camera_path_kernel_size']) + self.poses = torch.from_numpy(smo_poses) + print(f"{prefix}: Smooth head trajectory (rotation and translation) with a window size of {hparams['infer_smooth_camera_path_kernel_size']}") + self.bg_coords = get_bg_coords(self.H, self.W, 'cpu') # [1, H*W, 2] in [-1, 1] + + if self.cond_type == 'deepspeech': + self.conds = torch.stack([s['deepspeech_win'] for s in self.samples]) # [B=1, T=16, C=29] + elif self.cond_type == 'esperanto': + self.conds = torch.stack([s['esperanto_win'] for s in self.samples]) # [B=1, T=16, C=44] + elif self.cond_type == 'idexp_lm3d_normalized': + self.conds = torch.stack([s['idexp_lm3d_normalized_win'].reshape([hparams['cond_win_size'], 204]) for s in self.samples]) # [B=1, T=1, C=204] + else: + raise NotImplementedError + + self.finetune_lip_flag = False + self.lips_rect = [] + for sample in self.samples: + img_id = sample['idx'] + lms = np.loadtxt(os.path.join(hparams['processed_data_dir'],hparams['video_id'], 'ori_imgs', str(img_id) + '.lms')) # [68, 2] + lips = slice(48, 60) + xmin, xmax = int(lms[lips, 1].min()), int(lms[lips, 1].max()) + ymin, ymax = int(lms[lips, 0].min()), int(lms[lips, 0].max()) + + # padding to H == W + cx = (xmin + xmax) // 2 + cy = (ymin + ymax) // 2 + + l = max(xmax - xmin, ymax - ymin) // 2 + xmin = max(0, cx - l) + xmax = min(self.H, cx + l) + ymin = max(0, cy - l) + ymax = min(self.W, cy + l) + self.lips_rect.append([xmin, xmax, ymin, ymax]) + + self.training = training + self.global_step = 0 + + @property + def num_rays(self): + return hparams['n_rays'] if self.training else -1 + + def __getitem__(self, idx): + raw_sample = self.samples[idx] + + if hparams.get("load_imgs_to_memory", True): + # disable it to save memory usage. + # for 5500 images, it takes 1 minutes to imread, by contrast, only 1s is needed to index them in memory. + # But it reuqires 15GB memory for caching 5500 images at 512x512 resolution. + if 'torso_img' not in self.samples[idx].keys(): + self.samples[idx]['torso_img'] = load_image_as_uint8_tensor(self.samples[idx]['torso_img_fname']) + self.samples[idx]['gt_img'] = load_image_as_uint8_tensor(self.samples[idx]['gt_img_fname']) + torso_img = self.samples[idx]['torso_img'] + gt_img = self.samples[idx]['gt_img'] + else: + torso_img = load_image_as_uint8_tensor(self.samples[idx]['torso_img_fname']) + gt_img = load_image_as_uint8_tensor(self.samples[idx]['gt_img_fname']) + + + sample = { + 'H': self.H, + 'W': self.W, + 'focal': self.focal, + 'cx': self.cx, + 'cy': self.cy, + 'near': self.near, + 'far': self.far, + 'idx': raw_sample['idx'], + 'face_rect': raw_sample['face_rect'], + 'lip_rect': self.lips_rect[idx], + 'bg_img': self.bg_img, + } + + sample['cond_wins'] = get_audio_features(self.conds, att_mode=2, index=idx) + + ngp_pose = self.poses[idx].unsqueeze(0) + sample['pose'] = convert_poses(ngp_pose) # [B, 6] + sample['pose_matrix'] = ngp_pose # [B, 4, 4] + + sample.update({ + 'torso_img': torso_img.float() / 255., + 'gt_img': gt_img.float() / 255., + }) + + if self.training: + if self.finetune_lip_flag: + # the finetune_lip_flag is controlled by the task that use this dataset + rays = get_rays(ngp_pose.cuda(), self.intrinsics, self.H, self.W, N=-1, rect=sample['lip_rect']) + else: + # training phase + rays = get_rays(ngp_pose.cuda(), self.intrinsics, self.H, self.W, N=self.num_rays, rect=None) + else: + # inference phase + rays = get_rays(ngp_pose.cuda(), self.intrinsics, self.H, self.W, N=-1) + sample['rays_o'] = rays['rays_o'] + sample['rays_d'] = rays['rays_d'] + + xmin, xmax, ymin, ymax = raw_sample['face_rect'] + face_mask = (rays['j'] >= xmin) & (rays['j'] < xmax) & (rays['i'] >= ymin) & (rays['i'] < ymax) # [B, N] + sample['face_mask'] = face_mask + + bg_torso_img = sample['torso_img'] + bg_torso_img = bg_torso_img[..., :3] * bg_torso_img[..., 3:] + self.bg_img * (1 - bg_torso_img[..., 3:]) + bg_torso_img = bg_torso_img.view(1, -1, 3) # treat torso as a part of background + bg_img = self.bg_img.view(1, -1, 3) + + C = sample['gt_img'].shape[-1] + if self.training: + bg_img = torch.gather(bg_img.cuda(), 1, torch.stack(3 * [rays['inds']], -1)) # [B, N, 3] + bg_torso_img = torch.gather(bg_torso_img.cuda(), 1, torch.stack(3 * [rays['inds']], -1)) # [B, N, 3] + gt_img = torch.gather(sample['gt_img'].reshape(1, -1, C).cuda(), 1, torch.stack(C * [rays['inds']], -1)) # [B, N, 3/4] + sample['gt_img'] = gt_img + else: + sample['gt_img'] = sample['gt_img'].reshape([1,-1,C]) + sample['bg_img'] = bg_img + sample['bg_torso_img'] = bg_torso_img + + if self.training: + bg_coords = torch.gather(self.bg_coords.cuda(), 1, torch.stack(2 * [rays['inds']], -1)) # [1, N, 2] + else: + bg_coords = self.bg_coords # [1, N, 2] + sample['bg_coords'] = bg_coords + + return sample + + def __len__(self): + return len(self.samples) + + def collater(self, samples): + assert len(samples) == 1 # NeRF only take 1 image for each iteration + return samples[0] + +if __name__ == '__main__': + set_hparams() + ds = RADNeRFDataset('trainval', data_dir='data/binary/videos/May') + ds[0] + print("done!") \ No newline at end of file diff --git a/Geneface_main/GeneFace/tasks/radnerfs/radnerf.py b/Geneface_main/GeneFace/tasks/radnerfs/radnerf.py new file mode 100644 index 00000000..53a1f9de --- /dev/null +++ b/Geneface_main/GeneFace/tasks/radnerfs/radnerf.py @@ -0,0 +1,411 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np +import os +import cv2 +import lpips +import matplotlib.pyplot as plt + +from modules.radnerfs.radnerf import RADNeRF +from modules.radnerfs.utils import convert_poses, get_bg_coords, get_rays + +from utils.commons.image_utils import to8b +from utils.commons.base_task import BaseTask +from utils.commons.dataset_utils import data_loader +from utils.commons.hparams import hparams +from utils.commons.ckpt_utils import load_ckpt +from utils.commons.tensor_utils import tensors_to_scalars, convert_to_np, move_to_cuda +from utils.nn.model_utils import print_arch, num_params +from utils.nn.schedulers import ExponentialScheduleForRADNeRF +from utils.nn.grad import get_grad_norm + +from tasks.radnerfs.dataset_utils import RADNeRFDataset + + +class RADNeRFTask(BaseTask): + def __init__(self): + super().__init__() + self.dataset_cls = RADNeRFDataset + self.train_dataset = self.dataset_cls(prefix='train', training=True) + self.val_dataset = self.dataset_cls(prefix='val', training=False) + + self.criterion_lpips = lpips.LPIPS(net='alex') + self.finetune_lip_flag = False + + @property + def device(self): + return iter(self.model.parameters()).__next__().device + def build_model(self): + self.model = RADNeRF(hparams) + self.embedders_params = [] + self.embedders_params += [p for k, p in self.model.named_parameters() if p.requires_grad and 'position_embedder' in k] + self.embedders_params += [p for k, p in self.model.named_parameters() if p.requires_grad and 'ambient_embedder' in k] + self.network_params = [p for k, p in self.model.named_parameters() if (p.requires_grad and 'position_embedder' not in k and 'ambient_embedder' not in k and 'cond_att_net' not in k)] + self.att_net_params = [p for k, p in self.model.named_parameters() if p.requires_grad and 'cond_att_net' in k] + + self.model.conds = self.train_dataset.conds + self.model.mark_untrained_grid(self.train_dataset.poses, self.train_dataset.intrinsics) + + return self.model + + def on_train_start(self): + super().on_train_start() + for n, m in self.model.named_children(): + num_params(m, model_name=n) + + def build_optimizer(self, model): + self.optimizer = torch.optim.Adam( + self.network_params, + lr=hparams['lr'], + betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']), + eps=1e-15) + self.optimizer.add_param_group({ + 'params': self.embedders_params, + 'lr': hparams['lr'] * 10, + 'betas': (hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']), + 'eps': 1e-15 + }) + self.optimizer.add_param_group({ + 'params': self.att_net_params, + 'lr': hparams['lr'] * 5, + 'betas': (hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']), + 'eps': 1e-15 + }) + return self.optimizer + + def build_scheduler(self, optimizer): + return ExponentialScheduleForRADNeRF(optimizer, hparams['lr'], hparams['warmup_updates']) + + @data_loader + def train_dataloader(self): + self.train_dl = torch.utils.data.DataLoader(self.train_dataset,collate_fn=self.train_dataset.collater, + batch_size=1, shuffle=True, + # num_workers=0, pin_memory=True) + num_workers=0, pin_memory=False) + return self.train_dl + + @data_loader + def val_dataloader(self): + self.val_dl = torch.utils.data.DataLoader(self.val_dataset,collate_fn=self.val_dataset.collater, + batch_size=1, shuffle=True, + # num_workers=0, pin_memory=True) + num_workers=0, pin_memory=False) + return self.val_dl + + @data_loader + def test_dataloader(self): + self.val_dl = torch.utils.data.DataLoader(self.val_dataset,collate_fn=self.val_dataset.collater, + batch_size=1, shuffle=False, + # num_workers=0, pin_memory=True) + num_workers=0, pin_memory=False) + return self.val_dl + + ########################## + # forward the model + ########################## + def run_model(self, sample, infer=False): + """ + render or train on a single-frame + :param sample: a batch of data + :param infer: bool, run in infer mode + :return: + if not infer: + return losses, model_out + if infer: + return model_out + """ + cond_wins = sample['cond_wins'] + rays_o = sample['rays_o'] # [B, N, 3] + rays_d = sample['rays_d'] # [B, N, 3] + bg_coords = sample['bg_coords'] # [1, N, 2] + poses = sample['pose'] # [B, 6] + idx = sample['idx'] # [B] + bg_color = sample['bg_torso_img'] if 'bg_torso_img' in sample else sample['bg_img'] # treat torso as a part of background + H, W = sample['H'], sample['W'] + + cond_inp = cond_wins + start_finetune_lip = hparams['finetune_lips'] and self.global_step > hparams['finetune_lips_start_iter'] + + if not infer: + # training phase, sample rays from the image + model_out = self.model.render(rays_o, rays_d, cond_inp, bg_coords, poses, index=idx, staged=False, bg_color=bg_color, perturb=True, force_all_rays=False, **hparams) + pred_rgb = model_out['rgb_map'] + + losses_out = {} + gt_rgb = sample['gt_img'] + losses_out['mse_loss'] = torch.mean((pred_rgb - gt_rgb) ** 2) # [B, N, 3] --> scalar + + if self.model.training: + alphas = model_out['weights_sum'].clamp(1e-5, 1 - 1e-5) + losses_out['weights_entropy_loss'] = torch.mean(- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)) + ambient = model_out['ambient'] # [N], abs sum + face_mask = sample['face_mask'] # [B, N] + losses_out['ambient_loss'] = (ambient * (~face_mask.view(-1))).mean() + + if start_finetune_lip and self.finetune_lip_flag: + # during the training phase of finetuning lip, all rays are from lip part + xmin, xmax, ymin, ymax = sample['lip_rect'] + gt_rgb = gt_rgb.view(-1, xmax - xmin, ymax - ymin, 3).permute(0, 3, 1, 2).contiguous() + pred_rgb = pred_rgb.view(-1, xmax - xmin, ymax - ymin, 3).permute(0, 3, 1, 2).contiguous() + losses_out['lpips_loss'] = self.criterion_lpips(pred_rgb, gt_rgb).mean() + else: + # validation step, calulate lpips loss + if 'lip_rect' in sample: + xmin, xmax, ymin, ymax = sample['lip_rect'] + lip_gt_rgb = gt_rgb.view(-1,H,W,3)[:,xmin:xmax,ymin:ymax,:].permute(0, 3, 1, 2).contiguous() + lip_pred_rgb = pred_rgb.view(-1,H,W,3)[:,xmin:xmax,ymin:ymax,:].permute(0, 3, 1, 2).contiguous() + losses_out['lpips_loss'] = self.criterion_lpips(lip_pred_rgb, lip_gt_rgb).mean() + + if self.model.training and start_finetune_lip: + # during training, flip in each iteration, to prevent forgetting other facial parts. + self.finetune_lip_flag = not self.finetune_lip_flag + self.train_dataset.finetune_lip_flag = self.finetune_lip_flag + return losses_out, model_out + + else: + # infer phase, generate the whole image + model_out = self.model.render(rays_o, rays_d, cond_inp, bg_coords, poses, index=idx, staged=False, bg_color=bg_color, perturb=False, force_all_rays=True, **hparams) + # calculate val loss + if 'gt_img' in sample: + gt_rgb = sample['gt_img'] + pred_rgb = model_out['rgb_map'] + model_out['mse_loss'] = torch.mean((pred_rgb - gt_rgb) ** 2) # [B, N, 3] --> scalar + if 'lip_rect' in sample: + xmin, xmax, ymin, ymax = sample['lip_rect'] + gt_rgb = gt_rgb.view(-1, H, W, 3)[:,xmin:xmax,ymin:ymax,:].permute(0, 3, 1, 2).contiguous() + pred_rgb = pred_rgb.view(-1, H, W, 3)[:,xmin:xmax,ymin:ymax,:].permute(0, 3, 1, 2).contiguous() + model_out['lpips_loss'] = self.criterion_lpips(pred_rgb, gt_rgb).mean() + return model_out + + ########################## + # training + ########################## + def _training_step(self, sample, batch_idx, optimizer_idx): + outputs = {} + self.train_dataset.global_step = self.global_step + if self.global_step % hparams['update_extra_interval'] == 0: + start_finetune_lips = hparams['finetune_lips'] and self.global_step > hparams['finetune_lips_start_iter'] + if not start_finetune_lips: + # when finetuning lips, we don't update the density grid and bitfield. + self.model.update_extra_state() + + loss_output, model_out = self.run_model(sample) + loss_weights = { + 'mse_loss': 1.0, + 'weights_entropy_loss': hparams['lambda_weights_entropy'], + 'lpips_loss': hparams['lambda_lpips_loss'], + 'ambient_loss': min(self.global_step / 250000, 1.0) * hparams['lambda_ambient'], # gradually increase it + } + total_loss = sum([loss_weights.get(k, 1) * v for k, v in loss_output.items() if isinstance(v, torch.Tensor) and v.requires_grad]) + def mse2psnr(x): return -10. * torch.log(x) / torch.log(torch.Tensor([10.])).to(x.device) + loss_output['head_psnr'] = mse2psnr(loss_output['mse_loss'].detach()) + outputs.update(loss_output) + + if (self.global_step+1) % hparams['tb_log_interval'] == 0: + density_grid_info = { + "density_grid_info/min_density": self.model.density_grid.min().item(), + "density_grid_info/max_density": self.model.density_grid.max().item(), + "density_grid_info/mean_density": self.model.mean_density, + # "density_grid_info/occupancy_rate": (self.model.density_grid > 0.01).sum() / (128**3 * self.model.cascade), + "density_grid_info/occupancy_rate": (self.model.density_grid > min(self.model.mean_density, self.model.density_thresh)).sum() / (128**3 * self.model.cascade), + "density_grid_info/step_mean_count": self.model.mean_count + } + outputs.update(density_grid_info) + return total_loss, outputs + + def on_before_optimization(self, opt_idx): + prefix = f"grad_norm_opt_idx_{opt_idx}" + grad_norm_dict = { + f'{prefix}/cond_att': get_grad_norm(self.att_net_params), + f'{prefix}/embedders_params': get_grad_norm(self.embedders_params), + f'{prefix}/network_params': get_grad_norm(self.network_params ), + } + if self.gradient_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(self.parameters(), self.gradient_clip_norm) + if self.gradient_clip_val > 0: + torch.nn.utils.clip_grad_value_(self.parameters(), self.gradient_clip_val) + return grad_norm_dict + + def on_after_optimization(self, epoch, batch_idx, optimizer, optimizer_idx): + self.scheduler.step(self.global_step // hparams['accumulate_grad_batches']) + + ##################### + # Validation + ##################### + def validation_start(self): + if self.global_step % hparams['valid_infer_interval'] == 0: + self.gen_dir = os.path.join(hparams['work_dir'], f'validation_results/validation_{self.trainer.global_step}') + os.makedirs(self.gen_dir, exist_ok=True) + os.makedirs(f'{self.gen_dir}/images', exist_ok=True) + os.makedirs(f'{self.gen_dir}/depth', exist_ok=True) + + @torch.no_grad() + def validation_step(self, sample, batch_idx): + outputs = {} + outputs['losses'] = {} + outputs['losses'], model_out = self.run_model(sample, infer=False) + outputs['total_loss'] = sum(outputs['losses'].values()) + outputs['nsamples'] = 1 + outputs = tensors_to_scalars(outputs) + if self.global_step % hparams['valid_infer_interval'] == 0 \ + and batch_idx < hparams['num_valid_plots']: + num_val_samples = len(self.val_dataset) + interval = (num_val_samples-1) // 4 + idx_lst = [i * interval for i in range(5)] + sample = move_to_cuda(self.val_dataset[idx_lst[batch_idx]]) + infer_outputs = self.run_model(sample, infer=True) + H, W = sample['H'], sample['W'] + img_pred = infer_outputs['rgb_map'].reshape([H, W, 3]) + depth_pred = infer_outputs['depth_map'].reshape([H, W]) + + base_fn = f"frame_{sample['idx']}" + self.logger.add_figure(f"frame_{sample['idx']}/img_pred", self.rgb_to_figure(img_pred), self.global_step) + self.logger.add_figure(f"frame_{sample['idx']}/depth_pred", self.rgb_to_figure(depth_pred), self.global_step) + + self.save_rgb_to_fname(img_pred, f"{self.gen_dir}/images/{base_fn}.png") + self.save_rgb_to_fname(depth_pred, f"{self.gen_dir}/depth/{base_fn}.png") + + if hparams['save_gt']: + img_gt = sample['gt_img'].reshape([H, W, 3]) + if self.global_step == hparams['valid_infer_interval']: + self.logger.add_figure(f"frame_{sample['idx']}/img_gt", self.rgb_to_figure(img_gt), self.global_step) + base_fn = f"frame_{sample['idx']}_gt" + self.save_rgb_to_fname(img_gt, f"{self.gen_dir}/images/{base_fn}.png") + return outputs + + def validation_end(self, outputs): + return super().validation_end(outputs) + + ##################### + # Testing + ##################### + def test_start(self): + self.gen_dir = os.path.join(hparams['work_dir'], f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}') + os.makedirs(self.gen_dir, exist_ok=True) + os.makedirs(f'{self.gen_dir}/images', exist_ok=True) + os.makedirs(f'{self.gen_dir}/depth', exist_ok=True) + + @torch.no_grad() + def test_step(self, sample, batch_idx): + outputs = self.run_model(sample, infer=True) + rgb_pred = outputs['rgb_map'] + H, W = sample['H'], sample['W'] + img_pred = rgb_pred.reshape([H, W, 3]) + gen_dir = self.gen_dir + base_fn = f"frame_{sample['idx']}" + self.save_rgb_to_fname(img_pred, f"{gen_dir}/images/{base_fn}.png") + self.save_rgb_to_fname(img_pred, f"{gen_dir}/depth/{base_fn}.png") + target = sample['gt_img'] + img_gt = target.reshape([H, W, 3]) + if hparams['save_gt']: + base_fn = f"frame_{sample['idx']}_gt" + self.save_rgb_to_fname(img_gt, f"{gen_dir}/images/{base_fn}.png") + + outputs['losses'] = (img_gt - img_pred).mean() + return outputs + + def test_end(self, outputs): + pass + + ##################### + # Visualization utils + ##################### + @staticmethod + def rgb_to_figure(rgb): + fig = plt.figure(figsize=(12, 6)) + rgb = convert_to_np(rgb * 255.).astype(np.uint8) + plt.imshow(rgb) + return fig + + @staticmethod + def save_rgb_to_fname(rgb, fname): + rgb = convert_to_np(rgb * 255.).astype(np.uint8) + if rgb.ndim == 3: + bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) + cv2.imwrite(f"{fname}", bgr) + else: + # gray image + cv2.imwrite(f"{fname}", rgb) + + ### GUI utils + def test_gui_with_editable_data(self, pose, intrinsics, W, H, cond_wins, index=0, bg_color=None, spp=1, downscale=1): + # def test_gui_with_edited_data(self, pose, intrinsics, W, H, cond_wins, index=0, bg_color=None, downscale=1): + # render resolution (may need downscale to for better frame rate) + rH = int(H * downscale) + rW = int(W * downscale) + intrinsics = intrinsics * downscale + + cond_wins = cond_wins.cuda() + pose = torch.from_numpy(pose).unsqueeze(0).cuda() + rays = get_rays(pose, intrinsics, rH, rW, -1) + bg_coords = get_bg_coords(rH, rW, 'cuda:0') + + sample = { + 'rays_o': rays['rays_o'].cuda(), + 'rays_d': rays['rays_d'].cuda(), + 'H': rH, + 'W': rW, + 'cond_wins': cond_wins, + 'idx': [index], # support choosing index for individual codes + 'pose': convert_poses(pose), + 'bg_coords': bg_coords, + 'bg_img': bg_color.cuda() + } + + self.model.eval() + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=hparams['amp']): + # here spp is used as perturb random seed! + # face: do not perturb for the first spp, else lead to scatters. + infer_outputs = self.run_model(sample, infer=True) + preds = infer_outputs['rgb_map'].reshape([1,rH, rW, 3]) + preds_depth = infer_outputs['depth_map'].reshape([1, rH, rW]) + + # interpolation to the original resolution + if downscale != 1: + # TODO: have to permute twice with torch... + preds = F.interpolate(preds.permute(0, 3, 1, 2), size=(H, W), mode='bilinear').permute(0, 2, 3, 1).contiguous() + preds_depth = F.interpolate(preds_depth.unsqueeze(1), size=(H, W), mode='nearest').squeeze(1) + + pred = preds[0].detach().cpu().numpy() + pred_depth = preds_depth[0].detach().cpu().numpy() + + outputs = { + 'image': pred, + 'depth': pred_depth, + } + + return outputs + + # [GUI] test with provided data + def test_gui_with_data(self, sample, target_W, target_H): + # prevent calculate loss, which increase costs. + del sample['gt_img'] + del sample['lip_rect'] + + self.model.eval() + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=hparams['amp']): + # here spp is used as perturb random seed! + # face: do not perturb for the first spp, else lead to scatters. + infer_outputs = self.run_model(sample, infer=True) + H, W = sample['H'], sample['W'] + preds = infer_outputs['rgb_map'].reshape([1,H, W, 3]) + preds_depth = infer_outputs['depth_map'].reshape([1,H, W]) + + # the H/W in data may be differnt to GUI, so we still need to resize... + preds = F.interpolate(preds.permute(0, 3, 1, 2), size=(target_H, target_W), mode='bilinear').permute(0, 2, 3, 1).contiguous() + preds_depth = F.interpolate(preds_depth.unsqueeze(1), size=(target_H, target_W), mode='nearest').squeeze(1) + + pred = preds[0].detach().cpu().numpy() + pred_depth = preds_depth[0].detach().cpu().numpy() + + outputs = { + 'image': pred, + 'depth': pred_depth, + } + + return outputs + diff --git a/Geneface_main/GeneFace/tasks/radnerfs/radnerf_torso.py b/Geneface_main/GeneFace/tasks/radnerfs/radnerf_torso.py new file mode 100644 index 00000000..b7317d7b --- /dev/null +++ b/Geneface_main/GeneFace/tasks/radnerfs/radnerf_torso.py @@ -0,0 +1,165 @@ +import torch +import torch.nn as nn +import numpy as np +import os +import cv2 +import lpips +import matplotlib.pyplot as plt + +from modules.radnerfs.radnerf import RADNeRF +from modules.radnerfs.radnerf_torso import RADNeRFTorso +from tasks.radnerfs.radnerf import RADNeRFTask + +from utils.commons.image_utils import to8b +from utils.commons.base_task import BaseTask +from utils.commons.dataset_utils import data_loader +from utils.commons.hparams import hparams +from utils.commons.ckpt_utils import load_ckpt +from utils.commons.tensor_utils import tensors_to_scalars, convert_to_np, move_to_cuda +from utils.nn.model_utils import print_arch, num_params, not_requires_grad +from utils.nn.schedulers import ExponentialScheduleForRADNeRFTorso +from utils.nn.grad import get_grad_norm + +from tasks.radnerfs.dataset_utils import RADNeRFDataset + + +class RADNeRFTorsoTask(RADNeRFTask): + def __init__(self): + super().__init__() + + def build_model(self): + self.model = RADNeRFTorso(hparams) + # todo: load state_dict in RADNeRF + head_model = RADNeRF(hparams) + load_ckpt(head_model, hparams['head_model_dir']) + print(f"Loaded Head Model from {hparams['head_model_dir']}") + self.model.load_state_dict(head_model.state_dict(), strict=False) + print(f"Loaded state_dict of Head Model to the RADNeRFTorso Model") + del head_model + + self.torso_embedders_params = [p for k, p in self.model.named_parameters() if p.requires_grad and 'torso_embedder' in k] + self.torso_network_params = [p for k, p in self.model.named_parameters() if (p.requires_grad and 'torso_embedder' not in k and 'torso' in k)] + for k, p in self.model.named_parameters(): + if 'torso' not in k: + not_requires_grad(p) + + self.model.poses = self.train_dataset.poses + return self.model + + def on_train_start(self): + super().on_train_start() + for n, m in self.model.named_children(): + num_params(m, model_name=n) + + def build_optimizer(self, model): + self.optimizer = torch.optim.Adam( + self.torso_network_params, + lr=hparams['lr'], + betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']), + eps=1e-15) + self.optimizer.add_param_group({ + 'params': self.torso_embedders_params, + 'lr': hparams['lr'] * 10, + 'betas': (hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']), + 'eps': 1e-15 + }) + return self.optimizer + + def build_scheduler(self, optimizer): + return ExponentialScheduleForRADNeRFTorso(optimizer, hparams['lr'], hparams['warmup_updates']) + + ########################## + # forward the model + ########################## + def run_model(self, sample, infer=False): + """ + render or train on a single-frame + :param sample: a batch of data + :param infer: bool, run in infer mode + :return: + if not infer: + return losses, model_out + if infer: + return model_out + """ + cond_wins = sample['cond_wins'] + rays_o = sample['rays_o'] # [B, N, 3] + rays_d = sample['rays_d'] # [B, N, 3] + bg_coords = sample['bg_coords'] # [1, N, 2] + poses = sample['pose'] # [B, 6] + idx = sample['idx'] # [B] + bg_color = sample['bg_img'] + H, W = sample['H'], sample['W'] + + cond_inp = cond_wins + + if not infer: + model_out = self.model.render(rays_o, rays_d, cond_inp, bg_coords, poses, index=idx, staged=False, bg_color=bg_color, perturb=True, force_all_rays=False, **hparams) + if hparams['torso_train_mode'] == 1: + pred_rgb = model_out['torso_rgb_map'] + gt_rgb = sample['bg_torso_img'] # the target is bg_torso_img + else: + pred_rgb = model_out['rgb_map'] # todo: try whole image + gt_rgb = sample['gt_img'] # todo: try gt_image + + losses_out = {} + + losses_out['torso_mse_loss'] = torch.mean((pred_rgb - gt_rgb) ** 2) # [B, N, 3] --> scalar + + alphas = model_out['torso_alpha_map'].clamp(1e-5, 1 - 1e-5) + losses_out['torso_weights_entropy_loss'] = torch.mean(- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)) + + return losses_out, model_out + + else: + # infer phase, generate the whole image + model_out = self.model.render(rays_o, rays_d, cond_inp, bg_coords, poses, index=idx, staged=False, bg_color=bg_color, perturb=False, force_all_rays=True, **hparams) + # calculate val loss + if 'gt_img' in sample: + gt_rgb = sample['gt_img'] + pred_rgb = model_out['rgb_map'] + model_out['mse_loss'] = torch.mean((pred_rgb - gt_rgb) ** 2) # [B, N, 3] --> scalar + return model_out + + ########################## + # training + ########################## + def _training_step(self, sample, batch_idx, optimizer_idx): + outputs = {} + self.train_dataset.global_step = self.global_step + if self.global_step % hparams['update_extra_interval'] == 0: + self.model.update_extra_state() + + loss_output, model_out = self.run_model(sample) + loss_weights = { + 'torso_mse_loss': 1.0, + 'torso_weights_entropy_loss': hparams['lambda_weights_entropy'], + } + total_loss = sum([loss_weights.get(k, 1) * v for k, v in loss_output.items() if isinstance(v, torch.Tensor) and v.requires_grad]) + def mse2psnr(x): return -10. * torch.log(x) / torch.log(torch.Tensor([10.])).to(x.device) + loss_output['image_psnr'] = mse2psnr(loss_output['torso_mse_loss'].detach()) + outputs.update(loss_output) + + if (self.global_step+1) % hparams['tb_log_interval'] == 0: + density_grid_info = { + "density_grid_info/min_density_torso": self.model.density_grid_torso.min().item(), + "density_grid_info/max_density_torso": self.model.density_grid_torso.max().item(), + "density_grid_info/mean_density_torso": self.model.mean_density_torso, + "density_grid_info/occupancy_rate_torso": (self.model.density_grid_torso > min(self.model.mean_density_torso, self.model.density_thresh_torso)).sum() / (128**3 * self.model.cascade), + "density_grid_info/step_mean_count_torso": self.model.mean_count + } + outputs.update(density_grid_info) + return total_loss, outputs + + def on_before_optimization(self, opt_idx): + prefix = f"grad_norm_opt_idx_{opt_idx}" + grad_norm_dict = { + f'{prefix}/torso_embedders_params': get_grad_norm(self.torso_embedders_params), + f'{prefix}/torso_network_params': get_grad_norm(self.torso_network_params ), + } + if self.gradient_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(self.parameters(), self.gradient_clip_norm) + if self.gradient_clip_val > 0: + torch.nn.utils.clip_grad_value_(self.parameters(), self.gradient_clip_val) + return grad_norm_dict + diff --git a/Geneface_main/GeneFace/tasks/run.py b/Geneface_main/GeneFace/tasks/run.py new file mode 100644 index 00000000..ef2b0a31 --- /dev/null +++ b/Geneface_main/GeneFace/tasks/run.py @@ -0,0 +1,19 @@ +import os + +os.environ["OMP_NUM_THREADS"] = "1" + +from utils.commons.hparams import hparams, set_hparams +import importlib + + +def run_task(): + assert hparams['task_cls'] != '' + pkg = ".".join(hparams["task_cls"].split(".")[:-1]) + cls_name = hparams["task_cls"].split(".")[-1] + task_cls = getattr(importlib.import_module(pkg), cls_name) + task_cls.start() + + +if __name__ == '__main__': + set_hparams() + run_task() diff --git a/Geneface_main/GeneFace/tasks/syncnet/__pycache__/lm3d_syncnet.cpython-39.pyc b/Geneface_main/GeneFace/tasks/syncnet/__pycache__/lm3d_syncnet.cpython-39.pyc new file mode 100644 index 00000000..75fe415d Binary files /dev/null and b/Geneface_main/GeneFace/tasks/syncnet/__pycache__/lm3d_syncnet.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/tasks/syncnet/lm3d_syncnet.py b/Geneface_main/GeneFace/tasks/syncnet/lm3d_syncnet.py new file mode 100644 index 00000000..241143e1 --- /dev/null +++ b/Geneface_main/GeneFace/tasks/syncnet/lm3d_syncnet.py @@ -0,0 +1,165 @@ +import torch +import random + +from utils.commons.base_task import BaseTask +from utils.commons.dataset_utils import data_loader +from utils.commons.hparams import hparams +from utils.commons.tensor_utils import tensors_to_scalars +from utils.nn.model_utils import print_arch +from utils.nn.schedulers import CosineSchedule + +from modules.syncnet.models import LandmarkHubertSyncNet +from tasks.audio2motion.dataset_utils.lrs3_dataset import LRS3SeqDataset + +class SyncNetTask(BaseTask): + def __init__(self): + super().__init__() + self.dataset_cls = LRS3SeqDataset + + def build_model(self): + lm_dim = 20*3 + self.model = LandmarkHubertSyncNet(lm_dim) + print_arch(self.model) + return self.model + + def build_optimizer(self, model): + self.optimizer = optimizer = torch.optim.Adam( + model.parameters(), + lr=hparams['lr'], + betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2'])) + return optimizer + + def build_scheduler(self, optimizer): + return CosineSchedule(optimizer, hparams['lr'], warmup_updates=0, total_updates=hparams['max_updates']) + + @data_loader + def train_dataloader(self): + train_dataset = self.dataset_cls(prefix='train') + self.train_dl = train_dataset.get_dataloader() + return self.train_dl + + @data_loader + def val_dataloader(self): + val_dataset = self.dataset_cls(prefix='val') + self.val_dl = val_dataset.get_dataloader() + return self.val_dl + + @data_loader + def test_dataloader(self): + val_dataset = self.dataset_cls(prefix='val') + self.val_dl = val_dataset.get_dataloader() + return self.val_dl + + ########################## + # training and validation + ########################## + def run_model(self, sample, infer=False, batch_size=1024): + """ + render or train on a single-frame + :param sample: a batch of data + :param infer: bool, run in infer mode + :return: + if not infer: + return losses, model_out + if infer: + return model_out + """ + model_out = {} + mouth_lm3d = sample['mouth_idexp_lm3d'] + mel = sample['hubert'] + + y_mask = sample['y_mask'] + y_len = y_mask.sum(dim=1) + mouth_lst, mel_lst, label_lst = [], [], [] + while len(mouth_lst) < batch_size: + for i in range(mouth_lm3d.shape[0]): + if not infer: + is_pos_sample = random.choice([True, False]) + else: + is_pos_sample = True + exp_idx = random.randint(a=0, b=y_len[i]-1-5) + mouth_clip = mouth_lm3d[i, exp_idx: exp_idx+5] + assert mouth_clip.shape[0]==5, f"exp_idx={exp_idx},y_len={y_len[i]}" + if is_pos_sample: + mel_clip = mel[i, exp_idx*2: exp_idx*2 + 10] + label_lst.append(1.) + else: + if random.random() < 0.25: + wrong_spk_idx = random.randint(a=0, b=len(y_len)-1) + wrong_exp_idx = random.randint(a=0, b=y_len[wrong_spk_idx]-1-5) + while wrong_exp_idx == exp_idx: + wrong_exp_idx = random.randint(a=0, b=y_len[wrong_spk_idx]-1-5) + mel_clip = mel[wrong_spk_idx, wrong_exp_idx*2: wrong_exp_idx*2 + 10] + assert mel_clip.shape[0]==10 + elif random.random() < 0.5: + wrong_exp_idx = random.randint(a=0, b=y_len[i]-1-5) + while wrong_exp_idx == exp_idx: + wrong_exp_idx = random.randint(a=0, b=y_len[i]-1-5) + mel_clip = mel[i, wrong_exp_idx*2: wrong_exp_idx*2 + 10] + assert mel_clip.shape[0]==10 + else: + left_offset = max(-5, -exp_idx) + right_offset = min(5, (y_len[i]-5-exp_idx)) + exp_offset = random.randint(a=left_offset, b=right_offset) + while abs(exp_offset) <= 1: + exp_offset = random.randint(a=left_offset, b=right_offset) + wrong_exp_idx = exp_offset + exp_idx + mel_clip = mel[i, wrong_exp_idx*2: wrong_exp_idx*2 + 10] + assert mel_clip.shape[0]==10, y_len[i]-wrong_exp_idx + mel_clip = mel[i, wrong_exp_idx*2: wrong_exp_idx*2 + 10] + label_lst.append(0.) + mouth_lst.append(mouth_clip) + mel_lst.append(mel_clip) + mel_clips = torch.stack(mel_lst) + mouth_clips = torch.stack(mouth_lst) + labels = torch.tensor(label_lst).float().to(mel_clips.device) + + audio_embedding, mouth_embedding = self.model(mel_clips, mouth_clips) + sync_loss, cosine_sim = self.model.cal_sync_loss(audio_embedding, mouth_embedding, labels) + if not infer: + losses_out = {} + model_out = {} + losses_out['sync_loss'] = sync_loss + model_out['cosine_sim'] = cosine_sim + return losses_out, model_out + else: + model_out['sync_loss'] = sync_loss + return model_out + + def _training_step(self, sample, batch_idx, optimizer_idx): + loss_output, model_out = self.run_model(sample, infer=False) + loss_weights = {} + total_loss = sum([loss_weights.get(k, 1) * v for k, v in loss_output.items() if isinstance(v, torch.Tensor) and v.requires_grad]) + return total_loss, loss_output + + def validation_start(self): + pass + + @torch.no_grad() + def validation_step(self, sample, batch_idx): + outputs = {} + outputs['losses'] = {} + outputs['losses'], model_out = self.run_model(sample, infer=False, batch_size=20000) + outputs = tensors_to_scalars(outputs) + return outputs + + def validation_end(self, outputs): + return super().validation_end(outputs) + + ##################### + # Testing + ##################### + def test_start(self): + pass + + @torch.no_grad() + def test_step(self, sample, batch_idx): + """ + :param sample: + :param batch_idx: + :return: + """ + pass + + def test_end(self, outputs): + pass diff --git a/Geneface_main/GeneFace/utils/commons/__pycache__/base_task.cpython-39.pyc b/Geneface_main/GeneFace/utils/commons/__pycache__/base_task.cpython-39.pyc new file mode 100644 index 00000000..31bc0be2 Binary files /dev/null and b/Geneface_main/GeneFace/utils/commons/__pycache__/base_task.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/utils/commons/__pycache__/ckpt_utils.cpython-39.pyc b/Geneface_main/GeneFace/utils/commons/__pycache__/ckpt_utils.cpython-39.pyc new file mode 100644 index 00000000..bc96e181 Binary files /dev/null and b/Geneface_main/GeneFace/utils/commons/__pycache__/ckpt_utils.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/utils/commons/__pycache__/dataset_utils.cpython-39.pyc b/Geneface_main/GeneFace/utils/commons/__pycache__/dataset_utils.cpython-39.pyc new file mode 100644 index 00000000..935c0ff1 Binary files /dev/null and b/Geneface_main/GeneFace/utils/commons/__pycache__/dataset_utils.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/utils/commons/__pycache__/ddp_utils.cpython-39.pyc b/Geneface_main/GeneFace/utils/commons/__pycache__/ddp_utils.cpython-39.pyc new file mode 100644 index 00000000..7225ac9d Binary files /dev/null and b/Geneface_main/GeneFace/utils/commons/__pycache__/ddp_utils.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/utils/commons/__pycache__/euler2rot.cpython-39.pyc b/Geneface_main/GeneFace/utils/commons/__pycache__/euler2rot.cpython-39.pyc new file mode 100644 index 00000000..56195cc2 Binary files /dev/null and b/Geneface_main/GeneFace/utils/commons/__pycache__/euler2rot.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/utils/commons/__pycache__/hparams.cpython-39.pyc b/Geneface_main/GeneFace/utils/commons/__pycache__/hparams.cpython-39.pyc new file mode 100644 index 00000000..f87549ae Binary files /dev/null and b/Geneface_main/GeneFace/utils/commons/__pycache__/hparams.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/utils/commons/__pycache__/image_utils.cpython-39.pyc b/Geneface_main/GeneFace/utils/commons/__pycache__/image_utils.cpython-39.pyc new file mode 100644 index 00000000..a4e73a19 Binary files /dev/null and b/Geneface_main/GeneFace/utils/commons/__pycache__/image_utils.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/utils/commons/__pycache__/indexed_datasets.cpython-39.pyc b/Geneface_main/GeneFace/utils/commons/__pycache__/indexed_datasets.cpython-39.pyc new file mode 100644 index 00000000..05e83dc4 Binary files /dev/null and b/Geneface_main/GeneFace/utils/commons/__pycache__/indexed_datasets.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/utils/commons/__pycache__/meters.cpython-39.pyc b/Geneface_main/GeneFace/utils/commons/__pycache__/meters.cpython-39.pyc new file mode 100644 index 00000000..7e7093c6 Binary files /dev/null and b/Geneface_main/GeneFace/utils/commons/__pycache__/meters.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/utils/commons/__pycache__/multiprocess_utils.cpython-39.pyc b/Geneface_main/GeneFace/utils/commons/__pycache__/multiprocess_utils.cpython-39.pyc new file mode 100644 index 00000000..b73b5e90 Binary files /dev/null and b/Geneface_main/GeneFace/utils/commons/__pycache__/multiprocess_utils.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/utils/commons/__pycache__/os_utils.cpython-39.pyc b/Geneface_main/GeneFace/utils/commons/__pycache__/os_utils.cpython-39.pyc new file mode 100644 index 00000000..f7e9a949 Binary files /dev/null and b/Geneface_main/GeneFace/utils/commons/__pycache__/os_utils.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/utils/commons/__pycache__/pitch_utils.cpython-39.pyc b/Geneface_main/GeneFace/utils/commons/__pycache__/pitch_utils.cpython-39.pyc new file mode 100644 index 00000000..abd8037e Binary files /dev/null and b/Geneface_main/GeneFace/utils/commons/__pycache__/pitch_utils.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/utils/commons/__pycache__/tensor_utils.cpython-310.pyc b/Geneface_main/GeneFace/utils/commons/__pycache__/tensor_utils.cpython-310.pyc new file mode 100644 index 00000000..b2b1f050 Binary files /dev/null and b/Geneface_main/GeneFace/utils/commons/__pycache__/tensor_utils.cpython-310.pyc differ diff --git a/Geneface_main/GeneFace/utils/commons/__pycache__/tensor_utils.cpython-39.pyc b/Geneface_main/GeneFace/utils/commons/__pycache__/tensor_utils.cpython-39.pyc new file mode 100644 index 00000000..3ae37939 Binary files /dev/null and b/Geneface_main/GeneFace/utils/commons/__pycache__/tensor_utils.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/utils/commons/__pycache__/trainer.cpython-39.pyc b/Geneface_main/GeneFace/utils/commons/__pycache__/trainer.cpython-39.pyc new file mode 100644 index 00000000..eda4129c Binary files /dev/null and b/Geneface_main/GeneFace/utils/commons/__pycache__/trainer.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/utils/commons/base_task.py b/Geneface_main/GeneFace/utils/commons/base_task.py new file mode 100644 index 00000000..3e6a30a1 --- /dev/null +++ b/Geneface_main/GeneFace/utils/commons/base_task.py @@ -0,0 +1,256 @@ +import logging +import os +import random +import subprocess +import sys +from datetime import datetime +import numpy as np +import torch.utils.data +from torch import nn +from torch.utils.tensorboard import SummaryWriter +from utils.commons.dataset_utils import data_loader +from utils.commons.hparams import hparams +from utils.commons.meters import AvgrageMeter +from utils.commons.tensor_utils import tensors_to_scalars +from utils.commons.trainer import Trainer + +torch.multiprocessing.set_sharing_strategy(os.getenv('TORCH_SHARE_STRATEGY', 'file_system')) + +log_format = '%(asctime)s %(message)s' +logging.basicConfig(stream=sys.stdout, level=logging.INFO, + format=log_format, datefmt='%m/%d %I:%M:%S %p') + + +class BaseTask(nn.Module): + def __init__(self, *args, **kwargs): + super(BaseTask, self).__init__() + self.current_epoch = 0 + self.global_step = 0 + self.trainer = None + self.use_ddp = False + self.gradient_clip_norm = hparams['clip_grad_norm'] + self.gradient_clip_val = hparams.get('clip_grad_value', 0) + self.model = None + self.training_losses_meter = None + self.logger: SummaryWriter = None + + ###################### + # build model, dataloaders, optimizer, scheduler and tensorboard + ###################### + def build_model(self): + raise NotImplementedError + + @data_loader + def train_dataloader(self): + raise NotImplementedError + + @data_loader + def test_dataloader(self): + raise NotImplementedError + + @data_loader + def val_dataloader(self): + raise NotImplementedError + + def build_scheduler(self, optimizer): + return None + + def build_optimizer(self, model): + raise NotImplementedError + + def configure_optimizers(self): + optm = self.build_optimizer(self.model) + self.scheduler = self.build_scheduler(optm) + if isinstance(optm, (list, tuple)): + return optm + return [optm] + + def build_tensorboard(self, save_dir, name, **kwargs): + log_dir = os.path.join(save_dir, name) + os.makedirs(log_dir, exist_ok=True) + self.logger = SummaryWriter(log_dir=log_dir, **kwargs) + + ###################### + # training + ###################### + def on_train_start(self): + pass + + def on_train_end(self): + pass + + def on_epoch_start(self): + self.training_losses_meter = {'total_loss': AvgrageMeter()} + + def on_epoch_end(self): + loss_outputs = {k: v.avg for k, v in self.training_losses_meter.items()} + print(f"Epoch {self.current_epoch} ended. Steps: {self.global_step}. {loss_outputs}") + loss_outputs = {"epoch_mean/"+k:v for k,v in loss_outputs.items()} + return loss_outputs + + def _training_step(self, sample, batch_idx, optimizer_idx): + """ + + :param sample: + :param batch_idx: + :return: total loss: torch.Tensor, loss_log: dict + """ + raise NotImplementedError + + def training_step(self, sample, batch_idx, optimizer_idx=-1): + """ + + :param sample: + :param batch_idx: + :param optimizer_idx: + :return: {'loss': torch.Tensor, 'progress_bar': dict, 'tb_log': dict} + """ + # perform the main training step in a specific task + loss_ret = self._training_step(sample, batch_idx, optimizer_idx) + if loss_ret is None: + return {'loss': None} + total_loss, log_outputs = loss_ret + log_outputs = tensors_to_scalars(log_outputs) + + # add to epoch meter + for k, v in log_outputs.items(): + if '/' in k: + k_split = k.split("/") + assert len(k_split) == 2, "we only support one `/` in tag_name, i.e., `/`" + k = k_split[-1] + if k not in self.training_losses_meter: + self.training_losses_meter[k] = AvgrageMeter() + if not np.isnan(v): + self.training_losses_meter[k].update(v) + self.training_losses_meter['total_loss'].update(total_loss.item()) + + if optimizer_idx >= 0: + log_outputs[f'lr_{optimizer_idx}'] = self.trainer.optimizers[optimizer_idx].param_groups[0]['lr'] + + # add to progress bar + progress_bar_log = {} + for k, v in log_outputs.items(): + if '/' in k: + k_split = k.split("/") + assert len(k_split) == 2, "we only support one `/` in tag_name, i.e., `/`" + k = k_split[-1] + assert k not in progress_bar_log, f"we got duplicate tags in log_outputs, check this `{k}`" + progress_bar_log[k] = v + + # add to progress bar + tb_log = {} + for k, v in log_outputs.items(): + if '/' in k: + tb_log[k] = v + else: + tb_log[f'tr/{k}'] = v + return { + 'loss': total_loss, + 'progress_bar': progress_bar_log, + 'tb_log': tb_log + } + + def on_before_optimization(self, opt_idx): + if self.gradient_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(self.parameters(), self.gradient_clip_norm) + if self.gradient_clip_val > 0: + torch.nn.utils.clip_grad_value_(self.parameters(), self.gradient_clip_val) + + def on_after_optimization(self, epoch, batch_idx, optimizer, optimizer_idx): + if self.scheduler is not None: + self.scheduler.step(self.global_step // hparams['accumulate_grad_batches']) + + ###################### + # validation + ###################### + def validation_start(self): + pass + + def validation_step(self, sample, batch_idx): + """ + + :param sample: + :param batch_idx: + :return: output: {"losses": {...}, "total_loss": float, ...} or (total loss: torch.Tensor, loss_log: dict) + """ + raise NotImplementedError + + def validation_end(self, outputs): + """ + + :param outputs: + :return: loss_output: dict + """ + all_losses_meter = {'total_loss': AvgrageMeter()} + for output in outputs: + if len(output) == 0 or output is None: + continue + if isinstance(output, dict): + assert 'losses' in output, 'Key "losses" should exist in validation output.' + n = output.pop('nsamples', 1) + losses = tensors_to_scalars(output['losses']) + total_loss = output.get('total_loss', sum(losses.values())) + else: + assert len(output) == 2, 'Validation output should only consist of two elements: (total_loss, losses)' + n = 1 + total_loss, losses = output + losses = tensors_to_scalars(losses) + if isinstance(total_loss, torch.Tensor): + total_loss = total_loss.item() + for k, v in losses.items(): + if k not in all_losses_meter: + all_losses_meter[k] = AvgrageMeter() + all_losses_meter[k].update(v, n) + all_losses_meter['total_loss'].update(total_loss, n) + loss_output = {k: round(v.avg, 4) for k, v in all_losses_meter.items()} + print(f"| Validation results@{self.global_step}: {loss_output}") + return { + 'tb_log': {f'val/{k}': v for k, v in loss_output.items()}, + 'val_loss': loss_output['total_loss'] + } + + ###################### + # testing + ###################### + def test_start(self): + pass + + def test_step(self, sample, batch_idx): + return self.validation_step(sample, batch_idx) + + def test_end(self, outputs): + return self.validation_end(outputs) + + ###################### + # start training/testing + ###################### + @classmethod + def start(cls): + os.environ['MASTER_PORT'] = str(random.randint(15000, 30000)) + random.seed(hparams['seed']) + np.random.seed(hparams['seed']) + work_dir = hparams['work_dir'] + trainer = Trainer( + work_dir=work_dir, + val_check_interval=hparams['val_check_interval'], + tb_log_interval=hparams['tb_log_interval'], + max_updates=hparams['max_updates'], + num_sanity_val_steps=hparams['num_sanity_val_steps'] if not hparams['validate'] else 10000, + accumulate_grad_batches=hparams['accumulate_grad_batches'], + print_nan_grads=hparams['print_nan_grads'], + resume_from_checkpoint=hparams.get('resume_from_checkpoint', 0), + amp=hparams['amp'], + monitor_key=hparams['valid_monitor_key'], + monitor_mode=hparams['valid_monitor_mode'], + num_ckpt_keep=hparams['num_ckpt_keep'], + save_best=hparams['save_best'], + seed=hparams['seed'], + debug=hparams['debug'] + ) + if not hparams['infer']: # train + trainer.fit(cls) + else: + trainer.test(cls) + + def on_keyboard_interrupt(self): + pass diff --git a/Geneface_main/GeneFace/utils/commons/ckpt_utils.py b/Geneface_main/GeneFace/utils/commons/ckpt_utils.py new file mode 100644 index 00000000..3460c59a --- /dev/null +++ b/Geneface_main/GeneFace/utils/commons/ckpt_utils.py @@ -0,0 +1,66 @@ +import glob +import os +import re +import torch + + +def get_last_checkpoint(work_dir, steps=None): + checkpoint = None + last_ckpt_path = None + ckpt_paths = get_all_ckpts(work_dir, steps) + if len(ckpt_paths) > 0: + last_ckpt_path = ckpt_paths[0] + checkpoint = torch.load(last_ckpt_path, map_location='cpu') + return checkpoint, last_ckpt_path + + +def get_all_ckpts(work_dir, steps=None): + if steps is None: + ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_*.ckpt' + else: + ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_{steps}.ckpt' + return sorted(glob.glob(ckpt_path_pattern), + key=lambda x: -int(re.findall('.*steps\_(\d+)\.ckpt', x)[0])) + + +def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True, steps=None): + if os.path.isfile(ckpt_base_dir): + base_dir = os.path.dirname(ckpt_base_dir) + ckpt_path = ckpt_base_dir + checkpoint = torch.load(ckpt_base_dir, map_location='cpu') + else: + base_dir = ckpt_base_dir + checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir, steps) + if checkpoint is not None: + state_dict = checkpoint["state_dict"] + if len([k for k in state_dict.keys() if '.' in k]) > 0: + state_dict = {k[len(model_name) + 1:]: v for k, v in state_dict.items() + if k.startswith(f'{model_name}.')} + else: + if '.' not in model_name: + state_dict = state_dict[model_name] + else: + base_model_name = model_name.split('.')[0] + rest_model_name = model_name[len(base_model_name) + 1:] + state_dict = { + k[len(rest_model_name) + 1:]: v for k, v in state_dict[base_model_name].items() + if k.startswith(f'{rest_model_name}.')} + if not strict: + cur_model_state_dict = cur_model.state_dict() + unmatched_keys = [] + for key, param in state_dict.items(): + if key in cur_model_state_dict: + new_param = cur_model_state_dict[key] + if new_param.shape != param.shape: + unmatched_keys.append(key) + print("| Unmatched keys: ", key, new_param.shape, param.shape) + for key in unmatched_keys: + del state_dict[key] + cur_model.load_state_dict(state_dict, strict=strict) + print(f"| load '{model_name}' from '{ckpt_path}'.") + else: + e_msg = f"| ckpt not found in {base_dir}." + if force: + assert False, e_msg + else: + print(e_msg) diff --git a/Geneface_main/GeneFace/utils/commons/crop_head.py b/Geneface_main/GeneFace/utils/commons/crop_head.py new file mode 100644 index 00000000..e61cfc51 --- /dev/null +++ b/Geneface_main/GeneFace/utils/commons/crop_head.py @@ -0,0 +1,106 @@ +import face_alignment +import os +import cv2 +import skimage.transform as trans +import argparse +import torch +import numpy as np +import tqdm + +device = 'cuda' if torch.cuda.is_available() else 'cpu' + + +def get_affine(src): + dst = np.array([[87, 59], + [137, 59], + [112, 120]], dtype=np.float32) + tform = trans.SimilarityTransform() + tform.estimate(src, dst) + M = tform.params[0:2, :] + return M + + +def affine_align_img(img, M, crop_size=224): + warped = cv2.warpAffine(img, M, (crop_size, crop_size), borderValue=0.0) + return warped + + +def affine_align_3landmarks(landmarks, M): + new_landmarks = np.concatenate([landmarks, np.ones((3, 1))], 1) + affined_landmarks = np.matmul(new_landmarks, M.transpose()) + return affined_landmarks + + +def get_eyes_mouths(landmark): + three_points = np.zeros((3, 2)) + three_points[0] = landmark[36:42].mean(0) + three_points[1] = landmark[42:48].mean(0) + three_points[2] = landmark[60:68].mean(0) + return three_points + + +def get_mouth_bias(three_points): + bias = np.array([112, 120]) - three_points[2] + return bias + + +def align_folder(folder_path, folder_save_path): + + fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device=device) + preds = fa.get_landmarks_from_directory(folder_path) + + sumpoints = 0 + three_points_list = [] + + for img in tqdm.tqdm(preds.keys(), desc='preprocessing..'): + pred_points = np.array(preds[img]) + if pred_points is None or len(pred_points.shape) != 3: + print('preprocessing failed') + return False + else: + num_faces, size, _ = pred_points.shape + if num_faces == 1 and size == 68: + + three_points = get_eyes_mouths(pred_points[0]) + sumpoints += three_points + three_points_list.append(three_points) + else: + + print('preprocessing failed') + return False + avg_points = sumpoints / len(preds) + M = get_affine(avg_points) + p_bias = None + for i, img_pth in tqdm.tqdm(enumerate(preds.keys()), desc='affine and save'): + three_points = three_points_list[i] + affined_3landmarks = affine_align_3landmarks(three_points, M) + bias = get_mouth_bias(affined_3landmarks) + if p_bias is None: + bias = bias + else: + bias = p_bias * 0.2 + bias * 0.8 + p_bias = bias + M_i = M.copy() + M_i[:, 2] = M[:, 2] + bias + img = cv2.imread(img_pth) + wrapped = affine_align_img(img, M_i) + img_save_path = os.path.join(folder_save_path, img_pth.split('/')[-1]) + cv2.imwrite(img_save_path, wrapped) + print('cropped files saved at {}'.format(folder_save_path)) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--folder_path', help='the folder which needs processing') + args = parser.parse_args() + + if os.path.isdir(args.folder_path): + home_path = '/'.join(args.folder_path.split('/')[:-1]) + save_img_path = os.path.join(home_path, args.folder_path.split('/')[-1] + '_cropped') + os.makedirs(save_img_path, exist_ok=True) + + align_folder(args.folder_path, save_img_path) + + +if __name__ == '__main__': + main() diff --git a/Geneface_main/GeneFace/utils/commons/dataset_utils.py b/Geneface_main/GeneFace/utils/commons/dataset_utils.py new file mode 100644 index 00000000..44c2ca0c --- /dev/null +++ b/Geneface_main/GeneFace/utils/commons/dataset_utils.py @@ -0,0 +1,247 @@ +import os +import sys +import traceback +import types +from functools import wraps +from itertools import chain +import numpy as np +import torch.utils.data +from torch.utils.data import ConcatDataset +from utils.commons.hparams import hparams + + +def collate_1d_or_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1): + if len(values[0].shape) == 1: + return collate_1d(values, pad_idx, left_pad, shift_right, max_len, shift_id) + else: + return collate_2d(values, pad_idx, left_pad, shift_right, max_len) + + +def collate_1d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1): + """Convert a list of 1d tensors into a padded 2d tensor.""" + size = max(v.size(0) for v in values) if max_len is None else max_len + res = values[0].new(len(values), size).fill_(pad_idx) + + def copy_tensor(src, dst): + assert dst.numel() == src.numel() + if shift_right: + dst[1:] = src[:-1] + dst[0] = shift_id + else: + dst.copy_(src) + + for i, v in enumerate(values): + copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) + return res + + +def collate_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None): + """Convert a list of 2d tensors into a padded 3d tensor.""" + size = max(v.size(0) for v in values) if max_len is None else max_len + res = values[0].new(len(values), size, values[0].shape[1]).fill_(pad_idx) + + def copy_tensor(src, dst): + assert dst.numel() == src.numel() + if shift_right: + dst[1:] = src[:-1] + else: + dst.copy_(src) + + for i, v in enumerate(values): + copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) + return res + + +def _is_batch_full(batch, num_tokens, max_tokens, max_sentences): + if len(batch) == 0: + return 0 + if len(batch) == max_sentences: + return 1 + if num_tokens > max_tokens: + return 1 + return 0 + + +def batch_by_size( + indices, num_tokens_fn, max_tokens=None, max_sentences=None, + required_batch_size_multiple=1, distributed=False +): + """ + Yield mini-batches of indices bucketed by size. Batches may contain + sequences of different lengths. + + Args: + indices (List[int]): ordered list of dataset indices + num_tokens_fn (callable): function that returns the number of tokens at + a given index + max_tokens (int, optional): max number of tokens in each batch + (default: None). + max_sentences (int, optional): max number of sentences in each + batch (default: None). + required_batch_size_multiple (int, optional): require batch size to + be a multiple of N (default: 1). + """ + max_tokens = max_tokens if max_tokens is not None else sys.maxsize + max_sentences = max_sentences if max_sentences is not None else sys.maxsize + bsz_mult = required_batch_size_multiple + + if isinstance(indices, types.GeneratorType): + indices = np.fromiter(indices, dtype=np.int64, count=-1) + + sample_len = 0 + sample_lens = [] + batch = [] + batches = [] + for i in range(len(indices)): + idx = indices[i] + num_tokens = num_tokens_fn(idx) + sample_lens.append(num_tokens) + sample_len = max(sample_len, num_tokens) + + assert sample_len <= max_tokens, ( + "sentence at index {} of size {} exceeds max_tokens " + "limit of {}!".format(idx, sample_len, max_tokens) + ) + num_tokens = (len(batch) + 1) * sample_len + + if _is_batch_full(batch, num_tokens, max_tokens, max_sentences): + mod_len = max( + bsz_mult * (len(batch) // bsz_mult), + len(batch) % bsz_mult, + ) + batches.append(batch[:mod_len]) + batch = batch[mod_len:] + sample_lens = sample_lens[mod_len:] + sample_len = max(sample_lens) if len(sample_lens) > 0 else 0 + batch.append(idx) + if len(batch) > 0: + batches.append(batch) + return batches + + +def unpack_dict_to_list(samples): + samples_ = [] + bsz = samples.get('outputs').size(0) + for i in range(bsz): + res = {} + for k, v in samples.items(): + try: + res[k] = v[i] + except: + pass + samples_.append(res) + return samples_ + + +def remove_padding(x, padding_idx=0): + if x is None: + return None + assert len(x.shape) in [1, 2] + if len(x.shape) == 2: # [T, H] + return x[np.abs(x).sum(-1) != padding_idx] + elif len(x.shape) == 1: # [T] + return x[x != padding_idx] + + +def data_loader(fn): + """ + Decorator to make any fx with this use the lazy property + :param fn: + :return: + """ + + wraps(fn) + attr_name = '_lazy_' + fn.__name__ + + def _get_data_loader(self): + try: + value = getattr(self, attr_name) + except AttributeError: + try: + value = fn(self) # Lazy evaluation, done only once. + except AttributeError as e: + # Guard against AttributeError suppression. (Issue #142) + traceback.print_exc() + error = f'{fn.__name__}: An AttributeError was encountered: ' + str(e) + raise RuntimeError(error) from e + setattr(self, attr_name, value) # Memoize evaluation. + return value + + return _get_data_loader + + +class BaseDataset(torch.utils.data.Dataset): + def __init__(self, shuffle): + super().__init__() + self.hparams = hparams + self.shuffle = shuffle + self.sort_by_len = hparams['sort_by_len'] + self.sizes = None + + @property + def _sizes(self): + return self.sizes + + def __getitem__(self, index): + raise NotImplementedError + + def collater(self, samples): + raise NotImplementedError + + def __len__(self): + return len(self._sizes) + + def num_tokens(self, index): + return self.size(index) + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + return min(self._sizes[index], hparams['max_frames']) + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + if self.shuffle: + indices = np.random.permutation(len(self)) + if self.sort_by_len: + indices = indices[np.argsort(np.array(self._sizes)[indices], kind='mergesort')] + else: + indices = np.arange(len(self)) + return indices + + @property + def num_workers(self): + return int(os.getenv('NUM_WORKERS', hparams['ds_workers'])) + + +class BaseConcatDataset(ConcatDataset): + def collater(self, samples): + return self.datasets[0].collater(samples) + + @property + def _sizes(self): + if not hasattr(self, 'sizes'): + self.sizes = list(chain.from_iterable([d._sizes for d in self.datasets])) + return self.sizes + + def size(self, index): + return min(self._sizes[index], hparams['max_frames']) + + def num_tokens(self, index): + return self.size(index) + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + if self.datasets[0].shuffle: + indices = np.random.permutation(len(self)) + if self.datasets[0].sort_by_len: + indices = indices[np.argsort(np.array(self._sizes)[indices], kind='mergesort')] + else: + indices = np.arange(len(self)) + return indices + + @property + def num_workers(self): + return self.datasets[0].num_workers diff --git a/Geneface_main/GeneFace/utils/commons/ddp_utils.py b/Geneface_main/GeneFace/utils/commons/ddp_utils.py new file mode 100644 index 00000000..4b529198 --- /dev/null +++ b/Geneface_main/GeneFace/utils/commons/ddp_utils.py @@ -0,0 +1,137 @@ +from torch.nn.parallel import DistributedDataParallel +from torch.nn.parallel.distributed import _find_tensors +import torch.optim +import torch.utils.data +import torch +from packaging import version + +class DDP(DistributedDataParallel): + """ + Override the forward call in lightning so it goes to training and validation step respectively + """ + + def forward(self, *inputs, **kwargs): # pragma: no cover + if version.parse(torch.__version__[:6]) < version.parse("1.11"): + self._sync_params() + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + assert len(self.device_ids) == 1 + if self.module.training: + output = self.module.training_step(*inputs[0], **kwargs[0]) + elif self.module.testing: + output = self.module.test_step(*inputs[0], **kwargs[0]) + else: + output = self.module.validation_step(*inputs[0], **kwargs[0]) + if torch.is_grad_enabled(): + # We'll return the output object verbatim since it is a freeform + # object. We need to find any tensors in this object, though, + # because we need to figure out which parameters were used during + # this forward pass, to ensure we short circuit reduction for any + # unused parameters. Only if `find_unused_parameters` is set. + if self.find_unused_parameters: + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + else: + from torch.nn.parallel.distributed import \ + logging, Join, _DDPSink, _tree_flatten_with_rref, _tree_unflatten_with_rref + with torch.autograd.profiler.record_function("DistributedDataParallel.forward"): + if torch.is_grad_enabled() and self.require_backward_grad_sync: + self.logger.set_runtime_stats_and_log() + self.num_iterations += 1 + self.reducer.prepare_for_forward() + + # Notify the join context that this process has not joined, if + # needed + work = Join.notify_join_context(self) + if work: + self.reducer._set_forward_pass_work_handle( + work, self._divide_by_initial_world_size + ) + + # Calling _rebuild_buckets before forward compuation, + # It may allocate new buckets before deallocating old buckets + # inside _rebuild_buckets. To save peak memory usage, + # call _rebuild_buckets before the peak memory usage increases + # during forward computation. + # This should be called only once during whole training period. + if torch.is_grad_enabled() and self.reducer._rebuild_buckets(): + logging.info("Reducer buckets have been rebuilt in this iteration.") + self._has_rebuilt_buckets = True + + # sync params according to location (before/after forward) user + # specified as part of hook, if hook was specified. + buffer_hook_registered = hasattr(self, 'buffer_hook') + if self._check_sync_bufs_pre_fwd(): + self._sync_buffers() + + if self._join_config.enable: + # Notify joined ranks whether they should sync in backwards pass or not. + self._check_global_requires_backward_grad_sync(is_joined_rank=False) + + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + if self.module.training: + output = self.module.training_step(*inputs[0], **kwargs[0]) + elif self.module.testing: + output = self.module.test_step(*inputs[0], **kwargs[0]) + else: + output = self.module.validation_step(*inputs[0], **kwargs[0]) + + # sync params according to location (before/after forward) user + # specified as part of hook, if hook was specified. + if self._check_sync_bufs_post_fwd(): + self._sync_buffers() + + if torch.is_grad_enabled() and self.require_backward_grad_sync: + self.require_forward_param_sync = True + # We'll return the output object verbatim since it is a freeform + # object. We need to find any tensors in this object, though, + # because we need to figure out which parameters were used during + # this forward pass, to ensure we short circuit reduction for any + # unused parameters. Only if `find_unused_parameters` is set. + if self.find_unused_parameters and not self.static_graph: + # Do not need to populate this for static graph. + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + else: + self.require_forward_param_sync = False + + # TODO: DDPSink is currently enabled for unused parameter detection and + # static graph training for first iteration. + if (self.find_unused_parameters and not self.static_graph) or ( + self.static_graph and self.num_iterations == 1 + ): + state_dict = { + 'static_graph': self.static_graph, + 'num_iterations': self.num_iterations, + } + + output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref( + output + ) + output_placeholders = [None for _ in range(len(output_tensor_list))] + # Do not touch tensors that have no grad_fn, which can cause issues + # such as https://github.com/pytorch/pytorch/issues/60733 + for i, output in enumerate(output_tensor_list): + if torch.is_tensor(output) and output.grad_fn is None: + output_placeholders[i] = output + + # When find_unused_parameters=True, makes tensors which require grad + # run through the DDPSink backward pass. When not all outputs are + # used in loss, this makes those corresponding tensors receive + # undefined gradient which the reducer then handles to ensure + # param.grad field is not touched and we don't error out. + passthrough_tensor_list = _DDPSink.apply( + self.reducer, + state_dict, + *output_tensor_list, + ) + for i in range(len(output_placeholders)): + if output_placeholders[i] is None: + output_placeholders[i] = passthrough_tensor_list[i] + + # Reconstruct output data structure. + output = _tree_unflatten_with_rref( + output_placeholders, treespec, output_is_rref + ) + return output diff --git a/Geneface_main/GeneFace/utils/commons/euler2rot.py b/Geneface_main/GeneFace/utils/commons/euler2rot.py new file mode 100644 index 00000000..9a9202d6 --- /dev/null +++ b/Geneface_main/GeneFace/utils/commons/euler2rot.py @@ -0,0 +1,37 @@ +import torch +from scipy.spatial.transform import Rotation as R +from utils.commons.tensor_utils import convert_to_tensor + + +def rot2euler(rot, use_radian=True): + r = R.from_matrix(rot) + return r.as_euler('xyz', degrees=not use_radian) + +def euler2rot(euler, use_radian=True): + r = R.from_euler('xyz',euler, degrees=not use_radian) + return r.as_matrix() + +def c2w_to_euler_trans(c2w): + if c2w.ndim == 3: + e = rot2euler(c2w[:, :3, :3]) # [B, 3] + t = c2w[:, :3, 3].reshape([-1, 3]) + else: + e = rot2euler(c2w[:3, :3]) # [B, 3] + t = c2w[:3, 3].reshape([3]) + return e, t # [3+3] + +def euler_trans_2_c2w(euler, trans): + if euler.ndim == 2: + rot = euler2rot(euler) # [b, 3, 3] + bs = trans.shape[0] + trans = trans.reshape([bs, 3, 1]) + rot = convert_to_tensor(rot).float() + trans = convert_to_tensor(trans).float() + c2w = torch.cat([rot, trans], dim=-1) # [b, 3, 4] + else: + rot = euler2rot(euler) # [3, 3] + trans = trans.reshape([3, 1]) + rot = convert_to_tensor(rot).float() + trans = convert_to_tensor(trans).float() + c2w = torch.cat([rot, trans], dim=-1) # [3, 4] + return c2w \ No newline at end of file diff --git a/Geneface_main/GeneFace/utils/commons/face_alignment_utils.py b/Geneface_main/GeneFace/utils/commons/face_alignment_utils.py new file mode 100644 index 00000000..04e0b6b8 --- /dev/null +++ b/Geneface_main/GeneFace/utils/commons/face_alignment_utils.py @@ -0,0 +1,22 @@ +import numpy as np + +yaw_idx_in_mediapipe_mesh = [356, 454, 361, 288, 397, 379, 378, 377, 152, 148, 149, 150, 172,58, 132, 234, 127] +brow_idx_in_mediapipe_mesh = [70, 63, 105, 66, 107, 336, 296, 334, 293, 300] +nose_idx_in_mediapipe_mesh = [6, 5, 1, 2, 129, 240, 2, 460, 358] +eye_idx_in_mediapipe_mesh = [33, 160, 158, 133, 153, 144, 362, 385, 387, 263, 373, 380] +mouth_idx_in_mediapipe_mesh = [61, 40, 37, 0, 267, 270, 291, 321, 314, 17, 84, 91, 78, 81, 13, 311, 308, 402, 14, 178] +lm68_idx_in_mediapipe_mesh = yaw_idx_in_mediapipe_mesh + brow_idx_in_mediapipe_mesh + nose_idx_in_mediapipe_mesh + eye_idx_in_mediapipe_mesh + mouth_idx_in_mediapipe_mesh + +def mediapipe_lm478_to_face_alignment_lm68(lm478, H, W, return_2d=True): + """ + lm478: [B, 478, 3] or [478,3] + """ + lm478[..., 0] *= W + lm478[..., 1] *= H + n_dim = 2 if return_2d else 3 + if lm478.ndim == 2: + return lm478[lm68_idx_in_mediapipe_mesh, :n_dim].astype(np.int16) + elif lm478.ndim == 3: + return lm478[:, lm68_idx_in_mediapipe_mesh, :n_dim].astype(np.int16) + else: + raise ValueError("input lm478 ndim should in 2 or 3!") \ No newline at end of file diff --git a/Geneface_main/GeneFace/utils/commons/hparams.py b/Geneface_main/GeneFace/utils/commons/hparams.py new file mode 100644 index 00000000..49e066f7 --- /dev/null +++ b/Geneface_main/GeneFace/utils/commons/hparams.py @@ -0,0 +1,132 @@ +import argparse +import os +import yaml + +from utils.commons.os_utils import remove_file + +global_print_hparams = True +hparams = {} + + +class Args: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + self.__setattr__(k, v) + + +def override_config(old_config: dict, new_config: dict): + for k, v in new_config.items(): + if isinstance(v, dict) and k in old_config: + override_config(old_config[k], new_config[k]) + else: + old_config[k] = v + + +def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True): + if config == '' and exp_name == '': + parser = argparse.ArgumentParser(description='') + parser.add_argument('--config', type=str, default='', + help='location of the data corpus') + parser.add_argument('--exp_name', type=str, default='', help='exp_name') + parser.add_argument('-hp', '--hparams', type=str, default='', + help='location of the data corpus') + parser.add_argument('--infer', action='store_true', help='infer') + parser.add_argument('--validate', action='store_true', help='validate') + parser.add_argument('--reset', action='store_true', help='reset hparams') + parser.add_argument('--remove', action='store_true', help='remove old ckpt') + parser.add_argument('--debug', action='store_true', help='debug') + args, unknown = parser.parse_known_args() + print("| Unknow hparams: ", unknown) + else: + args = Args(config=config, exp_name=exp_name, hparams=hparams_str, + infer=False, validate=False, reset=False, debug=False, remove=False) + global hparams + assert args.config != '' or args.exp_name != '' + if args.config != '': + assert os.path.exists(args.config) + + config_chains = [] + loaded_config = set() + + def load_config(config_fn): + # deep first inheritance and avoid the second visit of one node + if not os.path.exists(config_fn): + return {} + with open(config_fn) as f: + hparams_ = yaml.safe_load(f) + loaded_config.add(config_fn) + if 'base_config' in hparams_: + ret_hparams = {} + if not isinstance(hparams_['base_config'], list): + hparams_['base_config'] = [hparams_['base_config']] + for c in hparams_['base_config']: + if c.startswith('.'): + c = f'{os.path.dirname(config_fn)}/{c}' + c = os.path.normpath(c) + if c not in loaded_config: + override_config(ret_hparams, load_config(c)) + override_config(ret_hparams, hparams_) + else: + ret_hparams = hparams_ + config_chains.append(config_fn) + return ret_hparams + + saved_hparams = {} + args_work_dir = '' + if args.exp_name != '': + args_work_dir = f'checkpoints/{args.exp_name}' + ckpt_config_path = f'{args_work_dir}/config.yaml' + if os.path.exists(ckpt_config_path): + with open(ckpt_config_path) as f: + saved_hparams_ = yaml.safe_load(f) + if saved_hparams_ is not None: + saved_hparams.update(saved_hparams_) + hparams_ = {} + if args.config != '': + hparams_.update(load_config(args.config)) + if not args.reset: + hparams_.update(saved_hparams) + if args.exp_name != '': + hparams_['work_dir'] = args_work_dir + + # Support config overriding in command line. Support list type config overriding. + # Examples: --hparams="a=1,b.c=2,d=[1 1 1]" + if args.hparams != "": + for new_hparam in args.hparams.split(","): + k, v = new_hparam.split("=") + v = v.strip("\'\" ") + config_node = hparams_ + for k_ in k.split(".")[:-1]: + config_node = config_node[k_] + k = k.split(".")[-1] + if v in ['True', 'False'] or type(config_node[k]) in [bool, list, dict]: + if type(config_node[k]) == list: + v = v.replace(" ", ",") + config_node[k] = eval(v) + else: + config_node[k] = type(config_node[k])(v) + if args_work_dir != '' and args.remove: + answer = input("REMOVE old checkpoint? Y/N [Default: N]: ") + if answer.lower() == "y": + remove_file(args_work_dir) + if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer: + os.makedirs(hparams_['work_dir'], exist_ok=True) + with open(ckpt_config_path, 'w') as f: + yaml.safe_dump(hparams_, f) + + hparams_['infer'] = args.infer + hparams_['debug'] = args.debug + hparams_['validate'] = args.validate + hparams_['exp_name'] = args.exp_name + global global_print_hparams + if global_hparams: + hparams.clear() + hparams.update(hparams_) + if print_hparams and global_print_hparams and global_hparams: + print('| Hparams chains: ', config_chains) + print('| Hparams: ') + for i, (k, v) in enumerate(sorted(hparams_.items())): + print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "") + print("") + global_print_hparams = False + return hparams_ diff --git a/Geneface_main/GeneFace/utils/commons/image_utils.py b/Geneface_main/GeneFace/utils/commons/image_utils.py new file mode 100644 index 00000000..6f836246 --- /dev/null +++ b/Geneface_main/GeneFace/utils/commons/image_utils.py @@ -0,0 +1,39 @@ +import numpy as np +import torch +import cv2 +import os +import imageio + + +def to8b(x): + return (255*np.clip(x, 0, 1)).astype(np.uint8) + +def mse2psnr(x): + return -10. * torch.log(x) / torch.log(torch.Tensor([10.])) + +def img2mse(x, y): + return torch.mean((x - y) ** 2) + +def video2images(video_name, out_dir): + cap = cv2.VideoCapture(video_name) + frame_num = 0 + while(True): + _, frame = cap.read() + if frame is None: + break + out_frame_name = os.path.join(out_dir, str(frame_num) + '.jpg') + cv2.imwrite(out_frame_name, frame) + frame_num += + 1 + cap.release() + +def load_image_as_uint8_tensor(fname): + """ + img: (H, W, 3) floatTensor + """ + img = torch.as_tensor(imageio.imread(fname)) + return img + +if __name__ =='__main__': + video2images("test_data/May_val/AD-NeRF.mp4", "test_data/May_val/AD-NeRF") + video2images("test_data/May_val/GeneFace.mp4", "test_data/May_val/GeneFace") + video2images("test_data/May_val/GT.mp4", "test_data/May_val/GT") \ No newline at end of file diff --git a/Geneface_main/GeneFace/utils/commons/indexed_datasets.py b/Geneface_main/GeneFace/utils/commons/indexed_datasets.py new file mode 100644 index 00000000..4b6e1d21 --- /dev/null +++ b/Geneface_main/GeneFace/utils/commons/indexed_datasets.py @@ -0,0 +1,200 @@ +import pickle +from bisect import bisect +from copy import deepcopy +import numpy as np +import gzip + + +def int2bytes(i: int, *, signed: bool = False) -> bytes: + length = ((i + ((i * signed) < 0)).bit_length() + 7 + signed) // 8 + return i.to_bytes(length, byteorder='little', signed=signed) + + +def bytes2int(b: bytes, *, signed: bool = False) -> int: + return int.from_bytes(b, byteorder='little', signed=signed) + + +def load_index_data(data_file): + index_data_size = bytes2int(data_file.read(32)) + index_data = data_file.read(index_data_size) + index_data = pickle.loads(index_data) + data_offsets = deepcopy(index_data['offsets']) + id2pos = deepcopy(index_data.get('id2pos', {})) + meta = deepcopy(index_data.get('meta', {})) + return data_offsets, id2pos, meta + + +class IndexedDataset: + def __init__(self, path, unpickle=True): + self.path = path + self.root_data_file = open(f"{path}.data", 'rb', buffering=-1) + try: + self.byte_offsets, self.id2pos, self.meta = load_index_data(self.root_data_file) + self.data_files = [self.root_data_file] + except: + self.__init__old(path) + self.meta = {} + self.gzip = self.meta.get('gzip', False) + if 'chunk_begin' not in self.meta: + self.meta['chunk_begin'] = [0] + for i in range(len(self.meta['chunk_begin'][1:])): + self.data_files.append(open(f"{self.path}.{i + 1}.data", 'rb')) + self.unpickle = unpickle + + def __init__old(self, path): + self.path = path + index_data = np.load(f"{path}.idx", allow_pickle=True).item() + self.byte_offsets = index_data['offsets'] + self.id2pos = index_data.get('id2pos', {}) + self.data_files = [open(f"{path}.data", 'rb', buffering=-1)] + + def __getitem__(self, i): + if self.id2pos is not None and len(self.id2pos) > 0: + i = self.id2pos[i] + self.check_index(i) + + chunk_id = bisect(self.meta['chunk_begin'][1:], self.byte_offsets[i]) + data_file = open(f"{self.path}.data", 'rb', buffering=-1) + data_file.seek(self.byte_offsets[i] - self.meta['chunk_begin'][chunk_id]) + b = data_file.read(self.byte_offsets[i + 1] - self.byte_offsets[i]) + data_file.close() + + # chunk_id = bisect(self.meta['chunk_begin'][1:], self.byte_offsets[i]) + # data_file = self.data_files[chunk_id] + # data_file.seek(self.byte_offsets[i] - self.meta['chunk_begin'][chunk_id]) + # b = data_file.read(self.byte_offsets[i + 1] - self.byte_offsets[i]) + + unpickle = self.unpickle + if unpickle: + if self.gzip: + b = gzip.decompress(b) + item = pickle.loads(b) + else: + item = b + return item + + def __del__(self): + for data_file in self.data_files: + data_file.close() + + def check_index(self, i): + if i < 0 or i >= len(self.byte_offsets) - 1: + raise IndexError('index out of range') + + def __len__(self): + return len(self.byte_offsets) - 1 + + def __iter__(self): + self.iter_i = 0 + return self + + def __next__(self): + if self.iter_i == len(self): + raise StopIteration + else: + item = self[self.iter_i] + self.iter_i += 1 + return item + + +class IndexedDatasetBuilder: + def __init__(self, path, append=False, max_size=1024 * 1024 * 1024 * 64, + default_idx_size=1024 * 1024 * 16, gzip=False): + self.path = self.root_path = path + self.default_idx_size = default_idx_size + if append: + self.data_file = open(f"{path}.data", 'r+b') + self.data_file.seek(0) + self.byte_offsets, self.id2pos, self.meta = load_index_data(self.data_file) + self.data_file.seek(0) + self.data_file.write(bytes(default_idx_size)) + self.data_file.seek(self.byte_offsets[-1]) + self.gzip = self.meta['gzip'] + else: + self.data_file = open(f"{path}.data", 'wb') + self.data_file.seek(default_idx_size) + self.byte_offsets = [default_idx_size] + self.id2pos = {} + self.meta = {} + self.meta['chunk_begin'] = [0] + self.gzip = self.meta['gzip'] = gzip + self.root_data_file = self.data_file + self.max_size = max_size + self.data_chunk_id = 0 + + def add_item(self, item, id=None, use_pickle=True): + if self.byte_offsets[-1] > self.meta['chunk_begin'][-1] + self.max_size: + if self.data_file != self.root_data_file: + self.data_file.close() + self.data_chunk_id += 1 + self.data_file = open(f"{self.path}.{self.data_chunk_id}.data", 'wb') + self.data_file.seek(0) + self.meta['chunk_begin'].append(self.byte_offsets[-1]) + if not use_pickle: + s = item + else: + s = pickle.dumps(item) + if self.gzip: + s = gzip.compress(s, 1) + bytes = self.data_file.write(s) + if id is not None: + self.id2pos[id] = len(self.byte_offsets) - 1 + self.byte_offsets.append(self.byte_offsets[-1] + bytes) + + def finalize(self): + self.root_data_file.seek(0) + s = pickle.dumps({'offsets': self.byte_offsets, 'id2pos': self.id2pos, 'meta': self.meta}) + assert len(s) < self.default_idx_size, (len(s), self.default_idx_size) + len_bytes = int2bytes(len(s)) + self.root_data_file.write(len_bytes) + self.root_data_file.seek(32) + self.root_data_file.write(s) + self.root_data_file.close() + try: + self.data_file.close() + except: + pass + + +if __name__ == "__main__": + import random + from tqdm import tqdm + + # builder = IndexedDatasetBuilder(ds_path, append=True) + # for i in tqdm(range(size)): + # builder.add_item(items[i], i + size) + # builder.finalize() + # ds = IndexedDataset(ds_path) + # for i in tqdm(range(1000)): + # idx = random.randint(size, 2 * size - 1) + # assert (ds[idx]['a'] == items[idx - size]['a']).all() + # idx = random.randint(0, size - 1) + # assert (ds[idx]['a'] == items[idx]['a']).all() + + ds_path = '/tmp/indexed_ds_example' + size = 100 + items = [{"a": np.random.normal(size=[10000, 10]), + "b": np.random.normal(size=[10000, 10])} for i in range(size)] + builder = IndexedDatasetBuilder(ds_path, max_size=1024 * 1024 * 40) + builder.meta['lengths'] = [1, 2, 3] + for i in tqdm(range(size)): + builder.add_item(pickle.dumps(items[i]), i, use_pickle=False) + builder.finalize() + ds = IndexedDataset(ds_path) + assert ds.meta['lengths'] == [1, 2, 3] + for i in tqdm(range(1000)): + idx = random.randint(0, size - 1) + assert (ds[idx]['a'] == items[idx]['a']).all() + + # builder = IndexedDataset2Builder(ds_path, append=True) + # builder.meta['lengths'] = [1, 2, 3, 5, 6, 7] + # for i in tqdm(range(size)): + # builder.add_item(items[i], i + size) + # builder.finalize() + # ds = IndexedDataset2(ds_path) + # assert ds.meta['lengths'] == [1, 2, 3, 5, 6, 7] + # for i in tqdm(range(1000)): + # idx = random.randint(size, 2 * size - 1) + # assert (ds[idx]['a'] == items[idx - size]['a']).all() + # idx = random.randint(0, size - 1) + # assert (ds[idx]['a'] == items[idx]['a']).all() diff --git a/Geneface_main/GeneFace/utils/commons/meters.py b/Geneface_main/GeneFace/utils/commons/meters.py new file mode 100644 index 00000000..e38790e9 --- /dev/null +++ b/Geneface_main/GeneFace/utils/commons/meters.py @@ -0,0 +1,42 @@ +import time +import torch + + +class AvgrageMeter(object): + + def __init__(self): + self.reset() + + def reset(self): + self.avg = 0 + self.sum = 0 + self.cnt = 0 + + def update(self, val, n=1): + self.sum += val * n + self.cnt += n + self.avg = self.sum / self.cnt + + +class Timer: + timer_map = {} + + def __init__(self, name, enable=False): + if name not in Timer.timer_map: + Timer.timer_map[name] = 0 + self.name = name + self.enable = enable + + def __enter__(self): + if self.enable: + if torch.cuda.is_available(): + torch.cuda.synchronize() + self.t = time.time() + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.enable: + if torch.cuda.is_available(): + torch.cuda.synchronize() + Timer.timer_map[self.name] += time.time() - self.t + if self.enable: + print(f'[Timer] {self.name}: {Timer.timer_map[self.name]}') diff --git a/Geneface_main/GeneFace/utils/commons/multiprocess_utils.py b/Geneface_main/GeneFace/utils/commons/multiprocess_utils.py new file mode 100644 index 00000000..e2773543 --- /dev/null +++ b/Geneface_main/GeneFace/utils/commons/multiprocess_utils.py @@ -0,0 +1,130 @@ +import os +import traceback +from functools import partial +from tqdm import tqdm + + +def chunked_worker(worker_id, args_queue=None, results_queue=None, init_ctx_func=None): + ctx = init_ctx_func(worker_id) if init_ctx_func is not None else None + while True: + args = args_queue.get() + if args == '': + return + job_idx, map_func, arg = args + try: + map_func_ = partial(map_func, ctx=ctx) if ctx is not None else map_func + if isinstance(arg, dict): + res = map_func_(**arg) + elif isinstance(arg, (list, tuple)): + res = map_func_(*arg) + else: + res = map_func_(arg) + results_queue.put((job_idx, res)) + except: + traceback.print_exc() + results_queue.put((job_idx, None)) + + +class MultiprocessManager: + def __init__(self, num_workers=None, init_ctx_func=None, multithread=False, queue_max=-1): + if multithread: + from multiprocessing.dummy import Queue, Process + else: + from multiprocessing import Queue, Process + if num_workers is None: + num_workers = int(os.getenv('N_PROC', os.cpu_count())) + self.num_workers = num_workers + self.results_queue = Queue(maxsize=-1) + self.jobs_pending = [] + self.args_queue = Queue(maxsize=queue_max) + self.workers = [] + self.total_jobs = 0 + self.multithread = multithread + for i in range(num_workers): + if multithread: + p = Process(target=chunked_worker, + args=(i, self.args_queue, self.results_queue, init_ctx_func)) + else: + p = Process(target=chunked_worker, + args=(i, self.args_queue, self.results_queue, init_ctx_func), + daemon=True) + self.workers.append(p) + p.start() + + def add_job(self, func, args): + if not self.args_queue.full(): + self.args_queue.put((self.total_jobs, func, args)) + else: + self.jobs_pending.append((self.total_jobs, func, args)) + self.total_jobs += 1 + + def get_results(self): + self.n_finished = 0 + while self.n_finished < self.total_jobs: + while len(self.jobs_pending) > 0 and not self.args_queue.full(): + self.args_queue.put(self.jobs_pending[0]) + self.jobs_pending = self.jobs_pending[1:] + job_id, res = self.results_queue.get() + yield job_id, res + self.n_finished += 1 + for w in range(self.num_workers): + self.args_queue.put("") + for w in self.workers: + w.join() + + def close(self): + if not self.multithread: + for w in self.workers: + w.terminate() + + def __len__(self): + return self.total_jobs + + +def multiprocess_run_tqdm(map_func, args, num_workers=None, ordered=True, init_ctx_func=None, + multithread=False, queue_max=-1, desc=None): + for i, res in tqdm( + multiprocess_run(map_func, args, num_workers, ordered, init_ctx_func, multithread, + queue_max=queue_max), + total=len(args), desc=desc): + yield i, res + + +def multiprocess_run(map_func, args, num_workers=None, ordered=True, init_ctx_func=None, multithread=False, + queue_max=-1): + """ + Multiprocessing running chunked jobs. + + Examples: + >>> for res in tqdm(multiprocess_run(job_func, args): + >>> print(res) + + :param map_func: + :param args: + :param num_workers: + :param ordered: + :param init_ctx_func: + :param q_max_size: + :param multithread: + :return: + """ + if num_workers is None: + num_workers = int(os.getenv('N_PROC', os.cpu_count())) + # num_workers = 1 + manager = MultiprocessManager(num_workers, init_ctx_func, multithread, queue_max=queue_max) + for arg in args: + manager.add_job(map_func, arg) + if ordered: + n_jobs = len(args) + results = ['' for _ in range(n_jobs)] + i_now = 0 + for job_i, res in manager.get_results(): + results[job_i] = res + while i_now < n_jobs and (not isinstance(results[i_now], str) or results[i_now] != ''): + yield i_now, results[i_now] + results[i_now] = None + i_now += 1 + else: + for job_i, res in manager.get_results(): + yield job_i, res + manager.close() diff --git a/Geneface_main/GeneFace/utils/commons/os_utils.py b/Geneface_main/GeneFace/utils/commons/os_utils.py new file mode 100644 index 00000000..4567d17c --- /dev/null +++ b/Geneface_main/GeneFace/utils/commons/os_utils.py @@ -0,0 +1,20 @@ +import os +import subprocess + + +def link_file(from_file, to_file): + subprocess.check_call( + f'ln -s "`realpath --relative-to="{os.path.dirname(to_file)}" "{from_file}"`" "{to_file}"', shell=True) + + +def move_file(from_file, to_file): + subprocess.check_call(f'mv "{from_file}" "{to_file}"', shell=True) + + +def copy_file(from_file, to_file): + subprocess.check_call(f'cp -r "{from_file}" "{to_file}"', shell=True) + + +def remove_file(*fns): + for f in fns: + subprocess.check_call(f'rm -rf "{f}"', shell=True) diff --git a/Geneface_main/GeneFace/utils/commons/pitch_utils.py b/Geneface_main/GeneFace/utils/commons/pitch_utils.py new file mode 100644 index 00000000..ec0a63ce --- /dev/null +++ b/Geneface_main/GeneFace/utils/commons/pitch_utils.py @@ -0,0 +1,37 @@ +import numpy as np +import torch + +f0_bin = 256 +f0_max = 1100.0 +f0_min = 50.0 +f0_mel_min = 1127 * np.log(1 + f0_min / 700) +f0_mel_max = 1127 * np.log(1 + f0_max / 700) + +def coarse_to_f0(coarse): + uv = coarse == 1 + f0_mel = (coarse - 1) * (f0_mel_max - f0_mel_min) / (f0_bin - 2) + f0_mel_min + f0 = ((f0_mel / 1127).exp() - 1) * 700 + f0[uv] = 0 + return f0 + +def f0_to_coarse(f0): + is_torch = isinstance(f0, torch.Tensor) + f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700) + f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1 + + f0_mel[f0_mel <= 1] = 1 + f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1 + f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(np.int) + assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min(), f0.min(), f0.max()) + return f0_coarse + + +def norm_f0(f0, uv, hparams): + is_torch = isinstance(f0, torch.Tensor) + if hparams['pitch_norm'] == 'standard': + f0 = (f0 - hparams['f0_mean']) / hparams['f0_std'] + if hparams['pitch_norm'] == 'log': + f0 = torch.log2(f0 + 1e-8) if is_torch else np.log2(f0 + 1e-8) + if uv is not None and hparams['use_uv']: + f0[uv > 0] = 0 + return f0 \ No newline at end of file diff --git a/Geneface_main/GeneFace/utils/commons/tensor_utils.py b/Geneface_main/GeneFace/utils/commons/tensor_utils.py new file mode 100644 index 00000000..e2e1c8b1 --- /dev/null +++ b/Geneface_main/GeneFace/utils/commons/tensor_utils.py @@ -0,0 +1,111 @@ +import torch +import torch.distributed as dist +import numpy as np + + +def reduce_tensors(metrics): + new_metrics = {} + for k, v in metrics.items(): + if isinstance(v, torch.Tensor): + dist.all_reduce(v) + v = v / dist.get_world_size() + if type(v) is dict: + v = reduce_tensors(v) + new_metrics[k] = v + return new_metrics + + +def tensors_to_scalars(tensors): + if isinstance(tensors, torch.Tensor): + tensors = tensors.item() + return tensors + elif isinstance(tensors, dict): + new_tensors = {} + for k, v in tensors.items(): + v = tensors_to_scalars(v) + new_tensors[k] = v + return new_tensors + elif isinstance(tensors, list): + return [tensors_to_scalars(v) for v in tensors] + else: + return tensors + + +def convert_to_np(tensors): + if isinstance(tensors, np.ndarray): + return tensors + elif isinstance(tensors, dict): + new_np = {} + for k, v in tensors.items(): + if isinstance(v, torch.Tensor): + v = v.cpu().numpy() + if type(v) is dict: + v = convert_to_np(v) + new_np[k] = v + elif isinstance(tensors, list): + new_np = [] + for v in tensors: + if isinstance(v, torch.Tensor): + v = v.cpu().numpy() + if type(v) is dict: + v = convert_to_np(v) + new_np.append(v) + elif isinstance(tensors, torch.Tensor): + v = tensors + if isinstance(v, torch.Tensor): + v = v.cpu().numpy() + if type(v) is dict: + v = convert_to_np(v) + new_np = v + else: + raise Exception(f'tensors_to_np does not support type {type(tensors)}.') + return new_np + + +def convert_to_tensor(arrays): + if isinstance(arrays, np.ndarray): + v = torch.from_numpy(arrays).float() + ret = v + elif isinstance(arrays, torch.Tensor): + ret = arrays + elif type(arrays) is dict: + ret = {} + for k, v in arrays.items(): + if isinstance(v, np.ndarray): + v = torch.from_numpy(v).float() + if type(v) is dict: + v = convert_to_tensor(v) + ret[k] = v + return ret + +def move_to_cpu(tensors): + ret = {} + for k, v in tensors.items(): + if isinstance(v, torch.Tensor): + v = v.cpu() + if type(v) is dict: + v = move_to_cpu(v) + ret[k] = v + return ret + + +def move_to_cuda(batch, gpu_id=0): + # base case: object can be directly moved using `cuda` or `to` + if callable(getattr(batch, 'cuda', None)): + return batch.cuda(gpu_id, non_blocking=True) + elif callable(getattr(batch, 'to', None)): + return batch.to(torch.device('cuda', gpu_id), non_blocking=True) + elif isinstance(batch, list): + for i, x in enumerate(batch): + batch[i] = move_to_cuda(x, gpu_id) + return batch + elif isinstance(batch, tuple): + batch = list(batch) + for i, x in enumerate(batch): + batch[i] = move_to_cuda(x, gpu_id) + return tuple(batch) + elif isinstance(batch, dict): + for k, v in batch.items(): + batch[k] = move_to_cuda(v, gpu_id) + return batch + return batch diff --git a/Geneface_main/GeneFace/utils/commons/trainer.py b/Geneface_main/GeneFace/utils/commons/trainer.py new file mode 100644 index 00000000..744d9278 --- /dev/null +++ b/Geneface_main/GeneFace/utils/commons/trainer.py @@ -0,0 +1,562 @@ +import random +import subprocess +import traceback +from datetime import datetime + +from torch.cuda.amp import GradScaler, autocast +import numpy as np +import torch.optim +import torch.utils.data +import copy +import logging +import os +import re +import sys +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import tqdm + +from utils.commons.ckpt_utils import get_last_checkpoint, get_all_ckpts +from utils.commons.ddp_utils import DDP +from utils.commons.hparams import hparams +from utils.commons.tensor_utils import move_to_cuda +from utils.commons.os_utils import remove_file + + +class Tee(object): + def __init__(self, name, mode): + self.file = open(name, mode) + self.stdout = sys.stdout + sys.stdout = self + + def __del__(self): + sys.stdout = self.stdout + self.file.close() + + def write(self, data): + self.file.write(data) + self.stdout.write(data) + + def flush(self): + self.file.flush() + + +class Trainer: + def __init__( + self, + work_dir, + default_save_path=None, + accumulate_grad_batches=1, + max_updates=160000, + print_nan_grads=False, + val_check_interval=2000, + num_sanity_val_steps=5, + amp=False, + # tb logger + log_save_interval=100, + tb_log_interval=10, + # checkpoint + monitor_key='val_loss', + monitor_mode='min', + num_ckpt_keep=5, + save_best=True, + resume_from_checkpoint=0, + seed=1234, + debug=False, + ): + os.makedirs(work_dir, exist_ok=True) + self.work_dir = work_dir + self.accumulate_grad_batches = accumulate_grad_batches + self.max_updates = max_updates + self.num_sanity_val_steps = num_sanity_val_steps + self.print_nan_grads = print_nan_grads + self.default_save_path = default_save_path + self.resume_from_checkpoint = resume_from_checkpoint if resume_from_checkpoint > 0 else None + self.seed = seed + self.debug = debug + # model and optm + self.task = None + self.optimizers = [] + + # trainer state + self.testing = False + self.global_step = 0 + self.current_epoch = 0 + self.total_batches = 0 + + # configure checkpoint + self.monitor_key = monitor_key + self.num_ckpt_keep = num_ckpt_keep + self.save_best = save_best + self.monitor_op = np.less if monitor_mode == 'min' else np.greater + self.best_val_results = np.Inf if monitor_mode == 'min' else -np.Inf + self.mode = 'min' + + # allow int, string and gpu list + self.all_gpu_ids = [ + int(x) for x in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",") if x != ''] + self.num_gpus = len(self.all_gpu_ids) + self.on_gpu = self.num_gpus > 0 + self.root_gpu = 0 + logging.info(f'GPU available: {torch.cuda.is_available()}, GPU used: {self.all_gpu_ids}') + self.use_ddp = self.num_gpus > 1 + self.proc_rank = 0 + # Tensorboard logging + self.log_save_interval = log_save_interval + self.val_check_interval = val_check_interval + self.tb_log_interval = tb_log_interval + self.amp = amp + self.amp_scalar = GradScaler() + + def test(self, task_cls): + self.testing = True + self.fit(task_cls) + + def fit(self, task_cls): + if len(self.all_gpu_ids) > 1: + mp.spawn(self.ddp_run, nprocs=self.num_gpus, args=(task_cls, copy.deepcopy(hparams))) + else: + self.task = task_cls() + self.task.trainer = self + self.run_single_process(self.task) + return 1 + + def ddp_run(self, gpu_idx, task_cls, hparams_): + hparams.update(hparams_) + self.proc_rank = gpu_idx + self.init_ddp_connection(self.proc_rank, self.num_gpus) + if dist.get_rank() != 0 and not self.debug: + sys.stdout = open(os.devnull, "w") + sys.stderr = open(os.devnull, "w") + task = task_cls() + task.trainer = self + torch.cuda.set_device(gpu_idx) + self.root_gpu = gpu_idx + self.task = task + self.run_single_process(task) + + def run_single_process(self, task): + """Sanity check a few things before starting actual training. + + :param task: + """ + # build model, optm and load checkpoint + if self.proc_rank == 0: + self.save_terminal_logs() + if not self.testing: + self.save_codes() + + model = task.build_model() + if model is not None: + task.model = model + checkpoint, _ = get_last_checkpoint(self.work_dir, self.resume_from_checkpoint) + if checkpoint is not None: + self.restore_weights(checkpoint) + elif self.on_gpu: + task.cuda(self.root_gpu) + if not self.testing: + self.optimizers = task.configure_optimizers() + self.fisrt_epoch = True + if checkpoint is not None: + self.restore_opt_state(checkpoint) + del checkpoint + # clear cache after restore + if self.on_gpu: + torch.cuda.empty_cache() + + if self.use_ddp: + self.task = self.configure_ddp(self.task) + dist.barrier() + + task_ref = self.get_task_ref() + task_ref.trainer = self + task_ref.testing = self.testing + # link up experiment object + if self.proc_rank == 0: + task_ref.build_tensorboard(save_dir=self.work_dir, name='tb_logs') + else: + os.makedirs('tmp', exist_ok=True) + task_ref.build_tensorboard(save_dir='tmp', name='tb_tmp') + self.logger = task_ref.logger + try: + if self.testing: + self.run_evaluation(test=True) + else: + self.train() + except KeyboardInterrupt as e: + traceback.print_exc() + task_ref.on_keyboard_interrupt() + + #################### + # valid and test + #################### + def run_evaluation(self, test=False): + eval_results = self.evaluate(self.task, test, tqdm_desc='Valid' if not test else 'test', + max_batches=hparams['eval_max_batches']) + if eval_results is not None and 'tb_log' in eval_results: + tb_log_output = eval_results['tb_log'] + self.log_metrics_to_tb(tb_log_output) + if self.proc_rank == 0 and not test: + self.save_checkpoint(epoch=self.current_epoch, logs=eval_results) + + def evaluate(self, task, test=False, tqdm_desc='Valid', max_batches=None): + if max_batches == -1: + max_batches = None + # enable eval mode + task.zero_grad() + task.eval() + torch.set_grad_enabled(False) + + task_ref = self.get_task_ref() + if test: + ret = task_ref.test_start() + if ret == 'EXIT': + return + else: + task_ref.validation_start() + outputs = [] + dataloader = task_ref.test_dataloader() if test else task_ref.val_dataloader() + pbar = tqdm.tqdm(dataloader, desc=tqdm_desc, total=max_batches, dynamic_ncols=True, unit='step', + disable=self.root_gpu > 0) + # give model a chance to do something with the outputs (and method defined) + for batch_idx, batch in enumerate(pbar): + if batch is None: # pragma: no cover + continue + # stop short when on fast_dev_run (sets max_batch=1) + if max_batches is not None and batch_idx >= max_batches: + break + + # make dataloader_idx arg in validation_step optional + if self.on_gpu: + batch = move_to_cuda(batch, self.root_gpu) + args = [batch, batch_idx] + if self.use_ddp: + output = task(*args) + else: + if test: + output = task_ref.test_step(*args) + else: + output = task_ref.validation_step(*args) + # track outputs for collation + outputs.append(output) + # give model a chance to do something with the outputs (and method defined) + if test: + eval_results = task_ref.test_end(outputs) + else: + eval_results = task_ref.validation_end(outputs) + # enable train mode again + task.train() + torch.set_grad_enabled(True) + return eval_results + + #################### + # train + #################### + def train(self): + task_ref = self.get_task_ref() + task_ref.on_train_start() + if self.num_sanity_val_steps > 0: + # run tiny validation (if validation defined) to make sure program won't crash during val + self.evaluate(self.task, False, 'Sanity Val', max_batches=self.num_sanity_val_steps) + # clear cache before training + if self.on_gpu: + torch.cuda.empty_cache() + dataloader = task_ref.train_dataloader() + epoch = self.current_epoch + # run all epochs + while True: + # set seed for distributed sampler (enables shuffling for each epoch) + if self.use_ddp and hasattr(dataloader.sampler, 'set_epoch'): + dataloader.sampler.set_epoch(epoch) + # update training progress in trainer and model + task_ref.current_epoch = epoch + self.current_epoch = epoch + # total batches includes multiple val checks + self.batch_loss_value = 0 # accumulated grads + # before epoch hook + task_ref.on_epoch_start() + + # run epoch + train_pbar = tqdm.tqdm(dataloader, initial=self.global_step, total=float('inf'), + dynamic_ncols=True, unit='step', disable=self.root_gpu > 0) + for batch_idx, batch in enumerate(train_pbar): + if self.global_step % self.val_check_interval == 0 and not self.fisrt_epoch: + self.run_evaluation() + pbar_metrics, tb_metrics = self.run_training_batch(batch_idx, batch) + train_pbar.set_postfix(**pbar_metrics) + self.fisrt_epoch = False + # when metrics should be logged + if (self.global_step + 1) % self.tb_log_interval == 0: + # logs user requested information to logger + self.log_metrics_to_tb(tb_metrics) + + self.global_step += 1 + task_ref.global_step = self.global_step + if self.global_step > self.max_updates: + print("| Training end..") + break + # epoch end hook + epoch_loss_dict = task_ref.on_epoch_end() + self.log_metrics_to_tb(epoch_loss_dict) + epoch += 1 + if self.global_step > self.max_updates: + break + task_ref.on_train_end() + + def run_training_batch(self, batch_idx, batch): + if batch is None: + return {} + all_progress_bar_metrics = [] + all_log_metrics = [] + task_ref = self.get_task_ref() + for opt_idx, optimizer in enumerate(self.optimizers): + if optimizer is None: + continue + # make sure only the gradients of the current optimizer's paramaters are calculated + # in the training step to prevent dangling gradients in multiple-optimizer setup. + if len(self.optimizers) > 1: + for param in task_ref.parameters(): + param.requires_grad = False + for group in optimizer.param_groups: + for param in group['params']: + param.requires_grad = True + + # forward pass + with autocast(enabled=self.amp): + if self.on_gpu: + batch = move_to_cuda(copy.copy(batch), self.root_gpu) + args = [batch, batch_idx, opt_idx] + if self.use_ddp: + output = self.task(*args) + else: + output = task_ref.training_step(*args) + loss = output['loss'] + if loss is None: + continue + progress_bar_metrics = output['progress_bar'] + log_metrics = output['tb_log'] + # accumulate loss + loss = loss / self.accumulate_grad_batches + + # backward pass + if loss.requires_grad: + if self.amp: + self.amp_scalar.scale(loss).backward() + else: + loss.backward() + + # track progress bar metrics + all_log_metrics.append(log_metrics) + all_progress_bar_metrics.append(progress_bar_metrics) + + if loss is None: + continue + + # nan grads + if self.print_nan_grads: + has_nan_grad = False + for name, param in task_ref.named_parameters(): + if (param.grad is not None) and torch.isnan(param.grad.float()).any(): + print("| NaN params: ", name, param, param.grad) + has_nan_grad = True + if has_nan_grad: + exit(0) + + # gradient update with accumulated gradients + if (self.global_step + 1) % self.accumulate_grad_batches == 0: + grad_norm_dict = task_ref.on_before_optimization(opt_idx) + if grad_norm_dict is not None: + all_log_metrics[-1].update(grad_norm_dict) + if self.amp: + self.amp_scalar.step(optimizer) + self.amp_scalar.update() + else: + optimizer.step() + optimizer.zero_grad() + task_ref.on_after_optimization(self.current_epoch, batch_idx, optimizer, opt_idx) + + # collapse all metrics into one dict + all_progress_bar_metrics = {k: v for d in all_progress_bar_metrics for k, v in d.items()} + all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()} + return all_progress_bar_metrics, all_log_metrics + + #################### + # load and save checkpoint + #################### + def restore_weights(self, checkpoint): + # load model state + task_ref = self.get_task_ref() + + for k, v in checkpoint['state_dict'].items(): + getattr(task_ref, k).load_state_dict(v) + + if self.on_gpu: + task_ref.cuda(self.root_gpu) + # load training state (affects trainer only) + self.best_val_results = checkpoint['checkpoint_callback_best'] + self.global_step = checkpoint['global_step'] + self.current_epoch = checkpoint['epoch'] + task_ref.global_step = self.global_step + + # wait for all models to restore weights + if self.use_ddp: + # wait for all processes to catch up + dist.barrier() + + def restore_opt_state(self, checkpoint): + if self.testing: + return + # restore the optimizers + optimizer_states = checkpoint['optimizer_states'] + for optimizer, opt_state in zip(self.optimizers, optimizer_states): + if optimizer is None: + return + try: + optimizer.load_state_dict(opt_state) + # move optimizer to GPU 1 weight at a time + if self.on_gpu: + for state in optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.cuda(self.root_gpu) + except ValueError: + print("| WARMING: optimizer parameters not match !!!") + try: + if dist.is_initialized() and dist.get_rank() > 0: + return + except Exception as e: + print(e) + return + did_restore = True + return did_restore + + def save_checkpoint(self, epoch, logs=None): + monitor_op = np.less + ckpt_path = f'{self.work_dir}/model_ckpt_steps_{self.global_step}.ckpt' + logging.info(f'Epoch {epoch:05d}@{self.global_step}: saving model to {ckpt_path}') + self._atomic_save(ckpt_path) + for old_ckpt in get_all_ckpts(self.work_dir)[self.num_ckpt_keep:]: + remove_file(old_ckpt) + logging.info(f'Delete ckpt: {os.path.basename(old_ckpt)}') + current = None + if logs is not None and self.monitor_key in logs: + current = logs[self.monitor_key] + if current is not None and self.save_best: + if monitor_op(current, self.best_val_results): + best_filepath = f'{self.work_dir}/model_ckpt_best.pt' + self.best_val_results = current + logging.info( + f'Epoch {epoch:05d}@{self.global_step}: {self.monitor_key} reached {current:0.5f}. ' + f'Saving model to {best_filepath}') + self._atomic_save(best_filepath) + + def _atomic_save(self, filepath): + checkpoint = self.dump_checkpoint() + tmp_path = str(filepath) + ".part" + torch.save(checkpoint, tmp_path, _use_new_zipfile_serialization=False) + os.replace(tmp_path, filepath) + + def dump_checkpoint(self): + checkpoint = {'epoch': self.current_epoch, 'global_step': self.global_step, + 'checkpoint_callback_best': self.best_val_results} + # save optimizers + optimizer_states = [] + for i, optimizer in enumerate(self.optimizers): + if optimizer is not None: + optimizer_states.append(optimizer.state_dict()) + + checkpoint['optimizer_states'] = optimizer_states + task_ref = self.get_task_ref() + checkpoint['state_dict'] = { + k: v.state_dict() for k, v in task_ref.named_children() if len(list(v.parameters())) > 0} + return checkpoint + + #################### + # DDP + #################### + def configure_ddp(self, task): + task = DDP(task, device_ids=[self.root_gpu], find_unused_parameters=True) + random.seed(self.seed) + np.random.seed(self.seed) + return task + + def init_ddp_connection(self, proc_rank, world_size): + root_node = '127.0.0.1' + root_node = self.resolve_root_node_address(root_node) + os.environ['MASTER_ADDR'] = root_node + dist.init_process_group('nccl', rank=proc_rank, world_size=world_size) + + def resolve_root_node_address(self, root_node): + if '[' in root_node: + name = root_node.split('[')[0] + number = root_node.split(',')[0] + if '-' in number: + number = number.split('-')[0] + number = re.sub('[^0-9]', '', number) + root_node = name + number + return root_node + + #################### + # utils + #################### + def get_task_ref(self): + from utils.commons.base_task import BaseTask + task: BaseTask = self.task.module if isinstance(self.task, DDP) else self.task + return task + + def log_metrics_to_tb(self, metrics, step=None): + """Logs the metric dict passed in. + + :param metrics: + """ + # turn all tensors to scalars + scalar_metrics = self.metrics_to_scalars(metrics) + + step = step if step is not None else self.global_step + # log actual metrics + if self.proc_rank == 0: + self.log_metrics(self.logger, scalar_metrics, step=step) + + @staticmethod + def log_metrics(logger, metrics, step=None): + for k, v in metrics.items(): + if isinstance(v, torch.Tensor): + v = v.item() + logger.add_scalar(k, v, step) + + def metrics_to_scalars(self, metrics): + new_metrics = {} + for k, v in metrics.items(): + if isinstance(v, torch.Tensor): + v = v.item() + + if type(v) is dict: + v = self.metrics_to_scalars(v) + + new_metrics[k] = v + + return new_metrics + + def save_terminal_logs(self): + t = datetime.now().strftime('%Y%m%d%H%M%S') + os.makedirs(f'{self.work_dir}/terminal_logs', exist_ok=True) + Tee(f'{self.work_dir}/terminal_logs/log_{t}.txt', 'w') + + def save_codes(self): + if len(hparams['save_codes']) > 0: + t = datetime.now().strftime('%Y%m%d%H%M%S') + code_dir = f'{self.work_dir}/codes/{t}' + subprocess.check_call(f'mkdir -p "{code_dir}"', shell=True) + for c in hparams['save_codes']: + if os.path.exists(c): + subprocess.check_call( + f'rsync -aR ' + f'--include="*.py" ' + f'--include="*.yaml" ' + f'--exclude="__pycache__" ' + f'--include="*/" ' + f'--exclude="*" ' + f'"./{c}" "{code_dir}/"', + shell=True) + print(f"| Copied codes to {code_dir}.") diff --git a/Geneface_main/GeneFace/utils/nn/__pycache__/grad.cpython-39.pyc b/Geneface_main/GeneFace/utils/nn/__pycache__/grad.cpython-39.pyc new file mode 100644 index 00000000..94ff8e25 Binary files /dev/null and b/Geneface_main/GeneFace/utils/nn/__pycache__/grad.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/utils/nn/__pycache__/model_utils.cpython-39.pyc b/Geneface_main/GeneFace/utils/nn/__pycache__/model_utils.cpython-39.pyc new file mode 100644 index 00000000..49c659a9 Binary files /dev/null and b/Geneface_main/GeneFace/utils/nn/__pycache__/model_utils.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/utils/nn/__pycache__/schedulers.cpython-39.pyc b/Geneface_main/GeneFace/utils/nn/__pycache__/schedulers.cpython-39.pyc new file mode 100644 index 00000000..33166017 Binary files /dev/null and b/Geneface_main/GeneFace/utils/nn/__pycache__/schedulers.cpython-39.pyc differ diff --git a/Geneface_main/GeneFace/utils/nn/grad.py b/Geneface_main/GeneFace/utils/nn/grad.py new file mode 100644 index 00000000..7a098082 --- /dev/null +++ b/Geneface_main/GeneFace/utils/nn/grad.py @@ -0,0 +1,44 @@ +import torch + +def get_grad_norm(model, l=2): + num_para = 0 + accu_grad = 0 + if isinstance(model, torch.nn.Module): + params = model.parameters() + else: + params = model + for p in params: + if p.grad is None: + continue + num_para += p.numel() + if l == 1: + accu_grad += p.grad.abs(1).sum() + elif l == 2: + accu_grad += p.grad.pow(2).sum() + else: + raise ValueError("Now we only implement l1/l2 norm !") + if l == 2: + accu_grad = accu_grad ** 0.5 + if isinstance(accu_grad, float): + return accu_grad + return accu_grad.item() + +class GradBuffer: + def __init__(self): + self.buffer = {} + + def add(self, model): + for item in model.named_parameters(): + name, param = item + if param.grad is None: + continue + self.buffer[name] = self.buffer.get(name, 0) + param.grad.data + + def apply(self, model): + for item in model.named_parameters(): + name, param = item + if param.grad is None: + continue + if name in self.buffer.keys(): + param.grad.data += self.buffer[name] + self.buffer = {} \ No newline at end of file diff --git a/Geneface_main/GeneFace/utils/nn/model_utils.py b/Geneface_main/GeneFace/utils/nn/model_utils.py new file mode 100644 index 00000000..3585da67 --- /dev/null +++ b/Geneface_main/GeneFace/utils/nn/model_utils.py @@ -0,0 +1,32 @@ +import numpy as np +import torch + + +def print_arch(model, model_name='model'): + print(f"| {model_name} Arch: ", model) + num_params(model, model_name=model_name) + + +def num_params(model, print_out=True, model_name="model"): + parameters = filter(lambda p: p.requires_grad, model.parameters()) + parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 + if print_out: + print(f'| {model_name} Trainable Parameters: %.3fM' % parameters) + return parameters + +def get_device_of_model(model): + return model.parameters().__next__().device + +def requires_grad(model): + if isinstance(model, torch.nn.Module): + for p in model.parameters(): + p.requires_grad = True + else: + model.requires_grad = True + +def not_requires_grad(model): + if isinstance(model, torch.nn.Module): + for p in model.parameters(): + p.requires_grad = False + else: + model.requires_grad = False diff --git a/Geneface_main/GeneFace/utils/nn/schedulers.py b/Geneface_main/GeneFace/utils/nn/schedulers.py new file mode 100644 index 00000000..19f7698b --- /dev/null +++ b/Geneface_main/GeneFace/utils/nn/schedulers.py @@ -0,0 +1,205 @@ +import numpy as np +from utils.commons.hparams import hparams + + +class NoneSchedule(object): + def __init__(self, optimizer, lr): + self.optimizer = optimizer + self.constant_lr = lr + self.step(0) + + def step(self, num_updates): + self.lr = self.constant_lr + for param_group in self.optimizer.param_groups: + param_group['lr'] = self.lr + return self.lr + + def get_lr(self): + return self.optimizer.param_groups[0]['lr'] + + def get_last_lr(self): + return self.get_lr() + + +class RSQRTSchedule(NoneSchedule): + def __init__(self, optimizer, lr, warmup_updates, hidden_size): + self.optimizer = optimizer + self.constant_lr = lr + self.warmup_updates = warmup_updates + self.hidden_size = hidden_size + self.lr = lr + for param_group in optimizer.param_groups: + param_group['lr'] = self.lr + self.step(0) + + def step(self, num_updates): + constant_lr = self.constant_lr + warmup = min(num_updates / self.warmup_updates, 1.0) + rsqrt_decay = max(self.warmup_updates, num_updates) ** -0.5 + rsqrt_hidden = self.hidden_size ** -0.5 + self.lr = max(constant_lr * warmup * rsqrt_decay * rsqrt_hidden, 1e-7) + for param_group in self.optimizer.param_groups: + param_group['lr'] = self.lr + return self.lr + + +class WarmupSchedule(NoneSchedule): + def __init__(self, optimizer, lr, warmup_updates): + self.optimizer = optimizer + self.constant_lr = self.lr = lr + self.warmup_updates = warmup_updates + for param_group in optimizer.param_groups: + param_group['lr'] = self.lr + self.step(0) + + def step(self, num_updates): + constant_lr = self.constant_lr + warmup = min(num_updates / self.warmup_updates, 1.0) + self.lr = max(constant_lr * warmup, 1e-7) + for param_group in self.optimizer.param_groups: + param_group['lr'] = self.lr + return self.lr + + +class ExponentialSchedule(NoneSchedule): + def __init__(self, optimizer, lr, warmup_updates): + self.optimizer = optimizer + self.constant_lr = self.lr = lr + self.warmup_updates = warmup_updates + for param_group in optimizer.param_groups: + param_group['lr'] = self.lr + self.step(0) + + def step(self, num_updates): + constant_lr = self.constant_lr + if self.warmup_updates > 0 and num_updates <= self.warmup_updates: + warmup = min(num_updates / self.warmup_updates, 1.0) + self.lr = max(constant_lr * warmup, 1e-7) + else: + new_lrate = constant_lr * (0.1 ** (num_updates / 250_000)) # decay by 0.1x for every 250k steps + self.lr = max(new_lrate, 1e-7) + for param_group in self.optimizer.param_groups: + param_group['lr'] = self.lr + return self.lr + + +class ExponentialScheduleWithAudattNet(NoneSchedule): + """ + Default Scheduler in AD-NeRF + for audatt net, since it starts at 20_0000 steps, we need to enlarge its lr + in optimizer, we set param_groups[1] to optimize audatt net + """ + def __init__(self, optimizer, lr, warmup_updates=0): + self.optimizer = optimizer + self.constant_lr = self.lr = lr + self.warmup_updates = warmup_updates + optimizer.param_groups[0]['lr'] = self.lr + optimizer.param_groups[1]['lr'] = self.lr * 5 + self.step(0) + + def step(self, num_updates): + constant_lr = self.constant_lr + if self.warmup_updates > 0 and num_updates <= self.warmup_updates: + warmup = min(num_updates / self.warmup_updates, 1.0) + self.lr = max(constant_lr * warmup, 1e-7) + else: + new_lrate = constant_lr * (0.1 ** (num_updates / 250_000)) # decay by 0.1x for every 250k steps + self.lr = max(new_lrate, 1e-7) + + self.optimizer.param_groups[0]['lr'] = self.lr + self.optimizer.param_groups[1]['lr'] = self.lr * 5 + return self.lr + +class ExponentialScheduleForRADNeRF(NoneSchedule): + """ + Default Scheduler in RAD-NeRF + RAD-NeRF has two groups of params with different lr + for tileGrid embedding, the lr=5e-3 + for other network params, the lr=5e-4 + """ + def __init__(self, optimizer, lr, warmup_updates=0): + self.optimizer = optimizer + self.constant_lr = self.lr = lr # 0.0005 + self.warmup_updates = warmup_updates + self.finetune_lips = hparams['finetune_lips'] + self.finetune_lips_start_iter = hparams['finetune_lips_start_iter'] + + optimizer.param_groups[0]['lr'] = self.lr # for Net_params in RAD-NeRF, lr starts from 0.0005 + optimizer.param_groups[1]['lr'] = self.lr * 10 # for tileGrid, lr starts from 0.005 + optimizer.param_groups[2]['lr'] = self.lr * 5 # for Att Net, lr starts from 0.0025 + self.step(0) + + def step(self, num_updates): + constant_lr = self.constant_lr + if self.warmup_updates > 0 and num_updates <= self.warmup_updates: + warmup = min(num_updates / self.warmup_updates, 1.0) + self.lr = max(constant_lr * warmup, 1e-7) + else: + if self.finetune_lips and num_updates > self.finetune_lips_start_iter: + new_lrate = constant_lr * (0.1 ** (num_updates / 250_000)) # decay by 0.05x for every 200k steps + else: + new_lrate = constant_lr * (0.1 ** (num_updates / 250_000)) # decay by 0.1x for every 200k steps + + self.lr = max(new_lrate, 1e-7) + + self.optimizer.param_groups[0]['lr'] = self.lr + self.optimizer.param_groups[1]['lr'] = self.lr * 10 + self.optimizer.param_groups[2]['lr'] = self.lr * 5 + return self.lr + + +class ExponentialScheduleForRADNeRFTorso(NoneSchedule): + """ + Default Scheduler in RAD-NeRF + RAD-NeRF has two groups of params with different lr + for tileGrid embedding, the lr=5e-3 + for other network params, the lr=5e-4 + """ + def __init__(self, optimizer, lr, warmup_updates=0): + self.optimizer = optimizer + self.constant_lr = self.lr = lr # 0.0005 + self.warmup_updates = warmup_updates + + optimizer.param_groups[0]['lr'] = self.lr # for Net_params in RAD-NeRF, lr starts from 0.0005 + optimizer.param_groups[1]['lr'] = self.lr * 10 # for tileGrid, lr starts from 0.005 + self.step(0) + + def step(self, num_updates): + constant_lr = self.constant_lr + if self.warmup_updates > 0 and num_updates <= self.warmup_updates: + warmup = min(num_updates / self.warmup_updates, 1.0) + self.lr = max(constant_lr * warmup, 1e-7) + else: + new_lrate = constant_lr * (0.1 ** (num_updates / 250_000)) # decay by 0.1x for every 200k steps + self.lr = max(new_lrate, 1e-7) + self.optimizer.param_groups[0]['lr'] = self.lr + self.optimizer.param_groups[1]['lr'] = self.lr * 10 + return self.lr + + +class CosineSchedule(NoneSchedule): + def __init__(self, optimizer, lr, warmup_updates, total_updates): + self.optimizer = optimizer + self.constant_lr = lr + self.warmup_updates = warmup_updates + self.total_updates = total_updates + self.lr = lr + self.assign_learning_rate(self.optimizer, self.lr) + self.step(0) + + def assign_learning_rate(self, optimizer, new_lr): + for param_group in optimizer.param_groups: + param_group["lr"] = new_lr + + def _warmup_lr(self, base_lr, warmup_length, step): + return base_lr * (step + 1) / warmup_length + + def step(self, num_updates): + if self.warmup_updates > 0 and num_updates <= self.warmup_updates: + lr = self._warmup_lr(self.lr, self.warmup_updates, num_updates) + else: + e = num_updates - self.warmup_updates + es = self.total_updates - self.warmup_updates + lr = 0.5 * (1 + np.cos(np.pi * e / es)) * self.lr + self.assign_learning_rate(self.optimizer, lr) + return lr diff --git a/Geneface_main/GeneFace/utils/nn/seq_utils.py b/Geneface_main/GeneFace/utils/nn/seq_utils.py new file mode 100644 index 00000000..1308bf7d --- /dev/null +++ b/Geneface_main/GeneFace/utils/nn/seq_utils.py @@ -0,0 +1,305 @@ +from collections import defaultdict +import torch +import torch.nn.functional as F + + +def make_positions(tensor, padding_idx): + """Replace non-padding symbols with their position numbers. + + Position numbers begin at padding_idx+1. Padding symbols are ignored. + """ + # The series of casts and type-conversions here are carefully + # balanced to both work with ONNX export and XLA. In particular XLA + # prefers ints, cumsum defaults to output longs, and ONNX doesn't know + # how to handle the dtype kwarg in cumsum. + mask = tensor.ne(padding_idx).int() + return ( + torch.cumsum(mask, dim=1).type_as(mask) * mask + ).long() + padding_idx + + +def softmax(x, dim): + return F.softmax(x, dim=dim, dtype=torch.float32) + + +def sequence_mask(lengths, maxlen, dtype=torch.bool): + if maxlen is None: + maxlen = lengths.max() + mask = ~(torch.ones((len(lengths), maxlen)).to(lengths.device).cumsum(dim=1).t() > lengths).t() + mask.type(dtype) + return mask + + +def weights_nonzero_speech(target): + # target : B x T x mel + # Assign weight 1.0 to all labels except for padding (id=0). + dim = target.size(-1) + return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim) + + +INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0) + + +def _get_full_incremental_state_key(module_instance, key): + module_name = module_instance.__class__.__name__ + + # assign a unique ID to each module instance, so that incremental state is + # not shared across module instances + if not hasattr(module_instance, '_instance_id'): + INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1 + module_instance._instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name] + + return '{}.{}.{}'.format(module_name, module_instance._instance_id, key) + + +def get_incremental_state(module, incremental_state, key): + """Helper for getting incremental state for an nn.Module.""" + full_key = _get_full_incremental_state_key(module, key) + if incremental_state is None or full_key not in incremental_state: + return None + return incremental_state[full_key] + + +def set_incremental_state(module, incremental_state, key, value): + """Helper for setting incremental state for an nn.Module.""" + if incremental_state is not None: + full_key = _get_full_incremental_state_key(module, key) + incremental_state[full_key] = value + + +def fill_with_neg_inf(t): + """FP16-compatible function that fills a tensor with -inf.""" + return t.float().fill_(float('-inf')).type_as(t) + + +def fill_with_neg_inf2(t): + """FP16-compatible function that fills a tensor with -inf.""" + return t.float().fill_(-1e8).type_as(t) + + +def select_attn(attn_logits, type='best'): + """ + + :param attn_logits: [n_layers, B, n_head, T_sp, T_txt] + :return: + """ + encdec_attn = torch.stack(attn_logits, 0).transpose(1, 2) + # [n_layers * n_head, B, T_sp, T_txt] + encdec_attn = (encdec_attn.reshape([-1, *encdec_attn.shape[2:]])).softmax(-1) + if type == 'best': + indices = encdec_attn.max(-1).values.sum(-1).argmax(0) + encdec_attn = encdec_attn.gather( + 0, indices[None, :, None, None].repeat(1, 1, encdec_attn.size(-2), encdec_attn.size(-1)))[0] + return encdec_attn + elif type == 'mean': + return encdec_attn.mean(0) + + +def make_pad_mask(lengths, xs=None, length_dim=-1): + """Make mask tensor containing indices of padded part. + Args: + lengths (LongTensor or List): Batch of lengths (B,). + xs (Tensor, optional): The reference tensor. + If set, masks will be the same shape as this tensor. + length_dim (int, optional): Dimension indicator of the above tensor. + See the example. + Returns: + Tensor: Mask tensor containing indices of padded part. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + Examples: + With only lengths. + >>> lengths = [5, 3, 2] + >>> make_non_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + With the reference tensor. + >>> xs = torch.zeros((3, 2, 4)) + >>> make_pad_mask(lengths, xs) + tensor([[[0, 0, 0, 0], + [0, 0, 0, 0]], + [[0, 0, 0, 1], + [0, 0, 0, 1]], + [[0, 0, 1, 1], + [0, 0, 1, 1]]], dtype=torch.uint8) + >>> xs = torch.zeros((3, 2, 6)) + >>> make_pad_mask(lengths, xs) + tensor([[[0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1]], + [[0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) + With the reference tensor and dimension indicator. + >>> xs = torch.zeros((3, 6, 6)) + >>> make_pad_mask(lengths, xs, 1) + tensor([[[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8) + >>> make_pad_mask(lengths, xs, 2) + tensor([[[0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1]], + [[0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) + """ + if length_dim == 0: + raise ValueError("length_dim cannot be 0: {}".format(length_dim)) + + if not isinstance(lengths, list): + lengths = lengths.tolist() + bs = int(len(lengths)) + if xs is None: + maxlen = int(max(lengths)) + else: + maxlen = xs.size(length_dim) + + seq_range = torch.arange(0, maxlen, dtype=torch.int64) + seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen) + seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + + if xs is not None: + assert xs.size(0) == bs, (xs.size(0), bs) + + if length_dim < 0: + length_dim = xs.dim() + length_dim + # ind = (:, None, ..., None, :, , None, ..., None) + ind = tuple( + slice(None) if i in (0, length_dim) else None for i in range(xs.dim()) + ) + mask = mask[ind].expand_as(xs).to(xs.device) + return mask + + +def make_non_pad_mask(lengths, xs=None, length_dim=-1): + """Make mask tensor containing indices of non-padded part. + Args: + lengths (LongTensor or List): Batch of lengths (B,). + xs (Tensor, optional): The reference tensor. + If set, masks will be the same shape as this tensor. + length_dim (int, optional): Dimension indicator of the above tensor. + See the example. + Returns: + ByteTensor: mask tensor containing indices of padded part. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + Examples: + With only lengths. + >>> lengths = [5, 3, 2] + >>> make_non_pad_mask(lengths) + masks = [[1, 1, 1, 1 ,1], + [1, 1, 1, 0, 0], + [1, 1, 0, 0, 0]] + With the reference tensor. + >>> xs = torch.zeros((3, 2, 4)) + >>> make_non_pad_mask(lengths, xs) + tensor([[[1, 1, 1, 1], + [1, 1, 1, 1]], + [[1, 1, 1, 0], + [1, 1, 1, 0]], + [[1, 1, 0, 0], + [1, 1, 0, 0]]], dtype=torch.uint8) + >>> xs = torch.zeros((3, 2, 6)) + >>> make_non_pad_mask(lengths, xs) + tensor([[[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0]], + [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0]], + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) + With the reference tensor and dimension indicator. + >>> xs = torch.zeros((3, 6, 6)) + >>> make_non_pad_mask(lengths, xs, 1) + tensor([[[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8) + >>> make_non_pad_mask(lengths, xs, 2) + tensor([[[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0]], + [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0]], + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) + """ + return ~make_pad_mask(lengths, xs, length_dim) + + +def get_mask_from_lengths(lengths): + max_len = torch.max(lengths).item() + ids = torch.arange(0, max_len).to(lengths.device) + mask = (ids < lengths.unsqueeze(1)).bool() + return mask + + +def group_hidden_by_segs(h, seg_ids, max_len): + """ + + :param h: [B, T, H] + :param seg_ids: [B, T] + :return: h_ph: [B, T_ph, H] + """ + B, T, H = h.shape + h_gby_segs = h.new_zeros([B, max_len + 1, H]).scatter_add_(1, seg_ids[:, :, None].repeat([1, 1, H]), h) + all_ones = h.new_ones(h.shape[:2]) + cnt_gby_segs = h.new_zeros([B, max_len + 1]).scatter_add_(1, seg_ids, all_ones).contiguous() + h_gby_segs = h_gby_segs[:, 1:] + cnt_gby_segs = cnt_gby_segs[:, 1:] + h_gby_segs = h_gby_segs / torch.clamp(cnt_gby_segs[:, :, None], min=1) + return h_gby_segs, cnt_gby_segs diff --git a/Geneface_main/GeneFace/utils/visualization/draw_3d_landmark.py b/Geneface_main/GeneFace/utils/visualization/draw_3d_landmark.py new file mode 100644 index 00000000..fe0ead4c --- /dev/null +++ b/Geneface_main/GeneFace/utils/visualization/draw_3d_landmark.py @@ -0,0 +1,364 @@ +import cv2 +import math +import numpy as np +import matplotlib.pyplot as plt +import dearpygui.dearpygui as dpg +from scipy.spatial.transform import Rotation as R +from utils.commons.hparams import set_hparams, hparams +from data_util.face3d_helper import Face3DHelper + +face3d_helper = Face3DHelper(use_gpu=False) + + +set_hparams("egs/datasets/videos/May/radnerf_torso.yaml") + +from tasks.radnerfs.dataset_utils import RADNeRFDataset +dataset = RADNeRFDataset("val") +idexp_lm3d_mean = dataset.idexp_lm3d_mean.reshape([68,3]) +lm3d_mean = idexp_lm3d_mean / 10 + face3d_helper.key_mean_shape +lm3d_mean /= 1.5 # normalize to [-1,1] + +class Landmark3D: + + def __init__(self): + + # init pose [18, 3], in [-1, 1]^3 + self.points3D = np.concatenate([lm3d_mean.numpy(), np.ones([68,1])],axis=1).reshape([68,4]) + + # lines [17, 2] + self.lines = [ + # yaw + [0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5,6], [6,7], [7,8], [8,9], [9,10], [10,11], [11,12], [12,13], [13,14], [14,15], [15,16], + # left brow + [17,18], [18,19], [19,20], [20,21], + # right brow + [22, 23], [23,24], [24,25], [25,26], + # nose + [27,28], [28,29], [29,30], [31,32], [32,33], [33,34], [34,35], + # left eye + [36,37], [37,38], [38,39], [39,40], [40,41], [41,36], + # right eye + [42,43], [43,44], [44,45], [45,46], [46,47], [47,42], + # mouth + [48, 49], [49,50], [50,51], [51,52], [52,53], [53,54], [54,55], [55,56], [56,57], [57,58], [58,59],[59,48], + [48, 60], [60,61], [61,62], [62,63], [63,64], [64,65], [65,66], [66,67], [67,60], [54,64] + ] + # # keypoint color [18, 3] + # self.colors = [[0, 0, 255], [255, 0, 0], [255, 170, 0], [255, 255, 0], [255, 85, 0], [170, 255, 0], + # [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], + # [0, 85, 255], [85, 0, 255], [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] + + self.colors = [[0,0,255] for _ in range(36)] + [[0,255,0] for _ in range(12)]+ [[255,0,0] for _ in range(20)] + self.line_colors = [[0,0,255] for _ in range(31)] + [[0,255,0] for _ in range(12)]+ [[255,0,0] for _ in range(22)] + + def draw(self, mvp, H, W): + # mvp: [4, 4] + + canvas = np.zeros((H, W, 3), dtype=np.uint8) + + points2D = self.points3D @ mvp.T # [18, 4] + points2D = points2D[:, :3] / points2D[:, 3:] # NDC in [-1, 1] + + xs = (points2D[:, 0] + 1) / 2 * H # [18] + ys = (points2D[:, 1] + 1) / 2 * W # [18] + + # 18 points + for i in range(len(self.points3D)): + cv2.circle(canvas, (int(xs[i]), int(ys[i])), 4, self.colors[i], thickness=-1) + + # 17 lines + for i in range(len(self.lines)): + cur_canvas = canvas.copy() + X = xs[self.lines[i]] + Y = ys[self.lines[i]] + mY = np.mean(Y) + mX = np.mean(X) + length = ((Y[0] - Y[1]) ** 2 + (X[0] - X[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(Y[0] - Y[1], X[0] - X[1])) + polygon = cv2.ellipse2Poly((int(mX), int(mY)), (int(length / 2), 4), int(angle), 0, 360, 1) + + cv2.fillConvexPoly(cur_canvas, polygon, self.line_colors[i]) + + canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0) + + canvas = canvas.astype(np.float32) / 255 + return canvas, np.stack([xs, ys], axis=1) + + +class OrbitCamera: + def __init__(self, W, H, r=2, fovy=60, near=0.01, far=100): + self.W = W + self.H = H + self.radius = r # camera distance from center + self.fovy = fovy # in degree + self.near = near + self.far = far + self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point + self.rot = R.from_matrix(np.eye(3)) + self.up = np.array([0, 1, 0], dtype=np.float32) # need to be normalized! + + # pose + @property + def pose(self): + # first move camera to radius + res = np.eye(4, dtype=np.float32) + res[2, 3] = self.radius # opengl convention... + # rotate + rot = np.eye(4, dtype=np.float32) + rot[:3, :3] = self.rot.as_matrix() + res = rot @ res + # translate + res[:3, 3] -= self.center + return res + + # view + @property + def view(self): + return np.linalg.inv(self.pose) + + # intrinsics + @property + def intrinsics(self): + focal = self.H / (2 * np.tan(np.radians(self.fovy) / 2)) + return np.array([focal, focal, self.W // 2, self.H // 2], dtype=np.float32) + + # projection (perspective) + @property + def perspective(self): + y = np.tan(np.radians(self.fovy) / 2) + aspect = self.W / self.H + return np.array([[1/(y*aspect), 0, 0, 0], + [ 0, -1/y, 0, 0], + [ 0, 0, -(self.far+self.near)/(self.far-self.near), -(2*self.far*self.near)/(self.far-self.near)], + [ 0, 0, -1, 0]], dtype=np.float32) + + + def orbit(self, dx, dy): + # rotate along camera up/side axis! + side = self.rot.as_matrix()[:3, 0] # why this is side --> ? # already normalized. + rotvec_x = self.up * np.radians(-0.05 * dx) + rotvec_y = side * np.radians(-0.05 * dy) + self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot + + def scale(self, delta): + self.radius *= 1.1 ** (-delta) + + def pan(self, dx, dy, dz=0): + # pan in camera coordinate system (careful on the sensitivity!) + self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([dx, -dy, dz]) + + +class GUI: + def __init__(self, opt): + self.opt = opt + self.W = opt.W + self.H = opt.H + self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy) + + self.skel = Landmark3D() + + self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32) + self.need_update = True # camera moved, should reset accumulation + + self.save_path = 'pose.png' + self.mouse_loc = np.array([0, 0]) + self.points2D = None # [18, 2] + self.point_idx = 0 + + dpg.create_context() + self.register_dpg() + self.step() + + + def __del__(self): + dpg.destroy_context() + + + def step(self): + + if self.need_update: + + # mvp + mv = self.cam.view # [4, 4] + proj = self.cam.perspective # [4, 4] + mvp = proj @ mv + + # render our openpose image, somehow + self.render_buffer, self.points2D = self.skel.draw(mvp, self.H, self.W) + + self.need_update = False + + dpg.set_value("_texture", self.render_buffer) + + + def register_dpg(self): + + ### register texture + + with dpg.texture_registry(show=False): + dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture") + + ### register window + + # the rendered image, as the primary window + with dpg.window(label="Viewer", tag="_primary_window", width=self.W, height=self.H): + dpg.add_image("_texture") + + dpg.set_primary_window("_primary_window", True) + + # control window + with dpg.window(label="Control", tag="_control_window", width=-1, height=-1): + + # button theme + with dpg.theme() as theme_button: + with dpg.theme_component(dpg.mvButton): + dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18)) + dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47)) + dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83)) + dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5) + dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3) + + def callback_save(sender, app_data): + image = (self.render_buffer * 255).astype(np.uint8) + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + cv2.imwrite(self.save_path, image) + print(f'[INFO] write image to {self.save_path}') + + def callback_set_save_path(sender, app_data): + self.save_path = app_data + + with dpg.group(horizontal=True): + dpg.add_button(label="save image", tag="_button_save", callback=callback_save) + dpg.bind_item_theme("_button_save", theme_button) + + dpg.add_input_text(label="", default_value=self.save_path, callback=callback_set_save_path) + + # fov slider + def callback_set_fovy(sender, app_data): + self.cam.fovy = app_data + self.need_update = True + + dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy) + + + ### register camera handler + + def callback_camera_drag_rotate(sender, app_data): + + if not dpg.is_item_focused("_primary_window"): + return + + # dx = app_data[1] + # dy = app_data[2] + + # self.cam.orbit(dx, dy) + self.need_update = True + + + def callback_camera_wheel_scale(sender, app_data): + + if not dpg.is_item_focused("_primary_window"): + return + + delta = app_data + + self.cam.scale(delta) + self.need_update = True + + + def callback_camera_drag_pan(sender, app_data): + + if not dpg.is_item_focused("_primary_window"): + return + + dx = app_data[1] + dy = app_data[2] + + self.cam.pan(dx, dy) + self.need_update = True + + def callback_set_mouse_loc(sender, app_data): + + if not dpg.is_item_focused("_primary_window"): + return + + # just the pixel coordinate in image + self.mouse_loc = np.array(app_data) + + def callback_skel_select(sender, app_data): + + if not dpg.is_item_focused("_primary_window"): + return + + # determine the selected keypoint from mouse_loc + if self.points2D is None: return # not prepared + + dist = np.linalg.norm(self.points2D - self.mouse_loc, axis=1) # [18] + self.point_idx = np.argmin(dist) + + + def callback_skel_drag(sender, app_data): + + if not dpg.is_item_focused("_primary_window"): + return + + # 2D to 3D delta + dx = app_data[1] + dy = app_data[2] + + self.skel.points3D[self.point_idx, :3] += 0.0002 * self.cam.rot.as_matrix()[:3, :3] @ np.array([dx, -dy, 0]) + self.need_update = True + + + with dpg.handler_registry(): + dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate) + dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale) + dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan) + + # for skeleton editing + dpg.add_mouse_move_handler(callback=callback_set_mouse_loc) + dpg.add_mouse_click_handler(button=dpg.mvMouseButton_Right, callback=callback_skel_select) + dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Right, callback=callback_skel_drag) + + + dpg.create_viewport(title='pose viewer', resizable=False, width=self.W, height=self.H) + + ### global theme + with dpg.theme() as theme_no_padding: + with dpg.theme_component(dpg.mvAll): + # set all padding to 0 to avoid scroll bar + dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core) + dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core) + dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core) + + dpg.bind_item_theme("_primary_window", theme_no_padding) + dpg.focus_item("_primary_window") + + dpg.setup_dearpygui() + + #dpg.show_metrics() + + dpg.show_viewport() + + + def render(self): + + while dpg.is_dearpygui_running(): + self.step() + dpg.render_dearpygui_frame() + + +if __name__ == '__main__': + + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--W', type=int, default=512, help="GUI width") + parser.add_argument('--H', type=int, default=512, help="GUI height") + parser.add_argument('--radius', type=float, default=3, help="default GUI camera radius from center") + parser.add_argument('--fovy', type=float, default=25, help="default GUI camera fovy") + + opt = parser.parse_args() + + gui = GUI(opt) + gui.render() \ No newline at end of file diff --git a/Geneface_main/GeneFace/utils/visualization/ffmpeg_utils.py b/Geneface_main/GeneFace/utils/visualization/ffmpeg_utils.py new file mode 100644 index 00000000..8cc8a715 --- /dev/null +++ b/Geneface_main/GeneFace/utils/visualization/ffmpeg_utils.py @@ -0,0 +1,18 @@ +import os + +def imgs_to_video(img_dir, video_path, audio_path=None, verbose=False): + cmd = f"ffmpeg -i {img_dir}/%5d.png " + if audio_path is not None: + cmd += f"-i {audio_path} " + cmd += "-strict -2 " + cmd += "-c:v libx264 -pix_fmt yuv420p -b:v 2000k -y " + if verbose is False: + cmd += " -v quiet " + cmd += f"{video_path} " + + os.system(cmd) + + +if __name__ == '__main__': + imgs_to_video('infer_out/tmp_imgs', 'infer_out/tmp_imgs/out.mp4', 'data/raw/val_wavs/zozo.wav') + imgs_to_video('infer_out/tmp_imgs', 'infer_out/tmp_imgs/out2.mp4', 'data/raw/val_wavs/zozo.wav') \ No newline at end of file diff --git a/Geneface_main/GeneFace/utils/visualization/lm_visualizer.py b/Geneface_main/GeneFace/utils/visualization/lm_visualizer.py new file mode 100644 index 00000000..463f9310 --- /dev/null +++ b/Geneface_main/GeneFace/utils/visualization/lm_visualizer.py @@ -0,0 +1,57 @@ +import numpy as np +import cv2 +from data_util.face3d_helper import Face3DHelper +from utils.visualization.ffmpeg_utils import imgs_to_video +import os + +face3d_helper = Face3DHelper('deep_3drecon/BFM') +# lrs3_stats = np.load('data/binary/lrs3/stats.npy',allow_pickle=True).tolist() +# lrs3_idexp_mean = lrs3_stats['idexp_lm3d_mean'].reshape([1,204]) +# lrs3_idexp_std = lrs3_stats['idexp_lm3d_std'].reshape([1,204]) + + +def render_idexp_npy_to_lm_video(npy_name, out_video_name, audio_name=None): + idexp_lm3d = np.load(npy_name) + lm3d = idexp_lm3d / 10 + face3d_helper.key_mean_shape.squeeze().reshape([1, -1]).cpu().numpy() + lm3d = lm3d.reshape([-1, 68, 3]) + + tmp_img_dir = os.path.join(os.path.dirname(out_video_name), "tmp_lm3d_imgs") + os.makedirs(tmp_img_dir, exist_ok=True) + + WH = 512 + lm3d = (lm3d * WH/2 + WH/2).astype(int) + eye_idx = list(range(36,48)) + mouth_idx = list(range(48,68)) + for i_img in range(len(lm3d)): + lm2d = lm3d[i_img ,:, :2] # [68, 2] + img = np.ones([WH, WH, 3], dtype=np.uint8) * 255 + + for i in range(len(lm2d)): + x, y = lm2d[i] + if i in eye_idx: + color = (0,0,255) + elif i in mouth_idx: + color = (0,255,0) + else: + color = (255,0,0) + img = cv2.circle(img, center=(x,y), radius=3, color=color, thickness=-1) + font = cv2.FONT_HERSHEY_SIMPLEX + img = cv2.flip(img, 0) + for i in range(len(lm2d)): + x, y = lm2d[i] + y = WH - y + img = cv2.putText(img, f"{i}", org=(x,y), fontFace=font, fontScale=0.3, color=(255,0,0)) + + out_name = os.path.join(tmp_img_dir, f'{format(i_img, "05d")}.png') + cv2.imwrite(out_name, img) + imgs_to_video(tmp_img_dir, out_video_name, audio_name) + os.system(f"rm -r {tmp_img_dir}") + +if __name__ == '__main__': + import argparse + argparser = argparse.ArgumentParser() + argparser.add_argument('--npy_name', type=str, default="infer_out/May/pred_lm3d/zozo.npy", help='the path of landmark .npy') + argparser.add_argument('--audio_name', type=str, default="data/raw/val_wavs/zozo.wav", help='the path of audio file') + argparser.add_argument('--out_path', type=str, default="infer_out/May/visualized_lm3d/zozo.mp4", help='the path to save visualization results') + args = argparser.parse_args() + render_idexp_npy_to_lm_video(args.npy_name, args.out_path, audio_name=args.audio_name) \ No newline at end of file diff --git a/Geneface_main/GeneFace/utils/visualization/t-sne.py b/Geneface_main/GeneFace/utils/visualization/t-sne.py new file mode 100644 index 00000000..a0322650 --- /dev/null +++ b/Geneface_main/GeneFace/utils/visualization/t-sne.py @@ -0,0 +1,132 @@ +from openTSNE import TSNE +import numpy as np +import matplotlib +import matplotlib.pyplot as plt +import random + +def visualize( + x, + y, + ax=None, + title=None, + draw_legend=True, + draw_centers=False, + draw_cluster_labels=False, + colors=None, + legend_kwargs=None, + label_order=None, + **kwargs +): + + if ax is None: + _, ax = matplotlib.pyplot.subplots(figsize=(10, 8)) + + if title is not None: + ax.set_title(title) + + plot_params = {"alpha": kwargs.get("alpha", 0.6), "s": kwargs.get("s", 1)} + + # Create main plot + if label_order is not None: + assert all(np.isin(np.unique(y), label_order)) + classes = [l for l in label_order if l in np.unique(y)] + else: + classes = np.unique(y) + if colors is None: + default_colors = matplotlib.rcParams["axes.prop_cycle"] + colors = {k: v["color"] for k, v in zip(classes, default_colors())} + + point_colors = list(map(colors.get, y)) + + ax.scatter(x[:, 0], x[:, 1], c=point_colors, rasterized=True, **plot_params) + + # Plot mediods + if draw_centers: + centers = [] + for yi in classes: + mask = yi == y + centers.append(np.median(x[mask, :2], axis=0)) + centers = np.array(centers) + + center_colors = list(map(colors.get, classes)) + ax.scatter( + centers[:, 0], centers[:, 1], c=center_colors, s=48, alpha=1, edgecolor="k" + ) + + # Draw mediod labels + if draw_cluster_labels: + for idx, label in enumerate(classes): + ax.text( + centers[idx, 0], + centers[idx, 1] + 2.2, + label, + fontsize=kwargs.get("fontsize", 6), + horizontalalignment="center", + ) + + # Hide ticks and axis + ax.set_xticks([]), ax.set_yticks([]), ax.axis("off") + + if draw_legend: + legend_handles = [ + matplotlib.lines.Line2D( + [], + [], + marker="s", + color="w", + markerfacecolor=colors[yi], + ms=10, + alpha=1, + linewidth=0, + label=yi, + markeredgecolor="k", + ) + for yi in classes + ] + legend_kwargs_ = dict(loc="best", bbox_to_anchor=(0.05, 0.5), frameon=False, ) + if legend_kwargs is not None: + legend_kwargs_.update(legend_kwargs) + ax.legend(handles=legend_handles, **legend_kwargs_) + + +tsne = TSNE( + perplexity=30, + metric="euclidean", + n_jobs=8, + random_state=42, + verbose=True, +) + +idexp_lm3d_pred_lrs3 = np.load("infer_out/tmp_npys/lrs3_pred_all.npy") +idx = np.random.choice(np.arange(len(idexp_lm3d_pred_lrs3)), 10000) +idexp_lm3d_pred_lrs3 = idexp_lm3d_pred_lrs3[idx] + +person_ds = np.load("data/binary/videos/May/trainval_dataset.npy", allow_pickle=True).tolist() +person_idexp_mean = person_ds['idexp_lm3d_mean'].reshape([1,204]) +person_idexp_std = person_ds['idexp_lm3d_std'].reshape([1,204]) +person_idexp_lm3d_train = np.stack([s['idexp_lm3d_normalized'].reshape([204,]) for s in person_ds['train_samples']]) +person_idexp_lm3d_val = np.stack([s['idexp_lm3d_normalized'].reshape([204,]) for s in person_ds['val_samples']]) + +lrs3_stats = np.load('/home/yezhenhui/datasets/binary/lrs3_0702/stats.npy',allow_pickle=True).tolist() +lrs3_idexp_mean = lrs3_stats['idexp_lm3d_mean'].reshape([1,204]) +lrs3_idexp_std = lrs3_stats['idexp_lm3d_std'].reshape([1,204]) +person_idexp_lm3d_train = person_idexp_lm3d_train * person_idexp_std + person_idexp_mean +# person_idexp_lm3d_train = (person_idexp_lm3d_train - lrs3_idexp_mean) / lrs3_idexp_std +person_idexp_lm3d_val = person_idexp_lm3d_val * person_idexp_std + person_idexp_mean +# person_idexp_lm3d_val = (person_idexp_lm3d_val - lrs3_idexp_mean) / lrs3_idexp_std +idexp_lm3d_pred_lrs3 = idexp_lm3d_pred_lrs3 * lrs3_idexp_std + lrs3_idexp_mean + + +idexp_lm3d_pred_vae = np.load("infer_out/tmp_npys/pred_exp_0_vae.npy").reshape([-1,204]) +idexp_lm3d_pred_postnet = np.load("infer_out/tmp_npys/pred_exp_0_postnet_hubert.npy").reshape([-1,204]) +# idexp_lm3d_pred_postnet = idexp_lm3d_pred_postnet * lrs3_idexp_std + lrs3_idexp_mean + +idexp_lm3d_all = np.concatenate([idexp_lm3d_pred_lrs3, person_idexp_lm3d_train,idexp_lm3d_pred_vae, idexp_lm3d_pred_postnet]) +idexp_lm3d_all_emb = tsne.fit(idexp_lm3d_all) # array(float64) [B,50]==>[B, 2] +# z_p_emb = tsne.fit(z_p) # array(float64) [B,50]==>[B, 2] +y1 = ["pred_lrs3" for _ in range(len(idexp_lm3d_pred_lrs3))] +y2 = ["person_train" for _ in range(len(person_idexp_lm3d_train))] +y3 = ["vae" for _ in range(len(idexp_lm3d_pred_vae))] +y4 = ["postnet" for _ in range(len(idexp_lm3d_pred_postnet))] +visualize(idexp_lm3d_all_emb, y1+y2+y3+y4) +plt.savefig("infer_out/tmp_npys/lrs3_pred_all_0k.png") \ No newline at end of file diff --git a/Geneface_main/README.md b/Geneface_main/README.md new file mode 100644 index 00000000..6464dd2b --- /dev/null +++ b/Geneface_main/README.md @@ -0,0 +1,15 @@ +# GeneFace-Reproduction + +这是2024年语音识别课程大作业的仓库,用于[GeneFace](https://github.com/yerfor/GeneFace)的复现 + +## 报告 + +- [报告](./report.pdf) +- [复现流程](./C3.md) + +## 获取项目 + +- [校内网盘](https://maru.hana.im/dedfaf/filebrowser/share/w_i9-LQN) +- [校外网盘(使用nlibvpn)](https://nlibvpn.bit.edu.cn/https/77726476706e69737468656265737421f8f64f9d2a317a45301b8dbf821b26201ef36cf7/dedfaf/filebrowser/share/w_i9-LQN) + +其中`Geneface.zip`为项目文件,`Geneface_docker.tar`为本项目封装的docker diff --git a/Geneface_main/general-report.docx b/Geneface_main/general-report.docx new file mode 100644 index 00000000..23e8233d Binary files /dev/null and b/Geneface_main/general-report.docx differ diff --git a/Geneface_main/imgs/Audio2Motion.png b/Geneface_main/imgs/Audio2Motion.png new file mode 100644 index 00000000..edf09d71 Binary files /dev/null and b/Geneface_main/imgs/Audio2Motion.png differ diff --git a/Geneface_main/imgs/SyncNet.png b/Geneface_main/imgs/SyncNet.png new file mode 100644 index 00000000..0aac9630 Binary files /dev/null and b/Geneface_main/imgs/SyncNet.png differ diff --git a/Geneface_main/report.docx b/Geneface_main/report.docx new file mode 100644 index 00000000..48e271e7 Binary files /dev/null and b/Geneface_main/report.docx differ diff --git a/Geneface_main/report.pdf b/Geneface_main/report.pdf new file mode 100644 index 00000000..60a8f129 Binary files /dev/null and b/Geneface_main/report.pdf differ diff --git a/README.md b/README.md index 449c4ea9..8d1c8b69 100644 --- a/README.md +++ b/README.md @@ -1 +1 @@ -# talkingface-kit \ No newline at end of file +