From 3ac6c2c5ab904b49f16faedc64e5c97d2c421d45 Mon Sep 17 00:00:00 2001
From: Jiusheng Chen
Date: Tue, 17 Nov 2020 21:54:44 -0800
Subject: [PATCH 01/12] Fix prophetnet dict loading. (#58)

* Fix prophetnet dict loading.

* Use logger.

* Fix import.
---
 fastseq/models/prophetnet_fs/bert_dictionary.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/fastseq/models/prophetnet_fs/bert_dictionary.py b/fastseq/models/prophetnet_fs/bert_dictionary.py
index c63bc1a8..b4d0a372 100644
--- a/fastseq/models/prophetnet_fs/bert_dictionary.py
+++ b/fastseq/models/prophetnet_fs/bert_dictionary.py
@@ -7,14 +7,16 @@
 from collections import Counter
 from multiprocessing import Pool
+import logging
 import os
-
 import torch
 from fairseq.tokenizer import tokenize_line
 from fairseq.binarizer import safe_readline
 from fairseq.data import data_utils, Dictionary

+from fastseq.logging import get_logger
+
+logger = get_logger(__name__, logging.INFO)

 class BertDictionary(Dictionary):
     """A mapping from symbols to consecutive integers"""
@@ -37,11 +39,17 @@ def load_from_file(cls, filename):
         d.count = []
         d.indices = {}

+        line_cnt = 0
         with open(
             filename, 'r', encoding='utf-8', errors='ignore') as input_file:
             for line in input_file:
-                k, v = line.split()
-                d.add_symbol(k)
+                line_cnt += 1
+                try:
+                    k, v = line.split(" ")
+                    d.add_symbol(k)
+                except:
+                    logger.error("Bad line at line: %d (1-based), content: '%s'." % (line_cnt, line))
+                    raise

         d.unk_word = '[UNK]'
         d.pad_word = '[PAD]'

From 62b66573911030fd6b4d918510887846b7a5d5bb Mon Sep 17 00:00:00 2001
From: NickNickGo <66033489+NickNickGo@users.noreply.github.com>
Date: Fri, 20 Nov 2020 10:10:39 -0800
Subject: [PATCH 02/12] made ngram op device agnostic, unit test cleaned (#61)

---
 .../fairseq/beam_search_optimizer.py          | 51 ++++++++++++++++++-
 .../transformers/beam_search_optimizer.py     | 14 ++++-
 tests/run_transformers_tests.py               |  9 ++--
 3 files changed, 67 insertions(+), 7 deletions(-)

diff --git a/fastseq/optimizer/fairseq/beam_search_optimizer.py b/fastseq/optimizer/fairseq/beam_search_optimizer.py
index 69fc1788..e4a84c73 100644
--- a/fastseq/optimizer/fairseq/beam_search_optimizer.py
+++ b/fastseq/optimizer/fairseq/beam_search_optimizer.py
@@ -494,6 +494,49 @@ def is_finished(sent, step, unfin_idx):
                 return True
             return False

+    def apply_no_repeat_ngram_cpu(self, tokens, lprobs, bsz, step,
+        beam_size, no_repeat_ngram_size):
+        """ Fairseq implementation of blocking
+        repeated ngrams
+        """
+        banned_list = [[] for bbsz_idx in range(bsz * beam_size)]
+        cpu_tokens = tokens.cpu()[:, :step + 1].numpy()
+        check_start_pos = step + 2 - no_repeat_ngram_size
+        for bbsz_idx in range(bsz * beam_size):
+            for i in range(check_start_pos):
+                is_banned = True
+                for k in range(no_repeat_ngram_size - 1):
+                    if cpu_tokens[bbsz_idx, i + k] != cpu_tokens[
+                            bbsz_idx, check_start_pos + k]:
+                        is_banned = False
+                        break
+                if is_banned:
+                    banned_list[bbsz_idx].append(
+                        cpu_tokens[bbsz_idx,
+                                   i + no_repeat_ngram_size - 1])
+
+        def calculate_banned_tokens(bbsz_idx):
+            """before decoding the next token, prevent decoding
+            of ngrams that have already appeared
+            """
+            banned_tokens_per_sample = [
+                (bbsz_idx, t) for t in banned_list[bbsz_idx]
+            ]
+            return banned_tokens_per_sample
+
+        banned_tokens = []
+        if step + 2 - no_repeat_ngram_size >= 0:
+            for bbsz_idx in range(bsz * beam_size):
+                banned_tokens.extend(calculate_banned_tokens(bbsz_idx))
+
+        if banned_tokens:
+            banned_tokens = torch.LongTensor(banned_tokens)
+            lprobs.index_put_(
+                tuple(banned_tokens.t()),
+                lprobs.new_tensor([-math.inf] * len(banned_tokens)))
+
+        return lprobs
+
     def finalize_hypos(step, bbsz_idx, eos_scores):
         """
         Finalize the given hypotheses at this step, while keeping the total
@@ -658,8 +701,12 @@ def replicate_first_beam(tensor, mask):

             if self.no_repeat_ngram_size > 0:
                 #Applying Cuda Op for NGram repeat Blocking
-                lprobs = self.no_repeat_ngram_op(tokens,lprobs, bsz, step,
-                    beam_size, self.no_repeat_ngram_size)
+                if (tokens.is_cuda and lprobs.is_cuda):
+                    lprobs = self.no_repeat_ngram_op(tokens,lprobs, bsz, step,
+                        beam_size, self.no_repeat_ngram_size)
+                else:
+                    lprobs = self.apply_no_repeat_ngram_cpu(tokens, lprobs,
+                        bsz, step, beam_size, self.no_repeat_ngram_size)

             cand_scores, cand_indices, cand_beams = self.search.step(
                 step,
diff --git a/fastseq/optimizer/transformers/beam_search_optimizer.py b/fastseq/optimizer/transformers/beam_search_optimizer.py
index b9102801..df8061bc 100644
--- a/fastseq/optimizer/transformers/beam_search_optimizer.py
+++ b/fastseq/optimizer/transformers/beam_search_optimizer.py
@@ -650,8 +650,18 @@ def _update_scores(banned_tokens):
         cpu_input_ids = input_ids.cpu()
         if no_repeat_ngram_size > 0:
             #custom op for Ngram repeat blocking
-            scores = self.no_repeat_ngram_op(input_ids,scores.float(),
-                batch_size, cur_len-1, num_beams, no_repeat_ngram_size)
+            if (input_ids.is_cuda and scores.is_cuda):
+                scores = self.no_repeat_ngram_op(input_ids,scores.float(),
+                    batch_size, cur_len-1, num_beams, no_repeat_ngram_size)
+            else:
+                num_batch_hypotheses = batch_size * num_beams
+                banned_ngram_tokens = calc_banned_ngram_tokens_v2(
+                    cpu_input_ids,
+                    num_batch_hypotheses,
+                    no_repeat_ngram_size,
+                    cur_len,
+                    self.config.pad_token_id)
+                _update_scores(banned_ngram_tokens)

         if bad_words_ids is not None:
             # calculate a list of banned tokens according to bad words
diff --git a/tests/run_transformers_tests.py b/tests/run_transformers_tests.py
index cc23d8b0..f2ba15ab 100644
--- a/tests/run_transformers_tests.py
+++ b/tests/run_transformers_tests.py
@@ -44,7 +44,10 @@ def clone_and_build_transformers(self, repo, version):
         'testcase_name': 'Normal',
         'without_fastseq_opt': False,
         'transformers_version': 'v3.0.2',
-        'blocked_tests': ['test_modeling_reformer.py']
+        'blocked_tests': ['modeling_reformer',
+                          'multigpu',
+                          'HfApiEndpoints'
+                          ]
     })
     def test_suites(self, without_fastseq_opt, transformers_version,
                     blocked_tests):
@@ -56,8 +59,8 @@ def test_suites(self, without_fastseq_opt, transformers_version,
         import pytest #pylint: disable=import-outside-toplevel
         self.prepare_env()
         os.chdir(TRANSFORMERS_PATH)
-        blocked_tests_string = (' not '+
-            ' not '.join([test[5:-3] for test in blocked_tests]))
+        blocked_tests_string = (
+            ' and '.join([' not '+ test for test in blocked_tests]))
         exit_code = pytest.main(['-sv', '-k'+blocked_tests_string, './tests/'])
         assert str(exit_code).strip() == 'ExitCode.OK'

From 4c58e5a0f04fb0d7057a40e77f4072ba884edf77 Mon Sep 17 00:00:00 2001
From: Fei Hu
Date: Fri, 20 Nov 2020 16:40:53 -0800
Subject: [PATCH 03/12] Generate the XML log file for each fastseq unit test
 (#56)

* Generate the XML log file for each unit tests

* run all fastseq unit tests

* Add Nikhil's changes on pipeline to publish XML

* Just use a small unit test to test pipeline

* Change the xml folder path

* Add more tests

* Add env var for xml log dir and test the failures

* Enable all fastseq unit tests

* Enable all tests

* Generate xml files for fairseq and transformers unit tests

* Fix an issue in pytest command

* Trigger the CI pipeline
---
 azure-pipelines.yml                           | 11 ++++-
fastseq/config.py | 2 + fastseq/utils/test_utils.py | 17 +++++++- tests/models/test_prophetnet_fs.py | 4 +- .../fairseq/benchmark_fairseq_optimizer.py | 4 +- .../fairseq/test_fairseq_optimizer.py | 5 ++- .../transformers/test_bart_optimizer.py | 4 +- .../transformers/test_t5_optimizer.py | 4 +- tests/run_fairseq_tests.py | 43 +++++++++++++------ tests/run_fairseq_tests.sh | 5 ++- tests/run_fastseq_tests.sh | 7 +++ tests/run_transformers_tests.py | 35 +++++++++++---- tests/run_transformers_tests.sh | 5 ++- tests/utils/test_api_decorator.py | 4 +- tests/utils/test_file_utils.py | 4 +- 15 files changed, 111 insertions(+), 43 deletions(-) create mode 100755 tests/run_fastseq_tests.sh diff --git a/azure-pipelines.yml b/azure-pipelines.yml index aff30f15..a34f270d 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -32,11 +32,20 @@ jobs: echo "******* Running fairseq unittests *******" bash tests/run_fairseq_tests.sh + echo "******* Running transformers unittests *******" bash tests/run_transformers_tests.sh + echo "******* Running fastseq unittests *******" pip install pytorch-transformers==1.0.0 - python -m unittest discover -s tests/ -p 'test_*.py' -v + bash tests/run_fastseq_tests.sh + #cd benchmarks/ #bash run_all_benchmarks.sh displayName: 'run fastseq unit tests' + - task: PublishTestResults@2 + condition: succeededOrFailed() + inputs: + testRunTitle: 'Publish test results for Python $(python.version)' + testResultsFiles: 'tests/log_xml/*.xml' + failTaskOnFailedTests: true diff --git a/fastseq/config.py b/fastseq/config.py index 412712cc..7a1dfb86 100644 --- a/fastseq/config.py +++ b/fastseq/config.py @@ -9,6 +9,8 @@ FASTSEQ_DEFAULT_LOG_LEVEL = 'INFO' FASTSEQ_LOG_LEVEL = os.getenv('FASTSEQ_LOG_LEVEL', FASTSEQ_DEFAULT_LOG_LEVEL) FASTSEQ_CACHE_DIR = os.getenv('FASTSEQ_CACHE_DIR', os.path.join(os.sep, 'tmp')) +FASTSEQ_UNITTEST_LOG_XML_DIR = os.getenv( + 'FASTSEQ_UNITTEST_LOG_XML_DIR', os.path.join('tests', 'log_xml')) FASTSEQ_LOG_FORMAT = ( '%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s') diff --git a/fastseq/utils/test_utils.py b/fastseq/utils/test_utils.py index 2ef9d5a6..0a44a3ea 100644 --- a/fastseq/utils/test_utils.py +++ b/fastseq/utils/test_utils.py @@ -3,18 +3,31 @@ """Utilities to make it easy to add unit tests""" +from inspect import getframeinfo, stack import os from statistics import mean, stdev import time -from absl.testing import parameterized +from absl import flags +from absl.testing import absltest, parameterized -from fastseq.config import FASTSEQ_CACHE_DIR +from fastseq.config import FASTSEQ_CACHE_DIR, FASTSEQ_UNITTEST_LOG_XML_DIR from fastseq.logging import get_logger from fastseq.utils.api_decorator import get_class logger = get_logger(__name__) +FLAGS = flags.FLAGS + +def fastseq_test_main(): + caller = getframeinfo(stack()[1][0]) + suffix = '_' + time.strftime("%Y%m%d%H%M%S") + '.xml' + log_xml_file = caller.filename.replace(os.sep, '_').replace('.py', suffix) + log_xml_file = os.path.join(FASTSEQ_UNITTEST_LOG_XML_DIR, log_xml_file) + FLAGS.xml_output_file = log_xml_file + logger.info(f"Fastseq unit test log output filepath: {log_xml_file}") + absltest.main() + class TestCaseBase(parameterized.TestCase): """Base class used for unittest.""" diff --git a/tests/models/test_prophetnet_fs.py b/tests/models/test_prophetnet_fs.py index 7e02f155..5eed8c82 100644 --- a/tests/models/test_prophetnet_fs.py +++ b/tests/models/test_prophetnet_fs.py @@ -19,7 +19,7 @@ from fastseq.utils.file_utils import decompress_file, make_dirs, wget from 
fastseq.utils.test_utils import (PROPHETNET_MODEL_URLS, CACHED_PROPHETNET_MODEL_PATHS, - TestCaseBase) + fastseq_test_main, TestCaseBase) logger = get_logger(__name__) @@ -136,4 +136,4 @@ def test_beam_search_optimizer(self, beam_size, batch_size, need_attn, self.assertEqual(output, self.expected_outputs[i]) if __name__ == "__main__": - absltest.main() + fastseq_test_main() diff --git a/tests/optimizer/fairseq/benchmark_fairseq_optimizer.py b/tests/optimizer/fairseq/benchmark_fairseq_optimizer.py index 98975574..28f78b00 100644 --- a/tests/optimizer/fairseq/benchmark_fairseq_optimizer.py +++ b/tests/optimizer/fairseq/benchmark_fairseq_optimizer.py @@ -15,7 +15,7 @@ from fastseq.utils.file_utils import decompress_file, make_dirs, wget from fastseq.utils.test_utils import (BART_MODEL_URLS, CACHED_BART_MODEL_DIR, CACHED_BART_MODEL_PATHS, BenchmarkBase, - benchmark) + benchmark, fastseq_test_main) logger = get_logger(__name__) @@ -128,4 +128,4 @@ def test_beam_search_optimizer(self, beam_size, batch_size, need_attn, if __name__ == "__main__": - absltest.main() + fastseq_test_main() diff --git a/tests/optimizer/fairseq/test_fairseq_optimizer.py b/tests/optimizer/fairseq/test_fairseq_optimizer.py index e476a927..0fb73bb3 100644 --- a/tests/optimizer/fairseq/test_fairseq_optimizer.py +++ b/tests/optimizer/fairseq/test_fairseq_optimizer.py @@ -16,7 +16,8 @@ from fastseq.logging import get_logger from fastseq.utils.file_utils import decompress_file, make_dirs, wget from fastseq.utils.test_utils import (BART_MODEL_URLS, CACHED_BART_MODEL_DIR, - CACHED_BART_MODEL_PATHS, TestCaseBase) + CACHED_BART_MODEL_PATHS, + fastseq_test_main, TestCaseBase) logger = get_logger(__name__) @@ -117,4 +118,4 @@ def test_beam_search_optimizer(self, beam_size, batch_size, need_attn, if __name__ == "__main__": - absltest.main() + fastseq_test_main() diff --git a/tests/optimizer/transformers/test_bart_optimizer.py b/tests/optimizer/transformers/test_bart_optimizer.py index bb4b809a..e5d4ed0c 100644 --- a/tests/optimizer/transformers/test_bart_optimizer.py +++ b/tests/optimizer/transformers/test_bart_optimizer.py @@ -14,7 +14,7 @@ from absl.testing import absltest, parameterized from transformers import BartForConditionalGeneration, BartTokenizer -from fastseq.utils.test_utils import TestCaseBase +from fastseq.utils.test_utils import fastseq_test_main, TestCaseBase class BARTOptimizerTest(TestCaseBase): @@ -183,4 +183,4 @@ def test_beam_search_optimizer(self, if __name__ == "__main__": - absltest.main() + fastseq_test_main() diff --git a/tests/optimizer/transformers/test_t5_optimizer.py b/tests/optimizer/transformers/test_t5_optimizer.py index 9d38a913..43b155ab 100644 --- a/tests/optimizer/transformers/test_t5_optimizer.py +++ b/tests/optimizer/transformers/test_t5_optimizer.py @@ -12,7 +12,7 @@ import fastseq from fastseq.logging import get_logger -from fastseq.utils.test_utils import TestCaseBase +from fastseq.utils.test_utils import fastseq_test_main, TestCaseBase from transformers import (T5ForConditionalGeneration, T5Tokenizer) @@ -184,4 +184,4 @@ def test_beam_search_optimizer(self, if __name__ == "__main__": - absltest.main() + fastseq_test_main() diff --git a/tests/run_fairseq_tests.py b/tests/run_fairseq_tests.py index 980d87b6..2747c9f0 100644 --- a/tests/run_fairseq_tests.py +++ b/tests/run_fairseq_tests.py @@ -3,15 +3,19 @@ """ script for importing fairseq tests """ import glob -import sys -import os -import argparse +import io import logging +import os import shutil +import sys +import time import unittest + 
+import xmlrunner +from absl.testing import parameterized from git import Repo -from absl.testing import absltest, parameterized from pip._internal import main as pipmain +from xmlrunner.extra.xunit_plugin import transform FASTSEQ_PATH = os.sep.join(os.path.realpath(__file__).split('/')[0:-2]) FAIRSEQ_PATH = '/tmp/fairseq/' @@ -32,7 +36,7 @@ def clone_and_build_fairseq(self, repo, version): if os.path.isdir(FAIRSEQ_PATH): shutil.rmtree(FAIRSEQ_PATH) Repo.clone_from(FAIRSEQ_GIT_URL, FAIRSEQ_PATH, branch=version) - pipmain(['install', 'git+https://github.com/pytorch/fairseq.git@' + + pipmain(['install', 'git+https://github.com/pytorch/fairseq.git@' + version]) original_pythonpath = os.environ[ 'PYTHONPATH'] if 'PYTHONPATH' in os.environ else '' @@ -54,12 +58,9 @@ def get_test_suites(self, test_files_path, blocked_tests): return suites @parameterized.named_parameters({ - 'testcase_name': - 'Normal', - 'without_fastseq_opt': - False, - 'fairseq_version': - 'v0.9.0', + 'testcase_name': 'Normal', + 'without_fastseq_opt': False, + 'fairseq_version': 'v0.9.0', 'blocked_tests': [ 'test_binaries.py', 'test_bmuf.py', 'test_reproducibility.py'] }) @@ -67,14 +68,28 @@ def test_suites(self, without_fastseq_opt, fairseq_version, blocked_tests): """"run test suites""" self.clone_and_build_fairseq(FAIRSEQ_GIT_URL, fairseq_version) if not without_fastseq_opt: - import fastseq #pylint: disable=import-outside-toplevel + import fastseq # pylint: disable=import-outside-toplevel self.prepare_env() test_files_path = FAIRSEQ_PATH + '/tests/test_*.py' suites = self.get_test_suites(test_files_path, blocked_tests) test_suite = unittest.TestSuite(suites) test_runner = unittest.TextTestRunner() test_result = test_runner.run(test_suite) - assert len(test_result.errors) == 0 + assert len(test_result.errors) == 0 if __name__ == "__main__": - absltest.main() + log_xml_dir = os.getenv( + 'FASTSEQ_UNITTEST_LOG_XML_DIR', + os.path.join(os.getcwd(), 'tests', 'log_xml')) + os.makedirs(log_xml_dir, exist_ok=True) + suffix = '_' + time.strftime("%Y%m%d%H%M%S") + '.xml' + log_xml_file = __file__.replace(os.sep, '_').replace('.py', suffix) + log_xml_file = os.path.join(log_xml_dir, log_xml_file) + + out = io.BytesIO() + unittest.main( + testRunner=xmlrunner.XMLTestRunner(output=out), + failfast=False, buffer=False, catchbreak=False, exit=False) + with open(log_xml_file, 'wb') as report: + report.write(transform(out.getvalue())) + print("Save the log of fairseq unit tests into %s" % (log_xml_file)) diff --git a/tests/run_fairseq_tests.sh b/tests/run_fairseq_tests.sh index 23b12d88..697c6ae2 100755 --- a/tests/run_fairseq_tests.sh +++ b/tests/run_fairseq_tests.sh @@ -6,9 +6,10 @@ source ${ENV_PATH}/testing_env/bin/activate pip install gitpython pip install absl-py pip install packaging +pip install unittest-xml-reporting +pip install lxml cd ${FASTSEQ_TEST_PATH}/../ pip install torch==1.5.0+cu101 torchvision==0.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html pip install --editable . 
-cd tests -python run_fairseq_tests.py +python tests/run_fairseq_tests.py deactivate diff --git a/tests/run_fastseq_tests.sh b/tests/run_fastseq_tests.sh new file mode 100755 index 00000000..aff0ddc3 --- /dev/null +++ b/tests/run_fastseq_tests.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +for fastseq_py_test_file in $(find tests/ -name "test_*.py") +do + echo "Running $fastseq_py_test_file" + python $fastseq_py_test_file +done diff --git a/tests/run_transformers_tests.py b/tests/run_transformers_tests.py index f2ba15ab..ffa7e679 100644 --- a/tests/run_transformers_tests.py +++ b/tests/run_transformers_tests.py @@ -2,15 +2,18 @@ # Licensed under the MIT License. """ script for importing transformers tests """ -import glob -import sys +import io import os -import argparse -import logging import shutil +import sys +import time +import unittest + +import xmlrunner +from absl.testing import parameterized from git import Repo -from absl.testing import absltest, parameterized from pip._internal import main as pipmain +from xmlrunner.extra.xunit_plugin import transform FASTSEQ_PATH = os.sep.join(os.path.realpath(__file__).split('/')[0:-2]) TRANSFORMERS_PATH = '/tmp/transformers/' @@ -60,9 +63,25 @@ def test_suites(self, without_fastseq_opt, transformers_version, self.prepare_env() os.chdir(TRANSFORMERS_PATH) blocked_tests_string = ( - ' and '.join([' not '+ test for test in blocked_tests])) - exit_code = pytest.main(['-sv', '-k'+blocked_tests_string, './tests/']) + ' and '.join([' not '+ test for test in blocked_tests])) + exit_code = pytest.main( + ['-sv', '-k' + blocked_tests_string, './tests/']) assert str(exit_code).strip() == 'ExitCode.OK' if __name__ == "__main__": - absltest.main() + log_xml_dir = os.getenv( + 'FASTSEQ_UNITTEST_LOG_XML_DIR', + os.path.join(os.getcwd(), 'tests', 'log_xml')) + os.makedirs(log_xml_dir, exist_ok=True) + suffix = '_' + time.strftime("%Y%m%d%H%M%S") + '.xml' + log_xml_file = __file__.replace(os.sep, '_').replace('.py', suffix) + log_xml_file = os.path.join(log_xml_dir, log_xml_file) + + out = io.BytesIO() + unittest.main( + testRunner=xmlrunner.XMLTestRunner(output=out), + failfast=False, buffer=False, catchbreak=False, exit=False) + with open(log_xml_file, 'wb') as report: + report.write(transform(out.getvalue())) + print( + "Save the log of transformers unit tests into %s" % (log_xml_file)) diff --git a/tests/run_transformers_tests.sh b/tests/run_transformers_tests.sh index be5c9db9..6667081c 100644 --- a/tests/run_transformers_tests.sh +++ b/tests/run_transformers_tests.sh @@ -9,8 +9,9 @@ pip install packaging pip install pytest pip install timeout-decorator pip install torch torchvision +pip install unittest-xml-reporting +pip install lxml cd ${FASTSEQ_TEST_PATH}/../ pip install --editable . 
-cd tests -python run_transformers_tests.py +python tests/run_transformers_tests.py deactivate diff --git a/tests/utils/test_api_decorator.py b/tests/utils/test_api_decorator.py index a4d2a238..f07a250f 100644 --- a/tests/utils/test_api_decorator.py +++ b/tests/utils/test_api_decorator.py @@ -5,7 +5,7 @@ from absl.testing import absltest, parameterized from fastseq.utils.api_decorator import get_class, override_method, add_method, export_api, replace -from fastseq.utils.test_utils import TestCaseBase +from fastseq.utils.test_utils import fastseq_test_main, TestCaseBase class A: @@ -152,4 +152,4 @@ def name(self): if __name__ == "__main__": - absltest.main() + fastseq_test_main() diff --git a/tests/utils/test_file_utils.py b/tests/utils/test_file_utils.py index 2b48a305..bc94c1fd 100644 --- a/tests/utils/test_file_utils.py +++ b/tests/utils/test_file_utils.py @@ -8,7 +8,7 @@ from absl.testing import absltest, parameterized from fastseq.utils.file_utils import decompress_file, get_temp_dir, make_dirs, wget -from fastseq.utils.test_utils import TestCaseBase +from fastseq.utils.test_utils import fastseq_test_main, TestCaseBase class FileUtilsTest(TestCaseBase): @@ -90,4 +90,4 @@ def test_wget_and_decompress_file(self, tar_file_url, tar_file_name, if __name__ == "__main__": - absltest.main() + fastseq_test_main() From 01e7d492be71ef41e246802774da064dbead49f5 Mon Sep 17 00:00:00 2001 From: Fei Hu Date: Fri, 20 Nov 2020 22:14:25 -0800 Subject: [PATCH 04/12] Update install_requires and enable fairseq to work with torch 1.6&1.7 (#59) * Update install_requires and enable fairseq to work with torch 1.6&1.7 * Better error message and address some warnings in torch1.7 * Raise the error if fairseq/transformers are installed but the optmizations can not be applied * Move transformers/fairseq to extra_require * Remove the out-of-dated build files for ngram cuda op * Run fastseq units before transformers and fairseq --- README.md | 10 +++++ azure-pipelines.yml | 15 +++---- fastseq/config.py | 2 +- fastseq/optimizer/fairseq/__init__.py | 19 ++++++--- .../fairseq/beam_search_optimizer.py | 41 +++++++++++++++---- fastseq/optimizer/transformers/__init__.py | 19 ++++++--- setup.py | 28 +++++++++---- tests/run_fairseq_tests.sh | 2 + tests/run_transformers_tests.sh | 2 + 9 files changed, 102 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index 10b91c2a..7f5f9127 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,17 @@ If you use fairseq or transformers, you only need to install one of them. If you `fastseq` Python package can be directly installed with pip using ```bash +# when fairseq and/or transformers has been installed $ pip install fastseq + +# install fastseq + transformers +$ pip install fastseq[transformers] + +# install fastseq + fairseq +$ pip install fastseq[fairseq] + +# install fastseq + transformers + fairseq +$ pip install fastseq[transformers,fairseq] ``` ### Install from the source diff --git a/azure-pipelines.yml b/azure-pipelines.yml index a34f270d..328e01c8 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -21,7 +21,14 @@ jobs: steps: - script: | #install fastseq - pip install --editable . 
+ pip install --editable .[transformers,fairseq] + + echo "******* Running fastseq unittests *******" + pip install pytorch-transformers==1.0.0 + bash tests/run_fastseq_tests.sh + + #cd benchmarks/ + #bash run_all_benchmarks.sh #show environment which python @@ -36,12 +43,6 @@ jobs: echo "******* Running transformers unittests *******" bash tests/run_transformers_tests.sh - echo "******* Running fastseq unittests *******" - pip install pytorch-transformers==1.0.0 - bash tests/run_fastseq_tests.sh - - #cd benchmarks/ - #bash run_all_benchmarks.sh displayName: 'run fastseq unit tests' - task: PublishTestResults@2 condition: succeededOrFailed() diff --git a/fastseq/config.py b/fastseq/config.py index 7a1dfb86..7d5f269d 100644 --- a/fastseq/config.py +++ b/fastseq/config.py @@ -13,7 +13,7 @@ 'FASTSEQ_UNITTEST_LOG_XML_DIR', os.path.join('tests', 'log_xml')) FASTSEQ_LOG_FORMAT = ( - '%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s') + '%(levelname)s %(asctime)s %(pathname)s:%(lineno)d] %(message)s') FASTSEQ_VERSION = '0.0.4' diff --git a/fastseq/optimizer/fairseq/__init__.py b/fastseq/optimizer/fairseq/__init__.py index 96f56b99..05f635d1 100644 --- a/fastseq/optimizer/fairseq/__init__.py +++ b/fastseq/optimizer/fairseq/__init__.py @@ -40,7 +40,9 @@ def apply_fairseq_optimization(): f"fairseq(v{fairseq.__version__}) is not supported by fastseq(v" f"{FASTSEQ_VERSION}) yet, please change fairseq to " f"v{MIN_FAIRSEQ_VERSION} ~ v{MAX_FAIRSEQ_VERSION}, or check other " - "versions of fastseq.") + "versions of fastseq. Currently, no optimization in fastseq has " + "been applied. Please ignore this warning if you are not using " + "fairseq") return import fastseq.optimizer.fairseq.beam_search_optimizer # pylint: disable=import-outside-toplevel @@ -68,15 +70,20 @@ def _update_fairseq_model_registration(): "Update the register model arch {} from {} to {}".format( arch_name, model_class, OPTIMIZED_CLASSES[model_class])) +is_fairseq_installed = True try: import fairseq # pylint: disable=ungrouped-imports from fairseq.models import ARCH_MODEL_REGISTRY, MODEL_REGISTRY # pylint: disable=ungrouped-imports from fairseq.sequence_generator import SequenceGenerator # pylint: disable=ungrouped-imports - apply_fairseq_optimization() except ImportError as error: + is_fairseq_installed = False logger.warning('fairseq can not be imported. 
Please ignore this warning if ' - 'you are not using fairseq') -except: - logger.error("Unexpected error: {}".format(sys.exc_info()[0])) - raise + 'you are not using fairseq: {}'.format(error)) + +if is_fairseq_installed: + try: + apply_fairseq_optimization() + except: + logger.error("Unexpected error: {}".format(sys.exc_info()[0])) + raise diff --git a/fastseq/optimizer/fairseq/beam_search_optimizer.py b/fastseq/optimizer/fairseq/beam_search_optimizer.py index e4a84c73..f984f175 100644 --- a/fastseq/optimizer/fairseq/beam_search_optimizer.py +++ b/fastseq/optimizer/fairseq/beam_search_optimizer.py @@ -13,10 +13,40 @@ from fairseq import utils from fairseq.models.transformer import TransformerEncoder, TransformerModel from fairseq.modules.multihead_attention import MultiheadAttention +from fairseq.search import BeamSearch from fairseq.sequence_generator import SequenceGenerator from fastseq.ops.ngram_repeat_block import NGramRepeatBlock from fastseq.utils.api_decorator import replace +@replace(BeamSearch) +class BeamSearchV2(BeamSearch): + + def step(self, step, lprobs, scores): + super()._init_buffers(lprobs) + bsz, beam_size, vocab_size = lprobs.size() + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam + lprobs = lprobs[:, ::beam_size, :].contiguous() + else: + # make probs contain cumulative scores for each hypothesis + lprobs.add_(scores[:, :, step - 1].unsqueeze(-1)) + + torch.topk( + lprobs.view(bsz, -1), + k=min( + # Take the best 2 x beam_size predictions. We'll choose the first + # beam_size of these which don't predict eos to continue with. + beam_size * 2, + lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad + ), + out=(self.scores_buf, self.indices_buf), + ) + self.beams_buf = torch.floor_divide(self.indices_buf, vocab_size) + self.indices_buf.fmod_(vocab_size) + return self.scores_buf, self.indices_buf, self.beams_buf + @replace(TransformerEncoder) class TransformerEncoderV2(TransformerEncoder): """ @@ -725,18 +755,16 @@ def replicate_first_beam(tensor, mask): eos_mask[:, :beam_size][blacklist] = 0 # only consider eos when it's among the top beam_size indices - torch.masked_select( + eos_bbsz_idx = torch.masked_select( cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size], - out=eos_bbsz_idx, ) finalized_sents = set() if eos_bbsz_idx.numel() > 0: - torch.masked_select( + eos_scores = torch.masked_select( cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size], - out=eos_scores, ) finalized_sents = finalize_hypos(step, eos_bbsz_idx, eos_scores) @@ -753,7 +781,7 @@ def replicate_first_beam(tensor, mask): # construct batch_idxs which holds indices of batches to keep for the next pass batch_mask = cand_indices.new_ones(bsz) batch_mask[cand_indices.new(finalized_sents)] = 0 - batch_idxs = batch_mask.nonzero().squeeze(-1) + batch_idxs = torch.nonzero(batch_mask).squeeze(-1) eos_mask = eos_mask[batch_idxs] cand_beams = cand_beams[batch_idxs] @@ -786,10 +814,9 @@ def replicate_first_beam(tensor, mask): # candidate active hypos. 
active_mask = buffer('active_mask') eos_mask[:, :beam_size] |= blacklist - torch.add( + active_mask = torch.add( eos_mask.type_as(cand_offsets) * cand_size, cand_offsets[:eos_mask.size(1)], - out=active_mask, ) # get the top beam_size active hypotheses, which are just the hypos diff --git a/fastseq/optimizer/transformers/__init__.py b/fastseq/optimizer/transformers/__init__.py index 38dc845e..e34c19ab 100644 --- a/fastseq/optimizer/transformers/__init__.py +++ b/fastseq/optimizer/transformers/__init__.py @@ -7,6 +7,7 @@ """ from packaging import version +import sys from fastseq.config import MIN_TRANSFORMERS_VERSION, MAX_TRANSFORMER_VERSION from fastseq.logging import get_logger @@ -36,7 +37,9 @@ def apply_transformers_optimization(): logger.warning( f"transformers == {v} is not supported yet, please change it to " f"v{MIN_TRANSFORMERS_VERSION} to v{MAX_TRANSFORMER_VERSION}, or try" - f" other versions of fastseq.") + f" other versions of fastseq. Currently, no optimization provided " + "by fastseq has been applied. Please ignore this warning if you are" + " not using transformers") return import fastseq.optimizer.transformers.modeling_bart_optimizer # pylint: disable=import-outside-toplevel @@ -45,13 +48,17 @@ def apply_transformers_optimization(): logger.debug(f"transformers == {v} has been optimized.") - +is_transformers_installed = True try: import transformers - apply_transformers_optimization() except ImportError as error: + is_transformers_installed = False logger.warning('transformers can not be imported. Please ignore this ' 'warning if you are not using transformers') -except: - logger.error("Unexpected error: {}".format(sys.exc_info()[0])) - raise + +if is_transformers_installed: + try: + apply_transformers_optimization() + except: + logger.error("Unexpected error: {}".format(sys.exc_info()[0])) + raise diff --git a/setup.py b/setup.py index 6f80569b..45ebef29 100644 --- a/setup.py +++ b/setup.py @@ -4,23 +4,29 @@ from setuptools import find_packages, setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension -from fastseq.config import FASTSEQ_VERSION +FASTSEQ_VERSION = '0.0.4' +MIN_FAIRSEQ_VERSION = '0.9.0' +MAX_FAIRSEQ_VERSION = '0.9.0' +MIN_TRANSFORMERS_VERSION = '3.0.2' +MAX_TRANSFORMER_VERSION = '3.0.2' def get_fastseq_version(): return FASTSEQ_VERSION extras = {} -extras["torch"] = ["torch>=1.4.0"] -extras["fairseq"] = ["fairseq>=0.9.0"] -extras["transformers"] = ["transformers>=3.0.2"] +extras["transformers"] = ["transformers >= {}, <= {}".format( + MIN_TRANSFORMERS_VERSION, MAX_TRANSFORMER_VERSION)] +extras["fairseq"] = ["fairseq >= {}, <= {}".format( + MIN_FAIRSEQ_VERSION, MAX_FAIRSEQ_VERSION)] +extras["gitpython"] = ["gitpython>=3.1.7"] +extras["editdistance"] = ["editdistance>=0.5.3"] extensions = [ - CUDAExtension('ngram_repeat_block_cuda', [ - 'fastseq/clib/cuda/ngram_repeat_block_cuda.cpp', - 'fastseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu', - ]), - ] + CUDAExtension('ngram_repeat_block_cuda', [ + 'fastseq/clib/cuda/ngram_repeat_block_cuda.cpp', + 'fastseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu',]), +] setup( name="fastseq", @@ -41,9 +47,13 @@ def get_fastseq_version(): ], install_requires=[ "absl-py", + "filelock", "numpy", "requests", + "rouge-score>=0.0.4", "packaging", + "torch>=1.4.0", + "pytorch-transformers==1.0.0", ], extras_require=extras, python_requires=">=3.6.0", diff --git a/tests/run_fairseq_tests.sh b/tests/run_fairseq_tests.sh index 697c6ae2..bf4b4369 100755 --- a/tests/run_fairseq_tests.sh +++ b/tests/run_fairseq_tests.sh 
@@ -10,6 +10,8 @@ pip install unittest-xml-reporting pip install lxml cd ${FASTSEQ_TEST_PATH}/../ pip install torch==1.5.0+cu101 torchvision==0.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html +rm -rf build/ +rm ngram_repeat_block_cuda*.so pip install --editable . python tests/run_fairseq_tests.py deactivate diff --git a/tests/run_transformers_tests.sh b/tests/run_transformers_tests.sh index 6667081c..200635df 100644 --- a/tests/run_transformers_tests.sh +++ b/tests/run_transformers_tests.sh @@ -12,6 +12,8 @@ pip install torch torchvision pip install unittest-xml-reporting pip install lxml cd ${FASTSEQ_TEST_PATH}/../ +rm -rf build/ +rm ngram_repeat_block_cuda*.so pip install --editable . python tests/run_transformers_tests.py deactivate From 7558a5c59c93596a5d0113c93819db3bd8b48f47 Mon Sep 17 00:00:00 2001 From: Fei Hu Date: Sat, 21 Nov 2020 10:35:00 -0800 Subject: [PATCH 05/12] Add missing init files (#62) --- fastseq/ops/__init__.py | 2 ++ tests/optimizer/transformers/__init__.py | 2 ++ 2 files changed, 4 insertions(+) create mode 100644 fastseq/ops/__init__.py create mode 100644 tests/optimizer/transformers/__init__.py diff --git a/fastseq/ops/__init__.py b/fastseq/ops/__init__.py new file mode 100644 index 00000000..59e481eb --- /dev/null +++ b/fastseq/ops/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. diff --git a/tests/optimizer/transformers/__init__.py b/tests/optimizer/transformers/__init__.py new file mode 100644 index 00000000..59e481eb --- /dev/null +++ b/tests/optimizer/transformers/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. From a55e9ded447470389c976400cdf318cdf7fa8270 Mon Sep 17 00:00:00 2001 From: Fei Hu Date: Wed, 25 Nov 2020 13:19:50 -0800 Subject: [PATCH 06/12] Update the instructions for installation (#64) --- README.md | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 7f5f9127..42569b9e 100644 --- a/README.md +++ b/README.md @@ -37,30 +37,20 @@ We developped a wide range of speedup techniques, including improving beam searc If you use fairseq or transformers, you only need to install one of them. If you use both, you need to install both. 
-### Install from PIP package - -`fastseq` Python package can be directly installed with pip using +### Install from the source ```bash # when fairseq and/or transformers has been installed -$ pip install fastseq +$ pip install git+https://github.com/microsoft/fastseq.git # install fastseq + transformers -$ pip install fastseq[transformers] +$ pip install git+https://github.com/microsoft/fastseq.git#egg=project[transformers] # install fastseq + fairseq -$ pip install fastseq[fairseq] +$ pip install git+https://github.com/microsoft/fastseq.git#egg=project[fairseq] # install fastseq + transformers + fairseq -$ pip install fastseq[transformers,fairseq] -``` - -### Install from the source - -```bash -$ git clone https://github.com/microsoft/fastseq -$ cd fastseq -$ pip install --editable ./ +$ pip install git+https://github.com/microsoft/fastseq.git#egg=project[transformers,fairseq] ``` ## Usage From f9d5a82bfa1e358a0ddd3a8e12d4dc55ada032ad Mon Sep 17 00:00:00 2001 From: Fei Hu Date: Wed, 11 Nov 2020 21:32:49 +0000 Subject: [PATCH 07/12] Init code to support TorchScript and graph rewrite --- fastseq/optimizer/jit/__init__.py | 2 + fastseq/optimizer/jit/einsum_rewriter.py | 52 ++++ fastseq/optimizer/jit/graph_rewriter.py | 11 + fastseq/optimizer/jit/utils.py | 20 ++ fastseq/optimizer/transformers/__init__.py | 1 + .../transformers/modeling_bart_optimizer.py | 283 +++++++++++++++--- tests/optimizer/jit/__init__.py | 2 + tests/optimizer/jit/test_einsum_rewriter.py | 37 +++ 8 files changed, 363 insertions(+), 45 deletions(-) create mode 100644 fastseq/optimizer/jit/__init__.py create mode 100644 fastseq/optimizer/jit/einsum_rewriter.py create mode 100644 fastseq/optimizer/jit/graph_rewriter.py create mode 100644 fastseq/optimizer/jit/utils.py create mode 100644 tests/optimizer/jit/__init__.py create mode 100644 tests/optimizer/jit/test_einsum_rewriter.py diff --git a/fastseq/optimizer/jit/__init__.py b/fastseq/optimizer/jit/__init__.py new file mode 100644 index 00000000..59e481eb --- /dev/null +++ b/fastseq/optimizer/jit/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. diff --git a/fastseq/optimizer/jit/einsum_rewriter.py b/fastseq/optimizer/jit/einsum_rewriter.py new file mode 100644 index 00000000..5a245253 --- /dev/null +++ b/fastseq/optimizer/jit/einsum_rewriter.py @@ -0,0 +1,52 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Optmize einsum operation in the graph""" + +from typing import List + +import torch +from torch import Tensor + +from fastseq.optimizer.jit.utils import graph_pattern, rewrite_graph + +@graph_pattern +def einsum_pattern_0(t0: str, t1: List[Tensor]): + r = torch.einsum(t0, t1) + return r + +@graph_pattern +def einsum_rewrite_pattern_0(equation: str, operands: List[Tensor]): + if equation == "bmhtd,bnhsd->bmhts": + t0 = operands[0] + t1 = operands[1] + expand_shape = list(t1.shape) + expand_shape[1] = t0.size(1) + result_shape = list(t0.shape) + result_shape[4] = expand_shape[3] + t1 = t1.expand(expand_shape).transpose(3, 4).contiguous() + t1 = t1.view(-1, t1.size(3), t1.size(4)) + t0 = t0.view(-1, t0.size(3), t0.size(4)) + r = torch.bmm(t0, t1).view(result_shape) + return r + + if equation == "bmhts,bnhsd->bmhtd": + t0 = operands[0] + t1 = operands[1] + expand_shape = list(t1.shape) + expand_shape[1] = t0.size(1) + result_shape = list(t0.shape) + result_shape[4] = expand_shape[4] + t0 = t0.view(-1, t0.size(3), t0.size(4)) + t1 = t1.expand(expand_shape).contiguous() + t1 = t1.view(-1, t1.size(3), t1.size(4)) + r = torch.bmm(t0, t1).view(result_shape) + return r + + return torch.einsum(equation, operands) + +EINSUM_PATTERN_STR = einsum_pattern_0() +EINSUM_REWRITE_PATTERN_STR = einsum_rewrite_pattern_0() + +def rewrite_einsum(input_graph: torch._C.Graph): + rewrite_graph(EINSUM_PATTERN_STR, EINSUM_REWRITE_PATTERN_STR, input_graph) diff --git a/fastseq/optimizer/jit/graph_rewriter.py b/fastseq/optimizer/jit/graph_rewriter.py new file mode 100644 index 00000000..c77c6ad1 --- /dev/null +++ b/fastseq/optimizer/jit/graph_rewriter.py @@ -0,0 +1,11 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Load and apply the registered graph rewrite patterns""" + +import torch + +from fastseq.optimizer.jit.einsum_rewriter import rewrite_einsum + +def optimize_graph(input_graph: torch._C.Graph): + rewrite_einsum(input_graph) diff --git a/fastseq/optimizer/jit/utils.py b/fastseq/optimizer/jit/utils.py new file mode 100644 index 00000000..958113d7 --- /dev/null +++ b/fastseq/optimizer/jit/utils.py @@ -0,0 +1,20 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Utilities for graph rewriting""" + +import torch + +def rewrite_graph(pattern: str, + rewrite_pattern: str, + input_graph: torch._C.Graph): + torch._C._jit_pass_custom_pattern_based_rewrite_graph( + pattern, rewrite_pattern, input_graph) + + +def graph_pattern(obj): + def convert_to_graph_pattern(): + script = torch.jit.script(obj) + return script.graph.str() + + return convert_to_graph_pattern diff --git a/fastseq/optimizer/transformers/__init__.py b/fastseq/optimizer/transformers/__init__.py index e34c19ab..ed4d131f 100644 --- a/fastseq/optimizer/transformers/__init__.py +++ b/fastseq/optimizer/transformers/__init__.py @@ -5,6 +5,7 @@ Automatically apply the optimizations if the supported versions of transformers are detected. 
""" +import sys from packaging import version import sys diff --git a/fastseq/optimizer/transformers/modeling_bart_optimizer.py b/fastseq/optimizer/transformers/modeling_bart_optimizer.py index 851bf46b..4f3b11ec 100755 --- a/fastseq/optimizer/transformers/modeling_bart_optimizer.py +++ b/fastseq/optimizer/transformers/modeling_bart_optimizer.py @@ -6,21 +6,24 @@ from typing import Dict, Optional, Tuple import torch -from torch import Tensor +from torch import Tensor, nn from torch.nn import functional as F +from transformers.activations import ACT2FN from transformers.configuration_auto import BartConfig from transformers.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING from transformers.modeling_bart import (BartForConditionalGeneration, + DecoderLayer, EncoderLayer, LayerNorm, SelfAttention, _reorder_buffer) from fastseq.logging import get_logger from fastseq.utils.api_decorator import replace +from fastseq.optimizer.jit.graph_rewriter import optimize_graph logger = get_logger(__name__) @replace(SelfAttention) -class SelfAttentionV2(SelfAttention): +class SelfAttentionV2(nn.Module): """" The BART Model with a language modeling head. Can be used for summarization. """ @@ -34,52 +37,68 @@ def __init__( encoder_decoder_attention=False, # otherwise self_attention num_beams=1, ): - super().__init__( - embed_dim, num_heads, dropout, bias, encoder_decoder_attention) + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, ( + "embed_dim must be divisible by num_heads") + self.scaling = self.head_dim ** -0.5 + + self.encoder_decoder_attention: bool = encoder_decoder_attention + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.cache_key = "encoder_decoder" if ( + self.encoder_decoder_attention) else "self" self.num_beams = num_beams + def _shape(self, tensor: Tensor, dim_0: int, bsz: int) -> Tensor: + return tensor.contiguous().view( + dim_0, bsz * self.num_heads, self.head_dim).transpose(0, 1) + def forward( self, - query, + query: Tensor, key: Optional[Tensor], key_padding_mask: Optional[Tensor] = None, - layer_state: Optional[Dict[str, Optional[Tensor]]] = None, + layer_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, attn_mask: Optional[Tensor] = None, - output_attentions=False, - ) -> Tuple[Tensor, Optional[Tensor]]: + output_attentions: bool=False, + ) -> Tuple[Tensor, + Optional[Tensor], + Optional[Dict[str, Dict[str, Optional[Tensor]]]]]: """Input shape: Time(SeqLen) x Batch x Channel""" static_kv: bool = self.encoder_decoder_attention tgt_len, bsz, embed_dim = query.size() assert embed_dim == self.embed_dim assert list(query.size()) == [tgt_len, bsz, embed_dim] + saved_state: Dict[str, Optional[Tensor]] = {} # get here for encoder decoder cause of static_kv if layer_state is not None: # reuse k,v and encoder_padding_mask - saved_state = layer_state.get(self.cache_key, {}) - if "prev_key" in saved_state and static_kv: + if self.cache_key in layer_state: + tmp_saved_state = layer_state.get(self.cache_key) + assert tmp_saved_state is not None + saved_state = tmp_saved_state + if static_kv and "prev_key" in saved_state: # previous time steps are cached - no need to recompute key and # value if they are static key = None - else: - 
saved_state = None - layer_state = {} - q = self.q_proj(query) * self.scaling - if static_kv: - if key is None: - k = v = None - else: - k = self.k_proj(key) - v = self.v_proj(key) - else: - k = self.k_proj(query) - v = self.v_proj(query) + q = self.q_proj(query) * self.scaling q = self._shape(q, tgt_len, bsz) - if k is not None: + + k: Optional[Tensor] = None + v: Optional[Tensor] = None + if key is not None: + k = self.k_proj(key) k = self._shape(k, -1, bsz) - if v is not None: + v = self.v_proj(key) v = self._shape(v, -1, bsz) - if saved_state is not None: + if len(saved_state) > 0: k, v, key_padding_mask = self._use_saved_state( k, v, saved_state, key_padding_mask, static_kv, bsz) @@ -87,35 +106,46 @@ def forward( cache_bsz = (bsz // self.num_beams if self.encoder_decoder_attention else bsz) + assert k is not None + assert v is not None if self.encoder_decoder_attention and ("prev_key" not in saved_state): cache_shape = ( cache_bsz, self.num_beams, self.num_heads, -1, self.head_dim) k = k.view(cache_shape)[:, 0 : 1, :, :, :].contiguous() v = v.view(cache_shape)[:, 0 : 1, :, :, :].contiguous() + prev_k: Optional[Tensor] = k + prev_v: Optional[Tensor] = v + prev_key_padding_mask: Optional[Tensor] = None if ( + static_kv) else key_padding_mask + assert layer_state is not None layer_state[self.cache_key] = { - "prev_key": k, - "prev_value": v, - "prev_key_padding_mask": - key_padding_mask if not static_kv else None, + "prev_key": prev_k, + "prev_value": prev_v, + "prev_key_padding_mask": prev_key_padding_mask, } - if not self.encoder_decoder_attention: + + if not self.encoder_decoder_attention and layer_state is not None: cache_shape = (bsz, self.num_heads, -1, self.head_dim) + prev_k: Optional[Tensor] = k.view(cache_shape) + prev_v: Optional[Tensor] = v.view(cache_shape) + prev_key_padding_mask: Optional[Tensor] = None if ( + static_kv) else key_padding_mask + assert layer_state is not None layer_state[self.cache_key] = { - "prev_key": k.view(cache_shape), - "prev_value": v.view(cache_shape), - "prev_key_padding_mask": - key_padding_mask if not static_kv else None, + "prev_key": prev_k, + "prev_value": prev_v, + "prev_key_padding_mask": prev_key_padding_mask, } - assert k is not None + # assert q is not None if self.encoder_decoder_attention: q = q.view(cache_bsz, self.num_beams, self.num_heads, tgt_len, self.head_dim) src_len = k.size(3) - attn_weights = torch.einsum("bmhtd,bnhsd->bmhts", q, - k).reshape(-1, tgt_len, src_len) - assert attn_weights.size() == (bsz * self.num_heads, tgt_len, - src_len) + attn_weights = torch.einsum("bmhtd,bnhsd->bmhts", q, k).reshape( + -1, tgt_len, src_len) + assert attn_weights.size() == ( + bsz * self.num_heads, tgt_len, src_len) else: src_len = k.size(1) attn_weights = torch.bmm(q, k.transpose(1, 2)) @@ -163,15 +193,23 @@ def forward( attn_output = attn_output.transpose(0, 1).contiguous().view( tgt_len, bsz, embed_dim) attn_output = self.out_proj(attn_output) + if output_attentions: attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + return attn_output, attn_weights, layer_state else: - attn_weights = None - return attn_output, attn_weights + return attn_output, None, layer_state - def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, - bsz): + def _use_saved_state( + self, + k: Optional[Tensor], + v: Optional[Tensor], + saved_state: Dict[str, Optional[Tensor]], + key_padding_mask: Optional[Tensor], + static_kv: bool, + bsz: int) -> Tuple[ + Optional[Tensor], Optional[Tensor], Optional[Tensor]]: # saved states 
are stored with shape (bsz, num_heads, seq_len, head_dim) # note that for self-attn, bsz=input_bsz * beam_size; for # encoder-decoder-attn, bsz=input_bsz. @@ -199,11 +237,13 @@ def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, assert k is not None and v is not None prev_key_padding_mask: Optional[Tensor] = saved_state.get( - "prev_key_padding_mask", None) + "prev_key_padding_mask") if prev_key_padding_mask is not None: if static_kv: new_key_padding_mask = prev_key_padding_mask else: + assert prev_key_padding_mask is not None + assert key_padding_mask is not None new_key_padding_mask = torch.cat( [prev_key_padding_mask, key_padding_mask], dim=1) else: @@ -211,6 +251,159 @@ def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, return k, v, new_key_padding_mask +@replace(EncoderLayer) +class EncoderLayerV2(nn.Module): + def __init__(self, config: BartConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = SelfAttention( + self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, + ) + self.self_attn = torch.jit.script(self.self_attn) + optimize_graph(self.self_attn.graph) + self.normalize_before = config.normalize_before + self.self_attn_layer_norm = LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = LayerNorm(self.embed_dim) + + def forward(self, x, encoder_padding_mask, output_attentions=False): + """ + Args: + x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_padding_mask (ByteTensor): binary ByteTensor of shape + `(batch, src_len)` where padding elements are indicated by ``1``. 
+ for t_tgt, t_src is excluded (or masked out), =0 means it is + included in attention + + Returns: + encoded output of shape `(seq_len, batch, embed_dim)` + """ + residual = x + if self.normalize_before: + x = self.self_attn_layer_norm(x) + x, attn_weights, layer_state = self.self_attn( + query=x, + key=x, + key_padding_mask=encoder_padding_mask, + output_attentions=output_attentions, + ) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + if not self.normalize_before: + x = self.self_attn_layer_norm(x) + + residual = x + if self.normalize_before: + x = self.final_layer_norm(x) + x = self.activation_fn(self.fc1(x)) + x = F.dropout(x, p=self.activation_dropout, training=self.training) + x = self.fc2(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + if not self.normalize_before: + x = self.final_layer_norm(x) + return x, attn_weights + + +@replace(DecoderLayer) +class DecoderLayerV2(nn.Module): + def __init__(self, config: BartConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = SelfAttention( + embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, + ) + self.self_attn = torch.jit.script(self.self_attn) + optimize_graph(self.self_attn.graph) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.normalize_before = config.normalize_before + + self.self_attn_layer_norm = LayerNorm(self.embed_dim) + self.encoder_attn = SelfAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + encoder_decoder_attention=True, + ) + self.encoder_attn = torch.jit.script(self.encoder_attn) + optimize_graph(self.encoder_attn.graph) + self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = LayerNorm(self.embed_dim) + + def forward( + self, + x, + encoder_hidden_states, + encoder_attn_mask=None, + layer_state=None, + causal_mask=None, + decoder_padding_mask=None, + output_attentions=False, + ): + residual = x + + if layer_state is None: + layer_state = {} + if self.normalize_before: + x = self.self_attn_layer_norm(x) + # Self Attention + + x, self_attn_weights, layer_state = self.self_attn( + query=x, + key=x, + layer_state=layer_state, # adds keys to layer state + key_padding_mask=decoder_padding_mask, + attn_mask=causal_mask, + output_attentions=output_attentions, + ) + + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + if not self.normalize_before: + x = self.self_attn_layer_norm(x) + + # Cross attention + residual = x + assert self.encoder_attn.cache_key != self.self_attn.cache_key + if self.normalize_before: + x = self.encoder_attn_layer_norm(x) + x, attn_weights, layer_state = self.encoder_attn( + query=x, + key=encoder_hidden_states, + key_padding_mask=encoder_attn_mask, + layer_state=layer_state, # mutates layer state + ) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + if not self.normalize_before: + x = self.encoder_attn_layer_norm(x) + + # Fully Connected + residual = x + if self.normalize_before: + x = self.final_layer_norm(x) + x = self.activation_fn(self.fc1(x)) + x = F.dropout(x, p=self.activation_dropout, training=self.training) + x = self.fc2(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = 
residual + x + if not self.normalize_before: + x = self.final_layer_norm(x) + return ( + x, + self_attn_weights, + layer_state, + ) # just self_attn weights for now, following t5, layer_state = cache for decoding + + @replace(BartForConditionalGeneration) class BartForConditionalGenerationV2(BartForConditionalGeneration): """ diff --git a/tests/optimizer/jit/__init__.py b/tests/optimizer/jit/__init__.py new file mode 100644 index 00000000..59e481eb --- /dev/null +++ b/tests/optimizer/jit/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. diff --git a/tests/optimizer/jit/test_einsum_rewriter.py b/tests/optimizer/jit/test_einsum_rewriter.py new file mode 100644 index 00000000..bc80e833 --- /dev/null +++ b/tests/optimizer/jit/test_einsum_rewriter.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import List + +from absl.testing import absltest +import torch +from torch import Tensor + +from fastseq.optimizer.jit.einsum_rewriter import rewrite_einsum +from fastseq.utils.test_utils import TestCaseBase + +class EinsumRewriterTest(TestCaseBase): + + def test_einsum_rewriter(self): + + def run_einsum(t0: Tensor, t1: Tensor): + r = torch.einsum("bmhtd,bnhsd->bmhts", t0, t1) + r = r + 2.0 + return r + + t0 = torch.randn(10, 3, 4, 3, 9, dtype=torch.float32) + t1 = torch.randn(10, 1, 4, 7, 9, dtype=torch.float32) + + r0 = run_einsum(t0, t1) + + script_run_einsum = torch.jit.script(run_einsum) + r1 = script_run_einsum(t0, t1) + + rewrite_einsum(script_run_einsum.graph) + r2 = script_run_einsum(t0, t1) + + self.assertTrue(torch.allclose(r0, r1)) + self.assertTrue(torch.allclose(r1, r2)) + +if __name__ == "__main__": + absltest.main() From 510f7f1f84e5030ed7957cf718e44aaaec24b63a Mon Sep 17 00:00:00 2001 From: Fei Hu Date: Thu, 12 Nov 2020 19:17:59 +0000 Subject: [PATCH 08/12] Improve einsum_rewrite_pattern --- fastseq/optimizer/jit/einsum_rewriter.py | 40 ++++++++++++--------- tests/optimizer/jit/test_einsum_rewriter.py | 7 ++-- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/fastseq/optimizer/jit/einsum_rewriter.py b/fastseq/optimizer/jit/einsum_rewriter.py index 5a245253..cd7c97d5 100644 --- a/fastseq/optimizer/jit/einsum_rewriter.py +++ b/fastseq/optimizer/jit/einsum_rewriter.py @@ -20,27 +20,35 @@ def einsum_rewrite_pattern_0(equation: str, operands: List[Tensor]): if equation == "bmhtd,bnhsd->bmhts": t0 = operands[0] t1 = operands[1] - expand_shape = list(t1.shape) - expand_shape[1] = t0.size(1) - result_shape = list(t0.shape) - result_shape[4] = expand_shape[3] - t1 = t1.expand(expand_shape).transpose(3, 4).contiguous() - t1 = t1.view(-1, t1.size(3), t1.size(4)) - t0 = t0.view(-1, t0.size(3), t0.size(4)) - r = torch.bmm(t0, t1).view(result_shape) + b = t0.size(0) + m = t0.size(1) + h = t0.size(2) + t = t0.size(3) + d = t0.size(4) + n = t1.size(1) + s = t1.size(3) + t0 = t0.permute(0, 2, 1, 3, 4) # (b, h, m, t, d) + t1 = t1.permute(0, 2, 4, 1, 3) # (b, h, d, n, s) + t0 = t0.reshape(b*h, m*t, d) + t1 = t1.reshape(b*h, d, n*s) # TODO: add a check: assert n == 1 + r = torch.bmm(t0, t1).view(b, h, m, t, n*s).permute(0, 2, 1, 3, 4) return r if equation == "bmhts,bnhsd->bmhtd": t0 = operands[0] t1 = operands[1] - expand_shape = list(t1.shape) - expand_shape[1] = t0.size(1) - result_shape = list(t0.shape) - result_shape[4] = expand_shape[4] - t0 = t0.view(-1, t0.size(3), t0.size(4)) - t1 = t1.expand(expand_shape).contiguous() - t1 = t1.view(-1, t1.size(3), t1.size(4)) - r = 
torch.bmm(t0, t1).view(result_shape) + b = t0.size(0) + m = t0.size(1) + h = t0.size(2) + t = t0.size(3) + s = t0.size(4) + n = t1.size(1) + d = t1.size(4) + t0 = t0.permute(0, 2, 1, 3, 4) # (b, h, m, t, s) + t1 = t1.permute(0, 2, 3, 1, 4) # (b, h, s, n, d) + t0 = t0.reshape(b*h, m*t, s) + t1 = t1.reshape(b*h, s, n*d) # TODO: add a check: assert n == 1 + r = torch.bmm(t0, t1).view(b, h, m, t, n*d).permute(0, 2, 1, 3, 4) return r return torch.einsum(equation, operands) diff --git a/tests/optimizer/jit/test_einsum_rewriter.py b/tests/optimizer/jit/test_einsum_rewriter.py index bc80e833..737c8cb5 100644 --- a/tests/optimizer/jit/test_einsum_rewriter.py +++ b/tests/optimizer/jit/test_einsum_rewriter.py @@ -25,13 +25,10 @@ def run_einsum(t0: Tensor, t1: Tensor): r0 = run_einsum(t0, t1) script_run_einsum = torch.jit.script(run_einsum) - r1 = script_run_einsum(t0, t1) - rewrite_einsum(script_run_einsum.graph) - r2 = script_run_einsum(t0, t1) + r1 = script_run_einsum(t0, t1) - self.assertTrue(torch.allclose(r0, r1)) - self.assertTrue(torch.allclose(r1, r2)) + self.assertTrue(torch.equal(r0, r1)) if __name__ == "__main__": absltest.main() From 83154f10acbacecd8bcd2524f066211252369aed Mon Sep 17 00:00:00 2001 From: Fei Hu Date: Fri, 13 Nov 2020 05:46:26 +0000 Subject: [PATCH 09/12] Make einsum_rewrite_pattern more general --- fastseq/optimizer/jit/einsum_rewriter.py | 30 ++++++++++++++++-------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/fastseq/optimizer/jit/einsum_rewriter.py b/fastseq/optimizer/jit/einsum_rewriter.py index cd7c97d5..177868aa 100644 --- a/fastseq/optimizer/jit/einsum_rewriter.py +++ b/fastseq/optimizer/jit/einsum_rewriter.py @@ -16,8 +16,11 @@ def einsum_pattern_0(t0: str, t1: List[Tensor]): return r @graph_pattern -def einsum_rewrite_pattern_0(equation: str, operands: List[Tensor]): - if equation == "bmhtd,bnhsd->bmhts": +def einsum_rewrite_pattern_0(eqn: str, operands: List[Tensor]): + # for cases like "bmhtd,bnhsd->bmhts" + if (len(eqn) == 18 and eqn[0:3] == eqn[13:16] and eqn[0] == eqn[6] and + eqn[2] == eqn[8] and eqn[4] == eqn[10] and eqn[3] == eqn[16] and + eqn[9] == eqn[17]): t0 = operands[0] t1 = operands[1] b = t0.size(0) @@ -26,15 +29,20 @@ def einsum_rewrite_pattern_0(equation: str, operands: List[Tensor]): t = t0.size(3) d = t0.size(4) n = t1.size(1) + if n > 1: + t1 = t1.sum(dim=1, keepdim=True) # (b, 1, h, d, s) s = t1.size(3) t0 = t0.permute(0, 2, 1, 3, 4) # (b, h, m, t, d) - t1 = t1.permute(0, 2, 4, 1, 3) # (b, h, d, n, s) + t1 = t1.permute(0, 2, 1, 4, 3) # (b, h, 1, d, s) t0 = t0.reshape(b*h, m*t, d) - t1 = t1.reshape(b*h, d, n*s) # TODO: add a check: assert n == 1 - r = torch.bmm(t0, t1).view(b, h, m, t, n*s).permute(0, 2, 1, 3, 4) + t1 = t1.reshape(b*h, d, s) + r = torch.bmm(t0, t1).view(b, h, m, t, s).permute(0, 2, 1, 3, 4) return r - if equation == "bmhts,bnhsd->bmhtd": + # for cases like "bmhts,bnhsd->bmhtd" + if (len(eqn) == 18 and eqn[0:3] == eqn[13:16] and eqn[0] == eqn[6] and + eqn[2] == eqn[8] and eqn[4] == eqn[9] and eqn[3] == eqn[16] and + eqn[10] == eqn[17]): t0 = operands[0] t1 = operands[1] b = t0.size(0) @@ -43,15 +51,17 @@ def einsum_rewrite_pattern_0(equation: str, operands: List[Tensor]): t = t0.size(3) s = t0.size(4) n = t1.size(1) + if n > 1: + t1 = t1.sum(dim=1, keepdim=True) # (b, 1, h, s, d) d = t1.size(4) t0 = t0.permute(0, 2, 1, 3, 4) # (b, h, m, t, s) - t1 = t1.permute(0, 2, 3, 1, 4) # (b, h, s, n, d) + t1 = t1.permute(0, 2, 1, 3, 4) # (b, h, 1, s, d) t0 = t0.reshape(b*h, m*t, s) - t1 = t1.reshape(b*h, s, n*d) # 
From 83154f10acbacecd8bcd2524f066211252369aed Mon Sep 17 00:00:00 2001
From: Fei Hu
Date: Fri, 13 Nov 2020 05:46:26 +0000
Subject: [PATCH 09/12] Make einsum_rewrite_pattern more general

---
 fastseq/optimizer/jit/einsum_rewriter.py | 30 ++++++++++++++++--------
 1 file changed, 20 insertions(+), 10 deletions(-)

diff --git a/fastseq/optimizer/jit/einsum_rewriter.py b/fastseq/optimizer/jit/einsum_rewriter.py
index cd7c97d5..177868aa 100644
--- a/fastseq/optimizer/jit/einsum_rewriter.py
+++ b/fastseq/optimizer/jit/einsum_rewriter.py
@@ -16,8 +16,11 @@ def einsum_pattern_0(t0: str, t1: List[Tensor]):
     return r
 
 @graph_pattern
-def einsum_rewrite_pattern_0(equation: str, operands: List[Tensor]):
-    if equation == "bmhtd,bnhsd->bmhts":
+def einsum_rewrite_pattern_0(eqn: str, operands: List[Tensor]):
+    # for cases like "bmhtd,bnhsd->bmhts"
+    if (len(eqn) == 18 and eqn[0:3] == eqn[13:16] and eqn[0] == eqn[6] and
+        eqn[2] == eqn[8] and eqn[4] == eqn[10] and eqn[3] == eqn[16] and
+        eqn[9] == eqn[17]):
         t0 = operands[0]
         t1 = operands[1]
         b = t0.size(0)
@@ -26,15 +29,20 @@ def einsum_rewrite_pattern_0(equation: str, operands: List[Tensor]):
         t = t0.size(3)
         d = t0.size(4)
         n = t1.size(1)
+        if n > 1:
+            t1 = t1.sum(dim=1, keepdim=True) # (b, 1, h, s, d)
         s = t1.size(3)
         t0 = t0.permute(0, 2, 1, 3, 4) # (b, h, m, t, d)
-        t1 = t1.permute(0, 2, 4, 1, 3) # (b, h, d, n, s)
+        t1 = t1.permute(0, 2, 1, 4, 3) # (b, h, 1, d, s)
         t0 = t0.reshape(b*h, m*t, d)
-        t1 = t1.reshape(b*h, d, n*s) # TODO: add a check: assert n == 1
-        r = torch.bmm(t0, t1).view(b, h, m, t, n*s).permute(0, 2, 1, 3, 4)
+        t1 = t1.reshape(b*h, d, s)
+        r = torch.bmm(t0, t1).view(b, h, m, t, s).permute(0, 2, 1, 3, 4)
         return r
 
-    if equation == "bmhts,bnhsd->bmhtd":
+    # for cases like "bmhts,bnhsd->bmhtd"
+    if (len(eqn) == 18 and eqn[0:3] == eqn[13:16] and eqn[0] == eqn[6] and
+        eqn[2] == eqn[8] and eqn[4] == eqn[9] and eqn[3] == eqn[16] and
+        eqn[10] == eqn[17]):
         t0 = operands[0]
         t1 = operands[1]
         b = t0.size(0)
@@ -43,15 +51,17 @@ def einsum_rewrite_pattern_0(equation: str, operands: List[Tensor]):
         t = t0.size(3)
         s = t0.size(4)
         n = t1.size(1)
+        if n > 1:
+            t1 = t1.sum(dim=1, keepdim=True) # (b, 1, h, s, d)
         d = t1.size(4)
         t0 = t0.permute(0, 2, 1, 3, 4) # (b, h, m, t, s)
-        t1 = t1.permute(0, 2, 3, 1, 4) # (b, h, s, n, d)
+        t1 = t1.permute(0, 2, 1, 3, 4) # (b, h, 1, s, d)
         t0 = t0.reshape(b*h, m*t, s)
-        t1 = t1.reshape(b*h, s, n*d) # TODO: add a check: assert n == 1
-        r = torch.bmm(t0, t1).view(b, h, m, t, n*d).permute(0, 2, 1, 3, 4)
+        t1 = t1.reshape(b*h, s, d)
+        r = torch.bmm(t0, t1).view(b, h, m, t, d).permute(0, 2, 1, 3, 4)
         return r
 
-    return torch.einsum(equation, operands)
+    return torch.einsum(eqn, operands)
 
 EINSUM_PATTERN_STR = einsum_pattern_0()
 EINSUM_REWRITE_PATTERN_STR = einsum_rewrite_pattern_0()
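The main change in the patch above is summing t1 over its n axis when n > 1, instead of assuming n == 1. This is valid because n appears in only one operand and not in the output label string, so einsum reduces over it independently of the d (or s) contraction. A quick numeric check of that identity, with illustrative shapes rather than the fastseq test shapes:

import torch

# n occurs only in the second operand and not in the output, so einsum
# sums over it; the sum can therefore be pushed onto t1 before the bmm.
t0 = torch.randn(2, 3, 4, 5, 6)  # (b, m, h, t, d)
t1 = torch.randn(2, 7, 4, 9, 6)  # (b, n, h, s, d) with n > 1

full = torch.einsum("bmhtd,bnhsd->bmhts", t0, t1)
pre_summed = torch.einsum("bmhtd,bnhsd->bmhts",
                          t0, t1.sum(dim=1, keepdim=True))

assert torch.allclose(full, pre_summed, atol=1e-4)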
From 69e5410f0abc1b87ca1aa0b7291a8571297d160e Mon Sep 17 00:00:00 2001
From: Fei Hu
Date: Sat, 14 Nov 2020 08:50:57 +0000
Subject: [PATCH 10/12] Enhance rewrite pattern and tests

---
 fastseq/optimizer/jit/einsum_rewriter.py    | 49 +++++++++-----------
 tests/optimizer/jit/test_einsum_rewriter.py | 51 +++++++++++++++++----
 2 files changed, 62 insertions(+), 38 deletions(-)

diff --git a/fastseq/optimizer/jit/einsum_rewriter.py b/fastseq/optimizer/jit/einsum_rewriter.py
index 177868aa..d46094cf 100644
--- a/fastseq/optimizer/jit/einsum_rewriter.py
+++ b/fastseq/optimizer/jit/einsum_rewriter.py
@@ -10,61 +10,54 @@
 
 from fastseq.optimizer.jit.utils import graph_pattern, rewrite_graph
 
-@graph_pattern
 def einsum_pattern_0(t0: str, t1: List[Tensor]):
     r = torch.einsum(t0, t1)
     return r
 
-@graph_pattern
 def einsum_rewrite_pattern_0(eqn: str, operands: List[Tensor]):
+    # eqn = eqn.replace(' ', '') # TODO: fix the issue: ValueError: stoll
     # for cases like "bmhtd,bnhsd->bmhts"
-    if (len(eqn) == 18 and eqn[0:3] == eqn[13:16] and eqn[0] == eqn[6] and
-        eqn[2] == eqn[8] and eqn[4] == eqn[10] and eqn[3] == eqn[16] and
-        eqn[9] == eqn[17]):
+    if (len(eqn) == 18 and eqn[0:4] == eqn[13:17] and eqn[0] == eqn[6] and
+        eqn[2] == eqn[8] and eqn[4] == eqn[10] and eqn[9] == eqn[17]):
         t0 = operands[0]
         t1 = operands[1]
-        b = t0.size(0)
-        m = t0.size(1)
-        h = t0.size(2)
-        t = t0.size(3)
-        d = t0.size(4)
+        b, m, h, t, d = t0.shape
+        s = t1.size(3)
         n = t1.size(1)
+        t1 = t1.permute(0, 2, 3, 4, 1) # (b, h, s, d, n)
         if n > 1:
-            t1 = t1.sum(dim=1, keepdim=True) # (b, 1, h, s, d)
-        s = t1.size(3)
+            t1 = t1.sum(dim=4, keepdim=True) # (b, h, s, d, 1)
+
         t0 = t0.permute(0, 2, 1, 3, 4) # (b, h, m, t, d)
-        t1 = t1.permute(0, 2, 1, 4, 3) # (b, h, 1, d, s)
+        t1 = t1.permute(0, 1, 3, 4, 2) # (b, h, d, 1, s)
         t0 = t0.reshape(b*h, m*t, d)
-        t1 = t1.reshape(b*h, d, s)
+        t1 = t1.view(b*h, d, s)
         r = torch.bmm(t0, t1).view(b, h, m, t, s).permute(0, 2, 1, 3, 4)
         return r
 
     # for cases like "bmhts,bnhsd->bmhtd"
-    if (len(eqn) == 18 and eqn[0:3] == eqn[13:16] and eqn[0] == eqn[6] and
-        eqn[2] == eqn[8] and eqn[4] == eqn[9] and eqn[3] == eqn[16] and
-        eqn[10] == eqn[17]):
+    if (len(eqn) == 18 and eqn[0:4] == eqn[13:17] and eqn[0] == eqn[6] and
+        eqn[2] == eqn[8] and eqn[4] == eqn[9] and eqn[10] == eqn[17]):
         t0 = operands[0]
         t1 = operands[1]
-        b = t0.size(0)
-        m = t0.size(1)
-        h = t0.size(2)
-        t = t0.size(3)
-        s = t0.size(4)
+        b, m, h, t, s = t0.shape
         n = t1.size(1)
-        if n > 1:
-            t1 = t1.sum(dim=1, keepdim=True) # (b, 1, h, s, d)
         d = t1.size(4)
+        t1 = t1.permute(0, 2, 4, 3, 1) # (b, h, d, s, n)
+        if n > 1:
+            t1 = t1.sum(dim=4, keepdim=True) # (b, h, d, s, 1)
+        # t1 = t1.squeeze(1) # (b, h, s, d)
         t0 = t0.permute(0, 2, 1, 3, 4) # (b, h, m, t, s)
-        t1 = t1.permute(0, 2, 1, 3, 4) # (b, h, 1, s, d)
+        t1 = t1.permute(0, 1, 3, 4, 2) # (b, h, s, 1, d)
         t0 = t0.reshape(b*h, m*t, s)
-        t1 = t1.reshape(b*h, s, d)
+        t1 = t1.view(b*h, s, d)
         r = torch.bmm(t0, t1).view(b, h, m, t, d).permute(0, 2, 1, 3, 4)
         return r
 
     return torch.einsum(eqn, operands)
 
-EINSUM_PATTERN_STR = einsum_pattern_0()
-EINSUM_REWRITE_PATTERN_STR = einsum_rewrite_pattern_0()
+EINSUM_PATTERN_STR = graph_pattern(einsum_pattern_0)()
+EINSUM_REWRITE_PATTERN_STR = graph_pattern(einsum_rewrite_pattern_0)()
 
 def rewrite_einsum(input_graph: torch._C.Graph):
     rewrite_graph(EINSUM_PATTERN_STR, EINSUM_REWRITE_PATTERN_STR, input_graph)
diff --git a/tests/optimizer/jit/test_einsum_rewriter.py b/tests/optimizer/jit/test_einsum_rewriter.py
index 737c8cb5..5b0a7581 100644
--- a/tests/optimizer/jit/test_einsum_rewriter.py
+++ b/tests/optimizer/jit/test_einsum_rewriter.py
@@ -1,34 +1,65 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 
-from typing import List
+import functools
+import logging
+import timeit
 
-from absl.testing import absltest
+from absl.testing import absltest, parameterized
 import torch
 from torch import Tensor
 
+from fastseq.logging import get_logger
 from fastseq.optimizer.jit.einsum_rewriter import rewrite_einsum
 from fastseq.utils.test_utils import TestCaseBase
 
+logger = get_logger(__name__, logging.INFO)
+
 class EinsumRewriterTest(TestCaseBase):
 
-    def test_einsum_rewriter(self):
+    @parameterized.parameters(
+        {'eqn': "bmhtd,bnhsd->bmhts",
+         'shape0': [128, 4, 16, 5, 64],
+         'shape1': [128, 2, 16, 1024, 64]},
+        {'eqn': "kmijd,knisd->kmijs",
+         'shape0': [128, 4, 16, 1, 64],
+         'shape1': [128, 2, 16, 1024, 64]},
+        {'eqn': "bmhts,bnhsd->bmhtd",
+         'shape0': [128, 4, 16, 3, 64],
+         'shape1': [128, 2, 16, 64, 7]},
+        {'eqn': "impts,inpsw->imptw",
+         'shape0': [128, 4, 16, 3, 64],
+         'shape1': [128, 2, 16, 64, 7]},
+    )
+    def test_einsum_rewriter(self, eqn, shape0, shape1):
 
-        def run_einsum(t0: Tensor, t1: Tensor):
-            r = torch.einsum("bmhtd,bnhsd->bmhts", t0, t1)
-            r = r + 2.0
+        def run_einsum(eqn: str, t0: Tensor, t1: Tensor):
+            r = torch.einsum(eqn, t0, t1)
             return r
 
-        t0 = torch.randn(10, 3, 4, 3, 9, dtype=torch.float32)
-        t1 = torch.randn(10, 1, 4, 7, 9, dtype=torch.float32)
+        t0 = torch.randn(shape0, dtype=torch.float32).cuda()
+        t1 = torch.randn(shape1, dtype=torch.float32).cuda()
+        repeat_times = 1000
 
-        r0 = run_einsum(t0, t1)
+        r0 = run_einsum(eqn, t0, t1)
+        time0 = timeit.Timer(functools.partial(run_einsum, eqn, t0, t1))
+        s0 = time0.timeit(repeat_times)
 
         script_run_einsum = torch.jit.script(run_einsum)
+        logger.debug(f"Original graph: \n{script_run_einsum.graph.str()}")
         rewrite_einsum(script_run_einsum.graph)
-        r1 = script_run_einsum(t0, t1)
+        logger.debug(f"Optimized graph: \n{script_run_einsum.graph.str()}")
+        self.assertTrue('bmm' in script_run_einsum.graph.str())
+
+        r1 = script_run_einsum(eqn, t0, t1)
+        time1 = timeit.Timer(
+            functools.partial(script_run_einsum, eqn, t0, t1))
+        s1 = time1.timeit(repeat_times)
 
         self.assertTrue(torch.equal(r0, r1))
 
+        logger.info(f"einsum took: {s0}; optimized einsum torchscript took: "
+                    f"{s1};")
+
 if __name__ == "__main__":
     absltest.main()
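One caveat about the timeit numbers logged by the test above: CUDA kernels are launched asynchronously, so host-side timing without synchronization mostly measures kernel-launch overhead rather than execution time. The next patch switches the test to a synchronize-before-and-after pattern; a minimal sketch of that pattern follows (shapes borrowed from the test; the helper name is hypothetical, and a CUDA device is assumed):

import time

import torch

def timed(fn, repeat: int = 100) -> float:
    # Drain any queued work before starting the clock, and wait for the
    # launched kernels to finish before stopping it.
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(repeat):
        fn()
    torch.cuda.synchronize()
    return time.time() - start

t0 = torch.randn(128, 4, 16, 5, 64, device="cuda")
t1 = torch.randn(128, 2, 16, 1024, 64, device="cuda")
print(timed(lambda: torch.einsum("bmhtd,bnhsd->bmhts", t0, t1)))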
From 699fba74fb8c965a4ab246faa04fbb258d724e63 Mon Sep 17 00:00:00 2001
From: Fei Hu
Date: Tue, 17 Nov 2020 06:07:25 +0000
Subject: [PATCH 11/12] Add synchronize and check for contiguous

---
 tests/optimizer/jit/test_einsum_rewriter.py | 44 +++++++++++++++------
 1 file changed, 31 insertions(+), 13 deletions(-)

diff --git a/tests/optimizer/jit/test_einsum_rewriter.py b/tests/optimizer/jit/test_einsum_rewriter.py
index 5b0a7581..79d63b1e 100644
--- a/tests/optimizer/jit/test_einsum_rewriter.py
+++ b/tests/optimizer/jit/test_einsum_rewriter.py
@@ -1,16 +1,15 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 
-import functools
 import logging
-import timeit
+import time
 
 from absl.testing import absltest, parameterized
 import torch
 from torch import Tensor
 
 from fastseq.logging import get_logger
-from fastseq.optimizer.jit.einsum_rewriter import rewrite_einsum
+from fastseq.optimizer.jit.einsum_rewriter import rewrite_einsum, einsum_rewrite_pattern_0
 from fastseq.utils.test_utils import TestCaseBase
 
 logger = get_logger(__name__, logging.INFO)
@@ -25,8 +24,8 @@ class EinsumRewriterTest(TestCaseBase):
          'shape0': [128, 4, 16, 1, 64],
          'shape1': [128, 2, 16, 1024, 64]},
         {'eqn': "bmhts,bnhsd->bmhtd",
-         'shape0': [128, 4, 16, 3, 64],
-         'shape1': [128, 2, 16, 64, 7]},
+         'shape0': [128, 4, 16, 5, 64],
+         'shape1': [128, 2, 16, 64, 1024]},
         {'eqn': "impts,inpsw->imptw",
          'shape0': [128, 4, 16, 3, 64],
          'shape1': [128, 2, 16, 64, 7]},
@@ -39,11 +38,15 @@ def run_einsum(eqn: str, t0: Tensor, t1: Tensor):
 
         t0 = torch.randn(shape0, dtype=torch.float32).cuda()
         t1 = torch.randn(shape1, dtype=torch.float32).cuda()
-        repeat_times = 1000
+        repeat_times = 1024
 
         r0 = run_einsum(eqn, t0, t1)
-        time0 = timeit.Timer(functools.partial(run_einsum, eqn, t0, t1))
-        s0 = time0.timeit(repeat_times)
+        torch.cuda.synchronize()
+        start0 = time.time()
+        for _ in range(repeat_times):
+            run_einsum(eqn, t0, t1)
+        torch.cuda.synchronize()
+        end0 = time.time()
 
         script_run_einsum = torch.jit.script(run_einsum)
         logger.debug(f"Original graph: \n{script_run_einsum.graph.str()}")
@@ -52,13 +55,28 @@ def run_einsum(eqn: str, t0: Tensor, t1: Tensor):
         self.assertTrue('bmm' in script_run_einsum.graph.str())
 
         r1 = script_run_einsum(eqn, t0, t1)
-        time1 = timeit.Timer(
-            functools.partial(script_run_einsum, eqn, t0, t1))
-        s1 = time1.timeit(repeat_times)
+        torch.cuda.synchronize()
+        start1 = time.time()
+        for _ in range(repeat_times):
+            script_run_einsum(eqn, t0, t1)
+        torch.cuda.synchronize()
+        end1 = time.time()
+
+        r2 = einsum_rewrite_pattern_0(eqn, [t0, t1])
+        torch.cuda.synchronize()
+        start2 = time.time()
+        for _ in range(repeat_times):
+            einsum_rewrite_pattern_0(eqn, [t0, t1])
+        torch.cuda.synchronize()
+        end2 = time.time()
 
         self.assertTrue(torch.equal(r0, r1))
-        logger.info(f"einsum took: {s0}; optimized einsum torchscript took: "
-                    f"{s1};")
+        self.assertTrue(torch.equal(r0, r2))
+        self.assertEqual(r0.is_contiguous(), r1.is_contiguous())
+        self.assertEqual(r0.is_contiguous(), r2.is_contiguous())
+        logger.info(f"einsum took: {end0 - start0}; "
+                    f"optimized einsum torchscript took: {end1 - start1}; "
+                    f"optimized einsum python took: {end2 - start2};")
 
 
 if __name__ == "__main__":
From 19fcdb1b8cb8998b0012d713bda69adcabccc494 Mon Sep 17 00:00:00 2001
From: Fei Hu
Date: Wed, 2 Dec 2020 20:04:03 +0000
Subject: [PATCH 12/12] Test commit for public fastseq

---
 benchmarks/models/fast_test.sh | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/benchmarks/models/fast_test.sh b/benchmarks/models/fast_test.sh
index e02f2d3d..3de68d67 100755
--- a/benchmarks/models/fast_test.sh
+++ b/benchmarks/models/fast_test.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
-# Run it at its parent folder, and check result at ../perf.
-# USAGE -./benchmark.sh
+# Run it from its parent folder, and check the results under ../perf.
+# USAGE - ./benchmark.sh
 #    [fairseq|fairseq+fastseq|transformers|transformers+fastseq]
 #
 #
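With the series applied, rewrite_einsum is the public entry point for the optimization: script a function that calls torch.einsum, rewrite its graph, and subsequent calls go through torch.bmm. A minimal end-to-end sketch mirroring the unit test above (the function name and shapes here are illustrative):

import torch
from torch import Tensor

from fastseq.optimizer.jit.einsum_rewriter import rewrite_einsum

def attention_scores(eqn: str, q: Tensor, k: Tensor) -> Tensor:
    return torch.einsum(eqn, q, k)

scripted = torch.jit.script(attention_scores)
rewrite_einsum(scripted.graph)  # rewrites matching aten::einsum nodes in place

q = torch.randn(2, 4, 8, 5, 16)  # (b, m, h, t, d)
k = torch.randn(2, 1, 8, 7, 16)  # (b, n, h, s, d)
out = scripted("bmhtd,bnhsd->bmhts", q, k)
print(out.shape)  # torch.Size([2, 4, 8, 5, 7])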