diff --git a/Talking-Face_PC-AVS/.gitignore b/Talking-Face_PC-AVS/.gitignore new file mode 100644 index 00000000..8efc0b64 --- /dev/null +++ b/Talking-Face_PC-AVS/.gitignore @@ -0,0 +1,121 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class +**/*.pyc + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +# custom +demo/ +checkpoints/ +results/ +.vscode +.idea +*.pkl +*.pkl.json +*.log.json +work_dirs/ +*.avi + +# Pytorch +*.pth +*.tar diff --git a/Talking-Face_PC-AVS/Dockerfile b/Talking-Face_PC-AVS/Dockerfile new file mode 100644 index 00000000..3d04d452 --- /dev/null +++ b/Talking-Face_PC-AVS/Dockerfile @@ -0,0 +1,39 @@ +#基于python的基础镜像 +FROM python:3.6-slim + +#代码添加到code文件夹 +ADD ./Talking-Face_PC-AVS /code + +#设置code文件夹是工作目录 +WORKDIR /code + +COPY requirements.txt /code/ + +RUN pip 
install --upgrade pip -i https://pypi.mirrors.ustc.edu.cn/simple/ + +RUN sed -i 's|http://deb.debian.org/debian|http://mirrors.aliyun.com/debian|g' /etc/apt/sources.list + +RUN apt-get update + +RUN apt-get install -y libsm6 libxext6 libxrender-dev + +RUN apt-get install -y libglib2.0-0 + +RUN apt-get install -y libglib2.0-dev + +RUN apt-get install -y libopencv-dev + +RUN apt-get install -y build-essential + +RUN apt-get install -y python3-dev + +RUN apt-get install -y gcc + +RUN apt-get install -y g++ + +RUN pip install numpy==1.19.5 -i https://pypi.mirrors.ustc.edu.cn/simple/ + +#conda安装环境 +RUN pip install -r requirements.txt -i https://pypi.mirrors.ustc.edu.cn/simple/ + +CMD ["/bin/bash", "/code/experiments/demo_vox_new.sh"] diff --git a/Talking-Face_PC-AVS/LICENSE b/Talking-Face_PC-AVS/LICENSE new file mode 100644 index 00000000..a6d7fd36 --- /dev/null +++ b/Talking-Face_PC-AVS/LICENSE @@ -0,0 +1,384 @@ +Attribution 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. 
The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. 
More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution 4.0 International Public License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution 4.0 International Public License ("Public License"). To the +extent this Public License may be interpreted as a contract, You are +granted the Licensed Rights in consideration of Your acceptance of +these terms and conditions, and the Licensor grants You such rights in +consideration of benefits the Licensor receives from making the +Licensed Material available under these terms and conditions. + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + +b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + +c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. 
+ +d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + +e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + +f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + +g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + +h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + +i. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + +j. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + +k. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + +Section 2 -- Scope. + +a. License grant. + + 1. 
Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part; and + + b. produce, reproduce, and Share Adapted Material. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. 
Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + +b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties. + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + +a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. 
indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + +a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database; + +b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + +c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + +a. 
UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + +b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + +c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + +Section 6 -- Term and Termination. + +a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + +b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. 
upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + +c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + +d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + +Section 7 -- Other Terms and Conditions. + +a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + +b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + +Section 8 -- Interpretation. + +a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + +b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + +c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + +d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. 
+ +======================================================================= + +Creative Commons is not a party to its public licenses. +Notwithstanding, Creative Commons may elect to apply one of its public +licenses to material it publishes and in those instances will be +considered the "Licensor." Except for the limited purpose of indicating +that material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the public +licenses. + +Creative Commons may be contacted at creativecommons.org. 
# ---------------------------------------------------------------------------
# File: Talking-Face_PC-AVS/NIQE.py
# ---------------------------------------------------------------------------
"""Compute NIQE image-quality scores for grayscale images.

The score is the distance between feature statistics extracted from
locally normalised (MSCN) image patches and the reference statistics
stored in ``data/niqe_image_params.mat`` next to this module.
"""
import functools
import math
import os
from os.path import dirname
from os.path import join

import numpy as np
import scipy
import scipy.io
import scipy.linalg  # scipy.linalg.pinv is used by niqe(); import it explicitly
import scipy.ndimage
import scipy.special

# NOTE: ``import scipy.misc`` was dropped on purpose: its only use here was
# ``scipy.misc.imresize``, which no longer exists in current SciPy releases.

try:
    from PIL import Image
except ImportError:  # Pillow is only needed for loading / resizing images
    Image = None


# Lookup table used to invert the moment ratio
#     Gamma(2/g)^2 / (Gamma(1/g) * Gamma(3/g))
# for candidate shape parameters g in [0.2, 10).
gamma_range = np.arange(0.2, 10, 0.001)
a = scipy.special.gamma(2.0/gamma_range)
a *= a
b = scipy.special.gamma(1.0/gamma_range)
c = scipy.special.gamma(3.0/gamma_range)
prec_gammas = a/(b*c)


def _require_pillow():
    """Raise a clear ImportError when Pillow is not installed."""
    if Image is None:
        raise ImportError("Pillow (PIL) is required to load or resize images")


def aggd_features(imdata):
    """Fit an asymmetric generalized Gaussian to ``imdata`` by moment matching.

    The input is flattened internally; the caller's array is left untouched
    (the previous implementation reshaped it in place).

    Returns:
        Tuple ``(alpha, N, bl, br, left_mean_sqrt, right_mean_sqrt)`` where
        ``alpha`` is the shape value picked from ``gamma_range``, ``bl``/``br``
        are the left/right scale estimates, ``N`` is the mean parameter and
        the last two entries are the RMS of the negative / non-negative
        samples.
    """
    flat = np.ravel(imdata)
    squared = flat * flat
    left_data = squared[flat < 0]
    right_data = squared[flat >= 0]

    left_mean_sqrt = 0
    right_mean_sqrt = 0
    if len(left_data) > 0:
        left_mean_sqrt = np.sqrt(np.average(left_data))
    if len(right_data) > 0:
        right_mean_sqrt = np.sqrt(np.average(right_data))

    if right_mean_sqrt != 0:
        gamma_hat = left_mean_sqrt/right_mean_sqrt
    else:
        gamma_hat = np.inf

    # solve r-hat norm
    squared_mean = np.mean(squared)
    if squared_mean != 0:
        r_hat = (np.average(np.abs(flat))**2) / squared_mean
    else:
        r_hat = np.inf
    rhat_norm = r_hat * (
        ((math.pow(gamma_hat, 3) + 1)*(gamma_hat + 1))
        / math.pow(math.pow(gamma_hat, 2) + 1, 2)
    )

    # pick the tabulated shape whose ratio is closest to rhat_norm
    pos = np.argmin((prec_gammas - rhat_norm)**2)
    alpha = gamma_range[pos]

    gam1 = scipy.special.gamma(1.0/alpha)
    gam2 = scipy.special.gamma(2.0/alpha)
    gam3 = scipy.special.gamma(3.0/alpha)

    aggdratio = np.sqrt(gam1) / np.sqrt(gam3)
    bl = aggdratio * left_mean_sqrt
    br = aggdratio * right_mean_sqrt

    # mean parameter
    N = (br - bl)*(gam2 / gam1)
    return (alpha, N, bl, br, left_mean_sqrt, right_mean_sqrt)


def ggd_features(imdata):
    """Estimate a generalized-Gaussian shape parameter and the variance.

    Returns:
        Tuple ``(shape, sigma_sq)``: the entry of ``gamma_range`` whose
        tabulated ratio best matches ``var / mean(|x|)**2``, and ``np.var``
        of the input.
    """
    nr_gam = 1/prec_gammas
    sigma_sq = np.var(imdata)
    E = np.mean(np.abs(imdata))
    rho = sigma_sq/E**2
    pos = np.argmin(np.abs(nr_gam - rho))
    return gamma_range[pos], sigma_sq


def paired_product(new_im):
    """Multiply an image with four circularly shifted copies of itself.

    Returns:
        Tuple ``(H_img, V_img, D1_img, D2_img)``: products with the image
        rolled by one along axis 1, along axis 0, along both axes, and along
        axis 0 combined with a roll of -1 along axis 1.
    """
    # np.roll already returns a new array, so no defensive copy is needed.
    rolled_rows = np.roll(new_im, 1, axis=0)
    shift1 = np.roll(new_im, 1, axis=1)
    shift2 = rolled_rows
    shift3 = np.roll(rolled_rows, 1, axis=1)
    shift4 = np.roll(rolled_rows, -1, axis=1)

    H_img = shift1 * new_im
    V_img = shift2 * new_im
    D1_img = shift3 * new_im
    D2_img = shift4 * new_im

    return (H_img, V_img, D1_img, D2_img)


def gen_gauss_window(lw, sigma):
    """Return a normalised 1-D Gaussian window as a list of ``2*lw + 1`` weights."""
    sd = np.float32(sigma)
    lw = int(lw)
    weights = [0.0] * (2 * lw + 1)
    weights[lw] = 1.0
    total = 1.0  # renamed from ``sum`` to stop shadowing the builtin
    sd *= sd
    for ii in range(1, lw + 1):
        tmp = np.exp(-0.5 * np.float32(ii * ii) / sd)
        weights[lw + ii] = tmp
        weights[lw - ii] = tmp
        total += 2.0 * tmp
    for ii in range(2 * lw + 1):
        weights[ii] /= total
    return weights


def compute_image_mscn_transform(image, C=1, avg_window=None, extend_mode='constant'):
    """Locally normalise a 2-D image (mean-subtracted, contrast-normalised).

    Args:
        image: 2-D array-like; converted to float32.
        C: constant added to the local deviation before dividing.
        avg_window: 1-D weights applied along each axis; defaults to
            ``gen_gauss_window(3, 7.0/6.0)``.
        extend_mode: border mode passed to ``scipy.ndimage.correlate1d``.

    Returns:
        Tuple ``(mscn, var_image, mu_image)`` where ``var_image`` holds the
        local deviation ``sqrt(|E[x^2] - mu^2|)`` and ``mu_image`` the local
        weighted mean.
    """
    if avg_window is None:
        avg_window = gen_gauss_window(3, 7.0/6.0)
    assert len(np.shape(image)) == 2
    image = np.array(image).astype('float32')
    # separable filtering: first along axis 0, then along axis 1
    mu_image = scipy.ndimage.correlate1d(image, avg_window, 0, mode=extend_mode)
    mu_image = scipy.ndimage.correlate1d(mu_image, avg_window, 1, mode=extend_mode)
    var_image = scipy.ndimage.correlate1d(image**2, avg_window, 0, mode=extend_mode)
    var_image = scipy.ndimage.correlate1d(var_image, avg_window, 1, mode=extend_mode)
    var_image = np.sqrt(np.abs(var_image - mu_image**2))
    return (image - mu_image)/(var_image + C), var_image, mu_image


def _niqe_extract_subband_feats(mscncoefs):
    """Build the 18-element feature vector for one patch of MSCN coefficients."""
    alpha_m, N, bl, br, lsq, rsq = aggd_features(mscncoefs)
    pps1, pps2, pps3, pps4 = paired_product(mscncoefs)
    alpha1, N1, bl1, br1, lsq1, rsq1 = aggd_features(pps1)
    alpha2, N2, bl2, br2, lsq2, rsq2 = aggd_features(pps2)
    alpha3, N3, bl3, br3, lsq3, rsq3 = aggd_features(pps3)
    alpha4, N4, bl4, br4, lsq4, rsq4 = aggd_features(pps4)
    # NOTE(review): the two diagonal rows repeat ``bl`` where ``br`` would be
    # expected. This is kept byte-for-byte because the reference statistics in
    # niqe_image_params.mat may have been produced with the same layout --
    # confirm against the parameter file before changing it.
    return np.array([alpha_m, (bl+br)/2.0,
                     alpha1, N1, bl1, br1,  # (H)
                     alpha2, N2, bl2, br2,  # (V)
                     alpha3, N3, bl3, bl3,  # (D1)
                     alpha4, N4, bl4, bl4,  # (D2)
                     ])


def get_patches_train_features(img, patch_size, stride=8):
    """Extract patch features with the ``is_train`` flag set to 1."""
    return _get_patches_generic(img, patch_size, 1, stride)


def get_patches_test_features(img, patch_size, stride=8):
    """Extract patch features with the ``is_train`` flag set to 0."""
    return _get_patches_generic(img, patch_size, 0, stride)


def extract_on_patches(img, patch_size):
    """Split ``img`` into non-overlapping square patches and featurise each.

    ``patch_size`` may be a float (callers pass ``patch_size/2``); it is
    truncated with the builtin ``int`` because the ``np.int`` alias no longer
    exists in current NumPy releases.

    Returns:
        Array with one row of features per patch.
    """
    h, w = img.shape
    patch_size = int(patch_size)
    patches = [
        img[j:j+patch_size, i:i+patch_size]
        for j in range(0, h-patch_size+1, patch_size)
        for i in range(0, w-patch_size+1, patch_size)
    ]
    patches = np.array(patches)

    return np.array([_niqe_extract_subband_feats(p) for p in patches])


def _downscale_half(img):
    """Return ``img`` resized to half its size with bicubic resampling (float32).

    Replacement for ``scipy.misc.imresize(img, 0.5, interp='bicubic', mode='F')``.
    NOTE(review): imresize delegated to Pillow as well, so results should be
    equivalent -- confirm on a reference image.
    """
    _require_pillow()
    height, width = img.shape
    new_size = (int(width * 0.5), int(height * 0.5))  # Pillow expects (width, height)
    pil_img = Image.fromarray(np.ascontiguousarray(img, dtype=np.float32))
    return np.asarray(pil_img.resize(new_size, resample=Image.BICUBIC), dtype=np.float32)


def _get_patches_generic(img, patch_size, is_train, stride):
    """Compute per-patch features at full and half resolution.

    ``is_train`` and ``stride`` are accepted for interface compatibility but
    are not used by the current implementation.

    Raises:
        ValueError: if the image is smaller than ``patch_size`` in either
            dimension (previously this printed a message and called
            ``exit(0)``, terminating the interpreter with a success status).
    """
    h, w = np.shape(img)
    if h < patch_size or w < patch_size:
        raise ValueError("Input image is too small")

    # ensure that the patch divides evenly into img
    hoffset = (h % patch_size)
    woffset = (w % patch_size)

    if hoffset > 0:
        img = img[:-hoffset, :]
    if woffset > 0:
        img = img[:, :-woffset]

    img = img.astype(np.float32)
    img2 = _downscale_half(img)

    mscn1, var, mu = compute_image_mscn_transform(img)
    mscn1 = mscn1.astype(np.float32)

    mscn2, _, _ = compute_image_mscn_transform(img2)
    mscn2 = mscn2.astype(np.float32)

    feats_lvl1 = extract_on_patches(mscn1, patch_size)
    feats_lvl2 = extract_on_patches(mscn2, patch_size/2)

    return np.hstack((feats_lvl1, feats_lvl2))


@functools.lru_cache(maxsize=1)
def _load_niqe_params():
    """Load ``(pop_mu, pop_cov)`` from data/niqe_image_params.mat once and cache it."""
    params = scipy.io.loadmat(join(dirname(__file__), 'data', 'niqe_image_params.mat'))
    return np.ravel(params["pop_mu"]), params["pop_cov"]


def niqe(inputImgData):
    """Return the NIQE score of a single-channel image.

    Args:
        inputImgData: 2-D array; both dimensions must exceed 193 pixels.

    Raises:
        AssertionError: if the image is too small for the 96-pixel patches.
    """
    patch_size = 96
    pop_mu, pop_cov = _load_niqe_params()

    M, N = inputImgData.shape

    assert M > (patch_size*2+1), "niqe called with small frame size, requires > 192x192 resolution video using current training parameters"
    assert N > (patch_size*2+1), "niqe called with small frame size, requires > 192x192 resolution video using current training parameters"

    feats = get_patches_test_features(inputImgData, patch_size)
    sample_mu = np.mean(feats, axis=0)
    sample_cov = np.cov(feats.T)

    # distance between the sample statistics and the stored reference ones
    X = sample_mu - pop_mu
    covmat = ((pop_cov+sample_cov)/2.0)
    pinvmat = scipy.linalg.pinv(covmat)
    niqe_score = np.sqrt(np.dot(np.dot(X, pinvmat), X))

    return niqe_score


def read_image_as_grayscale(image_path):
    """Open ``image_path`` with Pillow, convert to mode ``'L'`` and return an array."""
    _require_pillow()
    # the context manager closes the underlying file handle
    with Image.open(image_path) as img:
        return np.array(img.convert('L'))


def evaluate_niqe_in_folder(folder_path):
    """Print the NIQE score of every image in ``folder_path`` and their average.

    Files ending in .jpg, .jpeg or .png (case-insensitive) are processed in
    sorted order.

    Returns:
        The average score, or 0 when the folder contains no matching image.
    """
    niqe_values = []
    for filename in sorted(os.listdir(folder_path)):
        if not filename.lower().endswith((".jpg", ".jpeg", ".png")):
            continue
        image_path = os.path.join(folder_path, filename)
        img_gray = read_image_as_grayscale(image_path)
        niqe_value = niqe(img_gray)
        niqe_values.append(niqe_value)
        # Pure f-string formatting: mixing an f-string with ``%`` broke on
        # filenames that contain a '%' character.
        print(f'NIQE of the image "{filename}" is: {niqe_value:0.3f}')
    average_niqe = sum(niqe_values) / len(niqe_values) if niqe_values else 0
    print(f'Average NIQE of all images in the folder is: {average_niqe:0.3f}')
    return average_niqe


if __name__ == "__main__":
    folder_to_evaluate = 'results/id_May_cropped_pose_Macron_cropped_audio_Shaheen/G_Pose_Driven_'
    # evaluate every image in the folder
    evaluate_niqe_in_folder(folder_to_evaluate)


# ---------------------------------------------------------------------------
# File: Talking-Face_PC-AVS/PSNR.py
# ---------------------------------------------------------------------------
import numpy as np


def average_video_psnr(orig_path, gen_path):
    """Return the mean per-frame PSNR between two videos (grayscale frames).

    Frames are read in lockstep until either video runs out.

    Raises:
        ValueError: if no frame pair could be read (for example when a path
            is wrong); averaging an empty list would otherwise yield ``nan``.
    """
    # Imported here so that importing this module does not require
    # OpenCV / scikit-image.
    import cv2
    from skimage.metrics import peak_signal_noise_ratio as psnr

    cap_orig = cv2.VideoCapture(orig_path)
    cap_gen = cv2.VideoCapture(gen_path)
    psnr_values = []
    try:
        while cap_orig.isOpened() and cap_gen.isOpened():
            ret_orig, frame_orig = cap_orig.read()
            ret_gen, frame_gen = cap_gen.read()
            if not ret_orig or not ret_gen:
                break

            gray_orig = cv2.cvtColor(frame_orig, cv2.COLOR_BGR2GRAY)
            gray_gen = cv2.cvtColor(frame_gen, cv2.COLOR_BGR2GRAY)
            psnr_values.append(psnr(gray_orig, gray_gen))
    finally:
        # always release the capture handles, even if a frame fails
        cap_orig.release()
        cap_gen.release()

    if not psnr_values:
        raise ValueError(
            f"no frames could be read from {orig_path!r} and {gen_path!r}"
        )
    return np.mean(psnr_values)


if __name__ == "__main__":
    average_psnr = average_video_psnr(
        'raw/videos/Macron_224.mp4',
        'results/id_May_cropped_pose_Macron_cropped_audio_Shaheen/G_Pose_Driven_.mp4',
    )
    print('Average PSNR:', average_psnr)
b/Talking-Face_PC-AVS/README.md @@ -0,0 +1,39 @@ +(特别注意:由于评估数据集raw文件夹中的视频过大,无法上传到github上,请自行将raw文件夹放到根目录下) +PS:docker镜像下载地址https://drive.google.com/file/d/1H9K5E8GqzK1EDPScpSeiW0j5IksmWORo/view?usp=sharing + +1. 按照以下代码格式对数据集进行预处理,数据集可选用raw文件夹下的任意视频。 +python scripts/prepare_testing_files.py --src_pose_path /path/to/your/pose_path.mp4 --src_audio_path /path/to/your/audio_path.mp4 --src_mouth_frame_path /path/to/your/mouth_path.mp4 --src_input_path /path/to/your/input.mp4 --csv_path raw/metadata.csv + +2. 预处理后会生成文件夹,文件夹包含视频每一帧的图片,按照以下代码格式对图片进行人脸对齐,输入样式是文件夹 +python scripts/align_68_new.py --folder_path /path/to/your/folder + +3. 人脸对齐后会生成一个全新文件夹,修改raw/metadata.csv文件,将文件内原先写入的文件路径改为新生成的文件夹路径,一般情况文件夹名字为xxx_cropped,(注:修改完后找到mouth_source文件夹中对应的xxx_cropped文件夹,打开之后在文件夹中复制粘贴一份最后一张图片,原因是执行脚本会报错,mouth_source中缺少一张图片)随后执行experiments/demo_vox_new.sh脚本 +bash experiments/demo_vox_new.sh + +4. 运行结果会保存在results文件夹内。 + +结果评估: +PSNR:修改PSNR文件如图所示部分,![alt text](image.png)。cap_orig改为原视频的224分辨率视频,cap_gen改为结果文件夹中G_Pose_Driven_.mp4,随后运行文件即可。 +注:224分辨率视频获取方法:ffmpeg -i Jae-in.mp4 -vf "scale=224:224" Jae-in_224.mp4(若未安装ffmpeg,请通过系统包管理器安装,例如:apt-get install -y ffmpeg 或 conda install -c conda-forge ffmpeg;注意 pip install ffmpeg 并不会安装ffmpeg可执行程序) + +SSIM:方式同上,修改后运行SSIM.py即可 + +NIQE:将NIQE.py文件末尾处folder_to_evaluate修改为被测试路径即可,如图所示![alt text](image-1.png)。修改后执行即可。 + +FID:按照以下形式执行命令即可,修改命令中的文件路径为被测试路径,第一个路径为 人脸对齐 后的路径。 +python -m pytorch_fid /path/to/your/input/folder path/to/your/results/G_Pose_Driven_ --device cuda:0 + + +执行示例如下: +1. 预处理: + python scripts/prepare_testing_files.py --src_pose_path raw/videos/Jae-in.mp4 --src_audio_path raw/videos/May.mp4 --src_mouth_frame_path raw/videos/May.mp4 --src_input_path raw/videos/Lieu.mp4 --csv_path raw/metadata.csv +2. 人脸对齐: +python scripts/align_68_new.py --folder_path Pose_Source/Jae-in +python scripts/align_68_new.py --folder_path Mouth_Source/May +python scripts/align_68_new.py --folder_path Input/Lieu +3. 修改metadata.csv文件,更改文件路径 +4. 执行脚本 : +bash experiments/demo_vox_new.sh +5. 
结果评估 + +PS:文件夹中的my_image.tar中封装的是docker镜像,docker镜像中有部分预处理过得文件(人脸对齐完毕),可以直接执行脚本,加载后使用即可。 diff --git a/Talking-Face_PC-AVS/SSIM.py b/Talking-Face_PC-AVS/SSIM.py new file mode 100644 index 00000000..07d6da93 --- /dev/null +++ b/Talking-Face_PC-AVS/SSIM.py @@ -0,0 +1,25 @@ +import cv2 +import numpy as np +from skimage.metrics import structural_similarity as ssim + + +cap_orig = cv2.VideoCapture('raw/videos/Macron_224.mp4') +cap_gen = cv2.VideoCapture('results/id_May_cropped_pose_Macron_cropped_audio_Shaheen/G_Pose_Driven_.mp4') + + +ssim_values = [] +while cap_orig.isOpened() and cap_gen.isOpened(): + ret_orig, frame_orig = cap_orig.read() + ret_gen, frame_gen = cap_gen.read() + if not ret_orig or not ret_gen: + break + + gray_orig = cv2.cvtColor(frame_orig, cv2.COLOR_BGR2GRAY) + gray_gen = cv2.cvtColor(frame_gen, cv2.COLOR_BGR2GRAY) + + + ssim_value, _ = ssim(gray_orig, gray_gen, full=True) + ssim_values.append(ssim_value) + +average_ssim = np.mean(ssim_values) +print('Average SSIM:', average_ssim) \ No newline at end of file diff --git a/Talking-Face_PC-AVS/config/AudioConfig.py b/Talking-Face_PC-AVS/config/AudioConfig.py new file mode 100644 index 00000000..83207139 --- /dev/null +++ b/Talking-Face_PC-AVS/config/AudioConfig.py @@ -0,0 +1,180 @@ +import librosa +import librosa.filters +import numpy as np +from scipy import signal +from scipy.io import wavfile +import lws + + +class AudioConfig: + def __init__(self, frame_rate=25, + sample_rate=16000, + num_mels=80, + fft_size=1280, + hop_size=160, + num_frames_per_clip=5, + save_mel=True + ): + self.frame_rate = frame_rate + self.sample_rate = sample_rate + self.num_bins_per_frame = int(sample_rate / hop_size / frame_rate) + self.num_frames_per_clip = num_frames_per_clip + self.silence_threshold = 2 + self.num_mels = num_mels + self.save_mel = save_mel + self.fmin = 125 + self.fmax = 7600 + self.fft_size = fft_size + self.hop_size = hop_size + self.frame_shift_ms = None + self.min_level_db = -100 + self.ref_level_db 
= 20 + self.rescaling = True + self.rescaling_max = 0.999 + self.allow_clipping_in_normalization = True + self.log_scale_min = -32.23619130191664 + self.norm_audio = True + self.with_phase = False + + def load_wav(self, path): + return librosa.core.load(path, sr=self.sample_rate)[0] + + def audio_normalize(self, samples, desired_rms=0.1, eps=1e-4): + rms = np.maximum(eps, np.sqrt(np.mean(samples ** 2))) + samples = samples * (desired_rms / rms) + return samples + + def generate_spectrogram_magphase(self, audio): + spectro = librosa.core.stft(audio, hop_length=self.get_hop_size(), n_fft=self.fft_size, center=True) + spectro_mag, spectro_phase = librosa.core.magphase(spectro) + spectro_mag = np.expand_dims(spectro_mag, axis=0) + if self.with_phase: + spectro_phase = np.expand_dims(np.angle(spectro_phase), axis=0) + return spectro_mag, spectro_phase + else: + return spectro_mag + + def save_wav(self, wav, path): + wav = wav * (32767 / max(0.01, np.max(np.abs(wav)))) # peak-normalise a scaled copy to the int16 range; the caller's array is no longer mutated in place + wavfile.write(path, self.sample_rate, wav.astype(np.int16)) + + def trim(self, quantized): + start, end = self.start_and_end_indices(quantized, self.silence_threshold) + return quantized[start:end] + + def adjust_time_resolution(self, quantized, mel): + """Adjust time resolution by repeating features + + Args: + quantized (ndarray): (T,) + mel (ndarray): (N, D) + + Returns: + tuple: Tuple of (T,) and (T, D) + """ + assert len(quantized.shape) == 1 + assert len(mel.shape) == 2 + + upsample_factor = quantized.size // mel.shape[0] + mel = np.repeat(mel, upsample_factor, axis=0) + n_pad = quantized.size - mel.shape[0] + if n_pad != 0: + assert n_pad > 0 + mel = np.pad(mel, [(0, n_pad), (0, 0)], mode="constant", constant_values=0) + + # trim + start, end = self.start_and_end_indices(quantized, self.silence_threshold) + + return quantized[start:end], mel[start:end, :] + + adjast_time_resolution = adjust_time_resolution # 'adjust' is correct spelling, this is for compatibility + + def start_and_end_indices(self, 
quantized, silence_threshold=2): + for start in range(quantized.size): + if abs(quantized[start] - 127) > silence_threshold: + break + for end in range(quantized.size - 1, 1, -1): + if abs(quantized[end] - 127) > silence_threshold: + break + + assert abs(quantized[start] - 127) > silence_threshold + assert abs(quantized[end] - 127) > silence_threshold + + return start, end + + def melspectrogram(self, y): + D = self._lws_processor().stft(y).T + S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db + if not self.allow_clipping_in_normalization: + assert S.max() <= 0 and S.min() - self.min_level_db >= 0 + return self._normalize(S) + + def get_hop_size(self): + hop_size = self.hop_size + if hop_size is None: + assert self.frame_shift_ms is not None + hop_size = int(self.frame_shift_ms / 1000 * self.sample_rate) + return hop_size + + def _lws_processor(self): + return lws.lws(self.fft_size, self.get_hop_size(), mode="speech") + + def lws_num_frames(self, length, fsize, fshift): + """Compute number of time frames of lws spectrogram + """ + pad = (fsize - fshift) + if length % fshift == 0: + M = (length + pad * 2 - fsize) // fshift + 1 + else: + M = (length + pad * 2 - fsize) // fshift + 2 + return M + + def lws_pad_lr(self, x, fsize, fshift): + """Compute left and right padding lws internally uses + """ + M = self.lws_num_frames(len(x), fsize, fshift) + pad = (fsize - fshift) + T = len(x) + 2 * pad + r = (M - 1) * fshift + fsize - T + return pad, pad + r + + + def _linear_to_mel(self, spectrogram): + global _mel_basis + _mel_basis = self._build_mel_basis() + return np.dot(_mel_basis, spectrogram) + + def _build_mel_basis(self): + assert self.fmax <= self.sample_rate // 2 + return librosa.filters.mel(sr=self.sample_rate, n_fft=self.fft_size, # sr/n_fft passed by keyword: they are keyword-only in librosa >= 0.10 and accepted by name in older releases + fmin=self.fmin, fmax=self.fmax, + n_mels=self.num_mels) + + def _amp_to_db(self, x): + min_level = np.exp(self.min_level_db / 20 * np.log(10)) + return 20 * np.log10(np.maximum(min_level, x)) + + def _db_to_amp(self, x): + 
return np.power(10.0, x * 0.05) + + def _normalize(self, S): + return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1) + + def _denormalize(self, S): + return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db + + def read_audio(self, audio_path): + wav = self.load_wav(audio_path) + if self.norm_audio: + wav = self.audio_normalize(wav) + else: + wav = wav / np.abs(wav).max() + + return wav + + def audio_to_spectrogram(self, wav): + if self.save_mel: + spectrogram = self.melspectrogram(wav).astype(np.float32).T + else: + spectrogram = self.generate_spectrogram_magphase(wav) + return spectrogram diff --git a/Talking-Face_PC-AVS/data/__init__.py b/Talking-Face_PC-AVS/data/__init__.py new file mode 100644 index 00000000..946a7485 --- /dev/null +++ b/Talking-Face_PC-AVS/data/__init__.py @@ -0,0 +1,79 @@ +import importlib +import torch.utils.data +from data.base_dataset import BaseDataset + + +def find_dataset_using_name(dataset_name): + # Given the option --dataset [datasetname], + # the file "datasets/datasetname_dataset.py" + # will be imported. + dataset_filename = "data." + dataset_name + "_dataset" + datasetlib = importlib.import_module(dataset_filename) + + # In the file, the class called DatasetNameDataset() will + # be instantiated. It has to be a subclass of BaseDataset, + # and it is case-insensitive. + dataset = None + target_dataset_name = dataset_name.replace('_', '') + 'dataset' + for name, cls in datasetlib.__dict__.items(): + if name.lower() == target_dataset_name.lower() \ + and issubclass(cls, BaseDataset): + dataset = cls + + if dataset is None: + raise ValueError("In %s.py, there should be a subclass of BaseDataset " + "with class name that matches %s in lowercase." 
% + (dataset_filename, target_dataset_name)) + + return dataset + + +def get_option_setter(dataset_name): + dataset_class = find_dataset_using_name(dataset_name) + return dataset_class.modify_commandline_options + + +def create_dataloader(opt): + dataset_modes = opt.dataset_mode.split(',') + if len(dataset_modes) == 1: + dataset = find_dataset_using_name(opt.dataset_mode) + instance = dataset() + instance.initialize(opt) + print("dataset [%s] of size %d was created" % + (type(instance).__name__, len(instance))) + if not opt.isTrain: + shuffle = False + else: + shuffle = True + dataloader = torch.utils.data.DataLoader( + instance, + batch_size=opt.batchSize, + shuffle=shuffle, + num_workers=int(opt.nThreads), + drop_last=opt.isTrain + ) + return dataloader + + else: + dataloader_dict = {} + for dataset_mode in dataset_modes: + dataset = find_dataset_using_name(dataset_mode) + instance = dataset() + instance.initialize(opt) + print("dataset [%s] of size %d was created" % + (type(instance).__name__, len(instance))) + if not opt.isTrain: + shuffle = not opt.defined_driven + else: + shuffle = True + dataloader = torch.utils.data.DataLoader( + instance, + batch_size=opt.batchSize, + shuffle=shuffle, + num_workers=int(opt.nThreads), + drop_last=opt.isTrain + ) + dataloader_dict[dataset_mode] = dataloader + return dataloader_dict + + diff --git a/Talking-Face_PC-AVS/data/base_dataset.py b/Talking-Face_PC-AVS/data/base_dataset.py new file mode 100644 index 00000000..889522bf --- /dev/null +++ b/Talking-Face_PC-AVS/data/base_dataset.py @@ -0,0 +1,100 @@ +import torch.utils.data as data +import torch +import torchvision.transforms as transforms +import numpy as np +import cv2 + + +class BaseDataset(data.Dataset): + def __init__(self): + super(BaseDataset, self).__init__() + + @staticmethod + def modify_commandline_options(parser, is_train): + return parser + + def initialize(self, opt): + pass + + def to_Tensor(self, img): + if img.ndim == 3: + wrapped_img = img.transpose(2, 
0, 1) / 255.0 + elif img.ndim == 4: + wrapped_img = img.transpose(0, 3, 1, 2) / 255.0 + else: + wrapped_img = img / 255.0 + wrapped_img = torch.from_numpy(wrapped_img).float() + + return wrapped_img * 2 - 1 + + def face_augmentation(self, img, crop_size): + img = self._color_transfer(img) + img = self._reshape(img, crop_size) + img = self._blur_and_sharp(img) + return img + + def _blur_and_sharp(self, img): + blur = np.random.randint(0, 2) + img2 = img.copy() + output = [] + for i in range(len(img2)): + if blur: + ksize = np.random.choice([3, 5, 7, 9]) + output.append(cv2.medianBlur(img2[i], ksize)) + else: + kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]]) + output.append(cv2.filter2D(img2[i], -1, kernel)) + output = np.stack(output) + return output + + def _color_transfer(self, img): + + transfer_c = np.random.uniform(0.3, 1.6) + + start_channel = np.random.randint(0, 2) + end_channel = np.random.randint(start_channel + 1, 4) + + img2 = img.copy() + + img2[:, :, :, start_channel:end_channel] = np.minimum(np.maximum(img[:, :, :, start_channel:end_channel] * transfer_c, np.zeros(img[:, :, :, start_channel:end_channel].shape)), + np.ones(img[:, :, :, start_channel:end_channel].shape) * 255) + return img2 + + def perspective_transform(self, img, crop_size=224, pers_size=10, enlarge_size=-10): + h, w, c = img.shape + dst = np.array([ + [-enlarge_size, -enlarge_size], + [-enlarge_size + pers_size, w + enlarge_size], + [h + enlarge_size, -enlarge_size], + [h + enlarge_size - pers_size, w + enlarge_size],], dtype=np.float32) + src = np.array([[-enlarge_size, -enlarge_size], [-enlarge_size, w + enlarge_size], + [h + enlarge_size, -enlarge_size], [h + enlarge_size, w + enlarge_size]]).astype(np.float32()) + M = cv2.getPerspectiveTransform(src, dst) + warped = cv2.warpPerspective(img, M, (crop_size, crop_size), borderMode=cv2.BORDER_REPLICATE) + return warped, M + + def _reshape(self, img, crop_size): + reshape = np.random.randint(0, 2) + reshape_size = 
np.random.randint(15, 25) + extra_padding_size = np.random.randint(0, reshape_size // 2) + pers_size = np.random.randint(20, 30) * pow(-1, np.random.randint(2)) + + enlarge_size = np.random.randint(20, 40) * pow(-1, np.random.randint(2)) + shape = img[0].shape + img2 = img.copy() + output = [] + for i in range(len(img2)): + if reshape: + im = cv2.resize(img2[i], (shape[0] - reshape_size*2, shape[1] + reshape_size*2)) + im = cv2.copyMakeBorder(im, 0, 0, reshape_size + extra_padding_size, reshape_size + extra_padding_size, cv2.BORDER_REFLECT) # public constant; the private cv2.cv2 alias is missing from newer opencv-python builds + im = im[reshape_size - extra_padding_size:shape[0] + reshape_size + extra_padding_size, :, :] + im, _ = self.perspective_transform(im, crop_size=crop_size, pers_size=pers_size, enlarge_size=enlarge_size) + output.append(im) + else: + im = cv2.resize(img2[i], (shape[0] + reshape_size*2, shape[1] - reshape_size*2)) + im = cv2.copyMakeBorder(im, reshape_size + extra_padding_size, reshape_size + extra_padding_size, 0, 0, cv2.BORDER_REFLECT) + im = im[:, reshape_size - extra_padding_size:shape[0] + reshape_size + extra_padding_size, :] + im, _ = self.perspective_transform(im, crop_size=crop_size, pers_size=pers_size, enlarge_size=enlarge_size) + output.append(im) + output = np.stack(output) + return output \ No newline at end of file diff --git a/Talking-Face_PC-AVS/data/niqe_image_params.mat b/Talking-Face_PC-AVS/data/niqe_image_params.mat new file mode 100644 index 00000000..53df0998 Binary files /dev/null and b/Talking-Face_PC-AVS/data/niqe_image_params.mat differ diff --git "a/Talking-Face_PC-AVS/data/niqe_image_params.mat\357\200\272Zone.Identifier" "b/Talking-Face_PC-AVS/data/niqe_image_params.mat\357\200\272Zone.Identifier" new file mode 100644 index 00000000..43a1a01e --- /dev/null +++ "b/Talking-Face_PC-AVS/data/niqe_image_params.mat\357\200\272Zone.Identifier" @@ -0,0 +1,3 @@ +[ZoneTransfer] +ZoneId=3 +HostUrl=https://github.com/ diff --git a/Talking-Face_PC-AVS/data/voxtest_dataset.py 
b/Talking-Face_PC-AVS/data/voxtest_dataset.py new file mode 100644 index 00000000..7c9a38d9 --- /dev/null +++ b/Talking-Face_PC-AVS/data/voxtest_dataset.py @@ -0,0 +1,200 @@ +import os +import math +import numpy as np +from config import AudioConfig +import shutil +import cv2 +import glob +import random +import torch +from data.base_dataset import BaseDataset +import util.util as util + + +class VOXTestDataset(BaseDataset): + + @staticmethod + def modify_commandline_options(parser, is_train): + parser.add_argument('--no_pairing_check', action='store_true', + help='If specified, skip sanity check of correct label-image file pairing') + return parser + + def cv2_loader(self, img_str): + img_array = np.frombuffer(img_str, dtype=np.uint8) + return cv2.imdecode(img_array, cv2.IMREAD_COLOR) + + def load_img(self, image_path, M=None, crop=True, crop_len=16): + img = cv2.imread(image_path) + + if img is None: + print(f"error:The path isssssssssssssssssssssssssssssss {image_path}") + raise Exception('None Image') + + if M is not None: + img = cv2.warpAffine(img, M, (self.opt.crop_size, self.opt.crop_size), borderMode=cv2.BORDER_REPLICATE) + + if crop: + img = img[:self.opt.crop_size - crop_len*2, crop_len:self.opt.crop_size - crop_len] + if self.opt.target_crop_len > 0: + img = img[self.opt.target_crop_len:self.opt.crop_size - self.opt.target_crop_len, self.opt.target_crop_len:self.opt.crop_size - self.opt.target_crop_len] + img = cv2.resize(img, (self.opt.crop_size, self.opt.crop_size)) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + + def fill_list(self, tmp_list): + length = len(tmp_list) + if length % self.opt.batchSize != 0: + end = math.ceil(length / self.opt.batchSize) * self.opt.batchSize + tmp_list = tmp_list + tmp_list[-1 * (end - length) :] + return tmp_list + + def frame2audio_indexs(self, frame_inds): + start_frame_ind = frame_inds - self.audio.num_frames_per_clip // 2 + + start_audio_inds = start_frame_ind * self.audio.num_bins_per_frame + return 
start_audio_inds + + def initialize(self, opt): + self.opt = opt + self.path_label = opt.path_label + self.clip_len = opt.clip_len + self.frame_interval = opt.frame_interval + self.num_clips = opt.num_clips + self.frame_rate = opt.frame_rate + self.num_inputs = opt.num_inputs + self.filename_tmpl = opt.filename_tmpl + + self.mouth_num_frames = None + self.mouth_frame_path = None + self.pose_num_frames = None + + self.audio = AudioConfig.AudioConfig(num_frames_per_clip=opt.num_frames_per_clip, hop_size=opt.hop_size) + self.num_audio_bins = self.audio.num_frames_per_clip * self.audio.num_bins_per_frame + + + assert len(opt.path_label.split()) == 8, opt.path_label + id_path, ref_num, \ + pose_frame_path, pose_num_frames, \ + audio_path, mouth_frame_path, mouth_num_frames, spectrogram_path = opt.path_label.split() + + + id_idx, mouth_idx = id_path.split('/')[-1], audio_path.split('/')[-1].split('.')[0] + if not os.path.isdir(pose_frame_path): + pose_frame_path = id_path + pose_num_frames = 1 + + pose_idx = pose_frame_path.split('/')[-1] + id_idx, pose_idx, mouth_idx = str(id_idx), str(pose_idx), str(mouth_idx) + + self.processed_file_savepath = os.path.join('results', 'id_' + id_idx + '_pose_' + pose_idx + + '_audio_' + os.path.basename(audio_path)[:-4]) + if not os.path.exists(self.processed_file_savepath): os.makedirs(self.processed_file_savepath) + + + if not os.path.isfile(spectrogram_path): + wav = self.audio.read_audio(audio_path) + self.spectrogram = self.audio.audio_to_spectrogram(wav) + + else: + self.spectrogram = np.load(spectrogram_path) + + if os.path.isdir(mouth_frame_path): + self.mouth_frame_path = mouth_frame_path + self.mouth_num_frames = mouth_num_frames + + self.pose_num_frames = int(pose_num_frames) + + self.target_frame_inds = np.arange(2, len(self.spectrogram) // self.audio.num_bins_per_frame - 2) + self.audio_inds = self.frame2audio_indexs(self.target_frame_inds) + + self.dataset_size = len(self.target_frame_inds) + + id_img_paths = 
glob.glob(os.path.join(id_path, '*.jpg')) + glob.glob(os.path.join(id_path, '*.png')) + random.shuffle(id_img_paths) + opt.num_inputs = min(len(id_img_paths), opt.num_inputs) + id_img_tensors = [] + + for i, image_path in enumerate(id_img_paths): + id_img_tensor = self.to_Tensor(self.load_img(image_path)) + id_img_tensors += [id_img_tensor] + shutil.copyfile(image_path, os.path.join(self.processed_file_savepath, 'ref_id_{}.jpg'.format(i))) + if i == (opt.num_inputs - 1): + break + self.id_img_tensor = torch.stack(id_img_tensors) + self.pose_frame_path = pose_frame_path + self.audio_path = audio_path + self.id_path = id_path + self.mouth_frame_path = mouth_frame_path + self.initialized = False + + + def paths_match(self, path1, path2): + filename1_without_ext = os.path.splitext(os.path.basename(path1)[-10:])[0] + filename2_without_ext = os.path.splitext(os.path.basename(path2)[-10:])[0] + return filename1_without_ext == filename2_without_ext + + def load_one_frame(self, frame_ind, video_path, M=None, crop=True): + filepath = os.path.join(video_path, self.filename_tmpl.format(frame_ind)) + img = self.load_img(filepath, M=M, crop=crop) + img = self.to_Tensor(img) + return img + + def load_spectrogram(self, audio_ind): + mel_shape = self.spectrogram.shape + + if (audio_ind + self.num_audio_bins) <= mel_shape[0] and audio_ind >= 0: + spectrogram = np.array(self.spectrogram[audio_ind:audio_ind + self.num_audio_bins, :]).astype('float32') + else: + print('(audio_ind {} + opt.num_audio_bins {}) > mel_shape[0] {} '.format(audio_ind, self.num_audio_bins, + mel_shape[0])) + if audio_ind > 0: + spectrogram = np.array(self.spectrogram[audio_ind:audio_ind + self.num_audio_bins, :]).astype('float32') + else: + spectrogram = np.zeros((self.num_audio_bins, mel_shape[1])).astype(np.float16).astype(np.float32) + + spectrogram = torch.from_numpy(spectrogram) + spectrogram = spectrogram.unsqueeze(0) + + spectrogram = spectrogram.transpose(-2, -1) + return spectrogram + + def 
__getitem__(self, index): + + img_index = self.target_frame_inds[index] + mel_index = self.audio_inds[index] + + pose_index = util.calc_loop_idx(img_index, self.pose_num_frames) + + pose_frame = self.load_one_frame(pose_index, self.pose_frame_path) + + if os.path.isdir(self.mouth_frame_path): + mouth_frame = self.load_one_frame(img_index, self.mouth_frame_path) + else: + mouth_frame = torch.zeros_like(pose_frame) + + spectrograms = self.load_spectrogram(mel_index) + + input_dict = { + 'input': self.id_img_tensor, + 'target': mouth_frame, + 'driving_pose_frames': pose_frame, + 'augmented': pose_frame, + 'label': torch.zeros(1), + } + if self.opt.use_audio: + input_dict['spectrograms'] = spectrograms + + # Give subclasses a chance to modify the final output + self.postprocess(input_dict) + + return input_dict + + def postprocess(self, input_dict): + return input_dict + + def __len__(self): + return self.dataset_size + + def get_processed_file_savepath(self): + return self.processed_file_savepath diff --git a/Talking-Face_PC-AVS/evaluate/NIQE.py b/Talking-Face_PC-AVS/evaluate/NIQE.py new file mode 100644 index 00000000..8ec438d6 --- /dev/null +++ b/Talking-Face_PC-AVS/evaluate/NIQE.py @@ -0,0 +1,267 @@ +import numpy as np +import scipy.misc +import scipy.io +from os.path import dirname +from os.path import join +import scipy +from PIL import Image +import numpy as np +import scipy.ndimage +import numpy as np +import scipy.special +import math +import os + + +gamma_range = np.arange(0.2, 10, 0.001) +a = scipy.special.gamma(2.0/gamma_range) +a *= a +b = scipy.special.gamma(1.0/gamma_range) +c = scipy.special.gamma(3.0/gamma_range) +prec_gammas = a/(b*c) + +def aggd_features(imdata): + #flatten imdata + imdata.shape = (len(imdata.flat),) + imdata2 = imdata*imdata + left_data = imdata2[imdata<0] + right_data = imdata2[imdata>=0] + left_mean_sqrt = 0 + right_mean_sqrt = 0 + if len(left_data) > 0: + left_mean_sqrt = np.sqrt(np.average(left_data)) + if len(right_data) > 0: + 
right_mean_sqrt = np.sqrt(np.average(right_data)) + + if right_mean_sqrt != 0: + gamma_hat = left_mean_sqrt/right_mean_sqrt + else: + gamma_hat = np.inf + #solve r-hat norm + + imdata2_mean = np.mean(imdata2) + if imdata2_mean != 0: + r_hat = (np.average(np.abs(imdata))**2) / (np.average(imdata2)) + else: + r_hat = np.inf + rhat_norm = r_hat * (((math.pow(gamma_hat, 3) + 1)*(gamma_hat + 1)) / math.pow(math.pow(gamma_hat, 2) + 1, 2)) + + #solve alpha by guessing values that minimize ro + pos = np.argmin((prec_gammas - rhat_norm)**2); + alpha = gamma_range[pos] + + gam1 = scipy.special.gamma(1.0/alpha) + gam2 = scipy.special.gamma(2.0/alpha) + gam3 = scipy.special.gamma(3.0/alpha) + + aggdratio = np.sqrt(gam1) / np.sqrt(gam3) + bl = aggdratio * left_mean_sqrt + br = aggdratio * right_mean_sqrt + + #mean parameter + N = (br - bl)*(gam2 / gam1)#*aggdratio + return (alpha, N, bl, br, left_mean_sqrt, right_mean_sqrt) + +def ggd_features(imdata): + nr_gam = 1/prec_gammas + sigma_sq = np.var(imdata) + E = np.mean(np.abs(imdata)) + rho = sigma_sq/E**2 + pos = np.argmin(np.abs(nr_gam - rho)); + return gamma_range[pos], sigma_sq + +def paired_product(new_im): + shift1 = np.roll(new_im.copy(), 1, axis=1) + shift2 = np.roll(new_im.copy(), 1, axis=0) + shift3 = np.roll(np.roll(new_im.copy(), 1, axis=0), 1, axis=1) + shift4 = np.roll(np.roll(new_im.copy(), 1, axis=0), -1, axis=1) + + H_img = shift1 * new_im + V_img = shift2 * new_im + D1_img = shift3 * new_im + D2_img = shift4 * new_im + + return (H_img, V_img, D1_img, D2_img) + + +def gen_gauss_window(lw, sigma): + sd = np.float32(sigma) + lw = int(lw) + weights = [0.0] * (2 * lw + 1) + weights[lw] = 1.0 + sum = 1.0 + sd *= sd + for ii in range(1, lw + 1): + tmp = np.exp(-0.5 * np.float32(ii * ii) / sd) + weights[lw + ii] = tmp + weights[lw - ii] = tmp + sum += 2.0 * tmp + for ii in range(2 * lw + 1): + weights[ii] /= sum + return weights + +def compute_image_mscn_transform(image, C=1, avg_window=None, extend_mode='constant'): + 
if avg_window is None: + avg_window = gen_gauss_window(3, 7.0/6.0) + assert len(np.shape(image)) == 2 + h, w = np.shape(image) + mu_image = np.zeros((h, w), dtype=np.float32) + var_image = np.zeros((h, w), dtype=np.float32) + image = np.array(image).astype('float32') + scipy.ndimage.correlate1d(image, avg_window, 0, mu_image, mode=extend_mode) + scipy.ndimage.correlate1d(mu_image, avg_window, 1, mu_image, mode=extend_mode) + scipy.ndimage.correlate1d(image**2, avg_window, 0, var_image, mode=extend_mode) + scipy.ndimage.correlate1d(var_image, avg_window, 1, var_image, mode=extend_mode) + var_image = np.sqrt(np.abs(var_image - mu_image**2)) + return (image - mu_image)/(var_image + C), var_image, mu_image + + +def _niqe_extract_subband_feats(mscncoefs): + # alpha_m, = extract_ggd_features(mscncoefs) + alpha_m, N, bl, br, lsq, rsq = aggd_features(mscncoefs.copy()) + pps1, pps2, pps3, pps4 = paired_product(mscncoefs) + alpha1, N1, bl1, br1, lsq1, rsq1 = aggd_features(pps1) + alpha2, N2, bl2, br2, lsq2, rsq2 = aggd_features(pps2) + alpha3, N3, bl3, br3, lsq3, rsq3 = aggd_features(pps3) + alpha4, N4, bl4, br4, lsq4, rsq4 = aggd_features(pps4) + return np.array([alpha_m, (bl+br)/2.0, + alpha1, N1, bl1, br1, # (V) + alpha2, N2, bl2, br2, # (H) + alpha3, N3, bl3, bl3, # (D1) NOTE(review): bl3 is repeated and br3 is unused - confirm how niqe_image_params.mat was trained before changing + alpha4, N4, bl4, bl4, # (D2) NOTE(review): same repetition with bl4 / unused br4 + ]) + +def get_patches_train_features(img, patch_size, stride=8): + return _get_patches_generic(img, patch_size, 1, stride) + +def get_patches_test_features(img, patch_size, stride=8): + return _get_patches_generic(img, patch_size, 0, stride) + +def extract_on_patches(img, patch_size): + h, w = img.shape + patch_size = int(patch_size) # builtin int: the np.int alias was removed in NumPy 1.24 + patches = [] + for j in range(0, h-patch_size+1, patch_size): + for i in range(0, w-patch_size+1, patch_size): + patch = img[j:j+patch_size, i:i+patch_size] + patches.append(patch) + + patches = np.array(patches) + + patch_features = [] + for p in patches: + patch_features.append(_niqe_extract_subband_feats(p)) + patch_features = 
np.array(patch_features) + + return patch_features + +def _get_patches_generic(img, patch_size, is_train, stride): + h, w = np.shape(img) + if h < patch_size or w < patch_size: + print("Input image is too small") + raise ValueError("Input image is too small") # fail loudly; exit(0) would end the process with a success status + + # ensure that the patch divides evenly into img + hoffset = (h % patch_size) + woffset = (w % patch_size) + + if hoffset > 0: + img = img[:-hoffset, :] + if woffset > 0: + img = img[:, :-woffset] + + + img = img.astype(np.float32) + img2 = np.array(Image.fromarray(img).resize((int(img.shape[1] * 0.5), int(img.shape[0] * 0.5)), resample=Image.BICUBIC)) # half-size bicubic resize in float ('F') mode; replaces scipy.misc.imresize, removed in SciPy 1.3 + + mscn1, var, mu = compute_image_mscn_transform(img) + mscn1 = mscn1.astype(np.float32) + + mscn2, _, _ = compute_image_mscn_transform(img2) + mscn2 = mscn2.astype(np.float32) + + + feats_lvl1 = extract_on_patches(mscn1, patch_size) + feats_lvl2 = extract_on_patches(mscn2, patch_size/2) + + feats = np.hstack((feats_lvl1, feats_lvl2))# feats_lvl3)) + + return feats + +def niqe(inputImgData): + + patch_size = 96 + module_path = dirname(__file__) + + # TODO: memoize + params = scipy.io.loadmat(join(module_path, 'data', 'niqe_image_params.mat')) + pop_mu = np.ravel(params["pop_mu"]) + pop_cov = params["pop_cov"] + + + M, N = inputImgData.shape + + # assert C == 1, "niqe called with videos containing %d channels. 
Please supply only the luminance channel" % (C,) + assert M > (patch_size*2+1), "niqe called with small frame size, requires > 192x192 resolution video using current training parameters" + assert N > (patch_size*2+1), "niqe called with small frame size, requires > 192x192 resolution video using current training parameters" + + + feats = get_patches_test_features(inputImgData, patch_size) + sample_mu = np.mean(feats, axis=0) + sample_cov = np.cov(feats.T) + + X = sample_mu - pop_mu + covmat = ((pop_cov+sample_cov)/2.0) + pinvmat = scipy.linalg.pinv(covmat) + niqe_score = np.sqrt(np.dot(np.dot(X, pinvmat), X)) + + return niqe_score + + +'''if __name__ == "__main__": + + ref = np.array(Image.open('./test_imgs/bikes.bmp').convert('LA'))[:,:,0] # ref + dis = np.array(Image.open('./test_imgs/bikes_distorted.bmp').convert('LA'))[:,:,0] # dis + + print('NIQE of ref bikes image is: %0.3f'% niqe(ref)) + print('NIQE of dis bikes image is: %0.3f'% niqe(dis)) + + ref = np.array(Image.open('./test_imgs/parrots.bmp').convert('LA'))[:,:,0] # ref + dis = np.array(Image.open('./test_imgs/parrots_distorted.bmp').convert('LA'))[:,:,0] # dis + + print('NIQE of ref parrot image is: %0.3f'% niqe(ref)) + print('NIQE of dis parrot image is: %0.3f'% niqe(dis))''' +'''if __name__ == "__main__": + def read_image_as_grayscale(image_path): + img = Image.open(image_path).convert('L') + return np.array(img) + ref_bikes = read_image_as_grayscale('/home/lanlan/Talking-Face_PC-AVS-main/results/id_00010_pose_00473_audio_741400104/G_Pose_Driven_/G_Pose_Driven_0.jpg') # 注意文件扩展名 + + print('NIQE of the image is: %0.3f' % niqe(ref_bikes))''' + +def read_image_as_grayscale(image_path): + img = Image.open(image_path).convert('L') + return np.array(img) + +def evaluate_niqe_in_folder(folder_path): + # 遍历文件夹中的所有文件 + niqe_values = [] + for filename in os.listdir(folder_path): + # 检查文件扩展名,确保只处理图像文件(这里以.jpg为例) + if filename.endswith(".jpg") or filename.endswith(".jpeg") or filename.endswith(".png"): + # 
构建完整的文件路径 + image_path = os.path.join(folder_path, filename) + # 读取图像并转换为灰度图 + img_gray = read_image_as_grayscale(image_path) + # 计算NIQE值 + niqe_value = niqe(img_gray) + niqe_values.append(niqe_value) + # 打印NIQE值和对应的图像文件名 + print(f'NIQE of the image "{filename}" is: %0.3f' % niqe_value) + average_niqe = sum(niqe_values) / len(niqe_values) if niqe_values else 0 + print(f'Average NIQE of all images in the folder is: %0.3f' % average_niqe) + +if __name__ == "__main__": + folder_to_evaluate = '/home/lanlan/Talking-Face_PC-AVS-main/results/id_00010_pose_00473_audio_741400104/G_Pose_Driven_' +# 调用函数评估文件夹中的所有图像 + evaluate_niqe_in_folder(folder_to_evaluate) \ No newline at end of file diff --git a/Talking-Face_PC-AVS/evaluate/data/niqe_image_params.mat b/Talking-Face_PC-AVS/evaluate/data/niqe_image_params.mat new file mode 100644 index 00000000..53df0998 Binary files /dev/null and b/Talking-Face_PC-AVS/evaluate/data/niqe_image_params.mat differ diff --git "a/Talking-Face_PC-AVS/evaluate/data/niqe_image_params.mat\357\200\272Zone.Identifier" "b/Talking-Face_PC-AVS/evaluate/data/niqe_image_params.mat\357\200\272Zone.Identifier" new file mode 100644 index 00000000..43a1a01e --- /dev/null +++ "b/Talking-Face_PC-AVS/evaluate/data/niqe_image_params.mat\357\200\272Zone.Identifier" @@ -0,0 +1,3 @@ +[ZoneTransfer] +ZoneId=3 +HostUrl=https://github.com/ diff --git a/Talking-Face_PC-AVS/experiments/demo_vox.sh b/Talking-Face_PC-AVS/experiments/demo_vox.sh new file mode 100644 index 00000000..7b4b41c1 --- /dev/null +++ b/Talking-Face_PC-AVS/experiments/demo_vox.sh @@ -0,0 +1,26 @@ +meta_path_vox='./misc/demo.csv' + +python -u inference.py \ + --name demo \ + --meta_path_vox ${meta_path_vox} \ + --dataset_mode voxtest \ + --netG modulate \ + --netA resseaudio \ + --netA_sync ressesync \ + --netD multiscale \ + --netV resnext \ + --netE fan \ + --model av \ + --gpu_ids 0 \ + --clip_len 1 \ + --batchSize 16 \ + --style_dim 2560 \ + --nThreads 4 \ + --input_id_feature \ + 
--generate_interval 1 \ + --style_feature_loss \ + --use_audio 1 \ + --noise_pose \ + --driving_pose \ + --gen_video \ + --generate_from_audio_only \ diff --git a/Talking-Face_PC-AVS/experiments/demo_vox_new.sh b/Talking-Face_PC-AVS/experiments/demo_vox_new.sh new file mode 100644 index 00000000..3a92030f --- /dev/null +++ b/Talking-Face_PC-AVS/experiments/demo_vox_new.sh @@ -0,0 +1,26 @@ +meta_path_vox='./raw/metadata.csv' + +python -u inference.py \ + --name demo \ + --meta_path_vox ${meta_path_vox} \ + --dataset_mode voxtest \ + --netG modulate \ + --netA resseaudio \ + --netA_sync ressesync \ + --netD multiscale \ + --netV resnext \ + --netE fan \ + --model av \ + --gpu_ids 0 \ + --clip_len 1 \ + --batchSize 16 \ + --style_dim 2560 \ + --nThreads 4 \ + --input_id_feature \ + --generate_interval 1 \ + --style_feature_loss \ + --use_audio 1 \ + --noise_pose \ + --driving_pose \ + --gen_video \ + --generate_from_audio_only \ diff --git a/Talking-Face_PC-AVS/image-1.png b/Talking-Face_PC-AVS/image-1.png new file mode 100644 index 00000000..f5a45c75 Binary files /dev/null and b/Talking-Face_PC-AVS/image-1.png differ diff --git a/Talking-Face_PC-AVS/image.png b/Talking-Face_PC-AVS/image.png new file mode 100644 index 00000000..19db1651 Binary files /dev/null and b/Talking-Face_PC-AVS/image.png differ diff --git a/Talking-Face_PC-AVS/inference.py b/Talking-Face_PC-AVS/inference.py new file mode 100644 index 00000000..224cbbb5 --- /dev/null +++ b/Talking-Face_PC-AVS/inference.py @@ -0,0 +1,117 @@ +import os +import sys +sys.path.append('..') +from options.test_options import TestOptions +import torch +from models import create_model +import data +import util.util as util +from tqdm import tqdm + + +def video_concat(processed_file_savepath, name, video_names, audio_path): + cmd = ['ffmpeg'] + num_inputs = len(video_names) + for video_name in video_names: + cmd += ['-i', '\'' + str(os.path.join(processed_file_savepath, video_name + '.mp4'))+'\'',] + + cmd += 
['-filter_complex hstack=inputs=' + str(num_inputs), + '\'' + str(os.path.join(processed_file_savepath, name+'.mp4')) + '\'', '-loglevel error -y'] + cmd = ' '.join(cmd) + os.system(cmd) + + video_add_audio(name, audio_path, processed_file_savepath) + + +def video_add_audio(name, audio_path, processed_file_savepath): + os.system('cp {} {}'.format(audio_path, processed_file_savepath)) + cmd = ['ffmpeg', '-i', '\'' + os.path.join(processed_file_savepath, name + '.mp4') + '\'', + '-i', audio_path, + '-q:v 0', + '-strict -2', + '\'' + os.path.join(processed_file_savepath, 'av' + name + '.mp4') + '\'', + '-loglevel error -y'] + cmd = ' '.join(cmd) + os.system(cmd) + + +def img2video(dst_path, prefix, video_path): + cmd = ['ffmpeg', '-i', '\'' + video_path + '/' + prefix + '%d.jpg' + + '\'', '-q:v 0', '\'' + dst_path + '/' + prefix + '.mp4' + '\'', '-loglevel error -y'] + cmd = ' '.join(cmd) + os.system(cmd) + + +def inference_single_audio(opt, path_label, model): + # + opt.path_label = path_label + dataloader = data.create_dataloader(opt) + processed_file_savepath = dataloader.dataset.get_processed_file_savepath() + + idx = 0 + if opt.driving_pose: + video_names = ['Input_', 'G_Pose_Driven_', 'Pose_Source_', 'Mouth_Source_'] + else: + video_names = ['Input_', 'G_Fix_Pose_', 'Mouth_Source_'] + is_mouth_frame = os.path.isdir(dataloader.dataset.mouth_frame_path) + if not is_mouth_frame: + video_names.pop() + save_paths = [] + for name in video_names: + save_path = os.path.join(processed_file_savepath, name) + util.mkdir(save_path) + save_paths.append(save_path) + for data_i in tqdm(dataloader): + # print('==============', i, '===============') + fake_image_original_pose_a, fake_image_driven_pose_a = model.forward(data_i, mode='inference') + + for num in range(len(fake_image_driven_pose_a)): + util.save_torch_img(data_i['input'][num], os.path.join(save_paths[0], video_names[0] + str(idx) + '.jpg')) + if opt.driving_pose: + util.save_torch_img(fake_image_driven_pose_a[num], 
+ os.path.join(save_paths[1], video_names[1] + str(idx) + '.jpg')) + util.save_torch_img(data_i['driving_pose_frames'][num], + os.path.join(save_paths[2], video_names[2] + str(idx) + '.jpg')) + else: + util.save_torch_img(fake_image_original_pose_a[num], + os.path.join(save_paths[1], video_names[1] + str(idx) + '.jpg')) + if is_mouth_frame: + util.save_torch_img(data_i['target'][num], os.path.join(save_paths[-1], video_names[-1] + str(idx) + '.jpg')) + idx += 1 + + if opt.gen_video: + for i, video_name in enumerate(video_names): + img2video(processed_file_savepath, video_name, save_paths[i]) + video_concat(processed_file_savepath, 'concat', video_names, dataloader.dataset.audio_path) + + print('results saved...' + processed_file_savepath) + del dataloader + return + + +def main(): + + opt = TestOptions().parse() + opt.isTrain = False + torch.manual_seed(0) + model = create_model(opt).cuda() + model.eval() + + with open(opt.meta_path_vox, 'r') as f: + lines = f.read().splitlines() + + for clip_idx, path_label in enumerate(lines): + try: + assert len(path_label.split()) == 8, path_label + + inference_single_audio(opt, path_label, model) + + except Exception as ex: + import traceback + traceback.print_exc() + print(path_label + '\n') + print(str(ex)) + + +if __name__ == '__main__': + main() diff --git a/Talking-Face_PC-AVS/misc/Audio_Source.zip b/Talking-Face_PC-AVS/misc/Audio_Source.zip new file mode 100644 index 00000000..842b8aff Binary files /dev/null and b/Talking-Face_PC-AVS/misc/Audio_Source.zip differ diff --git a/Talking-Face_PC-AVS/misc/Input.zip b/Talking-Face_PC-AVS/misc/Input.zip new file mode 100644 index 00000000..2598533f Binary files /dev/null and b/Talking-Face_PC-AVS/misc/Input.zip differ diff --git a/Talking-Face_PC-AVS/misc/Mouth_Source.zip b/Talking-Face_PC-AVS/misc/Mouth_Source.zip new file mode 100644 index 00000000..15ca9a53 Binary files /dev/null and b/Talking-Face_PC-AVS/misc/Mouth_Source.zip differ diff --git 
a/Talking-Face_PC-AVS/misc/Pose_Source.zip b/Talking-Face_PC-AVS/misc/Pose_Source.zip new file mode 100644 index 00000000..bd9fc98e Binary files /dev/null and b/Talking-Face_PC-AVS/misc/Pose_Source.zip differ diff --git a/Talking-Face_PC-AVS/misc/demo.csv b/Talking-Face_PC-AVS/misc/demo.csv new file mode 100644 index 00000000..6524efbf --- /dev/null +++ b/Talking-Face_PC-AVS/misc/demo.csv @@ -0,0 +1 @@ +misc/Input/517600055 1 misc/Pose_Source/517600078 160 misc/Audio_Source/681600002.mp3 misc/Mouth_Source/681600002 363 dummy diff --git a/Talking-Face_PC-AVS/misc/demo.gif b/Talking-Face_PC-AVS/misc/demo.gif new file mode 100644 index 00000000..b3a72d3b Binary files /dev/null and b/Talking-Face_PC-AVS/misc/demo.gif differ diff --git a/Talking-Face_PC-AVS/misc/demo_id.gif b/Talking-Face_PC-AVS/misc/demo_id.gif new file mode 100644 index 00000000..257a80e4 Binary files /dev/null and b/Talking-Face_PC-AVS/misc/demo_id.gif differ diff --git a/Talking-Face_PC-AVS/misc/method.png b/Talking-Face_PC-AVS/misc/method.png new file mode 100644 index 00000000..bbc27bd7 Binary files /dev/null and b/Talking-Face_PC-AVS/misc/method.png differ diff --git a/Talking-Face_PC-AVS/misc/output.gif b/Talking-Face_PC-AVS/misc/output.gif new file mode 100644 index 00000000..bd4e1f5f Binary files /dev/null and b/Talking-Face_PC-AVS/misc/output.gif differ diff --git a/Talking-Face_PC-AVS/models/__init__.py b/Talking-Face_PC-AVS/models/__init__.py new file mode 100644 index 00000000..908658cf --- /dev/null +++ b/Talking-Face_PC-AVS/models/__init__.py @@ -0,0 +1,36 @@ +import importlib + +def find_model_using_name(model_name): + # Given the option --model [modelname], + # the file "models/modelname_model.py" + # will be imported. + model_filename = "models." + model_name + "_model" + modellib = importlib.import_module(model_filename) + + # In the file, the class called ModelNameModel() will + # be instantiated. It has to be a subclass of torch.nn.Module, + # and it is case-insensitive. 
+ model = None + target_model_name = model_name.replace('_', '') + 'model' + for name, cls in modellib.__dict__.items(): + if name.lower() == target_model_name.lower(): + model = cls + + if model is None: + print("In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase." % (model_filename, target_model_name)) + exit(0) + + return model + + +def get_option_setter(model_name): + model_class = find_model_using_name(model_name) + return model_class.modify_commandline_options + + +def create_model(opt): + model = find_model_using_name(opt.model) + instance = model(opt) + print("model [%s] was created" % (type(instance).__name__)) + + return instance diff --git a/Talking-Face_PC-AVS/models/av_model.py b/Talking-Face_PC-AVS/models/av_model.py new file mode 100644 index 00000000..8f656104 --- /dev/null +++ b/Talking-Face_PC-AVS/models/av_model.py @@ -0,0 +1,804 @@ +import torch +import models.networks as networks +from models.networks.architecture import VGGFace19 +import util.util as util +from models.networks.loss import CrossEntropyLoss +import os + + +class AvModel(torch.nn.Module): + @staticmethod + def modify_commandline_options(parser, is_train): + networks.modify_commandline_options(parser, is_train) + return parser + + def __init__(self, opt): + super(AvModel, self).__init__() + self.opt = opt + self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) + self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() \ + else torch.FloatTensor + self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() \ + else torch.ByteTensor + self.netG, self.netD, self.netA, self.netA_sync, self.netV, self.netE = \ + self.initialize_networks(opt) + + # set loss functions + if opt.isTrain: + self.loss_cls = CrossEntropyLoss() + self.criterionFeat = torch.nn.L1Loss() + + if opt.softmax_contrastive: + self.criterionSoftmaxContrastive = networks.SoftmaxContrastiveLoss() + if opt.train_recognition or opt.train_sync: + pass + + else: + 
self.criterionGAN = networks.GANLoss( + opt.gan_mode, tensor=self.FloatTensor, opt=self.opt) + + if not opt.no_vgg_loss: + self.criterionVGG = networks.VGGLoss(self.opt) + + if opt.vgg_face: + self.VGGFace = VGGFace19(self.opt) + self.criterionVGGFace = networks.VGGLoss(self.opt, self.VGGFace) + + if opt.disentangle: + self.criterionLogSoftmax = networks.L2SoftmaxLoss() + + # Entry point for all calls involving forward pass + # of deep networks. We used this approach since DataParallel module + # can't parallelize custom functions, we branch to different + # routines based on |mode|. + # |data|: dictionary of the input data + def preprocessing(self, data): + target_images = data['target'].cuda() + input_image = data['input'].cuda() + augmented = data['augmented'].cuda() + spectrogram = data['spectrograms'].cuda() if self.opt.use_audio else None + + target_images = target_images.view(-1, self.opt.output_nc, self.opt.crop_size, self.opt.crop_size) + augmented = augmented.view(-1, self.opt.output_nc, self.opt.crop_size, self.opt.crop_size) + + return input_image, target_images, augmented, spectrogram + + def forward(self, data, mode): + labels = data['label'] + input_image, target_images, augmentated, spectrogram = self.preprocessing(data) + if mode == 'generator': + g_loss, generated, id_scores = self.compute_generator_loss( + input_image, target_images, augmentated, spectrogram, + netD=self.netD, labels=labels, no_ganFeat_loss=self.opt.no_ganFeat_loss, + no_vgg_loss=self.opt.no_vgg_loss, lambda_D=self.opt.lambda_D) + return g_loss, generated, id_scores + if mode == 'encoder': + g_loss, cls_score = self.compute_encoder_loss( + input_image, target_images, spectrogram, labels) + return g_loss, cls_score + if mode == 'sync': + g_loss = self.sync(augmentated, spectrogram) + return g_loss + if mode == 'sync_D': + d_loss = self.sync_D(spectrogram, labels) + return d_loss + elif mode == 'discriminator': + d_loss = self.compute_discriminator_loss( + input_image, 
target_images, augmentated, spectrogram, netD=self.netD, labels=labels, lambda_D=self.opt.lambda_D) + return d_loss + elif mode == 'inference': + assert self.opt.use_audio, 'must use audio driven strategy.' + driving_pose_frames = data['driving_pose_frames'].cuda() + with torch.no_grad(): + fake_image_ref_pose_a, fake_image_driven_pose_a = self.inference(input_image, spectrogram, + driving_pose_frames) + return fake_image_ref_pose_a, fake_image_driven_pose_a + else: + raise ValueError("|mode| is invalid") + + def create_optimizers(self, opt): + optimizer_D = None + if opt.no_TTUR: + beta1, beta2 = opt.beta1, opt.beta2 + G_lr, D_lr = opt.lr, opt.lr + else: + beta1, beta2 = 0, 0.9 + G_lr, D_lr = opt.lr / 2, opt.lr * 2 + + if opt.train_recognition: + + util.freeze_model(self.netV) + for param in self.netV.fc.parameters(): + param.requires_grad = True + netV_params = list(self.netV.fc.parameters()) + netA_params = list(self.netA.parameters()) + G_params = netV_params + netA_params + + elif opt.train_sync: + + netA_sync_params = list(self.netA_sync.model.parameters()) + # netE_params = list(self.netE.model.parameters()) + netE_mouth_params = list(self.netE.to_mouth.parameters()) + G_params = netA_sync_params + netE_mouth_params + + D_params = list(self.netA_sync.fc.parameters()) + list(self.netE.classifier.parameters()) + optimizer_D = torch.optim.Adam(D_params, lr=D_lr, betas=(beta1, beta2)) + + elif opt.train_dis_pose: + netE_pure_pose_params = list(self.netE.pure_pose.parameters())+list(self.netE.headpose_embed.parameters()) + netG_params = list(self.netG.parameters()) + netV_params = list(self.netV.parameters()) + netE_params = list(self.netE.model.parameters()) + netA_sync_params = list(self.netA_sync.parameters()) if self.opt.use_audio else None + netE_mouth_all_params = list(self.netE.to_mouth.parameters()) + list(self.netE.mouth_fc.parameters()) + + G_params = [] + + if not opt.fix_netE_mouth: + G_params = G_params + netE_mouth_all_params + else: + 
util.freeze_model(self.netE.to_mouth) + util.freeze_model(self.netE.mouth_fc) + + if not opt.fix_netE_headpose: + G_params = G_params + netE_pure_pose_params + else: + util.freeze_model(self.netE.pure_pose) + util.freeze_model(self.netE.headpose_embed) + + if not opt.fix_netG: + G_params = G_params + netG_params + else: + util.freeze_model(self.netG) + + if not opt.fix_netV: + G_params = G_params + netV_params + else: + util.freeze_model(self.netV) + + if not opt.fix_netE: + G_params = G_params + netE_params + else: + util.freeze_model(self.netE.model) + + if self.opt.use_audio: + if not opt.fix_netA_sync: + G_params = G_params + netA_sync_params + else: + util.freeze_model(self.netA_sync) + + if opt.isTrain: + D_params = list(self.netD.parameters()) + + if opt.disentangle: + + if not opt.fix_netE_headpose: + D_params = list(self.netE.headpose_fc.parameters()) + D_params + else: + util.freeze_model(self.netE.headpose_fc) + + if not opt.fix_netD: + optimizer_D = torch.optim.Adam(D_params, lr=D_lr, betas=(beta1, beta2)) + else: + util.freeze_model(self.netD) + + else: + netG_params = list(self.netG.parameters()) + netA_sync_params = list(self.netA_sync.model.parameters()) if opt.use_audio else 0 + netE_mouth_params = list(self.netE.to_mouth.parameters()) + netV_params = list(self.netV.parameters()) + netE_params = list(self.netE.model.parameters()) + + G_params = netA_sync_params + netE_mouth_params + if not opt.fix_netV: + G_params = G_params + netV_params + else: + util.freeze_model(self.netV) + + if not opt.fix_netE: + G_params = G_params + netE_params + else: + util.freeze_model(self.netE) + + if not opt.fix_netG: + G_params = G_params + netG_params + else: + util.freeze_model(self.netG) + + if opt.isTrain: + D_params = list(self.netD.parameters()) + + if opt.disentangle: + D_params = list(self.netE.classifier.parameters()) + D_params + + if not opt.fix_netD: + optimizer_D = torch.optim.Adam(D_params, lr=D_lr, betas=(beta1, beta2)) + else: + 
util.freeze_model(self.netD) + + if opt.optimizer == 'sgd': + optimizer_G = torch.optim.SGD(G_params, lr=G_lr, momentum=0.9, nesterov=True) + else: + optimizer_G = torch.optim.Adam(G_params, lr=G_lr, betas=(beta1, beta2), amsgrad=True) + + return optimizer_G, optimizer_D + + def save(self, epoch): + if self.opt.train_recognition: + util.save_network(self.netV, 'V', epoch, self.opt) + elif self.opt.train_sync: + util.save_network(self.netE, 'E', epoch, self.opt) + if self.opt.use_audio: + util.save_network(self.netA_sync, 'A_sync', epoch, self.opt) + else: + util.save_network(self.netG, 'G', epoch, self.opt) + # util.save_network(self.netD, 'D', epoch, self.opt) + if self.opt.use_audio: + if self.opt.use_audio_id: + util.save_network(self.netA, 'A', epoch, self.opt) + util.save_network(self.netA_sync, 'A_sync', epoch, self.opt) + util.save_network(self.netV, 'V', epoch, self.opt) + util.save_network(self.netE, 'E', epoch, self.opt) + + ############################################################################ + # Private helper methods + ############################################################################ + + + def initialize_networks(self, opt): + netG = None + netD = None + netE = None + netV = None + netA = None + netA_sync = None + if opt.train_recognition: + netV = networks.define_V(opt) + elif opt.train_sync: + netA_sync = networks.define_A_sync(opt) if opt.use_audio else None + netE = networks.define_E(opt) + else: + + netG = networks.define_G(opt) + netA = networks.define_A(opt) if opt.use_audio and opt.use_audio_id else None + netA_sync = networks.define_A_sync(opt) if opt.use_audio else None + netE = networks.define_E(opt) + netV = networks.define_V(opt) + + if opt.isTrain: + netD = networks.define_D(opt) + + if not opt.isTrain or opt.continue_train: + self.load_network(netG, 'G', opt.which_epoch) + self.load_network(netV, 'V', opt.which_epoch) + self.load_network(netE, 'E', opt.which_epoch) + if opt.use_audio: + if opt.use_audio_id: + 
self.load_network(netA, 'A', opt.which_epoch) + self.load_network(netA_sync, 'A_sync', opt.which_epoch) + + if opt.isTrain and not opt.noload_D: + self.load_network(netD, 'D', opt.which_epoch) + # self.load_network(netD_rotate, 'D_rotate', opt.which_epoch, pretrained_path) + + else: + if self.opt.pretrain: + if opt.netE == 'fan': + netE.load_pretrain() + netV.load_pretrain() + if opt.load_separately: + netG = self.load_separately(netG, 'G', opt) + netA = self.load_separately(netA, 'A', opt) if opt.use_audio and opt.use_audio_id else None + netA_sync = self.load_separately(netA_sync, 'A_sync', opt) if opt.use_audio else None + netV = self.load_separately(netV, 'V', opt) + netE = self.load_separately(netE, 'E', opt) + if not opt.noload_D: + netD = self.load_separately(netD, 'D', opt) + return netG, netD, netA, netA_sync, netV, netE + + def compute_encoder_loss(self, input_img, real_image, spectrogram, labels): + G_losses = {} + real_image = real_image.view(-1, self.opt.clip_len, self.opt.output_nc, self.opt.crop_size, self.opt.crop_size) + + [image_feature, net_V_feature], cls_score_V = self.netV.forward(real_image) + audio_feature, cls_score_A_2 = self.netA.forward(spectrogram) + audio_feature = audio_feature.view(-1, self.opt.clip_len, audio_feature.shape[-1]) + audio_feature = torch.mean(audio_feature, 1) + + G_losses['loss_cls_V'] = self.loss_cls(cls_score_V, labels) + cls_score_A = self.netV.fc.forward(audio_feature) + G_losses['loss_cls_A'] = self.loss_cls(cls_score_A, labels) + # G_losses['loss_cls_A_2'] = self.loss_cls(cls_score_A_2, labels) + if not self.opt.no_cross_modal: + G_losses['CrossModal'] = self.criterionFeat(image_feature.detach(), audio_feature) * self.opt.lambda_crossmodal + + if self.opt.softmax_contrastive: + G_losses['SoftmaxContrastive'] = self.criterionSoftmaxContrastive(image_feature.detach(), audio_feature) * self.opt.lambda_contrastive + + return G_losses, cls_score_A + + def sync_D(self, spectrogram, labels): + D_losses = {} + with 
torch.no_grad(): + audio_content_feature = self.netA_sync.forward_feature(spectrogram) + audio_content_feature = audio_content_feature.detach() + audio_content_feature.requires_grad_() + cls_score_A = self.netA_sync.fc.forward(audio_content_feature) + labels = labels.unsqueeze(1) + labels_expand = labels.expand(-1, self.opt.clip_len) + labels_expand = labels_expand.contiguous().view(-1) + D_losses['loss_cls_A'] = self.loss_cls(cls_score_A, labels_expand) + return D_losses + + + def encode_audiosync_feature(self, spectrogram): + + audio_content_feature = self.netA_sync.forward_feature(spectrogram) + + audio_content_feature = audio_content_feature.view(-1, self.opt.clip_len, audio_content_feature.shape[-1]) + return audio_content_feature + + def sync(self, augmented, spectrogram): + G_losses = {} + pose_feature = self.encode_noid_feature(augmented) + + audio_content_feature = self.encode_audiosync_feature(spectrogram) + + G_losses = self.compute_sync_loss(pose_feature, audio_content_feature, G_losses) + return G_losses + + def compute_sync_loss(self, image_content_feature, audio_content_feature, G_losses, name=''): + + audio_content_feature_all = audio_content_feature.view(audio_content_feature.shape[0], -1) + image_content_feature_all = image_content_feature.view(image_content_feature.shape[0], -1) + + if not self.opt.no_cross_modal: + G_losses['CrossModal{}'.format(name)] = self.criterionFeat(image_content_feature_all.detach(), + audio_content_feature_all) * self.opt.lambda_crossmodal + + if self.opt.softmax_contrastive: + G_losses['SoftmaxContrastive{}'.format(name)] = self.criterionSoftmaxContrastive(image_content_feature_all.detach(), audio_content_feature_all) * self.opt.lambda_contrastive + G_losses['SoftmaxContrastive_v2a'] = self.criterionSoftmaxContrastive(audio_content_feature_all.detach(), image_content_feature_all) * self.opt.lambda_contrastive + + return G_losses + + def audio_identity_feature(self, id_mel, no_grad=True): + id_mel = id_mel.view(-1, 1, 
id_mel.shape[-2], id_mel.shape[-1]) + if no_grad: + with torch.no_grad(): + id_feature, id_scores = self.netA(id_mel) + else: + id_feature, id_scores = self.netA(id_mel) + return id_feature, id_scores + + def encode_identity_feature(self, input_img): + + input_img = input_img.view(-1, self.opt.output_nc, self.opt.crop_size, self.opt.crop_size) + if not self.opt.isTrain or self.opt.fix_netV: + with torch.no_grad(): + id_feature, id_scores = self.netV(input_img) + else: + id_feature, id_scores = self.netV(input_img) + + id_feature[0] = id_feature[0].unsqueeze(1).repeat(1, self.opt.clip_len, 1).view(-1, *id_feature[0].shape[1:]) + id_feature[1] = id_feature[1].unsqueeze(1).repeat(1, self.opt.clip_len, 1, 1, 1).view(-1, *id_feature[1].shape[1:]) + + return id_feature, id_scores + + def encode_ref_noid(self, input_img): + input_img = input_img.view(-1, self.opt.output_nc, self.opt.crop_size, self.opt.crop_size) + with torch.no_grad(): + ref_noid_feature = self.netE.forward_feature(input_img) + ref_noid_feature = ref_noid_feature.view(-1, self.opt.num_inputs, ref_noid_feature.shape[-1]) + ref_noid_feature = ref_noid_feature.mean(1).unsqueeze(1).repeat(1, self.opt.clip_len, 1) + return ref_noid_feature + + def compute_pose_diff(self, pose_feature, ref_noid_feature): + pose_feature = pose_feature.view(-1, self.opt.clip_len, pose_feature.shape[-1]) + pose_differences = pose_feature - ref_noid_feature + return pose_differences + + def compute_diff_loss(self, input_img, pose_feature, pose_feature_audio, G_losses): + + pose_feature_audio = pose_feature_audio.view(-1, self.opt.clip_len, pose_feature_audio.shape[-1]) + ref_noid_feature = self.encode_ref_noid(input_img) + pose_differences = self.compute_pose_diff(pose_feature, ref_noid_feature) + + self.compute_sync_loss(pose_differences, pose_feature_audio, G_losses) + + pose_feature_audio = ref_noid_feature + pose_feature_audio + + return pose_feature_audio + + def encode_noid_feature(self, augmented): + augmented = 
augmented.view(-1, self.opt.output_nc, self.opt.crop_size, self.opt.crop_size) + if (not self.opt.isTrain) or self.opt.train_sync or self.opt.fix_netE: + with torch.no_grad(): + noid_feature = self.netE.forward_feature(augmented) + else: + noid_feature = self.netE.forward_feature(augmented) + + noid_feature = noid_feature.view(-1, self.opt.clip_len, noid_feature.shape[-1]) + return noid_feature + + def select_frames(self, in_obj_ts): + if len(in_obj_ts.shape) == 2: + obj_ts = in_obj_ts.view(-1, self.opt.clip_len, in_obj_ts.shape[-1]) + obj_ts = obj_ts[:, ::self.opt.generate_interval, :].contiguous() + obj_ts = obj_ts.view(-1, obj_ts.shape[-1]) + elif len(in_obj_ts.shape) == 3: + obj_ts = in_obj_ts[:, ::self.opt.generate_interval, :].contiguous() + elif len(in_obj_ts.shape) == 4: + obj_ts = in_obj_ts.view(-1, self.opt.clip_len, *in_obj_ts.shape[1:]) + obj_ts = obj_ts[:, ::self.opt.generate_interval, :].contiguous() + obj_ts = obj_ts.view(-1, *obj_ts.shape[2:]) + elif len(in_obj_ts.shape) == 5: + obj_ts = in_obj_ts[:, ::self.opt.generate_interval, :].contiguous() + else: + raise ValueError + return obj_ts + + def generate_fake(self, id_feature, pose_feature): + pose_feature = pose_feature.view(-1, pose_feature.shape[-1]) + style = torch.cat([id_feature[0], pose_feature], 1) + style = [style] + if self.opt.input_id_feature: + fake_image, style_rgb = self.netG(style, identity_style=id_feature[1]) + else: + fake_image, style_rgb = self.netG(style) + + fake_image = fake_image.view(-1, self.opt.output_nc, self.opt.crop_size, self.opt.crop_size) + + return fake_image, style_rgb + + def merge_mouthpose(self, mouth_feature, headpose_feature, embed_headpose=False): + + mouth_feature = self.netE.mouth_embed(mouth_feature) + if not embed_headpose: + headpose_feature = self.netE.headpose_embed(headpose_feature) + pose_feature = torch.cat((mouth_feature, headpose_feature), dim=2) + + return pose_feature + + def inference(self, input_img, spectrogram, + driving_pose_frames, 
mouth_feature_weight=1.2): + + ##### ***************** encode image feature and generate ****************************** + id_feature, _ = self.encode_identity_feature(input_img) + + fake_image_pose_driven_a = None + if self.opt.generate_from_audio_only: + assert self.opt.use_audio, 'must use audio in this case' + + A_mouth_feature = self.encode_audiosync_feature(spectrogram) + A_mouth_feature = A_mouth_feature * mouth_feature_weight + + sel_id_feature = [] + sel_id_feature.append(self.select_frames(id_feature[0])) + sel_id_feature.append(self.select_frames(id_feature[1])) + + V_noid_ref_feature = self.encode_ref_noid(input_img) + V_headpose_ref_feature = self.netE.to_headpose(V_noid_ref_feature) + + ref_merge_feature_a = self.select_frames(self.merge_mouthpose(A_mouth_feature, V_headpose_ref_feature)) + fake_image_ref_pose_a, _ = self.generate_fake(sel_id_feature, ref_merge_feature_a) + if self.opt.driving_pose: + V_noid_driving_feature = self.encode_noid_feature(driving_pose_frames) + V_headpose_feature = self.netE.to_headpose(V_noid_driving_feature) + driven_merge_feature_a = self.merge_mouthpose(A_mouth_feature, V_headpose_feature) + sel_driven_pose_feature_a = self.select_frames(driven_merge_feature_a) + fake_image_pose_driven_a, _ = self.generate_fake(sel_id_feature, sel_driven_pose_feature_a) + + return fake_image_ref_pose_a, fake_image_pose_driven_a + + def compute_generator_loss(self, input_img, real_image, augmented, spectrogram, + netD, labels, no_ganFeat_loss=False, no_vgg_loss=False, lambda_D=1): + + G_losses = {} + + real_image = real_image.view(-1, self.opt.output_nc, self.opt.crop_size, self.opt.crop_size) + + ##### ***************** encode image feature and generate ****************************** + + V_noid_feature = self.encode_noid_feature(augmented) + + V_mouth_feature = self.netE.to_mouth(V_noid_feature) + V_headpose_feature = self.netE.to_headpose(V_noid_feature) + id_feature, id_scores = self.encode_identity_feature(input_img) + + 
sel_id_feature = [] + sel_id_feature.append(self.select_frames(id_feature[0])) + sel_id_feature.append(self.select_frames(id_feature[1])) + + sel_real_image = self.select_frames(real_image) + + fake_image_A, fake_image_V = None, None + + if self.opt.generate_from_audio_only: + assert self.opt.use_audio, 'must use audio in this case' + + V_merge_feature = self.merge_mouthpose(V_mouth_feature, V_headpose_feature) + + sel_V_merge_feature = self.select_frames(V_merge_feature) + if self.opt.use_audio: # use audio pose feature + + A_mouth_feature = self.encode_audiosync_feature(spectrogram) + self.compute_sync_loss(V_mouth_feature, A_mouth_feature, G_losses) + + A_merge_feature = self.merge_mouthpose(A_mouth_feature, V_headpose_feature) + sel_A_merge_feature = self.select_frames(A_merge_feature) + fake_image_A, style_rgb_a = self.generate_fake(sel_id_feature, sel_A_merge_feature) + pred_fake_audio = self.discriminate_single(fake_image_A, netD) + + if not self.opt.generate_from_audio_only: # use both audio and image pose feature + fake_image_V, style_rgb_v = self.generate_fake(sel_id_feature, sel_V_merge_feature) + + else: # only use image pose feature + fake_image_V, style_rgb_v = self.generate_fake(sel_id_feature, sel_V_merge_feature) + + pred_real = self.discriminate_single(sel_real_image, netD) + + ##### **************************************************************************** + + if (not self.opt.generate_from_audio_only) or (not self.opt.use_audio): + pred_fake = self.discriminate_single(fake_image_V, netD) + + if not no_ganFeat_loss: + if not self.opt.generate_from_audio_only: + G_losses['GAN_Feat'] = self.compute_GAN_Feat_loss(pred_fake, pred_real) + if self.opt.use_audio: + G_losses['GAN_Feat_audio'] = self.compute_GAN_Feat_loss(pred_fake_audio, pred_real) + + if not self.opt.fix_netD: + if not self.opt.generate_from_audio_only: + G_losses['GANv'] = self.criterionGAN(pred_fake, True, + for_discriminator=False) * lambda_D + if self.opt.use_audio: + 
G_losses['GANa'] = self.criterionGAN(pred_fake_audio, True, + for_discriminator=False) * lambda_D + + if not no_vgg_loss: + if not self.opt.generate_from_audio_only: + G_losses['VGGv'] = self.criterionVGG(fake_image_V, sel_real_image) \ + * self.opt.lambda_vgg + if self.opt.use_audio: + G_losses['VGGa'] = self.criterionVGG(fake_image_A, sel_real_image) \ + * self.opt.lambda_vgg + + if self.opt.vgg_face: + if not self.opt.generate_from_audio_only: + G_losses['VGGFace_v'] = self.criterionVGGFace(fake_image_V, sel_real_image, layer=2) \ + * self.opt.lambda_vggface + + if self.opt.use_audio: + G_losses['VGGFace_a'] = self.criterionVGGFace(fake_image_A, sel_real_image, layer=2) \ + * self.opt.lambda_vggface + + + if not self.opt.no_id_loss or not self.fix_netV: + G_losses['loss_cls'] = self.loss_cls(id_scores, labels) + + if self.opt.disentangle and self.opt.clip_len*self.opt.frame_interval >= 20: + V_headpose_embed = self.netE.headpose_embed(V_headpose_feature) + with torch.no_grad(): + V_all_headpose_embed = V_headpose_embed.view(-1, self.opt.clip_len * V_headpose_embed.shape[-1]) + headpose_word_scores = self.netE.headpose_fc(V_all_headpose_embed) + G_losses['logSoftmax_v'] = self.criterionLogSoftmax(headpose_word_scores) * self.opt.lambda_softmax + + return G_losses, [sel_real_image, fake_image_V, fake_image_A, + ], id_scores + + + # Given fake and real image, return the prediction of discriminator + # for each fake and real image. 
+ + def compute_GAN_Feat_loss(self, pred_fake, pred_real): + num_D = len(pred_fake) + GAN_Feat_loss = self.FloatTensor(1).fill_(0) + for i in range(num_D): # for each discriminator + # last output is the final prediction, so we exclude it + num_intermediate_outputs = len(pred_fake[i]) - 1 + for j in range(num_intermediate_outputs): # for each layer output + unweighted_loss = self.criterionFeat( + pred_fake[i][j], pred_real[i][j].detach()) + if j == 0: + unweighted_loss *= self.opt.lambda_image + GAN_Feat_loss += unweighted_loss * self.opt.lambda_feat / num_D + return GAN_Feat_loss + + def compute_discriminator_loss(self, input_img, real_image, augmented, spectrogram, netD, labels, lambda_D=1): + D_losses = {} + with torch.no_grad(): + ##### ***************** encode feature and generate ****************************** + + id_feature, _ = self.encode_identity_feature(input_img) + sel_id_feature = [] + sel_id_feature.append(self.select_frames(id_feature[0])) + sel_id_feature.append(self.select_frames(id_feature[1])) + + sel_real_image = self.select_frames(real_image) + sel_input_img = self.select_frames(input_img) + + V_noid_feature = self.encode_noid_feature(augmented) + V_noid_feature = V_noid_feature.detach() + V_noid_feature.requires_grad_() + + V_mouth_feature = self.netE.to_mouth(V_noid_feature) + V_headpose_feature = self.netE.to_headpose(V_noid_feature) + + fake_image_audio, fake_image = None, None + + if self.opt.generate_from_audio_only: + assert self.opt.use_audio, 'must use audio in this case' + + if not self.opt.generate_from_audio_only: + V_merge_feature = self.merge_mouthpose(V_mouth_feature, V_headpose_feature) + + sel_V_merge_feature = self.select_frames(V_merge_feature) + if self.opt.use_audio: + + A_mouth_feature = self.encode_audiosync_feature(spectrogram) + A_pose_feature = self.merge_mouthpose(A_mouth_feature, V_headpose_feature) + sel_A_pose_feature = self.select_frames(A_pose_feature) + fake_image_audio, style_rgb_a = 
self.generate_fake(sel_id_feature, sel_A_pose_feature) + fake_image = fake_image_audio + + if not self.opt.generate_from_audio_only: # use both audio and image pose feature + fake_image, style_rgb_v = self.generate_fake(sel_id_feature, sel_V_merge_feature) + fake_image = torch.cat([fake_image_audio, fake_image], 0) + + else: # only use image pose feature + fake_image, style_rgb_v = self.generate_fake(sel_id_feature, sel_V_merge_feature) + + sel_real_image = torch.cat([sel_real_image,]*(len(fake_image)//len(sel_real_image)), 0) + sel_input_img = torch.cat([sel_input_img,]*(len(fake_image)//len(sel_input_img)), 0) + + if fake_image is not None: + fake_image = fake_image.detach() + fake_image.requires_grad_() + if fake_image_audio is not None: + fake_image_audio = fake_image_audio.detach() + fake_image_audio.requires_grad_() + + if self.opt.disentangle: + V_headpose_embed = self.netE.headpose_embed(V_headpose_feature) + V_headpose_embed = V_headpose_embed.detach() + V_headpose_embed.requires_grad_() + + pred_fake, pred_real = self.discriminate( + sel_input_img, fake_image, sel_real_image, netD) + + if self.opt.stylegan_D: + pred_fake_styleGAN, pred_real_styleGAN = self.discriminate( + sel_input_img, fake_image, sel_real_image, self.net_styleGAN_D) + if type(pred_fake) == list and type(pred_real) == list: + pred_fake.append(pred_fake_styleGAN) + pred_real.append(pred_real_styleGAN) + else: + pred_fake = [pred_fake] + pred_fake.append(pred_fake_styleGAN) + pred_real = [pred_real] + pred_real.append(pred_real_styleGAN) + + D_losses['D_Fake'] = self.criterionGAN(pred_fake, False, + for_discriminator=True) * lambda_D + + D_losses['D_real'] = self.criterionGAN(pred_real, True, + for_discriminator=True) * lambda_D + + if self.opt.disentangle and self.opt.clip_len*self.opt.frame_interval >= 20: + V_all_headpose_embed = V_headpose_embed.view(-1, self.opt.clip_len * V_headpose_embed.shape[-1]) + headpose_word_scores = self.netE.headpose_fc(V_all_headpose_embed) + 
def divide_pred(self, pred):
    """Split a combined discriminator output into its fake and real halves.

    The batch is cut at ``size(0) // 2``: the first part is returned as the
    fake prediction and the remainder as the real prediction.  When ``pred``
    is a ``list`` (one entry per discriminator, each an iterable of tensors)
    the same split is applied to every tensor and the nesting is preserved.
    """
    def split_batch(tensor):
        midpoint = tensor.size(0) // 2
        return tensor[:midpoint], tensor[midpoint:]

    # A bare tensor is split directly.
    if type(pred) is not list:
        return split_batch(pred)

    fake, real = [], []
    for scale_outputs in pred:
        halves = [split_batch(tensor) for tensor in scale_outputs]
        fake.append([first for first, _ in halves])
        real.append([second for _, second in halves])
    return fake, real
def load_network(self, network, network_label, epoch_label):
    """Load ``<epoch_label>_net_<network_label>.pth`` from ``self.save_dir``.

    If the file is missing and ``self.opt.train_recognition`` is false, a
    notice is printed and, for the generator (``network_label == 'G'``), a
    ``FileNotFoundError`` is raised.  Otherwise the missing file is ignored.

    If the file exists, the state dict is loaded strictly first.  When that
    fails, only keys that also exist in ``network`` are loaded; when that
    fails too, every entry whose size matches is copied and the top-level
    names of the parameters left uninitialised are printed.

    Raises:
        FileNotFoundError: the generator checkpoint does not exist.
    """
    save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
    save_path = os.path.join(self.save_dir, save_filename)

    if not os.path.isfile(save_path):
        # NOTE(review): indentation of the original was lost; the raise is
        # assumed to sit under the train_recognition check -- confirm.
        if not self.opt.train_recognition:
            print('%s not exists yet!' % save_path)
            if network_label == 'G':
                # The original raised a plain string, which is itself a
                # TypeError in Python 3 and hid the real cause.
                raise FileNotFoundError('Generator must exist!')
        return

    # Read the checkpoint once; every fallback below reuses it.
    pretrained_dict = torch.load(save_path)
    try:
        network.load_state_dict(pretrained_dict)
    except Exception:
        model_dict = network.state_dict()
        try:
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            network.load_state_dict(pretrained_dict)
            if self.opt.verbose:
                print(
                    'Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
        except Exception:
            print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label)
            for k, v in pretrained_dict.items():
                if v.size() == model_dict[k].size():
                    model_dict[k] = v

            not_initialized = {
                k.split('.')[0]
                for k, v in model_dict.items()
                if k not in pretrained_dict or v.size() != pretrained_dict[k].size()
            }

            print(sorted(not_initialized))
            network.load_state_dict(model_dict)
class HourGlass(nn.Module):
    """Recursive hourglass block built from ``ConvBlock(256, 256)`` units.

    ``depth`` controls how many times the lower branch recurses.
    ``num_modules`` and ``num_features`` are stored on the instance but the
    layer widths are fixed at 256 in ``_generate_network``.
    """

    def __init__(self, num_modules, depth, num_features):
        super(HourGlass, self).__init__()
        self.num_modules = num_modules
        self.depth = depth
        self.features = num_features
        self.dropout = nn.Dropout(0.5)

        self._generate_network(self.depth)

    def _generate_network(self, level):
        """Register the blocks for ``level`` and, recursively, all lower levels.

        Modules are named ``b1_<level>``, ``b2_<level>``, ``b3_<level>`` and,
        at the innermost level only, ``b2_plus_<level>``.
        """
        self.add_module('b1_' + str(level), ConvBlock(256, 256))

        self.add_module('b2_' + str(level), ConvBlock(256, 256))

        if level > 1:
            self._generate_network(level - 1)
        else:
            self.add_module('b2_plus_' + str(level), ConvBlock(256, 256))

        self.add_module('b3_' + str(level), ConvBlock(256, 256))

    def _forward(self, level, inp):
        """Run one hourglass level and return the sum of both branches."""
        # Upper branch: full resolution, followed by dropout.
        up1 = self._modules['b1_' + str(level)](inp)
        up1 = self.dropout(up1)

        # Lower branch: halve the resolution, then recurse or hit the bottom.
        low1 = F.max_pool2d(inp, 2, stride=2)
        low1 = self._modules['b2_' + str(level)](low1)

        if level > 1:
            low2 = self._forward(level - 1, low1)
        else:
            low2 = self._modules['b2_plus_' + str(level)](low1)

        low3 = self._modules['b3_' + str(level)](low2)

        # Resize the lower branch back to the upper branch's spatial size.
        # F.upsample is deprecated; F.interpolate with align_corners=False is
        # what it resolves to, so the numerical result is unchanged.
        rescale_size = (up1.size(2), up1.size(3))
        up2 = F.interpolate(low3, size=rescale_size, mode='bilinear', align_corners=False)

        return up1 + up2

    def forward(self, x):
        """Apply the hourglass starting from the outermost level."""
        return self._forward(self.depth, x)
self._modules['top_m_' + str(i)](ll) + + ll = self._modules['bn_end' + str(i)](self._modules['conv_last' + str(i)](ll)) + tmp_out = self._modules['l' + str(i)](F.relu(ll)) + + net = self.relu(self.bn5(tmp_out)) + net = self.conv6(net) + net = net.view(-1, net.shape[-2] * net.shape[-1]) + net = self.relu(net) + net = self.fc(net) + return net diff --git a/Talking-Face_PC-AVS/models/networks/__init__.py b/Talking-Face_PC-AVS/models/networks/__init__.py new file mode 100644 index 00000000..38ea8c93 --- /dev/null +++ b/Talking-Face_PC-AVS/models/networks/__init__.py @@ -0,0 +1,87 @@ +import torch +from models.networks.base_network import BaseNetwork +from models.networks.loss import * +from models.networks.discriminator import MultiscaleDiscriminator, ImageDiscriminator +from models.networks.generator import ModulateGenerator +from models.networks.encoder import ResSEAudioEncoder, ResNeXtEncoder, ResSESyncEncoder, FanEncoder +import util.util as util + + +def find_network_using_name(target_network_name, filename): + target_class_name = target_network_name + filename + module_name = 'models.networks.' 
def create_network(cls, opt):
    """Build ``cls(opt)``, print its summary, move it to GPU if asked, and initialise it.

    The network is moved with ``.cuda()`` only when ``opt.gpu_ids`` is
    non-empty (CUDA availability is asserted first).  Weights are initialised
    with ``opt.init_type`` and ``opt.init_variance``.
    """
    network = cls(opt)
    network.print_network()

    wants_gpu = len(opt.gpu_ids) > 0
    if wants_gpu:
        assert torch.cuda.is_available()
        network.cuda()

    network.init_weights(opt.init_type, opt.init_variance)
    return network
# VGG architecture, used for the perceptual loss using a pretrained VGG network
class VGG19(torch.nn.Module):
    """Pretrained torchvision VGG-19 feature stack cut into five consecutive stages.

    ``forward`` returns the output of every stage, in order.
    """

    def __init__(self, requires_grad=False):
        super(VGG19, self).__init__()
        features = torchvision.models.vgg19(pretrained=True).features

        # (attribute name, first layer index, one-past-last layer index)
        stage_bounds = (
            ('slice1', 0, 2),
            ('slice2', 2, 7),
            ('slice3', 7, 12),
            ('slice4', 12, 21),
            ('slice5', 21, 30),
        )
        for attr_name, start, stop in stage_bounds:
            stage = torch.nn.Sequential()
            for layer_idx in range(start, stop):
                # Keep the original layer index as the sub-module name.
                stage.add_module(str(layer_idx), features[layer_idx])
            setattr(self, attr_name, stage)

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return ``[stage1(X), stage2(stage1(X)), ...]`` for the five stages."""
        activations = []
        hidden = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            hidden = stage(hidden)
            activations.append(hidden)
        return activations
# Returns a function that creates a normalization function
# that does not condition on semantic map
def get_nonspade_norm_layer(opt, norm_type='instance'):
    """Return a callable that wraps a layer with the requested normalisation.

    ``norm_type`` may start with ``'spectral'`` (spectral norm is applied to
    the layer) followed by ``'batch'``, ``'syncbatch'``, ``'instance'``,
    ``'none'`` or nothing.  For ``'none'``/empty the (possibly
    spectral-normalised) layer is returned as is; otherwise the layer's bias
    is removed and ``nn.Sequential(layer, norm)`` is returned.  Any other
    name raises ``ValueError``.  ``opt`` is not read here.
    """
    prefix = 'spectral'

    def out_channels_of(layer):
        # Conv layers expose out_channels; otherwise use the weight's first dim.
        if hasattr(layer, 'out_channels'):
            return layer.out_channels
        return layer.weight.size(0)

    # Builders are looked up lazily, so only the requested one is evaluated.
    norm_builders = {
        'batch': lambda channels: nn.BatchNorm2d(channels, affine=True),
        'syncbatch': lambda channels: SynchronizedBatchNorm2d(channels, affine=True),
        'instance': lambda channels: nn.InstanceNorm2d(channels, affine=False),
    }

    def add_norm_layer(layer):
        subnorm_type = norm_type
        if norm_type.startswith(prefix):
            layer = spectral_norm(layer)
            subnorm_type = norm_type[len(prefix):]

        if subnorm_type == 'none' or len(subnorm_type) == 0:
            return layer

        # A bias before a normalisation layer has no effect, so drop it.
        if getattr(layer, 'bias', None) is not None:
            delattr(layer, 'bias')
            layer.register_parameter('bias', None)

        if subnorm_type not in norm_builders:
            raise ValueError('normalization layer %s is not recognized' % subnorm_type)

        return nn.Sequential(layer, norm_builders[subnorm_type](out_channels_of(layer)))

    return add_norm_layer
% (nOut, encoder_type)) + + self.inplanes = num_filters[0] + self.encoder_type = encoder_type + self.n_mels = n_mels + self.log_input = log_input + + self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1) + self.relu = nn.ReLU(inplace=True) + self.bn1 = nn.BatchNorm2d(num_filters[0]) + + self.layer1 = self._make_layer(block, num_filters[0], layers[0]) + self.layer2 = self._make_layer(block, num_filters[1], layers[1], stride=(2, 2)) + self.layer3 = self._make_layer(block, num_filters[2], layers[2], stride=(2, 2)) + self.layer4 = self._make_layer(block, num_filters[3], layers[3], stride=(2, 2)) + + self.instancenorm = nn.InstanceNorm1d(n_mels) + + outmap_size = int(self.n_mels * n_mel_T / 8) + + self.attention = nn.Sequential( + nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1), + nn.ReLU(), + nn.BatchNorm1d(128), + nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1), + nn.Softmax(dim=2), + ) + + if self.encoder_type == "SAP": + out_dim = num_filters[3] * outmap_size + elif self.encoder_type == "ASP": + out_dim = num_filters[3] * outmap_size * 2 + else: + raise ValueError('Undefined encoder') + + self.fc = nn.Linear(out_dim, nOut) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def 
class SEBasicBlock(nn.Module):
    """Two 3x3 convolutions followed by an ``SELayer`` gate and a residual add.

    Note the first stage applies ReLU *before* batch norm
    (conv1 -> relu -> bn1), while the second is conv2 -> bn2 -> se.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
        super(SEBasicBlock, self).__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(inplanes, planes, stride=stride, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.se = SELayer(planes, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(se(bn2(conv2(bn1(relu(conv1(x)))))) + shortcut(x))``."""
        features = self.bn1(self.relu(self.conv1(x)))
        features = self.se(self.bn2(self.conv2(features)))

        # The shortcut is the input itself unless a downsample module was given.
        shortcut = x if self.downsample is None else self.downsample(x)

        features += shortcut
        return self.relu(features)
class SELayer(nn.Module):
    """Channel gate: global average pool -> Linear -> ReLU -> Linear -> Sigmoid.

    The resulting per-channel weights are multiplied onto the input.
    """

    def __init__(self, channel, reduction=8):
        super(SELayer, self).__init__()
        bottleneck = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        gate_layers = [
            nn.Linear(channel, bottleneck),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channel),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Scale every channel of the 4-D input ``x`` by its learned gate."""
        batch, channels, _height, _width = x.size()
        pooled = self.avg_pool(x).view(batch, channels)
        gate = self.fc(pooled).view(batch, channels, 1, 1)
        return x * gate
class MultiscaleDiscriminator(BaseNetwork):
    """Runs ``opt.num_D`` sub-discriminators on progressively downsampled input."""

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register multiscale options, then let the chosen sub-architecture add its own."""
        parser.add_argument('--netD_subarch', type=str, default='n_layer',
                            help='architecture of each discriminator')
        parser.add_argument('--num_D', type=int, default=2,
                            help='number of discriminators to be used in multiscale')
        known_opt, _unknown = parser.parse_known_args()

        # The sub-discriminator class contributes its own options to the parser.
        subnet_cls = util.find_class_in_module(known_opt.netD_subarch + 'discriminator',
                                               'models.networks.discriminator')
        subnet_cls.modify_commandline_options(parser, is_train)

        return parser

    def __init__(self, opt):
        super(MultiscaleDiscriminator, self).__init__()
        self.opt = opt

        for scale_idx in range(opt.num_D):
            self.add_module('discriminator_%d' % scale_idx,
                            self.create_single_discriminator(opt))

    def create_single_discriminator(self, opt):
        """Instantiate one sub-discriminator; only ``'n_layer'`` is recognised."""
        subarch = opt.netD_subarch
        if subarch != 'n_layer':
            raise ValueError('unrecognized discriminator subarchitecture %s' % subarch)
        return NLayerDiscriminator(opt)

    def downsample(self, input):
        """Downsample with a 3x3 average pool, stride 2, padding excluded from the average."""
        return F.avg_pool2d(input, kernel_size=3,
                            stride=2, padding=[1, 1],
                            count_include_pad=False)

    # Returns list of lists of discriminator outputs.
    # The final result is of size opt.num_D x opt.n_layers_D
    def forward(self, input):
        """Return one entry per sub-discriminator, each fed a further-downsampled input.

        When ``opt.no_ganFeat_loss`` is set, each sub-discriminator's output is
        wrapped in a single-element list.
        """
        wrap_single = self.opt.no_ganFeat_loss
        outputs = []
        for _name, discriminator in self.named_children():
            prediction = discriminator(input)
            outputs.append([prediction] if wrap_single else prediction)
            input = self.downsample(input)

        return outputs
class NLayerDiscriminator(BaseNetwork):
    """Convolutional discriminator whose layer groups are registered as ``model0``, ``model1``, ...

    Keeping the groups as separate sub-modules lets ``forward`` collect the
    output of each group.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.add_argument('--n_layers_D', type=int, default=4,
                            help='# layers in each discriminator')
        return parser

    def __init__(self, opt):
        super(NLayerDiscriminator, self).__init__()
        self.opt = opt

        kernel_size = 4
        padding = int(np.ceil((kernel_size - 1.0) / 2))
        channels = opt.ndf
        input_nc = self.compute_D_input_nc(opt)
        add_norm = get_nonspade_norm_layer(opt, opt.norm_D)

        # First group: plain strided conv, no normalisation.
        groups = [[
            nn.Conv2d(input_nc, channels, kernel_size=kernel_size, stride=2, padding=padding),
            nn.LeakyReLU(0.2, False),
        ]]

        # Middle groups: channels double (capped at 512); only the last one has stride 1.
        last_layer = opt.n_layers_D - 1
        for layer_idx in range(1, opt.n_layers_D):
            prev_channels, channels = channels, min(channels * 2, 512)
            conv = nn.Conv2d(prev_channels, channels, kernel_size=kernel_size,
                             stride=1 if layer_idx == last_layer else 2, padding=padding)
            groups.append([add_norm(conv), nn.LeakyReLU(0.2, False)])

        # Final group: single-channel prediction map.
        groups.append([nn.Conv2d(channels, 1, kernel_size=kernel_size, stride=1, padding=padding)])

        # We divide the layers into groups to extract intermediate layer outputs
        for group_idx, layers in enumerate(groups):
            self.add_module('model' + str(group_idx), nn.Sequential(*layers))

    def compute_D_input_nc(self, opt):
        """Return the number of input channels for the first convolution.

        NOTE(review): the original indentation was lost; the two ``+1``
        adjustments are assumed to apply only in the ``"concat"`` case -- confirm.
        """
        if opt.D_input != "concat":
            return 3

        input_nc = opt.label_nc + opt.output_nc
        if opt.contain_dontcare_label:
            input_nc += 1
        if not opt.no_instance:
            input_nc += 1
        return input_nc

    def forward(self, input):
        """Run every group in order.

        Returns the final output alone when ``opt.no_ganFeat_loss`` is set;
        otherwise a list starting with ``input`` itself followed by the
        output of each group.
        """
        activations = [input]
        for stage in self.children():
            activations.append(stage(activations[-1]))

        if self.opt.no_ganFeat_loss:
            return activations[-1]
        return activations[0:]
class ImageDiscriminator(BaseNetwork):
    """Defines a PatchGAN discriminator"""

    # Declared static to match BaseNetwork and the sibling discriminators;
    # without the decorator the method only worked when called on the class.
    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register ``--n_layers_D`` on ``parser`` and return it.

        NOTE(review): ``__init__`` takes its depth from the ``n_layers``
        argument and does not read ``opt.n_layers_D`` -- confirm this is intended.
        """
        parser.add_argument('--n_layers_D', type=int, default=4,
                            help='# layers in each discriminator')
        return parser

    def __init__(self, opt, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator

        Parameters:
            opt             -- options; ``D_input``, ``label_nc`` and (for
                               ``"concat"``) ``output_nc`` determine the input channels
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer

        The base filter count is fixed at 64.
        """
        super(ImageDiscriminator, self).__init__()
        # Convolutions keep a bias only when followed by instance norm.
        use_bias = norm_layer == nn.InstanceNorm2d
        if opt.D_input == "concat":
            input_nc = opt.label_nc + opt.output_nc
        else:
            input_nc = opt.label_nc
        ndf = 64
        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        # One more block at stride 1 before the prediction layer.
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
class FanEncoder(BaseNetwork):
    """``FAN_use`` backbone plus linear heads for classification, mouth and head pose.

    All heads operate on 512-dimensional features.
    """

    def __init__(self, opt):
        super(FanEncoder, self).__init__()
        self.opt = opt
        feat_dim = 512
        pose_dim = self.opt.pose_dim

        def two_layer_mlp(out_dim):
            return nn.Sequential(nn.Linear(feat_dim, feat_dim), nn.ReLU(), nn.Linear(feat_dim, out_dim))

        def relu_then_linear(in_dim, out_dim):
            return nn.Sequential(nn.ReLU(), nn.Linear(in_dim, out_dim))

        self.model = FAN_use()
        self.classifier = two_layer_mlp(opt.num_classes)

        # mapper to mouth subspace
        self.to_mouth = two_layer_mlp(feat_dim)
        self.mouth_embed = relu_then_linear(feat_dim, feat_dim - pose_dim)
        self.mouth_fc = relu_then_linear(feat_dim * opt.clip_len, opt.num_classes)

        # mapper to head pose subspace
        self.to_headpose = two_layer_mlp(feat_dim)
        self.headpose_embed = relu_then_linear(feat_dim, pose_dim)
        self.headpose_fc = relu_then_linear(pose_dim * opt.clip_len, opt.num_classes)

    def load_pretrain(self):
        """Copy the checkpoint at ``opt.FAN_pretrain_path`` into the backbone."""
        pretrain_path = self.opt.FAN_pretrain_path
        state = torch.load(pretrain_path)
        print("=> loading checkpoint '{}'".format(pretrain_path))
        util.copy_state_dict(state, self.model)

    def forward_feature(self, x):
        """Return the backbone output for ``x``."""
        return self.model(x)

    def forward(self, x):
        """Return ``(features, scores)``.

        ``x`` is reshaped to ``(-1, output_nc, crop_size, crop_size)``; the
        scores come from the classifier applied to the features averaged over
        groups of ``opt.num_clips``.
        """
        opt = self.opt
        frames = x.view(-1, opt.output_nc, opt.crop_size, opt.crop_size)
        features = self.forward_feature(frames)
        clip_mean = features.view(-1, opt.num_clips, 512).mean(1)
        return features, self.classifier(clip_mean)
def upfirdn2d_native(
    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
    """Upsample by zero insertion, pad/crop, FIR-filter, then downsample.

    Args:
        input: 4-D tensor; the last two axes are height and width.
        kernel: 2-D filter; it is flipped on both axes before ``conv2d`` so
            the result is a true convolution.
        up_x, up_y: zero-insertion factors along width / height.
        down_x, down_y: strides used to subsample the filtered result.
        pad_x0, pad_x1, pad_y0, pad_y1: padding before/after along each
            axis; negative values crop instead of pad.

    Returns:
        4-D tensor with the same channel count as ``input``.
    """
    channels, in_h, in_w = input.shape[1:]
    k_h, k_w = kernel.shape
    up_h, up_w = in_h * up_y, in_w * up_x

    # Zero insertion: give every pixel trailing singleton axes, pad those
    # axes with zeros, then fold them back into height and width.
    x = input.view(-1, channels, in_h, 1, in_w, 1)
    x = F.pad(x, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
    x = x.view(-1, channels, up_h, up_w)

    # Positive pads are applied by F.pad, negative ones by slicing.
    x = F.pad(
        x, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    top, left = max(-pad_y0, 0), max(-pad_x0, 0)
    bottom = x.shape[2] - max(-pad_y1, 0)
    right = x.shape[3] - max(-pad_x1, 0)
    x = x[:, :, top:bottom, left:right]

    padded_h = up_h + pad_y0 + pad_y1
    padded_w = up_w + pad_x0 + pad_x1

    # Every channel is filtered independently with the same kernel, so the
    # channels are moved onto the batch axis for a single-channel conv.
    x = x.reshape([-1, 1, padded_h, padded_w])
    flipped = torch.flip(kernel, [0, 1]).view(1, 1, k_h, k_w)
    x = F.conv2d(x, flipped)
    x = x.reshape(-1, channels, padded_h - k_h + 1, padded_w - k_w + 1)

    return x[:, :, ::down_y, ::down_x]
def make_kernel(k):
    """Build a normalised float32 2-D filter kernel.

    A 1-D ``k`` is expanded to the separable 2-D kernel given by its outer
    product with itself; input of any other rank is used as given. The
    result is scaled so that its entries sum to one.
    """
    kernel = torch.tensor(k, dtype=torch.float32)

    if kernel.ndim == 1:
        # Outer product: entry (i, j) is k[i] * k[j].
        kernel = kernel.unsqueeze(0) * kernel.unsqueeze(1)

    return kernel / kernel.sum()
class EqualLinear(nn.Module):
    """Linear layer with run-time weight scaling (equalised learning rate).

    The weight is stored divided by ``lr_mul`` and multiplied by
    ``lr_mul / sqrt(in_dim)`` on every forward pass; the bias is multiplied
    by ``lr_mul``.

    Args:
        in_dim: input feature count.
        out_dim: output feature count.
        bias: when false the layer has no bias term.
        bias_init: initial value of every bias entry.
        lr_mul: learning-rate multiplier applied as described above.
        activation: any truthy value routes the output through
            ``fused_leaky_relu``; the value itself is not inspected.
    """

    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))

        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        """Apply the scaled linear map, plus the activation when configured."""
        # A bias-free layer must not evaluate ``None * lr_mul``.
        bias = None if self.bias is None else self.bias * self.lr_mul

        if self.activation:
            out = F.linear(input, self.weight * self.scale)
            if bias is None:
                # fused_leaky_relu adds its bias argument unconditionally,
                # so a bias-free layer hands it a broadcastable zero.
                bias = out.new_zeros(1)
            out = fused_leaky_relu(out, bias)

        else:
            out = F.linear(input, self.weight * self.scale, bias=bias)

        return out

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
        )
2 + 1 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1)) + + fan_in = in_channel * kernel_size ** 2 + self.scale = 1 / math.sqrt(fan_in) + self.padding = kernel_size // 2 + + self.weight = nn.Parameter( + torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) + ) + + self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) + + self.demodulate = demodulate + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' + f'upsample={self.upsample}, downsample={self.downsample})' + ) + + def forward(self, input, style): + batch, in_channel, height, width = input.shape + + style = self.modulation(style).view(batch, 1, in_channel, 1, 1) + weight = self.scale * self.weight * style + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) + weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) + + weight = weight.view( + batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + + if self.upsample: + input = input.view(1, batch * in_channel, height, width) + weight = weight.view( + batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size + ) + weight = weight.transpose(1, 2).reshape( + batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size + ) + out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + out = self.blur(out) + + elif self.downsample: + input = self.blur(input) + _, _, height, width = input.shape + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, 
class ToRGB(nn.Module):
    """Project a feature map to 3 channels and add it to the running image.

    Args:
        in_channel: channel count of the incoming feature map.
        style_dim: width of the style vector given to the modulated conv.
        upsample: when true, an incoming ``skip`` image is upsampled (with
            ``Upsample(blur_kernel)``) before it is added.
        blur_kernel: filter taps handed to ``Upsample``.
    """

    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=(1, 3, 3, 1)):
        super().__init__()

        # Always define the attribute so forward() can test it; previously
        # it was missing entirely when ``upsample`` was false, so passing a
        # ``skip`` to such a layer raised AttributeError.
        self.upsample = Upsample(blur_kernel) if upsample else None

        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, input, style, skip=None):
        """Return ``(image, style)``.

        ``style`` in the result is the second value returned by the
        modulated convolution. When ``skip`` is given it is added to the
        projected image, after upsampling if this layer was built with
        ``upsample=True``.
        """
        out, style = self.conv(input, style)
        out = out + self.bias

        if skip is not None:
            if self.upsample is not None:
                skip = self.upsample(skip)

            out = out + skip

        return out, style
self.init_size * 2 ** res, self.init_size * 2 ** res] + self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape)) + + for i in range(1, self.log_size + 1): + out_channel = self.channels[self.init_size * 2 ** i] + + self.convs.append( + StyledConv( + in_channel, + out_channel, + 3, + self.style_dim, + upsample=True, + blur_kernel=blur_kernel, + ) + ) + + self.convs.append( + StyledConv( + out_channel, out_channel, 3, self.style_dim, blur_kernel=blur_kernel + ) + ) + + self.to_rgbs.append(ToRGB(out_channel, self.style_dim)) + + in_channel = out_channel + + self.n_latent = self.log_size * 2 + 2 + self.tanh = nn.Tanh() + + def forward( + self, + styles, + identity_style=None, + return_latents=False, + inject_index=None, + truncation=1, + truncation_latent=None, + noise=None, + randomize_noise=True, + ): + + Style_RGB = [] + if not self.input_is_latent: + styles = [self.style(s) for s in styles] + + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers + else: + noise = [ + getattr(self.noises, f'noise_{i}') for i in range(self.num_layers) + ] + + if truncation < 1: + style_t = [] + + for style in styles: + style_t.append( + truncation_latent + truncation * (style - truncation_latent) + ) + + styles = style_t + + if len(styles) < 2: + inject_index = self.n_latent + + if styles[0].ndim < 3: + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + + else: + latent = styles[0] + + else: + if inject_index is None: + inject_index = random.randint(1, self.n_latent - 1) + + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) + + latent = torch.cat([latent, latent2], 1) + + if identity_style is not None: + out = identity_style + else: + out = self.input(latent) + out = self.conv1(out, latent[:, 0], noise=noise[0]) + + skip, style_rgb = self.to_rgb1(out, latent[:, 1]) + Style_RGB.append(style_rgb) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in 
class ConvLayer(nn.Sequential):
    """Optional blur, then ``EqualConv2d``, then an optional leaky-ReLU.

    With ``downsample=True`` the input is blurred first and convolved with
    stride 2 and no padding; otherwise the convolution uses stride 1 and
    ``kernel_size // 2`` padding. When ``activate`` is true the bias lives
    in ``FusedLeakyReLU`` instead of the convolution (or
    ``ScaledLeakyReLU`` is used when ``bias`` is false).
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        stages = []

        if downsample:
            # Blur padding that accounts for the stride-2 convolution.
            total_pad = (len(blur_kernel) - 2) + (kernel_size - 1)
            blur_pad = ((total_pad + 1) // 2, total_pad // 2)
            stages.append(Blur(blur_kernel, pad=blur_pad))
            stride, padding = 2, 0
        else:
            stride, padding = 1, kernel_size // 2

        conv = EqualConv2d(
            in_channel,
            out_channel,
            kernel_size,
            padding=padding,
            stride=stride,
            bias=bias and not activate,
        )
        stages.append(conv)

        if activate:
            activation = FusedLeakyReLU(out_channel) if bias else ScaledLeakyReLU(0.2)
            stages.append(activation)

        super().__init__(*stages)
        # Convolution padding, kept as an instance attribute.
        self.padding = padding
# Defines the GAN loss which uses either LSGAN or the regular GAN.
# When LSGAN is used, it is basically same as MSELoss,
# but it abstracts away the need to create the target label tensor
# that has the same size as the input
class GANLoss(nn.Module):
    """Adversarial loss for the modes ``'ls'``, ``'original'``, ``'w'``, ``'hinge'``.

    Args:
        gan_mode: one of the modes above; anything else raises ``ValueError``.
        target_real_label: target value used for real samples.
        target_fake_label: target value used for fake samples.
        tensor: tensor constructor used to build the cached constants.
        opt: stored on the instance; not read by this class.
    """

    _SUPPORTED_MODES = ('ls', 'original', 'w', 'hinge')

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor, opt=None):
        super(GANLoss, self).__init__()
        if gan_mode not in self._SUPPORTED_MODES:
            raise ValueError('Unexpected gan_mode {}'.format(gan_mode))
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_tensor = None
        self.fake_label_tensor = None
        self.zero_tensor = None
        self.Tensor = tensor
        self.gan_mode = gan_mode
        self.opt = opt

    def _cached_constant(self, attr, value, input):
        """Return the one-element constant cached under ``attr``, expanded like ``input``.

        The constant is (re)built whenever it is missing or lives on a
        different device than ``input``; previously it stayed on whatever
        device ``self.Tensor`` allocates on, which fails as soon as the
        prediction is on another device.
        """
        cached = getattr(self, attr)
        if cached is None or cached.device != input.device:
            cached = self.Tensor(1).fill_(value).to(input.device)
            cached.requires_grad_(False)
            setattr(self, attr, cached)
        return cached.expand_as(input)

    def get_target_tensor(self, input, target_is_real):
        """Return a tensor shaped like ``input`` filled with the real or fake label."""
        if target_is_real:
            return self._cached_constant('real_label_tensor', self.real_label, input)
        return self._cached_constant('fake_label_tensor', self.fake_label, input)

    def get_zero_tensor(self, input):
        """Return a zero tensor shaped like ``input``."""
        return self._cached_constant('zero_tensor', 0, input)

    def loss(self, input, target_is_real, for_discriminator=True):
        """Compute the loss of one prediction tensor for the configured mode."""
        if self.gan_mode == 'original':  # cross entropy loss
            target_tensor = self.get_target_tensor(input, target_is_real)
            return F.binary_cross_entropy_with_logits(input, target_tensor)
        if self.gan_mode == 'ls':
            target_tensor = self.get_target_tensor(input, target_is_real)
            return F.mse_loss(input, target_tensor)
        if self.gan_mode == 'hinge':
            if for_discriminator:
                # Real samples are penalised below +1, fake ones above -1.
                margin = input - 1 if target_is_real else -input - 1
                minval = torch.min(margin, self.get_zero_tensor(input))
                return -torch.mean(minval)
            assert target_is_real, "The generator's hinge loss must be aiming for real"
            return -torch.mean(input)
        # wgan
        if target_is_real:
            return -input.mean()
        return input.mean()

    def __call__(self, input, target_is_real, for_discriminator=True):
        """Compute the loss for a tensor or a list of predictions.

        For a list (multiscale discriminator) each entry may itself be a
        list, in which case only its last element is scored; the per-entry
        losses are averaged.
        """
        # computing loss is a bit complicated because |input| may not be
        # a tensor, but list of tensors in case of multiscale discriminator
        if isinstance(input, list):
            loss = 0
            for pred_i in input:
                if isinstance(pred_i, list):
                    pred_i = pred_i[-1]
                loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
                bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
                new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
                loss += new_loss
            return loss / len(input)
        return self.loss(input, target_is_real, for_discriminator)
class SoftmaxContrastiveLoss(nn.Module):
    """Contrastive loss over inverse L2 distances between paired features.

    Row ``i`` of the face features is expected to match row ``i`` of the
    audio features; both inputs must have the same number of rows. Only
    ``mode='max'`` is implemented.
    """

    _KNOWN_MODES = ('max', 'confusion')
    # Lower bound on a pairwise distance before it is inverted, so that a
    # coinciding face/audio pair yields a large finite score instead of inf
    # (which turns the cross-entropy into NaN).
    _MIN_DISTANCE = 1e-8

    def __init__(self):
        super(SoftmaxContrastiveLoss, self).__init__()
        self.cross_ent = nn.CrossEntropyLoss()

    def l2_norm(self, x):
        """Return ``x`` with every row scaled to unit L2 norm."""
        x_norm = F.normalize(x, p=2, dim=1)
        return x_norm

    def l2_sim(self, feature1, feature2):
        """Return the matrix whose entry ``(i, j)`` is ``||feature1[i] - feature2[j]||``."""
        Feature = feature1.expand(feature1.size(0), feature1.size(0), feature1.size(1)).transpose(0, 1)
        return torch.norm(Feature - feature2, p=2, dim=2)

    def _check_mode(self, mode):
        """Raise ``ValueError`` unless ``mode`` is a known, implemented mode."""
        # The former ``assert mode in 'max' or 'confusion'`` was always true
        # (a non-empty string is truthy), so it never rejected anything.
        if mode not in self._KNOWN_MODES:
            raise ValueError('{} must be in max or confusion'.format(mode))
        if mode != 'max':
            raise ValueError("mode '{}' is not implemented".format(mode))

    def _cross_scores(self, face_feat, audio_feat):
        """Return inverse pairwise distances between the normalised features."""
        face_feat = self.l2_norm(face_feat)
        audio_feat = self.l2_norm(audio_feat)
        distance = self.l2_sim(face_feat, audio_feat)
        return 1.0 / distance.clamp_min(self._MIN_DISTANCE)

    @torch.no_grad()
    def evaluate(self, face_feat, audio_feat, mode='max'):
        """Return the fraction of faces whose closest audio row is their own."""
        self._check_mode(mode)
        cross_dist = self._cross_scores(face_feat, audio_feat)

        label = torch.arange(face_feat.size(0)).to(cross_dist.device)
        max_idx = torch.argmax(cross_dist, dim=1)
        acc = torch.sum(label == max_idx) * 1.0 / label.size(0)
        return acc

    def forward(self, face_feat, audio_feat, mode='max'):
        """Return the cross-entropy of the inverse distances against the diagonal."""
        self._check_mode(mode)
        cross_dist = self._cross_scores(face_feat, audio_feat)

        label = torch.arange(face_feat.size(0)).to(cross_dist.device)
        loss = F.cross_entropy(cross_dist, label)
        return loss
class _SynchronizedBatchNorm(_BatchNorm):
    """Batch norm whose training statistics are reduced across replicas.

    Outside data-parallel training the layer defers to ``F.batch_norm``.
    When replicated (see ``__data_parallel_replicate__``), every copy sends
    its per-channel sum and square-sum to the master copy, which computes
    the mean and inverse standard deviation and sends them back.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        # These names are None when the import from
        # torch.nn.parallel._functions failed at module load.
        assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.'

        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        # The master-side callback receives the messages of all replicas.
        self._sync_master = SyncMaster(self._data_parallel_master)

        # Filled in by __data_parallel_replicate__.
        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        """Normalise ``input`` with synchronised or local statistics."""
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        """Mark this copy as replica ``copy_id`` and wire it to the master.

        Copy 0 publishes its sync master on ``ctx``; every other copy
        registers itself with that master and keeps the returned pipe.
        """
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it.

        Each element of ``intermediates`` is indexed as ``(key, message)``
        where ``message`` carries ``sum``, ``ssum`` and ``sum_size``. The
        result is a list of ``(key, _MasterMessage)`` pairs.
        """

        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        # Keep only (sum, ssum) of every message, in device order.
        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            # Two broadcast tensors (mean, inv_std) belong to each replica.
            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device.

        Returns ``(mean, inv_std)`` where ``inv_std`` is the biased variance
        clamped from below at ``eps`` and raised to the power -0.5. The
        running variance is updated with the unbiased variance.
        """
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        # Both branches perform the same update; the first one only wraps it
        # in no_grad when the installed torch provides that context manager.
        if hasattr(torch, 'no_grad'):
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
        else:
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data

        return mean, bias_var.clamp(self.eps) ** -0.5
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device
    using the statistics only on that device, which accelerates the
    computation and is also easy to implement, but the statistics might be
    inaccurate. Instead, in this synchronized version, the statistics will be
    computed over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly
    the same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape::
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 4-dimensional."""
        # No delegation to the parent: in current PyTorch the inherited
        # _check_input_dim raises NotImplementedError, which made this
        # check fail for perfectly valid 4D input.
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. 
@contextlib.contextmanager
def patch_sync_batchnorm():
    """Temporarily swap ``torch.nn.BatchNorm{1,2,3}d`` for the synchronized classes.

    Inside the ``with`` body, ``nn.BatchNorm1d`` / ``nn.BatchNorm2d`` /
    ``nn.BatchNorm3d`` refer to ``SynchronizedBatchNorm1d/2d/3d``. The original
    classes are put back when the body exits, whether normally or through an
    exception.
    """
    import torch.nn as nn

    backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d

    nn.BatchNorm1d = SynchronizedBatchNorm1d
    nn.BatchNorm2d = SynchronizedBatchNorm2d
    nn.BatchNorm3d = SynchronizedBatchNorm3d

    try:
        yield
    finally:
        # Without the finally, an exception raised in the with-body would leave
        # torch.nn monkey-patched for the rest of the process.
        nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup


def convert_model(module):
    """Traverse the input module and its child recursively
       and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d
       to SynchronizedBatchNorm*N*d

    A ``torch.nn.DataParallel`` wrapper is replaced by
    ``DataParallelWithCallback`` around the converted inner module, keeping
    the wrapper's device placement.

    Args:
        module: the input module needs to be convert to SyncBN model

    Returns:
        The converted module. A module that is not a BatchNorm layer is
        returned as the same object, with its children replaced in place.

    Examples:
        >>> import torch.nn as nn
        >>> import torchvision
        >>> # m is a standard pytorch model
        >>> m = torchvision.models.resnet18(True)
        >>> m = nn.DataParallel(m)
        >>> # after convert, m is using SyncBN
        >>> m = convert_model(m)
    """
    if isinstance(module, torch.nn.DataParallel):
        mod = convert_model(module.module)
        # Carry over the placement of the original wrapper instead of silently
        # falling back to "all visible GPUs". getattr() is used because a
        # DataParallel built without CUDA may not define every attribute.
        return DataParallelWithCallback(
            mod,
            device_ids=getattr(module, 'device_ids', None) or None,
            output_device=getattr(module, 'output_device', None),
            dim=getattr(module, 'dim', 0),
        )

    mod = module
    for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
                                        torch.nn.modules.batchnorm.BatchNorm2d,
                                        torch.nn.modules.batchnorm.BatchNorm3d],
                                       [SynchronizedBatchNorm1d,
                                        SynchronizedBatchNorm2d,
                                        SynchronizedBatchNorm3d]):
        if isinstance(module, pth_module):
            mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
            # Running statistics are shared with the source layer (no copy).
            mod.running_mean = module.running_mean
            mod.running_var = module.running_var
            if module.affine:
                # Affine parameters are copied, so the two layers train independently.
                mod.weight.data = module.weight.data.clone().detach()
                mod.bias.data = module.bias.data.clone().detach()

    for name, child in module.named_children():
        mod.add_module(name, convert_model(child))

    return mod
class FutureResult(object):
    """A thread-safe future implementation. Used only as one-to-one pipe.

    ``None`` marks the slot as empty, so ``None`` itself cannot be passed
    through the pipe.
    """

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        """Store ``result`` and wake up the thread blocked in :meth:`get`.

        Raises:
            AssertionError: if the previously stored result was not fetched.
        """
        with self._lock:
            assert self._result is None, 'Previous result has\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        """Block until a result has been put, return it and empty the slot."""
        with self._lock:
            # Re-check the predicate in a loop: Condition.wait() is allowed to
            # return without a matching put() (spurious wakeup).
            while self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        """Send ``msg`` to the master, wait for its reply and acknowledge it.

        Returns: the message the master computed for this slave.
        """
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        # Acknowledgement consumed by SyncMaster.run_master before it returns.
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as the data parallel will trigger an callback of each module, all slave devices should
    call `register(id)` and obtain an `SlavePipe` to communicate with the master.
    - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
    and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine to message passed
    back to each slave devices.
    """

    def __init__(self, master_callback):
        """

        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        # Only the callback is serialized; queue and registry are rebuilt on load.
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register an slave device.

        The first registration after a completed `run_master` round clears
        the registry of the previous round.

        Args:
            identifier: an identifier, usually is the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.

        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages were first collected from each devices (including the master device), and then
        an callback will be invoked to compute the message to be sent back to each devices
        (including the master device).

        Args:
            master_msg: the message that the master want to send to itself. This will be placed as the first
            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.

        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for _ in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belongs to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for _ in range(self.nr_slaves):
            # Drain the acknowledgement outside of the assert: under
            # `python -O` asserts are stripped, and a get() placed inside one
            # would leave stale `True` items in the queue for the next round.
            ack = self._queue.get()
            assert ack is True

        return results[0][1]

    @property
    def nr_slaves(self):
        """Number of currently registered slave devices."""
        return len(self._registry)
class DataParallelWithCallback(DataParallel):
    """
    Data Parallel with a replication callback.

    An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
    original `replicate` function.
    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """
    def __init__(self, module, device_ids=None, output_device=None, dim=0, chunk_size=None):
        super(DataParallelWithCallback, self).__init__(module)

        # Set before the CPU-only early return so every instance exposes these
        # attributes, whether or not CUDA is available.
        self.dim = dim
        # NOTE(review): forwarded untouched to scatter_kwargs(); its exact
        # semantics are defined by the project's Scatter function -- confirm there.
        self.chunk_size = chunk_size

        if not torch.cuda.is_available():
            # No GPU: forward() calls the wrapped module directly.
            self.module = module
            self.device_ids = []
            return

        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]
        self.module = module
        self.device_ids = device_ids
        self.output_device = output_device

        if len(self.device_ids) == 1:
            self.module.cuda(device_ids[0])

    def forward(self, *inputs, **kwargs):
        """Run the wrapped module, scattering over ``device_ids`` when there are any."""
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids, self.chunk_size)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = self.parallel_apply(replicas, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def scatter(self, inputs, kwargs, device_ids, chunk_size=None):
        """Scatter ``inputs``/``kwargs`` over ``device_ids`` along ``self.dim``.

        ``chunk_size`` falls back to ``self.chunk_size`` when omitted, so the
        method can also be called with the three-argument parent signature.
        """
        if chunk_size is None:
            chunk_size = self.chunk_size
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim, chunk_size=chunk_size)

    def replicate(self, module, device_ids):
        """Replicate as the parent class does, then run the replication callbacks."""
        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules



def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have customized `DataParallel` implementation.

    Only the given instance is patched: its `replicate` attribute is replaced
    by a wrapper that runs `execute_replication_callbacks` on the replicas.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    data_parallel.replicate = new_replicate
def scatter_kwargs(inputs, kwargs, target_gpus, dim=0, chunk_size=None):
    r"""Scatter positional and keyword arguments, padding the shorter result.

    Falsy ``inputs`` / ``kwargs`` are not scattered at all. Whichever of the
    two scattered sequences is shorter is padded (with ``()`` for positional
    arguments, ``{}`` for keyword arguments) so both returned tuples have the
    same length.
    """
    per_gpu_args = scatter(inputs, target_gpus, dim, chunk_size) if inputs else []
    per_gpu_kwargs = scatter(kwargs, target_gpus, dim, chunk_size) if kwargs else []

    shortfall = len(per_gpu_kwargs) - len(per_gpu_args)
    if shortfall > 0:
        # fewer positional chunks than keyword chunks
        per_gpu_args.extend(() for _ in range(shortfall))
    elif shortfall < 0:
        # fewer keyword chunks than positional chunks
        per_gpu_kwargs.extend({} for _ in range(-shortfall))

    return tuple(per_gpu_args), tuple(per_gpu_kwargs)
def P2sRt(P):
    ''' Decompose an affine camera matrix P.
    Args:
        P: (3, 4). Affine Camera Matrix.
    Returns:
        s: scale factor, the mean norm of the first two rows of P[:, :3].
        R: (3, 3). rotation matrix built from the two normalized rows and
           their cross product.
        t3d: the last column of P, returned unchanged.
    '''
    translation = P[:, 3]
    row_x = P[0:1, :3]
    row_y = P[1:2, :3]

    norm_x = np.linalg.norm(row_x)
    norm_y = np.linalg.norm(row_y)
    s = (norm_x + norm_y) / 2.0

    r1 = row_x / norm_x
    r2 = row_y / norm_y
    r3 = np.cross(r1, r2)

    return s, np.concatenate((r1, r2, r3), 0), translation

def matrix2angle(R):
    ''' Compute three Euler angles from a rotation matrix.
    Ref: http://www.gregslabaugh.net/publications/euler.pdf
    Args:
        R: (3,3). rotation matrix
    Returns:
        [x, y, z] where, following the naming used in this file,
        x: yaw
        y: pitch
        z: roll
    '''
    # assert(isRotationMatrix(R))
    r20 = R[2, 0]

    if r20 == 1 or r20 == -1:
        # Gimbal lock: z is arbitrary, so it is fixed to 0.
        z = 0
        if r20 == -1:
            x = np.pi / 2
            y = z + atan2(R[0, 1], R[0, 2])
        else:
            x = -np.pi / 2
            y = -z + atan2(-R[0, 1], -R[0, 2])
        return [x, y, z]

    # Clamp before asin so a value slightly outside [-1, 1] cannot raise.
    x = -asin(max(-1, min(r20, 1)))
    cos_x = cos(x)
    y = atan2(R[2, 1] / cos_x, R[2, 2] / cos_x)
    z = atan2(R[1, 0] / cos_x, R[0, 0] / cos_x)
    return [x, y, z]
def tensor2im(input_image, imtype=np.uint8):
    """Convert a Tensor array into a numpy image array.

    Only the first element of the batch is converted. A single-channel image
    is tiled to three channels, then the array is transposed from (C, H, W)
    to (H, W, C) and rescaled with ``(v + 1) / 2 * 255``.

    Parameters:
        input_image (tensor) --  the input image tensor array; a numpy array
                                 is only cast to `imtype`, any other object is
                                 returned unchanged
        imtype (type)        --  the desired type of the converted numpy array
    """
    if isinstance(input_image, np.ndarray):
        # already a numpy array: only the cast is applied
        return input_image.astype(imtype)
    if not isinstance(input_image, torch.Tensor):
        return input_image

    first = input_image.data[0].cpu().float().numpy()
    if first.shape[0] == 1:  # grayscale to RGB
        first = np.tile(first, (3, 1, 1))
    height_width_channel = np.transpose(first, (1, 2, 0))
    rescaled = (height_width_channel + 1) / 2.0 * 255.0
    return rescaled.astype(imtype)


def diagnose_network(net, name='network'):
    """Print `name`, then the mean over parameters of mean(abs(gradient)).

    Parameters without a gradient are skipped; 0.0 is printed when no
    parameter has one.

    Parameters:
        net (torch network) -- Torch network
        name (str) -- the name of the network
    """
    grad_means = [torch.mean(torch.abs(param.grad.data))
                  for param in net.parameters()
                  if param.grad is not None]
    mean = sum(grad_means, 0.0)
    if grad_means:
        mean = mean / len(grad_means)
    print(name)
    print(mean)
def mkdirs(paths):
    """create empty directories if they don't exist

    Parameters:
        paths (str list) -- a list (or tuple) of directory paths, or a single
                            path string
    """
    # A str is a single path; lists and tuples are collections of paths.
    if isinstance(paths, (list, tuple)):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)


def mkdir(path):
    """create a single empty directory if it didn't exist

    Missing parent directories are created as well. Nothing happens when
    `path` already exists.

    Parameters:
        path (str) -- a single directory path
    """
    if not os.path.exists(path):
        # exist_ok guards against another process creating the directory
        # between the check above and this call.
        os.makedirs(path, exist_ok=True)
def load_pretrain(self):
    """Download the ResNeXt-50 (32x4d) weights and copy them into ``self.model``.

    The checkpoint is fetched from ``model_urls['resnext50_32x4d']`` and
    handed to ``util.copy_state_dict``.
    """
    # torch.load() only accepts a local path or a file-like object; handing it
    # the https URL fails, so the checkpoint is fetched through torch.hub.
    from torch.hub import load_state_dict_from_url

    check_point = load_state_dict_from_url(model_urls['resnext50_32x4d'])
    util.copy_state_dict(check_point, self.model)
It decides where to store samples and models') + parser.add_argument('--data_path', type=str, default='/home/SENSETIME/zhouhang1/Downloads/VoxCeleb2/voxceleb2_train.csv', help='where to load voxceleb train data') + parser.add_argument('--lrw_data_path', type=str, + default='/home/SENSETIME/zhouhang1/Downloads/VoxCeleb2/voxceleb2_train.csv', + help='where to load lrw train data') + + parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids') + parser.add_argument('--num_classes', type=int, default=5830, help='num classes') + parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') + parser.add_argument('--model', type=str, default='av', help='which model to use, rotate|rotatespade') + parser.add_argument('--trainer', type=str, default='audio', help='which trainer to use, rotate|rotatespade') + parser.add_argument('--norm_G', type=str, default='spectralinstance', help='instance normalization or batch normalization') + parser.add_argument('--norm_D', type=str, default='spectralinstance', help='instance normalization or batch normalization') + parser.add_argument('--norm_E', type=str, default='spectralinstance', help='instance normalization or batch normalization') + parser.add_argument('--norm_A', type=str, default='spectralinstance', help='instance normalization or batch normalization') + parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') + # input/output sizes + parser.add_argument('--batchSize', type=int, default=2, help='input batch size') + parser.add_argument('--preprocess_mode', type=str, default='resize_and_crop', help='scaling and cropping of images at load time.', choices=("resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "scale_shortside", "scale_shortside_and_crop", "fixed", "none")) + parser.add_argument('--crop_size', type=int, default=224, help='Crop to the width of crop_size (after initially scaling the images to load_size.)') + 
parser.add_argument('--crop_len', type=int, default=16, help='Crop len') + parser.add_argument('--target_crop_len', type=int, default=0, help='Crop len') + parser.add_argument('--crop', action='store_true', help='whether to crop the image') + parser.add_argument('--clip_len', type=int, default=1, help='num of imgs to process') + parser.add_argument('--pose_dim', type=int, default=12, help='num of imgs to process') + parser.add_argument('--frame_interval', type=int, default=1, help='the interval of frams') + parser.add_argument('--num_clips', type=int, default=1, help='num of clips to process') + parser.add_argument('--num_inputs', type=int, default=1, help='num of inputs to the network') + parser.add_argument('--feature_encoded_dim', type=int, default=2560, help='dim of reduced id feature') + + parser.add_argument('--aspect_ratio', type=float, default=1.0, help='The ratio width/height. The final height of the load image will be crop_size/aspect_ratio') + parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') + parser.add_argument('--audio_nc', type=int, default=256, help='# of output audio channels') + parser.add_argument('--frame_rate', type=int, default=25, help='fps') + parser.add_argument('--num_frames_per_clip', type=int, default=5, help='num of frames one audio bin') + parser.add_argument('--hop_size', type=int, default=160, help='audio hop size') + parser.add_argument('--generate_interval', type=int, default=1, help='select frames to generate') + parser.add_argument('--dis_feat_rec', action='store_true', help='select frames to generate') + + parser.add_argument('--train_recognition', action='store_true', help='train recognition only') + parser.add_argument('--train_sync', action='store_true', help='train sync only') + parser.add_argument('--train_word', action='store_true', help='train word only') + parser.add_argument('--train_dis_pose', action='store_true', help='train dis pose') + 
parser.add_argument('--generate_from_audio_only', action='store_true', help='if specified, generate only from audio features') + parser.add_argument('--noise_pose', action='store_true', help='noise pose to generate a talking face') + parser.add_argument('--style_feature_loss', action='store_true', help='style_feature_loss') + + # for setting inputsf + parser.add_argument('--dataset_mode', type=str, default='voxtest') + parser.add_argument('--landmark_align', action='store_true', help='wether there is landmark_align') + parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') + parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data argumentation') + parser.add_argument('--nThreads', default=1, type=int, help='# threads for loading data') + parser.add_argument('--n_mel_T', default=4, type=int, help='# threads for loading data') + parser.add_argument('--num_bins_per_frame', type=int, default=4, help='n_melT') + + parser.add_argument('--max_dataset_size', type=int, default=sys.maxsize, help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.') + parser.add_argument('--load_from_opt_file', action='store_true', help='load the options from checkpoints and use that as default') + parser.add_argument('--use_audio', type=int, default=1, help='use audio as driven input') + parser.add_argument('--use_audio_id', type=int, default=0, help='use audio id') + parser.add_argument('--augment_target', action='store_true', help='whether to use checkpoint') + parser.add_argument('--verbose', action='store_true', help='just add') + + parser.add_argument('--display_winsize', type=int, default=224, help='display window size') + + # for generator + parser.add_argument('--netG', type=str, default='modulate', help='selects model to use for netG (modulate)') + parser.add_argument('--netA', type=str, default='resseaudio', help='selects model to use for netA (audio | spade)') + parser.add_argument('--netA_sync', type=str, default='ressesync', help='selects model to use for netA (audio | spade)') + parser.add_argument('--netV', type=str, default='resnext', help='selects model to use for netV (mobile | id)') + parser.add_argument('--netE', type=str, default='fan', help='selects model to use for netV (mobile | fan)') + parser.add_argument('--netD', type=str, default='multiscale', help='(n_layers|multiscale|image|projection)') + parser.add_argument('--D_input', type=str, default='single', help='(concat|single|hinge)') + parser.add_argument('--driven_type', type=str, default='face', help='selects model to use for netV (heatmap | face)') + parser.add_argument('--landmark_type', type=str, default='min', help='selects model to use for netV (mobile | fan)') + parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') + parser.add_argument('--init_type', type=str, default='xavier', help='network initialization [normal|xavier|kaiming|orthogonal]') + parser.add_argument('--feature_fusion', type=str, default='concat', 
help='style fusion method') + parser.add_argument('--init_variance', type=float, default=0.02, help='variance of the initialization distribution') + + # for instance-wise features + parser.add_argument('--no_instance', action='store_true', help='if specified, do *not* add instance map as input') + parser.add_argument('--input_id_feature', action='store_true', help='if specified, use id feature as style gan input') + parser.add_argument('--load_landmark', action='store_true', help='if specified, load landmarks') + parser.add_argument('--nef', type=int, default=16, help='# of encoder filters in the first conv layer') + parser.add_argument('--style_dim', type=int, default=2580, help='# of encoder filters in the first conv layer') + + ####################### weight settings ################################################################### + + parser.add_argument('--vgg_face', action='store_true', help='if specified, use VGG feature matching loss') + + parser.add_argument('--VGGFace_pretrain_path', type=str, default='', help='VGGFace pretrain path') + parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss') + parser.add_argument('--lambda_image', type=float, default=1.0, help='weight for image reconstruction') + parser.add_argument('--lambda_vgg', type=float, default=10.0, help='weight for vgg loss') + parser.add_argument('--lambda_vggface', type=float, default=5.0, help='weight for vggface loss') + parser.add_argument('--lambda_rotate_D', type=float, default='0.1', + help='rotated D loss weight') + parser.add_argument('--lambda_D', type=float, default=1, + help='D loss weight') + parser.add_argument('--lambda_softmax', type=float, default=1000000, help='weight for softmax loss') + parser.add_argument('--lambda_crossmodal', type=float, default=1, help='weight for softmax loss') + + parser.add_argument('--lambda_contrastive', type=float, default=100, help='if specified, use contrastive loss for img and audio embed') + 
parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') + + parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss') + parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss') + parser.add_argument('--no_id_loss', action='store_true', help='if specified, do *not* use cls loss') + parser.add_argument('--word_loss', action='store_true', help='if specified, do *not* use cls loss') + parser.add_argument('--no_spectrogram', action='store_true', help='if specified, do *not* use mel spectrogram, use mfcc') + + parser.add_argument('--gan_mode', type=str, default='hinge', help='(ls|original|hinge)') + parser.add_argument('--no_TTUR', action='store_true', help='Use TTUR training scheme') + + ############################## optimizer ############################# + parser.add_argument('--optimizer', type=str, default='adam') + parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') + parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam') + parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate for adam') + + parser.add_argument('--no_gaussian_landmark', action='store_true', help='whether to use no_gaussian_landmark (1.0 landmark) for rotatespade model') + parser.add_argument('--label_mask', action='store_true', help='whether to use face mask') + parser.add_argument('--positional_encode', action='store_true', help='whether to use positional encode') + parser.add_argument('--use_transformer', action='store_true', help='whether to use transformer') + parser.add_argument('--has_mask', action='store_true', help='whether to use mask in transformer') + parser.add_argument('--heatmap_size', type=float, default=3, help='the size of the heatmap, used in rotatespade model') + + self.initialized = True + return parser + 
def gather_options(self):
    """Build the complete argument parser and parse the command line.

    The base options are registered through ``self.initialize``; the option
    setters of the selected model and of every comma-separated dataset mode
    then extend the parser. With ``--load_from_opt_file`` the values stored
    in the saved option file become the new defaults before the final parse.

    Returns:
        The parsed options namespace. The parser is kept on ``self.parser``.
    """
    # Always start from a fresh parser. The original only created one while
    # ``self.initialized`` was False and otherwise used an unbound local, so
    # a second call failed with UnboundLocalError.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser = self.initialize(parser)

    # First pass: only the basic options are known yet.
    opt, _ = parser.parse_known_args()

    # Let the chosen model add or modify options.
    model_option_setter = models.get_option_setter(opt.model)
    parser = model_option_setter(parser, self.isTrain)

    # Let every requested dataset mode add or modify options.
    for dataset_mode in opt.dataset_mode.split(','):
        dataset_option_setter = data.get_option_setter(dataset_mode)
        parser = dataset_option_setter(parser, self.isTrain)

    opt, _ = parser.parse_known_args()

    # Values from a saved option file replace the parser defaults.
    if opt.load_from_opt_file:
        parser = self.update_options_from_file(parser, opt)

    opt = parser.parse_args()
    self.parser = parser
    return opt


def print_options(self, opt):
    """Print all options, marking values that differ from the parser default."""
    lines = ['----------------- Options ---------------\n']
    for k, v in sorted(vars(opt).items()):
        default = self.parser.get_default(k)
        comment = '\t[default: %s]' % str(default) if v != default else ''
        lines.append('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment))
    lines.append('----------------- End -------------------')
    print(''.join(lines))


def option_file_path(self, opt, makedir=False):
    """Return ``<checkpoints_dir>/<name>/opt`` (no extension).

    Args:
        opt: options carrying ``checkpoints_dir`` and ``name``.
        makedir: create the experiment directory when True.
    """
    expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
    if makedir:
        util.mkdirs(expr_dir)
    return os.path.join(expr_dir, 'opt')


def save_options(self, opt):
    """Write the options as readable text (``opt.txt``) and as a pickle (``opt.pkl``)."""
    file_name = self.option_file_path(opt, makedir=True)
    with open(file_name + '.txt', 'wt') as opt_file:
        for k, v in sorted(vars(opt).items()):
            default = self.parser.get_default(k)
            comment = '\t[default: %s]' % str(default) if v != default else ''
            opt_file.write('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment))

    with open(file_name + '.pkl', 'wb') as opt_file:
        pickle.dump(opt, opt_file)


def update_options_from_file(self, parser, opt):
    """Make values from the saved option file the defaults of *parser*.

    Only keys present in both namespaces whose values differ are touched.
    """
    new_opt = self.load_options(opt)
    for k, v in sorted(vars(opt).items()):
        if hasattr(new_opt, k) and v != getattr(new_opt, k):
            parser.set_defaults(**{k: getattr(new_opt, k)})
    return parser


def load_options(self, opt):
    """Load the pickled options saved next to the checkpoints.

    NOTE(review): unpickling executes arbitrary code; only load option files
    that were written by a trusted run.
    """
    file_name = self.option_file_path(opt, makedir=False)
    # The original left the file handle open; close it deterministically.
    with open(file_name + '.pkl', 'rb') as opt_file:
        return pickle.load(opt_file)


def parse(self, save=False):
    """Parse, print and (when training) save the options; select the GPU.

    ``opt.gpu_ids`` is converted from a comma-separated string into a list of
    non-negative ints. ``save`` is unused and kept for interface
    compatibility.
    """
    opt = self.gather_options()
    opt.isTrain = self.isTrain  # train or test

    self.print_options(opt)
    if opt.isTrain:
        self.save_options(opt)

    # Negative ids are dropped, so "-1" results in an empty list.
    requested = [int(str_id) for str_id in opt.gpu_ids.split(',')]
    opt.gpu_ids = [gpu_id for gpu_id in requested if gpu_id >= 0]
    if opt.gpu_ids:
        torch.cuda.set_device(opt.gpu_ids[0])

    self.opt = opt
    return self.opt
class TrainOptions(BaseOptions):
    """Command-line options that only apply to training."""

    def initialize(self, parser):
        """Register the training options on top of the base options.

        Sets ``self.isTrain`` to True and returns the extended parser.
        """
        # Use the parser handed back by the base class instead of ignoring it.
        parser = BaseOptions.initialize(self, parser)

        def _flag(value):
            # argparse passes the raw string; without a converter every
            # non-empty value (including "False" or "0") would be truthy.
            return str(value).lower() in ('yes', 'true', 't', 'y', '1')

        # for displays
        parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
        parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
        parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
        parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs')
        parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
        parser.add_argument('--debug', action='store_true', help='only do one epoch and displays at each iteration')
        parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed')
        parser.add_argument('--tensorboard', type=_flag, default=True, help='if specified, use tensorboard logging. Requires tensorflow installed')
        parser.add_argument('--load_pretrain', type=str, default='',
                            help='load the pretrained model from the specified location')

        # for training
        parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
        parser.add_argument('--recognition', action='store_true', help='train only recognition')
        parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        parser.add_argument('--noload_D', action='store_true', help='whether to load D when continue training')
        parser.add_argument('--pose_noise', action='store_true', help='whether to use pose noise training')
        parser.add_argument('--load_separately', action='store_true', help='whether to continue train by loading separate models')
        parser.add_argument('--niter', type=int, default=10, help='# of iter at starting learning rate. This is NOT the total #epochs. Total #epochs is niter + niter_decay')
        parser.add_argument('--niter_decay', type=int, default=1000, help='# of iter to linearly decay learning rate to zero')
        parser.add_argument('--D_steps_per_G', type=int, default=1, help='number of discriminator iterations per generator iterations.')

        parser.add_argument('--G_pretrain_path', type=str, default='./checkpoints/100_net_G.pth', help='G pretrain path')
        parser.add_argument('--D_pretrain_path', type=str, default='', help='D pretrain path')
        parser.add_argument('--E_pretrain_path', type=str, default='', help='E pretrain path')
        parser.add_argument('--V_pretrain_path', type=str, default='', help='V pretrain path')
        parser.add_argument('--A_pretrain_path', type=str, default='', help='A pretrain path')
        parser.add_argument('--A_sync_pretrain_path', type=str, default='', help='A_sync pretrain path')
        parser.add_argument('--netE_pretrain_path', type=str, default='', help='E pretrain path')

        parser.add_argument('--fix_netV', action='store_true', help='if specified, fix net V')
        parser.add_argument('--fix_netE', action='store_true', help='if specified, fix net E')
        parser.add_argument('--fix_netE_mouth', action='store_true', help='if specified, fix net E mapper, fc and mapper')
        parser.add_argument('--fix_netE_mouth_embed', action='store_true', help='if specified, fix net E mapper, fc and mapper')
        parser.add_argument('--fix_netE_headpose', action='store_true', help='if specified, fix net E headpose')
        parser.add_argument('--fix_netA_sync', action='store_true', help='if specified fix net A_sync')
        parser.add_argument('--fix_netG', action='store_true', help='if specified, fix net G')
        parser.add_argument('--fix_netD', action='store_true', help='if specified, fix net D')
        parser.add_argument('--no_cross_modal', action='store_true', help='if specified, do *not* use cls loss')
        parser.add_argument('--softmax_contrastive', action='store_true', help='if specified, use contrastive loss for img and audio embed')

        # for discriminators
        parser.add_argument('--baseline_sync', action='store_true', help='train baseline sync')
        parser.add_argument('--style_feature_loss', action='store_true', help='to use style feature loss')
        parser.add_argument('--pretrain', action='store_true', help='Use outsider pretrain')
        parser.add_argument('--disentangle', action='store_true', help='whether to use disentangle loss')
        self.isTrain = True
        return parser
def get_eyes_mouths(landmark):
    """Reduce a landmark array to three anchor points.

    Rows 36-41, 42-47 and 60-67 of *landmark* are averaged into one point
    each (presumably left eye, right eye and mouth of a 68-point layout --
    confirm against the landmark detector).

    Returns:
        A ``(3, 2)`` float array with the three averaged points.
    """
    three_points = np.zeros((3, 2))
    three_points[0] = landmark[36:42].mean(0)
    three_points[1] = landmark[42:48].mean(0)
    three_points[2] = landmark[60:68].mean(0)
    return three_points


def get_mouth_bias(three_points):
    """Return the offset that moves the third anchor point onto (112, 120)."""
    return np.array([112, 120]) - three_points[2]


def align_folder(folder_path, folder_save_path):
    """Align every image in *folder_path* and write the crops to *folder_save_path*.

    Landmarks are detected for the whole folder, one similarity transform is
    estimated from the averaged anchor points, and each frame receives an
    additional translation (smoothed with the previous frame, 0.2 / 0.8) that
    keeps its third anchor point at the target position.

    Returns:
        True when all images were aligned, False when landmark detection did
        not yield exactly one 68-point face for every image or the folder
        contained no images. (The original returned None on success.)
    """
    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device=device)
    preds = fa.get_landmarks_from_directory(folder_path)

    three_points_list = []
    for img_pth, landmarks in preds.items():
        pred_points = np.array(landmarks)
        if pred_points.ndim != 3:
            print('preprocessing failed111')
            return False
        num_faces, size, _ = pred_points.shape
        if num_faces != 1 or size != 68:
            print(f"Image: {img_pth}, number of faces detected: {num_faces}")
            print('preprocessing failed222')
            # Stop at the first unusable image; the original kept scanning
            # although the outcome was already decided.
            return False
        three_points_list.append(get_eyes_mouths(pred_points[0]))

    if not three_points_list:
        # The original divided by len(preds) == 0 here.
        print('no images found in {}'.format(folder_path))
        return False

    avg_points = np.sum(three_points_list, axis=0) / len(three_points_list)
    M = get_affine(avg_points)

    p_bias = None
    for img_pth, three_points in zip(preds, three_points_list):
        affined_3landmarks = affine_align_3landmarks(three_points, M)
        bias = get_mouth_bias(affined_3landmarks)
        if p_bias is not None:
            # Blend with the previous frame's offset to reduce jitter.
            bias = p_bias * 0.2 + bias * 0.8
        p_bias = bias

        M_i = M.copy()
        M_i[:, 2] = M[:, 2] + bias
        img = cv2.imread(img_pth)
        wrapped = affine_align_img(img, M_i)
        # basename instead of split('/') so other path separators work too.
        img_save_path = os.path.join(folder_save_path, os.path.basename(img_pth))
        cv2.imwrite(img_save_path, wrapped)

    print('cropped files saved at {}'.format(folder_save_path))
    return True
def align_folder(folder_path, folder_save_path):
    """Align every image in *folder_path* and write the crops to *folder_save_path*.

    Unlike the strict variant in align_68.py, an image whose detection is not
    exactly one 68-point face is reported but still processed using its first
    detected face.

    Returns:
        True when the images were aligned, False when a detection result is
        unusable (not a faces x points x coords array) or the folder contained
        no images. (The original returned None on success.)
    """
    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device=device)
    preds = fa.get_landmarks_from_directory(folder_path)

    three_points_list = []
    for img_pth, landmarks in preds.items():
        pred_points = np.array(landmarks)
        if pred_points.ndim != 3:
            print('preprocessing failed111')
            return False
        num_faces, size, _ = pred_points.shape
        if num_faces != 1 or size != 68:
            # Reported only; the first face is used regardless. The original
            # carried a `flag` for this case that was never set to False.
            print(f"Image: {img_pth}, number of faces detected: {num_faces}")
            print('preprocessing failed222')
        three_points_list.append(get_eyes_mouths(pred_points[0]))

    if not three_points_list:
        # The original divided by len(preds) == 0 here.
        print('no images found in {}'.format(folder_path))
        return False

    avg_points = np.sum(three_points_list, axis=0) / len(three_points_list)
    M = get_affine(avg_points)

    p_bias = None
    for img_pth, three_points in zip(preds, three_points_list):
        affined_3landmarks = affine_align_3landmarks(three_points, M)
        bias = get_mouth_bias(affined_3landmarks)
        if p_bias is not None:
            # Blend with the previous frame's offset to reduce jitter.
            bias = p_bias * 0.2 + bias * 0.8
        p_bias = bias

        M_i = M.copy()
        M_i[:, 2] = M[:, 2] + bias
        img = cv2.imread(img_pth)
        wrapped = affine_align_img(img, M_i)
        img_save_path = os.path.join(folder_save_path, os.path.basename(img_pth))
        cv2.imwrite(img_save_path, wrapped)

    print('cropped files saved at {}'.format(folder_save_path))
    return True


def main():
    """Align the folder given by ``--folder_path`` into ``<folder>_cropped``."""
    parser = argparse.ArgumentParser()
    # Required: without it the original passed None to os.path.isdir.
    parser.add_argument('--folder_path', required=True,
                        help='the folder which needs processing')
    args = parser.parse_args()

    if not os.path.isdir(args.folder_path):
        print('{} is not a directory, nothing to do'.format(args.folder_path))
        return

    # normpath drops a trailing separator; the original split on '/' and
    # produced '<folder>/_cropped' for 'folder/'.
    save_img_path = os.path.normpath(args.folder_path) + '_cropped'
    os.makedirs(save_img_path, exist_ok=True)

    align_folder(args.folder_path, save_img_path)
def mkdir(path):
    """Create directory *path* (including parents) if it does not exist yet."""
    # exist_ok avoids the check-then-create race of the original.
    os.makedirs(path, exist_ok=True)


def _run_ffmpeg(arguments):
    """Run ffmpeg with *arguments* as a list, so no shell quoting is involved.

    Failures are reported but not raised, like the original os.system call.
    """
    import subprocess
    try:
        result = subprocess.run(['ffmpeg'] + list(arguments))
    except OSError as err:
        print('could not run ffmpeg: {}'.format(err))
        return False
    return result.returncode == 0


def _copy_file(src, dst):
    """Copy *src* to *dst*; copying a file onto itself is a no-op."""
    import shutil
    try:
        shutil.copy(src, dst)
    except shutil.SameFileError:
        pass


def _count_frames(folder):
    """Count the ``*.jpg`` and ``*.png`` files directly inside *folder*."""
    return len(glob.glob(os.path.join(folder, '*.jpg')) +
               glob.glob(os.path.join(folder, '*.png')))


def proc_frames(src_path, dst_path):
    """Extract the frames of *src_path* as ``%06d.jpg`` files into *dst_path*.

    Returns:
        The number of ``*.jpg`` files found in *dst_path* afterwards.
    """
    _run_ffmpeg(['-i', src_path, '-start_number', '0', '-qscale:v', '2',
                 os.path.join(dst_path, '%06d.jpg'), '-loglevel', 'error', '-y'])
    return len(glob.glob(os.path.join(dst_path, '*.jpg')))


def proc_audio(src_mouth_path, dst_audio_path):
    """Convert the audio track of *src_mouth_path* to 16 kHz 16-bit PCM wav data."""
    _run_ffmpeg(['-i', src_mouth_path, '-loglevel', 'error', '-y', '-f', 'wav',
                 '-acodec', 'pcm_s16le', '-ar', '16000', dst_audio_path])


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--dir_path', default='./misc',
                        help="dst file position")
    parser.add_argument('--src_pose_path', default='./misc/Pose_Source/00473.mp4',
                        help="pose source file position, this could be an mp4 or a folder")
    parser.add_argument('--src_audio_path', default='./misc/Audio_Source/00015.mp4',
                        help="audio source file position, it could be an mp3 file or an mp4 video with audio")
    parser.add_argument('--src_mouth_frame_path', default=None,
                        help="mouth frame file position, the video frames synced with audios")
    parser.add_argument('--src_input_path', default='./misc/Input/00098.mp4',
                        help="input file position, it could be a folder with frames, a jpg or an mp4")
    parser.add_argument('--csv_path', default='./misc/demo2.csv',
                        help="path to output index files")
    parser.add_argument('--convert_spectrogram', action='store_true', help='whether to convert audio to spectrogram')

    args = parser.parse_args()
    dir_path = args.dir_path
    mkdir(dir_path)

    # ===================== process input =====================
    input_save_path = os.path.join(dir_path, 'Input')
    mkdir(input_save_path)
    input_file = os.path.basename(args.src_input_path)
    input_name = input_file.split('.')[0]
    num_inputs = 1
    dst_input_path = os.path.join(input_save_path, input_name)
    mkdir(dst_input_path)
    if input_file.split('.')[-1] == 'mp4':
        num_inputs = proc_frames(args.src_input_path, dst_input_path)
    elif os.path.isdir(args.src_input_path):
        # NOTE(review): num_inputs stays 1 for a frame folder -- confirm intent.
        dst_input_path = args.src_input_path
    else:
        # Was `cp {} {}` through the shell, which broke on paths with spaces.
        _copy_file(args.src_input_path, os.path.join(dst_input_path, input_file))

    # ===================== process audio =====================
    audio_source_save_path = os.path.join(dir_path, 'Audio_Source')
    mkdir(audio_source_save_path)
    audio_file = os.path.basename(args.src_audio_path)
    audio_name = audio_file.split('.')[0]
    spec_dir = 'None'
    # NOTE(review): the target keeps an .mp3 suffix although proc_audio
    # writes wav data into it.
    dst_audio_path = os.path.join(audio_source_save_path, audio_name + '.mp3')

    if audio_file.split('.')[-1] == 'mp3':
        _copy_file(args.src_audio_path, dst_audio_path)
        if args.src_mouth_frame_path and os.path.isdir(args.src_mouth_frame_path):
            dst_mouth_frame_path = args.src_mouth_frame_path
            num_mouth_frames = _count_frames(args.src_mouth_frame_path)
        else:
            dst_mouth_frame_path = 'None'
            num_mouth_frames = 0
    else:
        mouth_source_save_path = os.path.join(dir_path, 'Mouth_Source')
        mkdir(mouth_source_save_path)
        dst_mouth_frame_path = os.path.join(mouth_source_save_path, audio_name)
        mkdir(dst_mouth_frame_path)
        proc_audio(args.src_audio_path, dst_audio_path)
        num_mouth_frames = proc_frames(args.src_audio_path, dst_mouth_frame_path)

    if args.convert_spectrogram:
        audio = AudioConfig(fft_size=1280, hop_size=160)
        wav = audio.read_audio(dst_audio_path)
        spectrogram = audio.audio_to_spectrogram(wav)
        spec_dir = os.path.join(audio_source_save_path, audio_name + '.npy')
        np.save(spec_dir,
                spectrogram.astype(np.float32), allow_pickle=False)

    # ===================== process pose =====================
    if os.path.isdir(args.src_pose_path):
        num_pose_frames = _count_frames(args.src_pose_path)
        dst_pose_frame_path = args.src_pose_path
    else:
        pose_source_save_path = os.path.join(dir_path, 'Pose_Source')
        mkdir(pose_source_save_path)
        pose_name = os.path.basename(args.src_pose_path).split('.')[0]
        dst_pose_frame_path = os.path.join(pose_source_save_path, pose_name)
        mkdir(dst_pose_frame_path)
        num_pose_frames = proc_frames(args.src_pose_path, dst_pose_frame_path)

    # ===================== form csv =====================
    # The with-block closes the file; the extra close() of the original is gone.
    with open(args.csv_path, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile, delimiter=' ', quoting=csv.QUOTE_MINIMAL)
        writer.writerows([[dst_input_path, str(num_inputs), dst_pose_frame_path, str(num_pose_frames),
                           dst_audio_path, dst_mouth_frame_path, str(num_mouth_frames), spec_dir]])
        print('meta-info saved at ' + args.csv_path)
# Helper class that keeps track of training iterations
class IterationCounter():
    """Track epochs and iterations and persist the position in ``iter.txt``.

    Args:
        opt: options providing ``isTrain``, ``checkpoints_dir``, ``name`` and
            ``batchSize``; when training also ``niter``, ``niter_decay``,
            ``continue_train`` and the ``*_freq`` settings.
        dataset_size: number of samples per epoch.
    """

    def __init__(self, opt, dataset_size):
        self.opt = opt
        self.dataset_size = dataset_size

        self.first_epoch = 1
        self.total_epochs = opt.niter + opt.niter_decay if opt.isTrain else 1
        self.epoch_iter = 0  # iter number within each epoch
        self.iter_record_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, 'iter.txt')
        if opt.isTrain and opt.continue_train:
            # Only the failures a missing or malformed record can cause are
            # caught; the original bare ``except`` also swallowed
            # KeyboardInterrupt and SystemExit.
            try:
                first_epoch, epoch_iter = np.loadtxt(
                    self.iter_record_path, delimiter=',', dtype=int)
            except (OSError, ValueError, TypeError):
                print('Could not load iteration record at %s. Starting from beginning.' %
                      self.iter_record_path)
            else:
                self.first_epoch, self.epoch_iter = int(first_epoch), int(epoch_iter)
                print('Resuming from epoch %d at iteration %d' % (self.first_epoch, self.epoch_iter))

        self.total_steps_so_far = (self.first_epoch - 1) * dataset_size + self.epoch_iter

    # return the iterator of epochs for the training
    def training_epochs(self):
        """Return the range of epoch numbers that are still to be run."""
        return range(self.first_epoch, self.total_epochs + 1)

    def record_epoch_start(self, epoch):
        """Reset the per-epoch counters at the start of *epoch*."""
        self.epoch_start_time = time.time()
        self.epoch_iter = 0
        self.last_iter_time = time.time()
        self.current_epoch = epoch

    def record_one_iteration(self):
        """Account for one processed batch of ``opt.batchSize`` samples."""
        current_time = time.time()

        # the last remaining batch is dropped (see data/__init__.py),
        # so we can assume batch size is always opt.batchSize
        self.time_per_iter = (current_time - self.last_iter_time) / self.opt.batchSize
        self.last_iter_time = current_time
        self.total_steps_so_far += self.opt.batchSize
        self.epoch_iter += self.opt.batchSize

    def record_epoch_end(self):
        """Report the epoch duration and, on save epochs, record the next epoch."""
        current_time = time.time()
        self.time_per_epoch = current_time - self.epoch_start_time
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (self.current_epoch, self.total_epochs, self.time_per_epoch))
        if self.current_epoch % self.opt.save_epoch_freq == 0:
            np.savetxt(self.iter_record_path, (self.current_epoch + 1, 0),
                       delimiter=',', fmt='%d')
            print('Saved current iteration count at %s.' % self.iter_record_path)

    def record_current_iter(self):
        """Write the current epoch and in-epoch iteration to the record file."""
        np.savetxt(self.iter_record_path, (self.current_epoch, self.epoch_iter),
                   delimiter=',', fmt='%d')
        print('Saved current iteration count at %s.' % self.iter_record_path)

    def needs_saving(self):
        """True when the step count just passed a multiple of ``save_latest_freq``."""
        return (self.total_steps_so_far % self.opt.save_latest_freq) < self.opt.batchSize

    def needs_printing(self):
        """True when the step count just passed a multiple of ``print_freq``."""
        return (self.total_steps_so_far % self.opt.print_freq) < self.opt.batchSize

    def needs_displaying(self):
        """True when the step count just passed a multiple of ``display_freq``."""
        return (self.total_steps_so_far % self.opt.display_freq) < self.opt.batchSize
# Converts a Tensor into a Numpy array
# |imtype|: the desired type of the converted numpy array
def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=True):
    """Convert an image tensor (or a list of them) to numpy image arrays.

    Args:
        image_tensor: a 2-D, 3-D (channels first) or 4-D (batched) tensor, or
            a list of such tensors.
        imtype: dtype of the returned array.
        normalize: when True values are mapped with ``(x + 1) / 2 * 255``,
            otherwise with ``x * 255``; the result is clipped to [0, 255].
        tile: for 4-D input, arrange the batch in a grid via ``tile_images``.

    Returns:
        A channels-last array (the channel axis is dropped for one channel),
        a tiled or stacked array for batches, or a list of results for list
        input.
    """
    if isinstance(image_tensor, list):
        # Pass `tile` on as well; the original dropped it for list input.
        return [tensor2im(one, imtype, normalize, tile) for one in image_tensor]

    if image_tensor.dim() == 4:
        # Convert each image with the caller's settings. The original
        # recursed with the defaults and silently ignored imtype/normalize.
        images_np = np.stack(
            [tensor2im(image_tensor[b], imtype, normalize)
             for b in range(image_tensor.size(0))],
            axis=0)
        if tile:
            return tile_images(images_np)
        if images_np.ndim == 4 and images_np.shape[0] == 1:
            images_np = images_np[0]
        return images_np

    if image_tensor.dim() == 2:
        image_tensor = image_tensor.unsqueeze(0)
    image_numpy = image_tensor.detach().cpu().float().numpy()
    if normalize:
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    else:
        image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
    image_numpy = np.clip(image_numpy, 0, 255)
    if image_numpy.shape[2] == 1:
        image_numpy = image_numpy[:, :, 0]
    return image_numpy.astype(imtype)


def save_image(image_numpy, image_path, create_dir=False):
    """Save *image_numpy* to *image_path* using PIL.

    A 4-D array contributes only its first image; single-channel images are
    expanded to three identical channels before saving.
    """
    if create_dir:
        target_dir = os.path.dirname(image_path)
        # A bare file name has no directory part and os.makedirs('') raises.
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
    if len(image_numpy.shape) == 4:
        image_numpy = image_numpy[0]
    if len(image_numpy.shape) == 2:
        image_numpy = np.expand_dims(image_numpy, axis=2)
    if image_numpy.shape[2] == 1:
        image_numpy = np.repeat(image_numpy, 3, 2)
    image_pil = Image.fromarray(image_numpy)

    # The format is chosen by PIL from the file extension.
    image_pil.save(image_path)
def find_class_in_module(target_cls_name, module):
    """Find an attribute of ``module`` by case-insensitive, underscore-free name.

    Args:
        target_cls_name: Name to look for. Underscores are removed and the
            result is lower-cased before comparison.
        module: Dotted module path handed to ``importlib.import_module``.

    Returns:
        The module attribute whose lower-cased name equals the normalized
        target. When several names match, the last one encountered wins.

    Raises:
        SystemExit: With exit status 1, after printing a diagnostic, when no
            attribute matches (or the matching attribute is ``None``).
    """
    wanted = target_cls_name.replace('_', '').lower()
    clslib = importlib.import_module(module)

    cls = None
    # No early break: a later matching name overrides an earlier one.
    for name, clsobj in vars(clslib).items():
        if name.lower() == wanted:
            cls = clsobj

    if cls is None:
        print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, wanted))
        # A failed lookup must not report success to the shell, and the
        # site-provided ``exit`` helper is not guaranteed to exist.
        raise SystemExit(1)

    return cls
def build_landmark_dict(ldmk_path):
    """Parse a whitespace-separated landmark file.

    Every non-blank line consists of numeric fields followed by a name (or
    path) as its last field. Only the part of the name after the final "/"
    is used as the key.

    Args:
        ldmk_path: Path of the text file to read.

    Returns:
        A tuple ``(ldmk_dict, paths)``: ``ldmk_dict`` maps each key to its
        list of floats, ``paths`` lists the keys in file order (one entry per
        parsed line, duplicates included).

    Raises:
        ValueError: If a selected numeric field cannot be parsed as a float.
    """
    num_coords = 106 * 2  # landmark values per line
    num_affine = 6        # leading affine-matrix values in the long format

    ldmk_dict = {}
    paths = []
    with open(ldmk_path) as f:
        for line in f:
            info = line.split()
            if not info:
                # Blank lines carry no record; indexing them would fail.
                continue
            key = info[-1].split("/")[-1]
            paths.append(key)
            if len(info) == num_affine + num_coords + 1:
                # affmat+landmark+name: drop the affine prefix.
                fields = info[num_affine:num_affine + num_coords]
            else:
                # landmark+name, mouth landmark+name and any other layout:
                # everything before the name.
                fields = info[:-1]
            ldmk_dict[key] = [float(it) for it in fields]
    return ldmk_dict, paths
def calc_loop_idx(idx, loop_num):
    """Map a running index onto a back-and-forth sweep over ``loop_num`` slots.

    Passes are numbered by ``idx // loop_num``. Even passes walk the slots
    upwards (0, 1, ..., loop_num - 1); odd passes walk them downwards
    (loop_num - 1, ..., 0).

    Args:
        idx: Running index.
        loop_num: Number of slots in one pass.

    Returns:
        The slot index for ``idx``.
    """
    pass_no = idx // loop_num
    offset = idx % loop_num
    if pass_no % 2:
        # Odd pass: count down from the last slot.
        position = -1 - offset
    else:
        position = offset
    return (position + loop_num) % loop_num
def display_current_results(self, visuals, epoch, step):
    """Send the current visuals to every enabled output.

    Args:
        visuals: Mapping of label -> image data to display or save.
        epoch: Current epoch; used in saved file names and page headers.
        step: Current iteration; used as the logging step and in file names.
    """
    # Entries handed to the HTML writer: the converted images when the
    # tf_log branch produced them, otherwise the entries as passed in.
    saved_visuals = visuals

    if self.tf_log:  # show images in tensorboard output
        # Convert a shallow copy: convert_visuals_to_numpy overwrites the
        # entries of the mapping it receives, and the tensorboard branch
        # below still needs the unconverted values.
        saved_visuals = self.convert_visuals_to_numpy(visuals.copy())
        img_summaries = []
        for label, image_numpy in saved_visuals.items():
            # Write the image to an in-memory buffer
            try:
                s = StringIO()
            except NameError:
                # StringIO is only bound when the Python 2 import succeeded.
                s = BytesIO()
            if len(image_numpy.shape) >= 4:
                image_numpy = image_numpy[0]
            scipy.misc.toimage(image_numpy).save(s, format="jpeg")
            # Create an Image object
            img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1])
            # Create a Summary value
            img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum))

        # Create and write Summary
        summary = self.tf.Summary(value=img_summaries)
        self.writer.add_summary(summary, step)

    if self.tensorboard:  # show images in tensorboard output
        for label, image_batch in visuals.items():
            # At most the first 16 items of each entry go into the grid.
            batch_size = image_batch.size(0)
            grid = vutils.make_grid(image_batch[:min(batch_size, 16)], normalize=True, scale_each=True)
            self.writer.add_image(label, grid, step)

    if self.use_html:  # save images to a html file
        # NOTE(review): when tf_log is off the entries reach util.save_image
        # unconverted -- confirm callers pass data that save_image accepts.
        for label, image_numpy in saved_visuals.items():
            if isinstance(image_numpy, list):
                for i, item in enumerate(image_numpy):
                    img_path = os.path.join(self.img_dir, 'epoch%.3d_iter%.3d_%s_%d.png' % (epoch, step, label, i))
                    util.save_image(item, img_path)
            else:
                img_path = os.path.join(self.img_dir, 'epoch%.3d_iter%.3d_%s.png' % (epoch, step, label))
                if len(image_numpy.shape) >= 4:
                    image_numpy = image_numpy[0]
                util.save_image(image_numpy, img_path)

        # update website
        webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=5)
        for n in range(epoch, 0, -1):
            webpage.add_header('epoch [%d]' % n)
            ims = []
            txts = []
            links = []

            for label, image_numpy in saved_visuals.items():
                if isinstance(image_numpy, list):
                    for i in range(len(image_numpy)):
                        img_path = 'epoch%.3d_iter%.3d_%s_%d.png' % (n, step, label, i)
                        ims.append(img_path)
                        txts.append(label+str(i))
                        links.append(img_path)
                else:
                    img_path = 'epoch%.3d_iter%.3d_%s.png' % (n, step, label)
                    ims.append(img_path)
                    txts.append(label)
                    links.append(img_path)
            if len(ims) < 10:
                webpage.add_images(ims, txts, links, width=self.win_size)
            else:
                # Split long rows into two halves.
                num = int(round(len(ims)/2.0))
                webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size)
                webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size)
        webpage.save()
def convert_visuals_to_numpy(self, visuals):
    """Convert every entry of ``visuals`` in place and return the mapping.

    The entry stored under ``'input_label'`` goes through
    ``util.tensor2label`` with ``self.opt.label_nc + 2``; every other entry
    goes through ``util.tensor2im``. Tiling is requested when
    ``self.opt.batchSize`` is greater than 8.

    Args:
        visuals: Mapping of label -> data to convert. Its values are
            overwritten with the converted results.

    Returns:
        The same ``visuals`` object that was passed in.
    """
    # NOTE(review): util.tensor2label is not defined in the util module
    # visible alongside this class -- confirm it exists before relying on
    # the 'input_label' path.
    for name, raw in visuals.items():
        tiled = self.opt.batchSize > 8
        if name == 'input_label':
            converted = util.tensor2label(raw, self.opt.label_nc + 2, tile=tiled)
        else:
            converted = util.tensor2im(raw, tile=tiled)
        visuals[name] = converted
    return visuals