From eb3c4e0c22fe8ba6136549a7e695844781bdbe41 Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Thu, 2 May 2024 19:31:41 +0000 Subject: [PATCH 01/22] Simplify build system. --- .bazelrc | 5 ----- .github/workflows/wheel.yaml | 5 +++-- BUILD | 9 ++++++--- README.md | 4 ++-- tf_shell/BUILD | 27 +++++++++++---------------- third_party/tensorflow/BUILD | 9 ++++++--- 6 files changed, 28 insertions(+), 31 deletions(-) diff --git a/.bazelrc b/.bazelrc index a39e2d8..cb4f578 100644 --- a/.bazelrc +++ b/.bazelrc @@ -5,10 +5,6 @@ build -c opt build --cxxopt='-std=c++17' build --cxxopt='-D_GLIBCXX_USE_CXX11_ABI=1' -build:test --cxxopt='-DPYBIND11_ABSEIL_STATUS_MODULE_PATH=pybind11_abseil.pybind11_abseil.status' - -build:release --cxxopt='-DPYBIND11_ABSEIL_STATUS_MODULE_PATH=pybind11_abseil.status' - # If there are compilation issues with asan and absl, try making the changes # described by: # https://github.com/abseil/abseil-cpp/pull/1399/files#diff-32cf2e2d37473ed6eb8f8b7e1126983fcca9a5fe02885098094c9ed4ceda8a6f @@ -16,7 +12,6 @@ build:asan --strip=never build:asan --copt -O1 build:asan --copt -g build:asan --copt -fno-omit-frame-pointer -build:asan --cxxopt='-DPYBIND11_ABSEIL_STATUS_MODULE_PATH=pybind11_abseil.pybind11_abseil.status' # Sanitizers don't work with absl::string at version 20230802.1. #build:asan --copt -fsanitize=address #build:asan --linkopt -fsanitize=address diff --git a/.github/workflows/wheel.yaml b/.github/workflows/wheel.yaml index 652e576..f9663fc 100644 --- a/.github/workflows/wheel.yaml +++ b/.github/workflows/wheel.yaml @@ -16,7 +16,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - #os: [ubuntu-latest, windows-2019, macOS-11] + #os: [ubuntu-latest, windows-latest, macOS-12] os: [ubuntu-latest] python-version: ["3.9", "3.10", "3.11"] # Also see MODULE.bazel @@ -24,10 +24,11 @@ jobs: - uses: actions/checkout@v4 - name: Build wheel + shell: bash run: | sed -i 's/DEFAULT_PYTHON = "3.10"/DEFAULT_PYTHON = "${{ matrix.python-version }}"/' ./MODULE.bazel sed -i 's/DEFAULT_PYTHON = "3.10"/DEFAULT_PYTHON = "${{ matrix.python-version }}"/' ./BUILD - bazelisk build --config release //:wheel + bazelisk build //:wheel bazelisk run //:wheel_rename - uses: actions/upload-artifact@v3 diff --git a/BUILD b/BUILD index 9899cd6..19e09a0 100644 --- a/BUILD +++ b/BUILD @@ -92,7 +92,7 @@ py_wheel( name = "wheel", abi = "ABI", author = "Google Inc.", - author_email = "jchoncholas@google.com", + author_email = "jchoncholas@gmail.com", classifiers = [ "Topic :: Security :: Cryptography", "Topic :: Scientific/Engineering :: Artificial Intelligence", @@ -115,10 +115,13 @@ py_wheel( "@bazel_tools//src/conditions:linux_x86_64": "LINUX_x86_64", "@bazel_tools//src/conditions:linux_aarch64": "LINUX_aarch64", }), - python_requires = "==" + DEFAULT_PYTHON + ".*", + python_requires = ">=3.9", python_tag = "INTERPRETER", requires = ["tensorflow-cpu==2.16.1"], # See also: requirements.in. - summary = "TF-Shell: Privacy preserving machine learning with Tensorflow and the SHELL encryption library", + # The summary is tailored for each python version because PyPI prevents + # wheel uploads for different versions which have the same contents. + # Changing the summary is sufficient to allow re-uploads. + summary = "TF-Shell: Privacy preserving machine learning with Tensorflow and the SHELL encryption library, built for python " + DEFAULT_PYTHON + ".", version = module_version(), deps = [ "//tf_shell:tf_shell_pkg", diff --git a/README.md b/README.md index 9f76390..0307119 100644 --- a/README.md +++ b/README.md @@ -51,13 +51,13 @@ the labels. 2. Run the tests. ```bash - bazelisk test --config test ... + bazelisk test ... ``` 3. Build the code. ```bash - bazelisk build --config release //:wheel + bazelisk build //:wheel bazelisk run //:wheel_rename ``` diff --git a/tf_shell/BUILD b/tf_shell/BUILD index 222b677..7ce3bb5 100644 --- a/tf_shell/BUILD +++ b/tf_shell/BUILD @@ -1,23 +1,18 @@ load("@rules_python//python:packaging.bzl", "py_package") -cc_binary( - name = "python/ops/_shell_ops.so", - srcs = [ - "cc/kernels/add_kernels.cc", - "cc/kernels/context_kernels.cc", - "cc/kernels/context_variant.h", - "cc/kernels/mod_switch_kernels.cc", - "cc/kernels/mul_kernels.cc", - "cc/kernels/polynomial_kernels.cc", - "cc/kernels/polynomial_variant.h", - "cc/kernels/rotation_kernels.cc", - "cc/kernels/rotation_variants.h", - "cc/kernels/shape_kernels.cc", - "cc/kernels/symmetric_kernels.cc", - "cc/kernels/symmetric_variants.h", - "cc/kernels/utils.h", +filegroup( + name = "shell_ops_src", + srcs = glob([ + "cc/kernels/*.cc", + "cc/kernels/*.h", + ]) + [ "cc/ops/shell_ops.cc", ], +) + +cc_binary( + name = "python/ops/_shell_ops.so", + srcs = [":shell_ops_src"], copts = [ "-pthread", "-fPIC", diff --git a/third_party/tensorflow/BUILD b/third_party/tensorflow/BUILD index 285c5c3..f576af2 100644 --- a/third_party/tensorflow/BUILD +++ b/third_party/tensorflow/BUILD @@ -26,7 +26,7 @@ dynamic_genrule( genrule( name = "hermetic_tf_lib", - outs = ["libtensorflow_framework.so.2"], + outs = ["libtensorflow_framework.so"], cmd = "$(location :extract_tf_lib) $@", tools = [":extract_tf_lib"], ) @@ -45,14 +45,17 @@ cc_library( deps = [":hermetic_tf_lib"], ) +# At runtime, libtensorflow_framework.so will be provided by pip, so this rule +# uses system_provided = 1 to indicate that the library is only needed to build +# but can be skipped when packaging. cc_import( name = "hermetic_tf", hdrs = [":tf_headers"], - shared_library = "libtensorflow_framework.so.2", + interface_library = "libtensorflow_framework.so", + system_provided = 1, visibility = ["//visibility:public"], deps = [ ":tf_headers", ":tf_lib", ], - alwayslink = 1, ) From a6af78f2a8fd2d2ce6f8705bffb23429080fe323 Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Thu, 2 May 2024 19:32:53 +0000 Subject: [PATCH 02/22] Remove unused includes. --- tf_shell/cc/kernels/add_kernels.cc | 1 - tf_shell/cc/kernels/context_variant.h | 1 - tf_shell/cc/kernels/utils.h | 1 - 3 files changed, 3 deletions(-) diff --git a/tf_shell/cc/kernels/add_kernels.cc b/tf_shell/cc/kernels/add_kernels.cc index e3709e0..928ddb4 100644 --- a/tf_shell/cc/kernels/add_kernels.cc +++ b/tf_shell/cc/kernels/add_kernels.cc @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "absl/status/status.h" #include "context_variant.h" #include "polynomial_variant.h" #include "shell_encryption/context.h" diff --git a/tf_shell/cc/kernels/context_variant.h b/tf_shell/cc/kernels/context_variant.h index 3e0e287..dcc6881 100644 --- a/tf_shell/cc/kernels/context_variant.h +++ b/tf_shell/cc/kernels/context_variant.h @@ -17,7 +17,6 @@ #pragma once #include -#include "absl/status/status.h" #include "shell_encryption/prng/hkdf_prng.h" #include "shell_encryption/rns/coefficient_encoder.h" #include "shell_encryption/rns/finite_field_encoder.h" diff --git a/tf_shell/cc/kernels/utils.h b/tf_shell/cc/kernels/utils.h index da008b2..53875b2 100644 --- a/tf_shell/cc/kernels/utils.h +++ b/tf_shell/cc/kernels/utils.h @@ -18,7 +18,6 @@ #include -#include "absl/status/status.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/variant.h" From 6263e2880dc8f9e17b379427e3587482a0770a7a Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Thu, 2 May 2024 20:00:59 +0000 Subject: [PATCH 03/22] Support python 3.12 --- .github/workflows/wheel.yaml | 2 +- BUILD | 10 + MODULE.bazel | 1 + README.md | 2 +- requirements_3_12.txt | 613 +++++++++++++++++++++++++++++++++++ 5 files changed, 626 insertions(+), 2 deletions(-) create mode 100644 requirements_3_12.txt diff --git a/.github/workflows/wheel.yaml b/.github/workflows/wheel.yaml index f9663fc..cd4cc2f 100644 --- a/.github/workflows/wheel.yaml +++ b/.github/workflows/wheel.yaml @@ -18,7 +18,7 @@ jobs: matrix: #os: [ubuntu-latest, windows-latest, macOS-12] os: [ubuntu-latest] - python-version: ["3.9", "3.10", "3.11"] # Also see MODULE.bazel + python-version: ["3.9", "3.10", "3.11", "3.12"] # Also see MODULE.bazel steps: - uses: actions/checkout@v4 diff --git a/BUILD b/BUILD index 19e09a0..503ca20 100644 --- a/BUILD +++ b/BUILD @@ -2,6 +2,7 @@ load("@buildifier_prebuilt//:rules.bzl", "buildifier") load("@pip//:requirements.bzl", "requirement") load("@python_versions//3.10:defs.bzl", compile_pip_requirements_3_10 = "compile_pip_requirements") load("@python_versions//3.11:defs.bzl", compile_pip_requirements_3_11 = "compile_pip_requirements") +load("@python_versions//3.12:defs.bzl", compile_pip_requirements_3_12 = "compile_pip_requirements") load("@python_versions//3.9:defs.bzl", compile_pip_requirements_3_9 = "compile_pip_requirements") load("@rules_python//python:defs.bzl", "py_binary") load("@rules_python//python:packaging.bzl", "py_wheel") @@ -15,6 +16,7 @@ exports_files([ "requirements_3_9.txt", "requirements_3_10.txt", "requirements_3_11.txt", + "requirements_3_12.txt", "README.md", "DESCRIPTION.md", ]) @@ -43,6 +45,14 @@ compile_pip_requirements_3_11( visibility = ["//visibility:public"], ) +compile_pip_requirements_3_12( + name = "requirements_3_12", + extra_args = ["--allow-unsafe"], # need setuptools + requirements_in = "//:requirements.in", + requirements_txt = "//:requirements_3_12.txt", + visibility = ["//visibility:public"], +) + buildifier( name = "bazel_formatter", exclude_patterns = [ diff --git a/MODULE.bazel b/MODULE.bazel index 951ad1e..5d9a89c 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -7,6 +7,7 @@ SUPPORTED_PYTHON_VERSIONS = [ "3.9", "3.10", "3.11", + "3.12", ] # Also see ./.github/workflows/wheel.yaml DEFAULT_PYTHON = "3.10" # Also see ./BUILD diff --git a/README.md b/README.md index 0307119..11bf228 100644 --- a/README.md +++ b/README.md @@ -95,7 +95,7 @@ Update requirements.in and run the following to update the requirements files for each python version. ```bash -for ver in 3_9 3_10 3_11; do +for ver in 3_9 3_10 3_11 3_12; do touch requirements_${ver}.txt bazelisk run //:requirements_${ver}.update done diff --git a/requirements_3_12.txt b/requirements_3_12.txt new file mode 100644 index 0000000..99cc837 --- /dev/null +++ b/requirements_3_12.txt @@ -0,0 +1,613 @@ +# +# This file is autogenerated by pip-compile with Python 3.12 +# by the following command: +# +# bazel run //:requirements_3_12.update +# +absl-py==2.1.0 \ + --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ + --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff + # via + # keras + # tensorboard + # tensorflow-cpu +astunparse==1.6.3 \ + --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \ + --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8 + # via tensorflow-cpu +black==24.4.0 \ + --hash=sha256:1bb9ca06e556a09f7f7177bc7cb604e5ed2d2df1e9119e4f7d2f1f7071c32e5d \ + --hash=sha256:21f9407063ec71c5580b8ad975653c66508d6a9f57bd008bb8691d273705adcd \ + --hash=sha256:4396ca365a4310beef84d446ca5016f671b10f07abdba3e4e4304218d2c71d33 \ + --hash=sha256:44d99dfdf37a2a00a6f7a8dcbd19edf361d056ee51093b2445de7ca09adac965 \ + --hash=sha256:5cd5b4f76056cecce3e69b0d4c228326d2595f506797f40b9233424e2524c070 \ + --hash=sha256:64578cf99b6b46a6301bc28bdb89f9d6f9b592b1c5837818a177c98525dbe397 \ + --hash=sha256:64e60a7edd71fd542a10a9643bf369bfd2644de95ec71e86790b063aa02ff745 \ + --hash=sha256:652e55bb722ca026299eb74e53880ee2315b181dfdd44dca98e43448620ddec1 \ + --hash=sha256:6644f97a7ef6f401a150cca551a1ff97e03c25d8519ee0bbc9b0058772882665 \ + --hash=sha256:6ad001a9ddd9b8dfd1b434d566be39b1cd502802c8d38bbb1ba612afda2ef436 \ + --hash=sha256:71d998b73c957444fb7c52096c3843875f4b6b47a54972598741fe9a7f737fcb \ + --hash=sha256:74eb9b5420e26b42c00a3ff470dc0cd144b80a766128b1771d07643165e08d0e \ + --hash=sha256:75a2d0b4f5eb81f7eebc31f788f9830a6ce10a68c91fbe0fade34fff7a2836e6 \ + --hash=sha256:7852b05d02b5b9a8c893ab95863ef8986e4dda29af80bbbda94d7aee1abf8702 \ + --hash=sha256:7f2966b9b2b3b7104fca9d75b2ee856fe3fdd7ed9e47c753a4bb1a675f2caab8 \ + --hash=sha256:8e5537f456a22cf5cfcb2707803431d2feeb82ab3748ade280d6ccd0b40ed2e8 \ + --hash=sha256:d4e71cdebdc8efeb6deaf5f2deb28325f8614d48426bed118ecc2dcaefb9ebf3 \ + --hash=sha256:dae79397f367ac8d7adb6c779813328f6d690943f64b32983e896bcccd18cbad \ + --hash=sha256:e3a3a092b8b756c643fe45f4624dbd5a389f770a4ac294cf4d0fce6af86addaf \ + --hash=sha256:eb949f56a63c5e134dfdca12091e98ffb5fd446293ebae123d10fc1abad00b9e \ + --hash=sha256:f07b69fda20578367eaebbd670ff8fc653ab181e1ff95d84497f9fa20e7d0641 \ + --hash=sha256:f95cece33329dc4aa3b0e1a771c41075812e46cf3d6e3f1dfe3d91ff09826ed2 + # via -r requirements.in +certifi==2024.2.2 \ + --hash=sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f \ + --hash=sha256:dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1 + # via requests +charset-normalizer==3.3.2 \ + --hash=sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027 \ + --hash=sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087 \ + --hash=sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786 \ + --hash=sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8 \ + --hash=sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09 \ + --hash=sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185 \ + --hash=sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574 \ + --hash=sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e \ + --hash=sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519 \ + --hash=sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898 \ + --hash=sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269 \ + --hash=sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3 \ + --hash=sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f \ + --hash=sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6 \ + --hash=sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8 \ + --hash=sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a \ + --hash=sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73 \ + --hash=sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc \ + --hash=sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714 \ + --hash=sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2 \ + --hash=sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc \ + --hash=sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce \ + --hash=sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d \ + --hash=sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e \ + --hash=sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6 \ + --hash=sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269 \ + --hash=sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96 \ + --hash=sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d \ + --hash=sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a \ + --hash=sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4 \ + --hash=sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77 \ + --hash=sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d \ + --hash=sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0 \ + --hash=sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed \ + --hash=sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068 \ + --hash=sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac \ + --hash=sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25 \ + --hash=sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8 \ + --hash=sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab \ + --hash=sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26 \ + --hash=sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2 \ + --hash=sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db \ + --hash=sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f \ + --hash=sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5 \ + --hash=sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99 \ + --hash=sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c \ + --hash=sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d \ + --hash=sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811 \ + --hash=sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa \ + --hash=sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a \ + --hash=sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03 \ + --hash=sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b \ + --hash=sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04 \ + --hash=sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c \ + --hash=sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001 \ + --hash=sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458 \ + --hash=sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389 \ + --hash=sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99 \ + --hash=sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985 \ + --hash=sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537 \ + --hash=sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238 \ + --hash=sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f \ + --hash=sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d \ + --hash=sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796 \ + --hash=sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a \ + --hash=sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143 \ + --hash=sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8 \ + --hash=sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c \ + --hash=sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5 \ + --hash=sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5 \ + --hash=sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711 \ + --hash=sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4 \ + --hash=sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6 \ + --hash=sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c \ + --hash=sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7 \ + --hash=sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4 \ + --hash=sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b \ + --hash=sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae \ + --hash=sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12 \ + --hash=sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c \ + --hash=sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae \ + --hash=sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8 \ + --hash=sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887 \ + --hash=sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b \ + --hash=sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4 \ + --hash=sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f \ + --hash=sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5 \ + --hash=sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33 \ + --hash=sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519 \ + --hash=sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561 + # via requests +click==8.1.7 \ + --hash=sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28 \ + --hash=sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de + # via black +flatbuffers==24.3.25 \ + --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ + --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 + # via tensorflow-cpu +gast==0.5.4 \ + --hash=sha256:6fc4fa5fa10b72fb8aab4ae58bcb023058386e67b6fa2e3e34cec5c769360316 \ + --hash=sha256:9c270fe5f4b130969b54174de7db4e764b09b4f7f67ccfc32480e29f78348d97 + # via tensorflow-cpu +google-pasta==0.2.0 \ + --hash=sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954 \ + --hash=sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed \ + --hash=sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e + # via tensorflow-cpu +grpcio==1.63.0 \ + --hash=sha256:01799e8649f9e94ba7db1aeb3452188048b0019dc37696b0f5ce212c87c560c3 \ + --hash=sha256:0697563d1d84d6985e40ec5ec596ff41b52abb3fd91ec240e8cb44a63b895094 \ + --hash=sha256:08e1559fd3b3b4468486b26b0af64a3904a8dbc78d8d936af9c1cf9636eb3e8b \ + --hash=sha256:166e5c460e5d7d4656ff9e63b13e1f6029b122104c1633d5f37eaea348d7356d \ + --hash=sha256:1ff737cf29b5b801619f10e59b581869e32f400159e8b12d7a97e7e3bdeee6a2 \ + --hash=sha256:219bb1848cd2c90348c79ed0a6b0ea51866bc7e72fa6e205e459fedab5770172 \ + --hash=sha256:259e11932230d70ef24a21b9fb5bb947eb4703f57865a404054400ee92f42f5d \ + --hash=sha256:2e93aca840c29d4ab5db93f94ed0a0ca899e241f2e8aec6334ab3575dc46125c \ + --hash=sha256:3a6d1f9ea965e750db7b4ee6f9fdef5fdf135abe8a249e75d84b0a3e0c668a1b \ + --hash=sha256:50344663068041b34a992c19c600236e7abb42d6ec32567916b87b4c8b8833b3 \ + --hash=sha256:56cdf96ff82e3cc90dbe8bac260352993f23e8e256e063c327b6cf9c88daf7a9 \ + --hash=sha256:5c039ef01516039fa39da8a8a43a95b64e288f79f42a17e6c2904a02a319b357 \ + --hash=sha256:6426e1fb92d006e47476d42b8f240c1d916a6d4423c5258ccc5b105e43438f61 \ + --hash=sha256:65bf975639a1f93bee63ca60d2e4951f1b543f498d581869922910a476ead2f5 \ + --hash=sha256:6a1a3642d76f887aa4009d92f71eb37809abceb3b7b5a1eec9c554a246f20e3a \ + --hash=sha256:6ef0ad92873672a2a3767cb827b64741c363ebaa27e7f21659e4e31f4d750280 \ + --hash=sha256:756fed02dacd24e8f488f295a913f250b56b98fb793f41d5b2de6c44fb762434 \ + --hash=sha256:75f701ff645858a2b16bc8c9fc68af215a8bb2d5a9b647448129de6e85d52bce \ + --hash=sha256:8064d986d3a64ba21e498b9a376cbc5d6ab2e8ab0e288d39f266f0fca169b90d \ + --hash=sha256:878b1d88d0137df60e6b09b74cdb73db123f9579232c8456f53e9abc4f62eb3c \ + --hash=sha256:8f3f6883ce54a7a5f47db43289a0a4c776487912de1a0e2cc83fdaec9685cc9f \ + --hash=sha256:91b73d3f1340fefa1e1716c8c1ec9930c676d6b10a3513ab6c26004cb02d8b3f \ + --hash=sha256:93a46794cc96c3a674cdfb59ef9ce84d46185fe9421baf2268ccb556f8f81f57 \ + --hash=sha256:93f45f27f516548e23e4ec3fbab21b060416007dbe768a111fc4611464cc773f \ + --hash=sha256:9e350cb096e5c67832e9b6e018cf8a0d2a53b2a958f6251615173165269a91b0 \ + --hash=sha256:a2d60cd1d58817bc5985fae6168d8b5655c4981d448d0f5b6194bbcc038090d2 \ + --hash=sha256:a3abfe0b0f6798dedd2e9e92e881d9acd0fdb62ae27dcbbfa7654a57e24060c0 \ + --hash=sha256:a44624aad77bf8ca198c55af811fd28f2b3eaf0a50ec5b57b06c034416ef2d0a \ + --hash=sha256:a7b19dfc74d0be7032ca1eda0ed545e582ee46cd65c162f9e9fc6b26ef827dc6 \ + --hash=sha256:ad2ac8903b2eae071055a927ef74121ed52d69468e91d9bcbd028bd0e554be6d \ + --hash=sha256:b005292369d9c1f80bf70c1db1c17c6c342da7576f1c689e8eee4fb0c256af85 \ + --hash=sha256:b2e44f59316716532a993ca2966636df6fbe7be4ab6f099de6815570ebe4383a \ + --hash=sha256:b3afbd9d6827fa6f475a4f91db55e441113f6d3eb9b7ebb8fb806e5bb6d6bd0d \ + --hash=sha256:b416252ac5588d9dfb8a30a191451adbf534e9ce5f56bb02cd193f12d8845b7f \ + --hash=sha256:b5194775fec7dc3dbd6a935102bb156cd2c35efe1685b0a46c67b927c74f0cfb \ + --hash=sha256:cacdef0348a08e475a721967f48206a2254a1b26ee7637638d9e081761a5ba86 \ + --hash=sha256:cd1e68776262dd44dedd7381b1a0ad09d9930ffb405f737d64f505eb7f77d6c7 \ + --hash=sha256:cdcda1156dcc41e042d1e899ba1f5c2e9f3cd7625b3d6ebfa619806a4c1aadda \ + --hash=sha256:cf8dae9cc0412cb86c8de5a8f3be395c5119a370f3ce2e69c8b7d46bb9872c8d \ + --hash=sha256:d2497769895bb03efe3187fb1888fc20e98a5f18b3d14b606167dacda5789434 \ + --hash=sha256:e3b77eaefc74d7eb861d3ffbdf91b50a1bb1639514ebe764c47773b833fa2d91 \ + --hash=sha256:e48cee31bc5f5a31fb2f3b573764bd563aaa5472342860edcc7039525b53e46a \ + --hash=sha256:e4cbb2100ee46d024c45920d16e888ee5d3cf47c66e316210bc236d5bebc42b3 \ + --hash=sha256:f28f8b2db7b86c77916829d64ab21ff49a9d8289ea1564a2b2a3a8ed9ffcccd3 \ + --hash=sha256:f3023e14805c61bc439fb40ca545ac3d5740ce66120a678a3c6c2c55b70343d1 \ + --hash=sha256:fdf348ae69c6ff484402cfdb14e18c1b0054ac2420079d575c53a60b9b2853ae + # via + # tensorboard + # tensorflow-cpu +h5py==3.11.0 \ + --hash=sha256:083e0329ae534a264940d6513f47f5ada617da536d8dccbafc3026aefc33c90e \ + --hash=sha256:1625fd24ad6cfc9c1ccd44a66dac2396e7ee74940776792772819fc69f3a3731 \ + --hash=sha256:21dbdc5343f53b2e25404673c4f00a3335aef25521bd5fa8c707ec3833934892 \ + --hash=sha256:52c416f8eb0daae39dabe71415cb531f95dce2d81e1f61a74537a50c63b28ab3 \ + --hash=sha256:55106b04e2c83dfb73dc8732e9abad69d83a436b5b82b773481d95d17b9685e1 \ + --hash=sha256:67462d0669f8f5459529de179f7771bd697389fcb3faab54d63bf788599a48ea \ + --hash=sha256:6c4b760082626120031d7902cd983d8c1f424cdba2809f1067511ef283629d4b \ + --hash=sha256:731839240c59ba219d4cb3bc5880d438248533366f102402cfa0621b71796b62 \ + --hash=sha256:754c0c2e373d13d6309f408325343b642eb0f40f1a6ad21779cfa9502209e150 \ + --hash=sha256:75bd7b3d93fbeee40860fd70cdc88df4464e06b70a5ad9ce1446f5f32eb84007 \ + --hash=sha256:77b19a40788e3e362b54af4dcf9e6fde59ca016db2c61360aa30b47c7b7cef00 \ + --hash=sha256:7b7e8f78072a2edec87c9836f25f34203fd492a4475709a18b417a33cfb21fa9 \ + --hash=sha256:8ec9df3dd2018904c4cc06331951e274f3f3fd091e6d6cc350aaa90fa9b42a76 \ + --hash=sha256:a76cae64080210389a571c7d13c94a1a6cf8cb75153044fd1f822a962c97aeab \ + --hash=sha256:aa6ae84a14103e8dc19266ef4c3e5d7c00b68f21d07f2966f0ca7bdb6c2761fb \ + --hash=sha256:bbd732a08187a9e2a6ecf9e8af713f1d68256ee0f7c8b652a32795670fb481ba \ + --hash=sha256:c072655ad1d5fe9ef462445d3e77a8166cbfa5e599045f8aa3c19b75315f10e5 \ + --hash=sha256:d9c944d364688f827dc889cf83f1fca311caf4fa50b19f009d1f2b525edd33a3 \ + --hash=sha256:ef4e2f338fc763f50a8113890f455e1a70acd42a4d083370ceb80c463d803972 \ + --hash=sha256:f3736fe21da2b7d8a13fe8fe415f1272d2a1ccdeff4849c1421d2fb30fd533bc \ + --hash=sha256:f4e025e852754ca833401777c25888acb96889ee2c27e7e629a19aee288833f0 + # via + # keras + # tensorflow-cpu +idna==3.7 \ + --hash=sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc \ + --hash=sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0 + # via requests +keras==3.3.3 \ + --hash=sha256:260df9ef71c6b89eb6816ce1c60f139c38ccdddd16f24e7005d2be127cdef8e4 \ + --hash=sha256:f2fdffc8434fd77045cf8fb21816dbaa2308d5f76974ca924b2f60b40433b1a0 + # via tensorflow-cpu +libclang==18.1.1 \ + --hash=sha256:3f0e1f49f04d3cd198985fea0511576b0aee16f9ff0e0f0cad7f9c57ec3c20e8 \ + --hash=sha256:4dd2d3b82fab35e2bf9ca717d7b63ac990a3519c7e312f19fa8e86dcc712f7fb \ + --hash=sha256:54dda940a4a0491a9d1532bf071ea3ef26e6dbaf03b5000ed94dd7174e8f9592 \ + --hash=sha256:69f8eb8f65c279e765ffd28aaa7e9e364c776c17618af8bff22a8df58677ff4f \ + --hash=sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5 \ + --hash=sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8 \ + --hash=sha256:a1214966d08d73d971287fc3ead8dfaf82eb07fb197680d8b3859dbbbbf78250 \ + --hash=sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b \ + --hash=sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe + # via tensorflow-cpu +markdown==3.6 \ + --hash=sha256:48f276f4d8cfb8ce6527c8f79e2ee29708508bf4d40aa410fbc3b4ee832c850f \ + --hash=sha256:ed4f41f6daecbeeb96e576ce414c41d2d876daa9a16cb35fa8ed8c2ddfad0224 + # via tensorboard +markdown-it-py==3.0.0 \ + --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ + --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb + # via rich +markupsafe==2.1.5 \ + --hash=sha256:00e046b6dd71aa03a41079792f8473dc494d564611a8f89bbbd7cb93295ebdcf \ + --hash=sha256:075202fa5b72c86ad32dc7d0b56024ebdbcf2048c0ba09f1cde31bfdd57bcfff \ + --hash=sha256:0e397ac966fdf721b2c528cf028494e86172b4feba51d65f81ffd65c63798f3f \ + --hash=sha256:17b950fccb810b3293638215058e432159d2b71005c74371d784862b7e4683f3 \ + --hash=sha256:1f3fbcb7ef1f16e48246f704ab79d79da8a46891e2da03f8783a5b6fa41a9532 \ + --hash=sha256:2174c595a0d73a3080ca3257b40096db99799265e1c27cc5a610743acd86d62f \ + --hash=sha256:2b7c57a4dfc4f16f7142221afe5ba4e093e09e728ca65c51f5620c9aaeb9a617 \ + --hash=sha256:2d2d793e36e230fd32babe143b04cec8a8b3eb8a3122d2aceb4a371e6b09b8df \ + --hash=sha256:30b600cf0a7ac9234b2638fbc0fb6158ba5bdcdf46aeb631ead21248b9affbc4 \ + --hash=sha256:397081c1a0bfb5124355710fe79478cdbeb39626492b15d399526ae53422b906 \ + --hash=sha256:3a57fdd7ce31c7ff06cdfbf31dafa96cc533c21e443d57f5b1ecc6cdc668ec7f \ + --hash=sha256:3c6b973f22eb18a789b1460b4b91bf04ae3f0c4234a0a6aa6b0a92f6f7b951d4 \ + --hash=sha256:3e53af139f8579a6d5f7b76549125f0d94d7e630761a2111bc431fd820e163b8 \ + --hash=sha256:4096e9de5c6fdf43fb4f04c26fb114f61ef0bf2e5604b6ee3019d51b69e8c371 \ + --hash=sha256:4275d846e41ecefa46e2015117a9f491e57a71ddd59bbead77e904dc02b1bed2 \ + --hash=sha256:4c31f53cdae6ecfa91a77820e8b151dba54ab528ba65dfd235c80b086d68a465 \ + --hash=sha256:4f11aa001c540f62c6166c7726f71f7573b52c68c31f014c25cc7901deea0b52 \ + --hash=sha256:5049256f536511ee3f7e1b3f87d1d1209d327e818e6ae1365e8653d7e3abb6a6 \ + --hash=sha256:58c98fee265677f63a4385256a6d7683ab1832f3ddd1e66fe948d5880c21a169 \ + --hash=sha256:598e3276b64aff0e7b3451b72e94fa3c238d452e7ddcd893c3ab324717456bad \ + --hash=sha256:5b7b716f97b52c5a14bffdf688f971b2d5ef4029127f1ad7a513973cfd818df2 \ + --hash=sha256:5dedb4db619ba5a2787a94d877bc8ffc0566f92a01c0ef214865e54ecc9ee5e0 \ + --hash=sha256:619bc166c4f2de5caa5a633b8b7326fbe98e0ccbfacabd87268a2b15ff73a029 \ + --hash=sha256:629ddd2ca402ae6dbedfceeba9c46d5f7b2a61d9749597d4307f943ef198fc1f \ + --hash=sha256:656f7526c69fac7f600bd1f400991cc282b417d17539a1b228617081106feb4a \ + --hash=sha256:6ec585f69cec0aa07d945b20805be741395e28ac1627333b1c5b0105962ffced \ + --hash=sha256:72b6be590cc35924b02c78ef34b467da4ba07e4e0f0454a2c5907f473fc50ce5 \ + --hash=sha256:7502934a33b54030eaf1194c21c692a534196063db72176b0c4028e140f8f32c \ + --hash=sha256:7a68b554d356a91cce1236aa7682dc01df0edba8d043fd1ce607c49dd3c1edcf \ + --hash=sha256:7b2e5a267c855eea6b4283940daa6e88a285f5f2a67f2220203786dfa59b37e9 \ + --hash=sha256:823b65d8706e32ad2df51ed89496147a42a2a6e01c13cfb6ffb8b1e92bc910bb \ + --hash=sha256:8590b4ae07a35970728874632fed7bd57b26b0102df2d2b233b6d9d82f6c62ad \ + --hash=sha256:8dd717634f5a044f860435c1d8c16a270ddf0ef8588d4887037c5028b859b0c3 \ + --hash=sha256:8dec4936e9c3100156f8a2dc89c4b88d5c435175ff03413b443469c7c8c5f4d1 \ + --hash=sha256:97cafb1f3cbcd3fd2b6fbfb99ae11cdb14deea0736fc2b0952ee177f2b813a46 \ + --hash=sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc \ + --hash=sha256:a549b9c31bec33820e885335b451286e2969a2d9e24879f83fe904a5ce59d70a \ + --hash=sha256:ac07bad82163452a6884fe8fa0963fb98c2346ba78d779ec06bd7a6262132aee \ + --hash=sha256:ae2ad8ae6ebee9d2d94b17fb62763125f3f374c25618198f40cbb8b525411900 \ + --hash=sha256:b91c037585eba9095565a3556f611e3cbfaa42ca1e865f7b8015fe5c7336d5a5 \ + --hash=sha256:bc1667f8b83f48511b94671e0e441401371dfd0f0a795c7daa4a3cd1dde55bea \ + --hash=sha256:bec0a414d016ac1a18862a519e54b2fd0fc8bbfd6890376898a6c0891dd82e9f \ + --hash=sha256:bf50cd79a75d181c9181df03572cdce0fbb75cc353bc350712073108cba98de5 \ + --hash=sha256:bff1b4290a66b490a2f4719358c0cdcd9bafb6b8f061e45c7a2460866bf50c2e \ + --hash=sha256:c061bb86a71b42465156a3ee7bd58c8c2ceacdbeb95d05a99893e08b8467359a \ + --hash=sha256:c8b29db45f8fe46ad280a7294f5c3ec36dbac9491f2d1c17345be8e69cc5928f \ + --hash=sha256:ce409136744f6521e39fd8e2a24c53fa18ad67aa5bc7c2cf83645cce5b5c4e50 \ + --hash=sha256:d050b3361367a06d752db6ead6e7edeb0009be66bc3bae0ee9d97fb326badc2a \ + --hash=sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b \ + --hash=sha256:d9fad5155d72433c921b782e58892377c44bd6252b5af2f67f16b194987338a4 \ + --hash=sha256:daa4ee5a243f0f20d528d939d06670a298dd39b1ad5f8a72a4275124a7819eff \ + --hash=sha256:db0b55e0f3cc0be60c1f19efdde9a637c32740486004f20d1cff53c3c0ece4d2 \ + --hash=sha256:e61659ba32cf2cf1481e575d0462554625196a1f2fc06a1c777d3f48e8865d46 \ + --hash=sha256:ea3d8a3d18833cf4304cd2fc9cbb1efe188ca9b5efef2bdac7adc20594a0e46b \ + --hash=sha256:ec6a563cff360b50eed26f13adc43e61bc0c04d94b8be985e6fb24b81f6dcfdf \ + --hash=sha256:f5dfb42c4604dddc8e4305050aa6deb084540643ed5804d7455b5df8fe16f5e5 \ + --hash=sha256:fa173ec60341d6bb97a89f5ea19c85c5643c1e7dedebc22f5181eb73573142c5 \ + --hash=sha256:fa9db3f79de01457b03d4f01b34cf91bc0048eb2c3846ff26f66687c2f6d16ab \ + --hash=sha256:fce659a462a1be54d2ffcacea5e3ba2d74daa74f30f5f143fe0c58636e355fdd \ + --hash=sha256:ffee1f21e5ef0d712f9033568f8344d5da8cc2869dbd08d87c84656e6a2d2f68 + # via werkzeug +mdurl==0.1.2 \ + --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ + --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba + # via markdown-it-py +ml-dtypes==0.3.2 \ + --hash=sha256:2c34f2ba9660b21fe1034b608308a01be82bbef2a92fb8199f24dc6bad0d5226 \ + --hash=sha256:3a17ef2322e60858d93584e9c52a5be7dd6236b056b7fa1ec57f1bb6ba043e33 \ + --hash=sha256:533059bc5f1764fac071ef54598db358c167c51a718f68f5bb55e3dee79d2967 \ + --hash=sha256:6604877d567a29bfe7cc02969ae0f2425260e5335505cf5e7fefc3e5465f5655 \ + --hash=sha256:6b35c4e8ca957c877ac35c79ffa77724ecc3702a1e4b18b08306c03feae597bb \ + --hash=sha256:763697ab8a88d47443997a7cdf3aac7340049aed45f7521f6b0ec8a0594821fe \ + --hash=sha256:7a4c3fcbf86fa52d0204f07cfd23947ef05b4ad743a1a988e163caa34a201e5e \ + --hash=sha256:7afde548890a92b41c0fed3a6c525f1200a5727205f73dc21181a2726571bb53 \ + --hash=sha256:7ba8e1fafc7fff3e643f453bffa7d082df1678a73286ce8187d3e825e776eb94 \ + --hash=sha256:91f8783fd1f2c23fd3b9ee5ad66b785dafa58ba3cdb050c4458021fa4d1eb226 \ + --hash=sha256:93b78f53431c93953f7850bb1b925a17f0ab5d97527e38a7e865b5b4bc5cfc18 \ + --hash=sha256:961134ea44c7b8ca63eda902a44b58cd8bd670e21d62e255c81fba0a8e70d9b7 \ + --hash=sha256:b89b194e9501a92d289c1ffd411380baf5daafb9818109a4f49b0a1b6dce4462 \ + --hash=sha256:c7b3fb3d4f6b39bcd4f6c4b98f406291f0d681a895490ee29a0f95bab850d53c \ + --hash=sha256:d1a746fe5fb9cd974a91070174258f0be129c592b93f9ce7df6cc336416c3fbd \ + --hash=sha256:e8505946df1665db01332d885c2020b4cb9e84a8b1241eb4ba69d59591f65855 \ + --hash=sha256:f47619d978ab1ae7dfdc4052ea97c636c6263e1f19bd1be0e42c346b98d15ff4 + # via + # keras + # tensorflow-cpu +mypy-extensions==1.0.0 \ + --hash=sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d \ + --hash=sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782 + # via black +namex==0.0.8 \ + --hash=sha256:32a50f6c565c0bb10aa76298c959507abdc0e850efe085dc38f3440fcb3aa90b \ + --hash=sha256:7ddb6c2bb0e753a311b7590f84f6da659dd0c05e65cb89d519d54c0a250c0487 + # via keras +numpy==1.26.4 \ + --hash=sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b \ + --hash=sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818 \ + --hash=sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20 \ + --hash=sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0 \ + --hash=sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010 \ + --hash=sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a \ + --hash=sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea \ + --hash=sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c \ + --hash=sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71 \ + --hash=sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110 \ + --hash=sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be \ + --hash=sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a \ + --hash=sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a \ + --hash=sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5 \ + --hash=sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed \ + --hash=sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd \ + --hash=sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c \ + --hash=sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e \ + --hash=sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0 \ + --hash=sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c \ + --hash=sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a \ + --hash=sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b \ + --hash=sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0 \ + --hash=sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6 \ + --hash=sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2 \ + --hash=sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a \ + --hash=sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30 \ + --hash=sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218 \ + --hash=sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5 \ + --hash=sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07 \ + --hash=sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2 \ + --hash=sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4 \ + --hash=sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764 \ + --hash=sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef \ + --hash=sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3 \ + --hash=sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f + # via + # h5py + # keras + # ml-dtypes + # opt-einsum + # tensorboard + # tensorflow-cpu +opt-einsum==3.3.0 \ + --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ + --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 + # via tensorflow-cpu +optree==0.11.0 \ + --hash=sha256:00a63f10d4a476e8e9aa2988daba9b2e88cb369c5aacc12545957d7d00bcd1a7 \ + --hash=sha256:0db6968394096223881053dffdcaf2b8e220fd85db904f14aa931e4dc422c046 \ + --hash=sha256:0df9a3923725aabb112ec7f10c74fa96b6c640da1cd30e7bc62fd4b03ef02875 \ + --hash=sha256:162ed3ff2eb3f1c358e131e72c025f2b93d69b906e9057a811d014032ec71dc8 \ + --hash=sha256:228b97e8c991739b10c8548c118747ba32ee765f88236342e492bf9648afc0bc \ + --hash=sha256:234a4f8f97a1217f13390df7ac416771689749d9a1c8eda31bf8622cd333219e \ + --hash=sha256:26b1230f9b75b579923a4f837c7c13db8b8d815cf68ce5af31dda5d818a877b2 \ + --hash=sha256:2b3bb59324d635f2015bb3e237fd772b1fd548eee6cc80e008fbe0f092e9228d \ + --hash=sha256:2bc08fb9691f43afc3a01119dead6b823ce3d7239e42fc3e47d4028eed50a6a2 \ + --hash=sha256:31d444684ebd8c9f09a3d806fb3277843138ef9952b7a2954908e440e3b22519 \ + --hash=sha256:39bed744a61e2f795e172d2853779ac59b8dea236982dc160ea22063afc99ca3 \ + --hash=sha256:3cdc9fac9888d9eff11128ccfc4d4c10309163e372f312f7942ecee8df3d7824 \ + --hash=sha256:4144126dd3c2ece2d2dd1d5e0b39fb91adf1c46f660c2c5a2df7f80666989d5d \ + --hash=sha256:418850ceff364f51a6d81f32a1efd06a4e2d8df79a162e892685bc20c0aedd72 \ + --hash=sha256:5e250144eacdd5813dec0b18d91df0229197e3be402db42fd8e254ec90ea343d \ + --hash=sha256:5e5df0e8aaca124cc1ffca311786cc909810f3c046de090729cdafbf910082f8 \ + --hash=sha256:63e020a34b7168b5d0701a265c7c95b07984ff699d4894b20fa601282be88f20 \ + --hash=sha256:64c2e00fe508f50a42c50838df0d1f5be0dce5b4bef2373db8ad72b860211015 \ + --hash=sha256:6a406eee5acd3fd4875fa44c3972d29ae6d4329e7296e9219986fe6ff8e92ea0 \ + --hash=sha256:6cdd625dab2dff5374ff9c6792e8702fced8f0ea713ce959fc8f95499b5ecb2f \ + --hash=sha256:6e8c3757088cd7fce666f2a5e031b65d7898e210452380d2657c0fc0a7ec9932 \ + --hash=sha256:738e8bf4158e9c11cd051d89c2e453aeacf80ff8719ebc3251069015646554d0 \ + --hash=sha256:8e6a46e95c3ea8546055087d6fe52a1dcd56de5182365f1469106cc72cdf3307 \ + --hash=sha256:979ffc2b96f16595c219fb7a89597dd2fa00ac47a3b411fdcf8ae6821da52290 \ + --hash=sha256:9bf322ad14f907ad4660ca286e731e750546d54934a94cc5ba7efe8860c60ab4 \ + --hash=sha256:9d9d644e5448db9f32e2497487aca3bb2d3f92cbb50429a411ccda3f1f0968f3 \ + --hash=sha256:a5f37bcfe4e363e3bb8d36c5698fb829546956b2fe88951994387162a1859625 \ + --hash=sha256:a64df43fce2d8eeafd7db6e27447c56b3fa64842df847819684b3b1cc254c016 \ + --hash=sha256:a91840f9d45e7c01f151ba1815ae32b4c3c21e4290298772ee4b13314f729856 \ + --hash=sha256:b201a9405e250cf5770955863af2a236e382bdf5e4e086897ff03c41418c39da \ + --hash=sha256:b26ac807d8993b7e43081b4b7bbb0378b4e5f3e6525daf923c470bc176cc3327 \ + --hash=sha256:b8126d81ecb2c9e3554420834014ba343251f564c905ee3bef09d205b924b0c0 \ + --hash=sha256:b9d236bc1491a5e366921b95fecc05aa6ff55989a81f2242cd11121b82c24503 \ + --hash=sha256:bc17f9d085cd75a2de4f299a9c5e3c3520138eac7596061e581230b03862b44d \ + --hash=sha256:d666099a78f7bf31bf3a520d6871ddcae65484bcff095fc4271a391553b09c75 \ + --hash=sha256:e2d47bd28eff690eb2f7432e490265a291b04d6d346cf7b586491b2e2337bf97 \ + --hash=sha256:ee208f0bec6436085a9fa3ae98af54bfcb8822086894fc1ade283e80a6f11fd7 \ + --hash=sha256:f53951bfb640417558568284a8949d67bcdbf21fa0113107e20bd9403aa20b2b \ + --hash=sha256:fa9ed745d4cbac5e15df70339b30867ba033542b87f7b734f4cacae5ec73ba00 + # via keras +packaging==24.0 \ + --hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \ + --hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9 + # via + # black + # tensorflow-cpu +pathspec==0.12.1 \ + --hash=sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08 \ + --hash=sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712 + # via black +platformdirs==4.2.1 \ + --hash=sha256:031cd18d4ec63ec53e82dceaac0417d218a6863f7745dfcc9efe7793b7039bdf \ + --hash=sha256:17d5a1161b3fd67b390023cb2d3b026bbd40abde6fdb052dfbd3a29c3ba22ee1 + # via black +protobuf==4.25.3 \ + --hash=sha256:19b270aeaa0099f16d3ca02628546b8baefe2955bbe23224aaf856134eccf1e4 \ + --hash=sha256:209ba4cc916bab46f64e56b85b090607a676f66b473e6b762e6f1d9d591eb2e8 \ + --hash=sha256:25b5d0b42fd000320bd7830b349e3b696435f3b329810427a6bcce6a5492cc5c \ + --hash=sha256:7c8daa26095f82482307bc717364e7c13f4f1c99659be82890dcfc215194554d \ + --hash=sha256:c053062984e61144385022e53678fbded7aea14ebb3e0305ae3592fb219ccfa4 \ + --hash=sha256:d4198877797a83cbfe9bffa3803602bbe1625dc30d8a097365dbc762e5790faa \ + --hash=sha256:e3c97a1555fd6388f857770ff8b9703083de6bf1f9274a002a332d65fbb56c8c \ + --hash=sha256:e7cb0ae90dd83727f0c0718634ed56837bfeeee29a5f82a7514c03ee1364c019 \ + --hash=sha256:f0700d54bcf45424477e46a9f0944155b46fb0639d69728739c0e47bab83f2b9 \ + --hash=sha256:f1279ab38ecbfae7e456a108c5c0681e4956d5b1090027c1de0f934dfdb4b35c \ + --hash=sha256:f4f118245c4a087776e0a8408be33cf09f6c547442c00395fbfb116fac2f8ac2 + # via + # tensorboard + # tensorflow-cpu +pygments==2.17.2 \ + --hash=sha256:b27c2826c47d0f3219f29554824c30c5e8945175d888647acd804ddd04af846c \ + --hash=sha256:da46cec9fd2de5be3a8a784f434e4c4ab670b4ff54d605c4c2717e9d49c4c367 + # via rich +requests==2.31.0 \ + --hash=sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f \ + --hash=sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1 + # via tensorflow-cpu +rich==13.7.1 \ + --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ + --hash=sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432 + # via keras +six==1.16.0 \ + --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ + --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 + # via + # astunparse + # google-pasta + # tensorboard + # tensorflow-cpu +tensorboard==2.16.2 \ + --hash=sha256:9f2b4e7dad86667615c0e5cd072f1ea8403fc032a299f0072d6f74855775cc45 + # via tensorflow-cpu +tensorboard-data-server==0.7.2 \ + --hash=sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb \ + --hash=sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60 \ + --hash=sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530 + # via tensorboard +tensorflow-cpu==2.16.1 \ + --hash=sha256:050f550a8a1aa77959826fd642024d527699a817cbf3e16c59773981c1fae0a2 \ + --hash=sha256:10daa2bda40c85f7b0ed8d036e6c0394fe24e1806bec0835b5331f8a451e182d \ + --hash=sha256:1ca39bd2f4e28c78f86f744e2cd751d317e42b3b2f8454a9bef1e21aa15e7775 \ + --hash=sha256:282503444a5a61d330fb0a522a56d4c79a241941eb0a074916dfa37a10285e69 \ + --hash=sha256:32190a26ef5a4cc259926e5ed5e3c8c94cf47b7b04bdb18b3e54ec7769673ebd \ + --hash=sha256:3c79d15e51aab9d9cbeb1b4dc13f2c83a80e63540162afc1ee8c66b89ceb123a \ + --hash=sha256:45b31b5258726fbd2c9f2422415beb6fe737d8fe63f1c461e3648dad5c088348 \ + --hash=sha256:67bb51840057ba3a2f46ca6d5cee738974c80b44c9e94df39a55c558d392fc46 \ + --hash=sha256:8fde4a1a1515f3099119c2d71e1653aa6e7ae81ec58b7cb045cab5e5bb147b8b \ + --hash=sha256:98843927dcdc8ab1e4cc20ca0998c69d8623e6c5c779f4de4c82be613de37abd \ + --hash=sha256:aea6520308b1f15511e69bd40b52ba9478143e6e1e8e49d57cd36410321b7b6f \ + --hash=sha256:f7136781cfd6818b2fe74ffea4c585b020c9140652507cbce1558169f2058b58 + # via -r requirements.in +termcolor==2.4.0 \ + --hash=sha256:9297c0df9c99445c2412e832e882a7884038a25617c60cea2ad69488d4040d63 \ + --hash=sha256:aab9e56047c8ac41ed798fa36d892a37aca6b3e9159f3e0c24bc64a9b3ac7b7a + # via tensorflow-cpu +typing-extensions==4.11.0 \ + --hash=sha256:83f085bd5ca59c80295fc2a82ab5dac679cbe02b9f33f7d83af68e241bea51b0 \ + --hash=sha256:c1f94d72897edaf4ce775bb7558d5b79d8126906a14ea5ed1635921406c0387a + # via + # optree + # tensorflow-cpu +urllib3==2.2.1 \ + --hash=sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d \ + --hash=sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19 + # via requests +werkzeug==3.0.2 \ + --hash=sha256:3aac3f5da756f93030740bc235d3e09449efcf65f2f55e3602e1d851b8f48795 \ + --hash=sha256:e39b645a6ac92822588e7b39a692e7828724ceae0b0d702ef96701f90e70128d + # via tensorboard +wheel==0.43.0 \ + --hash=sha256:465ef92c69fa5c5da2d1cf8ac40559a8c940886afcef87dcf14b9470862f1d85 \ + --hash=sha256:55c570405f142630c6b9f72fe09d9b67cf1477fcf543ae5b8dcb1f5b7377da81 + # via astunparse +wrapt==1.16.0 \ + --hash=sha256:0d2691979e93d06a95a26257adb7bfd0c93818e89b1406f5a28f36e0d8c1e1fc \ + --hash=sha256:14d7dc606219cdd7405133c713f2c218d4252f2a469003f8c46bb92d5d095d81 \ + --hash=sha256:1a5db485fe2de4403f13fafdc231b0dbae5eca4359232d2efc79025527375b09 \ + --hash=sha256:1acd723ee2a8826f3d53910255643e33673e1d11db84ce5880675954183ec47e \ + --hash=sha256:1ca9b6085e4f866bd584fb135a041bfc32cab916e69f714a7d1d397f8c4891ca \ + --hash=sha256:1dd50a2696ff89f57bd8847647a1c363b687d3d796dc30d4dd4a9d1689a706f0 \ + --hash=sha256:2076fad65c6736184e77d7d4729b63a6d1ae0b70da4868adeec40989858eb3fb \ + --hash=sha256:2a88e6010048489cda82b1326889ec075a8c856c2e6a256072b28eaee3ccf487 \ + --hash=sha256:3ebf019be5c09d400cf7b024aa52b1f3aeebeff51550d007e92c3c1c4afc2a40 \ + --hash=sha256:418abb18146475c310d7a6dc71143d6f7adec5b004ac9ce08dc7a34e2babdc5c \ + --hash=sha256:43aa59eadec7890d9958748db829df269f0368521ba6dc68cc172d5d03ed8060 \ + --hash=sha256:44a2754372e32ab315734c6c73b24351d06e77ffff6ae27d2ecf14cf3d229202 \ + --hash=sha256:490b0ee15c1a55be9c1bd8609b8cecd60e325f0575fc98f50058eae366e01f41 \ + --hash=sha256:49aac49dc4782cb04f58986e81ea0b4768e4ff197b57324dcbd7699c5dfb40b9 \ + --hash=sha256:5eb404d89131ec9b4f748fa5cfb5346802e5ee8836f57d516576e61f304f3b7b \ + --hash=sha256:5f15814a33e42b04e3de432e573aa557f9f0f56458745c2074952f564c50e664 \ + --hash=sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d \ + --hash=sha256:66027d667efe95cc4fa945af59f92c5a02c6f5bb6012bff9e60542c74c75c362 \ + --hash=sha256:66dfbaa7cfa3eb707bbfcd46dab2bc6207b005cbc9caa2199bcbc81d95071a00 \ + --hash=sha256:685f568fa5e627e93f3b52fda002c7ed2fa1800b50ce51f6ed1d572d8ab3e7fc \ + --hash=sha256:6906c4100a8fcbf2fa735f6059214bb13b97f75b1a61777fcf6432121ef12ef1 \ + --hash=sha256:6a42cd0cfa8ffc1915aef79cb4284f6383d8a3e9dcca70c445dcfdd639d51267 \ + --hash=sha256:6dcfcffe73710be01d90cae08c3e548d90932d37b39ef83969ae135d36ef3956 \ + --hash=sha256:6f6eac2360f2d543cc875a0e5efd413b6cbd483cb3ad7ebf888884a6e0d2e966 \ + --hash=sha256:72554a23c78a8e7aa02abbd699d129eead8b147a23c56e08d08dfc29cfdddca1 \ + --hash=sha256:73870c364c11f03ed072dda68ff7aea6d2a3a5c3fe250d917a429c7432e15228 \ + --hash=sha256:73aa7d98215d39b8455f103de64391cb79dfcad601701a3aa0dddacf74911d72 \ + --hash=sha256:75ea7d0ee2a15733684badb16de6794894ed9c55aa5e9903260922f0482e687d \ + --hash=sha256:7bd2d7ff69a2cac767fbf7a2b206add2e9a210e57947dd7ce03e25d03d2de292 \ + --hash=sha256:807cc8543a477ab7422f1120a217054f958a66ef7314f76dd9e77d3f02cdccd0 \ + --hash=sha256:8e9723528b9f787dc59168369e42ae1c3b0d3fadb2f1a71de14531d321ee05b0 \ + --hash=sha256:9090c9e676d5236a6948330e83cb89969f433b1943a558968f659ead07cb3b36 \ + --hash=sha256:9153ed35fc5e4fa3b2fe97bddaa7cbec0ed22412b85bcdaf54aeba92ea37428c \ + --hash=sha256:9159485323798c8dc530a224bd3ffcf76659319ccc7bbd52e01e73bd0241a0c5 \ + --hash=sha256:941988b89b4fd6b41c3f0bfb20e92bd23746579736b7343283297c4c8cbae68f \ + --hash=sha256:94265b00870aa407bd0cbcfd536f17ecde43b94fb8d228560a1e9d3041462d73 \ + --hash=sha256:98b5e1f498a8ca1858a1cdbffb023bfd954da4e3fa2c0cb5853d40014557248b \ + --hash=sha256:9b201ae332c3637a42f02d1045e1d0cccfdc41f1f2f801dafbaa7e9b4797bfc2 \ + --hash=sha256:a0ea261ce52b5952bf669684a251a66df239ec6d441ccb59ec7afa882265d593 \ + --hash=sha256:a33a747400b94b6d6b8a165e4480264a64a78c8a4c734b62136062e9a248dd39 \ + --hash=sha256:a452f9ca3e3267cd4d0fcf2edd0d035b1934ac2bd7e0e57ac91ad6b95c0c6389 \ + --hash=sha256:a86373cf37cd7764f2201b76496aba58a52e76dedfaa698ef9e9688bfd9e41cf \ + --hash=sha256:ac83a914ebaf589b69f7d0a1277602ff494e21f4c2f743313414378f8f50a4cf \ + --hash=sha256:aefbc4cb0a54f91af643660a0a150ce2c090d3652cf4052a5397fb2de549cd89 \ + --hash=sha256:b3646eefa23daeba62643a58aac816945cadc0afaf21800a1421eeba5f6cfb9c \ + --hash=sha256:b47cfad9e9bbbed2339081f4e346c93ecd7ab504299403320bf85f7f85c7d46c \ + --hash=sha256:b935ae30c6e7400022b50f8d359c03ed233d45b725cfdd299462f41ee5ffba6f \ + --hash=sha256:bb2dee3874a500de01c93d5c71415fcaef1d858370d405824783e7a8ef5db440 \ + --hash=sha256:bc57efac2da352a51cc4658878a68d2b1b67dbe9d33c36cb826ca449d80a8465 \ + --hash=sha256:bf5703fdeb350e36885f2875d853ce13172ae281c56e509f4e6eca049bdfb136 \ + --hash=sha256:c31f72b1b6624c9d863fc095da460802f43a7c6868c5dda140f51da24fd47d7b \ + --hash=sha256:c5cd603b575ebceca7da5a3a251e69561bec509e0b46e4993e1cac402b7247b8 \ + --hash=sha256:d2efee35b4b0a347e0d99d28e884dfd82797852d62fcd7ebdeee26f3ceb72cf3 \ + --hash=sha256:d462f28826f4657968ae51d2181a074dfe03c200d6131690b7d65d55b0f360f8 \ + --hash=sha256:d5e49454f19ef621089e204f862388d29e6e8d8b162efce05208913dde5b9ad6 \ + --hash=sha256:da4813f751142436b075ed7aa012a8778aa43a99f7b36afe9b742d3ed8bdc95e \ + --hash=sha256:db2e408d983b0e61e238cf579c09ef7020560441906ca990fe8412153e3b291f \ + --hash=sha256:db98ad84a55eb09b3c32a96c576476777e87c520a34e2519d3e59c44710c002c \ + --hash=sha256:dbed418ba5c3dce92619656802cc5355cb679e58d0d89b50f116e4a9d5a9603e \ + --hash=sha256:dcdba5c86e368442528f7060039eda390cc4091bfd1dca41e8046af7c910dda8 \ + --hash=sha256:decbfa2f618fa8ed81c95ee18a387ff973143c656ef800c9f24fb7e9c16054e2 \ + --hash=sha256:e4fdb9275308292e880dcbeb12546df7f3e0f96c6b41197e0cf37d2826359020 \ + --hash=sha256:eb1b046be06b0fce7249f1d025cd359b4b80fc1c3e24ad9eca33e0dcdb2e4a35 \ + --hash=sha256:eb6e651000a19c96f452c85132811d25e9264d836951022d6e81df2fff38337d \ + --hash=sha256:ed867c42c268f876097248e05b6117a65bcd1e63b779e916fe2e33cd6fd0d3c3 \ + --hash=sha256:edfad1d29c73f9b863ebe7082ae9321374ccb10879eeabc84ba3b69f2579d537 \ + --hash=sha256:f2058f813d4f2b5e3a9eb2eb3faf8f1d99b81c3e51aeda4b168406443e8ba809 \ + --hash=sha256:f6b2d0c6703c988d334f297aa5df18c45e97b0af3679bb75059e0e0bd8b1069d \ + --hash=sha256:f8212564d49c50eb4565e502814f694e240c55551a5f1bc841d4fcaabb0a9b8a \ + --hash=sha256:ffa565331890b90056c01db69c0fe634a776f8019c143a5ae265f9c6bc4bd6d4 + # via tensorflow-cpu + +# The following packages are considered to be unsafe in a requirements file: +setuptools==69.5.1 \ + --hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \ + --hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32 + # via + # tensorboard + # tensorflow-cpu From 418c823bb15542a54437a6abd00741f4a8f14e05 Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Thu, 2 May 2024 22:30:55 +0000 Subject: [PATCH 04/22] Op kernels are stateless and support shape inference. --- tf_shell/cc/kernels/mul_kernels.cc | 64 +++---- tf_shell/cc/kernels/polynomial_kernels.cc | 20 +-- tf_shell/cc/kernels/shape_kernels.cc | 17 +- tf_shell/cc/ops/shell_ops.cc | 205 +++++++++++++++++----- tf_shell/python/shell_tensor.py | 2 + 5 files changed, 215 insertions(+), 93 deletions(-) diff --git a/tf_shell/cc/kernels/mul_kernels.cc b/tf_shell/cc/kernels/mul_kernels.cc index b70db59..f3b110b 100644 --- a/tf_shell/cc/kernels/mul_kernels.cc +++ b/tf_shell/cc/kernels/mul_kernels.cc @@ -633,55 +633,55 @@ REGISTER_KERNEL_BUILDER(Name("MulCtPt64").Device(DEVICE_CPU), // Multiply plaintext or ciphertext by plaintext scalar. REGISTER_KERNEL_BUILDER( - Name("MulPtTfScalar64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MulPtTfScalar64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MulShellTfScalarOp>); REGISTER_KERNEL_BUILDER( - Name("MulPtTfScalar64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MulPtTfScalar64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MulShellTfScalarOp>); REGISTER_KERNEL_BUILDER( - Name("MulCtTfScalar64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MulCtTfScalar64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MulShellTfScalarOp>); REGISTER_KERNEL_BUILDER( - Name("MulCtTfScalar64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MulCtTfScalar64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MulShellTfScalarOp>); REGISTER_KERNEL_BUILDER( - Name("MulPtTfScalar64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MulPtTfScalar64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MulShellTfScalarOp>); REGISTER_KERNEL_BUILDER( - Name("MulPtTfScalar64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MulPtTfScalar64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MulShellTfScalarOp>); REGISTER_KERNEL_BUILDER( - Name("MulCtTfScalar64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MulCtTfScalar64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MulShellTfScalarOp>); REGISTER_KERNEL_BUILDER( - Name("MulCtTfScalar64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MulCtTfScalar64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MulShellTfScalarOp>); REGISTER_KERNEL_BUILDER( - Name("MulPtTfScalar64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MulPtTfScalar64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MulShellTfScalarOp>); REGISTER_KERNEL_BUILDER( - Name("MulPtTfScalar64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MulPtTfScalar64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MulShellTfScalarOp>); REGISTER_KERNEL_BUILDER( - Name("MulCtTfScalar64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MulCtTfScalar64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MulShellTfScalarOp>); REGISTER_KERNEL_BUILDER( - Name("MulCtTfScalar64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MulCtTfScalar64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MulShellTfScalarOp>); REGISTER_KERNEL_BUILDER( - Name("MulPtTfScalar64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MulPtTfScalar64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MulShellTfScalarOp>); REGISTER_KERNEL_BUILDER( - Name("MulPtTfScalar64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MulPtTfScalar64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MulShellTfScalarOp>); REGISTER_KERNEL_BUILDER( - Name("MulCtTfScalar64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MulCtTfScalar64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MulShellTfScalarOp>); REGISTER_KERNEL_BUILDER( - Name("MulCtTfScalar64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MulCtTfScalar64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MulShellTfScalarOp>); // Multiply plaintext by plaintext. @@ -690,58 +690,58 @@ REGISTER_KERNEL_BUILDER(Name("MulPtPt64").Device(DEVICE_CPU), // Matrix multiply ciphertext and plaintext. REGISTER_KERNEL_BUILDER( - Name("MatMulCtPt64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MatMulCtPt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MatMulCtPtOp); REGISTER_KERNEL_BUILDER( - Name("MatMulCtPt64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MatMulCtPt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MatMulCtPtOp); REGISTER_KERNEL_BUILDER( - Name("MatMulCtPt64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MatMulCtPt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MatMulCtPtOp); REGISTER_KERNEL_BUILDER( - Name("MatMulCtPt64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MatMulCtPt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MatMulCtPtOp); REGISTER_KERNEL_BUILDER( - Name("MatMulCtPt64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MatMulCtPt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MatMulCtPtOp); REGISTER_KERNEL_BUILDER( - Name("MatMulCtPt64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MatMulCtPt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MatMulCtPtOp); REGISTER_KERNEL_BUILDER( - Name("MatMulCtPt64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MatMulCtPt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MatMulCtPtOp); REGISTER_KERNEL_BUILDER( - Name("MatMulCtPt64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MatMulCtPt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MatMulCtPtOp); // Matrix multiply plaintext and ciphertext. REGISTER_KERNEL_BUILDER( - Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MatMulPtCtOp); REGISTER_KERNEL_BUILDER( - Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MatMulPtCtOp); REGISTER_KERNEL_BUILDER( - Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MatMulPtCtOp); REGISTER_KERNEL_BUILDER( - Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MatMulPtCtOp); REGISTER_KERNEL_BUILDER( - Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MatMulPtCtOp); REGISTER_KERNEL_BUILDER( - Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MatMulPtCtOp); REGISTER_KERNEL_BUILDER( - Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MatMulPtCtOp); REGISTER_KERNEL_BUILDER( - Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("MatMulPtCt64").Device(DEVICE_CPU).TypeConstraint("Dtype"), MatMulPtCtOp); \ No newline at end of file diff --git a/tf_shell/cc/kernels/polynomial_kernels.cc b/tf_shell/cc/kernels/polynomial_kernels.cc index ed9f014..e90e371 100644 --- a/tf_shell/cc/kernels/polynomial_kernels.cc +++ b/tf_shell/cc/kernels/polynomial_kernels.cc @@ -243,46 +243,46 @@ class PolynomialExportOp : public OpKernel { // Import ops. REGISTER_KERNEL_BUILDER(Name("PolynomialImport64") .Device(DEVICE_CPU) - .TypeConstraint("dtype"), + .TypeConstraint("Dtype"), PolynomialImportOp); REGISTER_KERNEL_BUILDER( - Name("PolynomialImport64").Device(DEVICE_CPU).TypeConstraint("dtype"), + Name("PolynomialImport64").Device(DEVICE_CPU).TypeConstraint("Dtype"), PolynomialImportOp); REGISTER_KERNEL_BUILDER(Name("PolynomialImport64") .Device(DEVICE_CPU) - .TypeConstraint("dtype"), + .TypeConstraint("Dtype"), PolynomialImportOp); REGISTER_KERNEL_BUILDER(Name("PolynomialImport64") .Device(DEVICE_CPU) - .TypeConstraint("dtype"), + .TypeConstraint("Dtype"), PolynomialImportOp); REGISTER_KERNEL_BUILDER(Name("PolynomialImport64") .Device(DEVICE_CPU) - .TypeConstraint("dtype"), + .TypeConstraint("Dtype"), PolynomialImportOp); REGISTER_KERNEL_BUILDER(Name("PolynomialImport64") .Device(DEVICE_CPU) - .TypeConstraint("dtype"), + .TypeConstraint("Dtype"), PolynomialImportOp); REGISTER_KERNEL_BUILDER(Name("PolynomialImport64") .Device(DEVICE_CPU) - .TypeConstraint("dtype"), + .TypeConstraint("Dtype"), PolynomialImportOp); REGISTER_KERNEL_BUILDER(Name("PolynomialImport64") .Device(DEVICE_CPU) - .TypeConstraint("dtype"), + .TypeConstraint("Dtype"), PolynomialImportOp); REGISTER_KERNEL_BUILDER(Name("PolynomialImport64") .Device(DEVICE_CPU) - .TypeConstraint("dtype"), + .TypeConstraint("Dtype"), PolynomialImportOp); REGISTER_KERNEL_BUILDER(Name("PolynomialImport64") .Device(DEVICE_CPU) - .TypeConstraint("dtype"), + .TypeConstraint("Dtype"), PolynomialImportOp); // Import ops. diff --git a/tf_shell/cc/kernels/shape_kernels.cc b/tf_shell/cc/kernels/shape_kernels.cc index 934e92b..4093e5c 100644 --- a/tf_shell/cc/kernels/shape_kernels.cc +++ b/tf_shell/cc/kernels/shape_kernels.cc @@ -36,21 +36,22 @@ using tensorflow::errors::InvalidArgument; // This class is exactly like TensorFlow's ExpandDimsOp, but allows operating // on a tensor with a variant dtype. class ExpandDimsVariantOp : public OpKernel { + private: + int dim; + public: explicit ExpandDimsVariantOp(OpKernelConstruction* op_ctx) - : OpKernel(op_ctx) {} - - void Compute(OpKernelContext* ctx) override { - // OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT, - // InvalidArgument("ExpandDims on Variant not supported")); - - int32 dim = ctx->input(1).flat()(0); + : OpKernel(op_ctx) { + // Get the dimension to expand from the op attributes. + OP_REQUIRES_OK(op_ctx, op_ctx->GetAttr("axis", &dim)); // Recall first dimension of a shell variant tensor is the packing // dimension. We don't allow expanding this dimension. - OP_REQUIRES(ctx, dim != 0, InvalidArgument("Invalid dimension index.")); + OP_REQUIRES(op_ctx, dim != 0, InvalidArgument("Invalid dimension index.")); dim += dim > 0 ? -1 : 0; + } + void Compute(OpKernelContext* ctx) override { OP_REQUIRES( ctx, (dim >= -1 - ctx->input(0).dims() && dim <= ctx->input(0).dims()), InvalidArgument("Tried to expand dim index ", dim, " for tensor with ", diff --git a/tf_shell/cc/ops/shell_ops.cc b/tf_shell/cc/ops/shell_ops.cc index 45bbf3a..ec13826 100644 --- a/tf_shell/cc/ops/shell_ops.cc +++ b/tf_shell/cc/ops/shell_ops.cc @@ -33,17 +33,15 @@ REGISTER_OP("ContextImport64") .Input("noise_variance: uint64") .Input("seed: string") .Output("shell_context: variant") - .SetIsStateful() .SetShapeFn(ScalarShape); REGISTER_OP("PolynomialImport64") .Attr( - "dtype: {uint8, int8, int16, uint16, int32, uint32, int64, uint64, " + "Dtype: {uint8, int8, int16, uint16, int32, uint32, int64, uint64, " "float, double}") .Input("shell_context: variant") - .Input("in: dtype") + .Input("in: Dtype") .Output("val: variant") - .SetIsStateful() .SetShapeFn([](InferenceContext* c) { ShapeHandle output; @@ -58,15 +56,26 @@ REGISTER_OP("PolynomialImport64") // so no SetShapeFn() for this Op. REGISTER_OP("PolynomialExport64") .Attr("dtype: {uint8, int8, uint16, int16, uint32, int32, uint64, int64}") + .Attr("batching_dim: int") .Input("shell_context: variant") .Input("in: variant") .Output("val: dtype") - .SetIsStateful(); + .SetShapeFn([](InferenceContext* c) { + tsl::int32 batching_dim; + TF_RETURN_IF_ERROR(c->GetAttr("batching_dim", &batching_dim)); + ShapeHandle batching_dim_shape = c->MakeShape({batching_dim}); + + ShapeHandle output; + TF_RETURN_IF_ERROR( + c->Concatenate(c->input(1), batching_dim_shape, &output)); + + c->set_output(0, output); + return OkStatus(); + }); REGISTER_OP("KeyGen64") .Input("context: variant") .Output("key: variant") - .SetIsStateful() .SetShapeFn(ScalarShape); REGISTER_OP("Encrypt64") @@ -74,177 +83,287 @@ REGISTER_OP("Encrypt64") .Input("key: variant") .Input("val: variant") .Output("out: variant") - .SetIsStateful() .SetShapeFn([](InferenceContext* c) { c->set_output(0, c->input(2)); return OkStatus(); }); -// Output shape depends on content of shell_context -// so no SetShapeFn() for this Op. REGISTER_OP("Decrypt64") .Attr("dtype: {uint8, int8, uint16, int16, uint32, int32, uint64, int64}") + .Attr("batching_dim: int") .Input("context: variant") .Input("key: variant") .Input("val: variant") .Output("out: dtype") - .SetIsStateful(); + .SetShapeFn([](InferenceContext* c) { + tsl::int32 batching_dim; + TF_RETURN_IF_ERROR(c->GetAttr("batching_dim", &batching_dim)); + ShapeHandle batching_dim_shape = c->MakeShape({batching_dim}); + + ShapeHandle output; + TF_RETURN_IF_ERROR( + c->Concatenate(c->input(1), batching_dim_shape, &output)); + + c->set_output(0, output); + return OkStatus(); + }); // Add and subtract. REGISTER_OP("AddCtCt64") .Input("a: variant") .Input("b: variant") .Output("c: variant") - .SetIsStateful(); + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->input(0)); + return OkStatus(); + }); REGISTER_OP("AddCtPt64") .Input("a: variant") .Input("b: variant") .Output("c: variant") - .SetIsStateful(); + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->input(0)); + return OkStatus(); + }); REGISTER_OP("AddPtPt64") .Input("context: variant") .Input("a: variant") .Input("b: variant") .Output("c: variant") - .SetIsStateful(); + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->input(0)); + return OkStatus(); + }); REGISTER_OP("SubCtCt64") .Input("a: variant") .Input("b: variant") .Output("c: variant") - .SetIsStateful(); + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->input(0)); + return OkStatus(); + }); REGISTER_OP("SubCtPt64") .Input("a: variant") .Input("b: variant") .Output("c: variant") - .SetIsStateful(); + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->input(0)); + return OkStatus(); + }); REGISTER_OP("SubPtPt64") .Input("context: variant") .Input("a: variant") .Input("b: variant") .Output("c: variant") - .SetIsStateful(); + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->input(1)); + return OkStatus(); + }); REGISTER_OP("NegCt64") .Input("value: variant") .Output("negated_value: variant") - .SetIsStateful(); + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->input(0)); + return OkStatus(); + }); REGISTER_OP("NegPt64") .Input("context: variant") .Input("value: variant") .Output("negated_value: variant") - .SetIsStateful(); + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->input(0)); + return OkStatus(); + }); // Multiply. REGISTER_OP("MulCtCt64") .Input("a: variant") .Input("b: variant") .Output("c: variant") - .SetIsStateful(); + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->input(0)); + return OkStatus(); + }); REGISTER_OP("MulCtPt64") .Input("a: variant") .Input("b: variant") .Output("c: variant") - .SetIsStateful(); + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->input(0)); + return OkStatus(); + }); REGISTER_OP("MulCtTfScalar64") - .Attr("dtype: {uint8, int8, uint16, int16, uint32, int32, uint64, int64}") + .Attr("Dtype: {uint8, int8, uint16, int16, uint32, int32, uint64, int64}") .Input("context: variant") .Input("a: variant") - .Input("b: dtype") + .Input("b: Dtype") .Output("c: variant") - .SetIsStateful(); + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->input(1)); + return OkStatus(); + }); REGISTER_OP("MulPtTfScalar64") - .Attr("dtype: {uint8, int8, uint16, int16, uint32, int32, uint64, int64}") + .Attr("Dtype: {uint8, int8, uint16, int16, uint32, int32, uint64, int64}") .Input("context: variant") .Input("a: variant") - .Input("b: dtype") + .Input("b: Dtype") .Output("c: variant") - .SetIsStateful(); + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->input(1)); + return OkStatus(); + }); REGISTER_OP("MulPtPt64") .Input("context: variant") .Input("a: variant") .Input("b: variant") .Output("c: variant") - .SetIsStateful(); + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->input(0)); + return OkStatus(); + }); REGISTER_OP("MatMulCtPt64") - .Attr("dtype: {uint8, int8, uint16, int16, uint32, int32, uint64, int64}") + .Attr("Dtype: {uint8, int8, uint16, int16, uint32, int32, uint64, int64}") .Input("context: variant") .Input("a: variant") - .Input("b: dtype") + .Input("b: Dtype") .Output("c: variant") - .SetIsStateful(); + .SetShapeFn([](InferenceContext* c) { + // Output has the same shape as the plaintext b outer dim. + ShapeHandle output; + TF_RETURN_IF_ERROR(c->Subshape(c->input(2), 1, &output)); + c->set_output(0, output); + return OkStatus(); + }); REGISTER_OP("MatMulPtCt64") - .Attr("dtype: {uint8, int8, uint16, int16, uint32, int32, uint64, int64}") + .Attr("Dtype: {uint8, int8, uint16, int16, uint32, int32, uint64, int64}") .Input("context: variant") .Input("rotation_key: variant") - .Input("a: dtype") + .Input("a: Dtype") .Input("b: variant") .Output("c: variant") - .SetIsStateful(); + .SetShapeFn([](InferenceContext* c) { + // Output has the same shape as the plaintext b outer dim. + tsl::int32 a_rank = c->Rank(c->input(2)); + ShapeHandle a_shape_prefix; + TF_RETURN_IF_ERROR( + c->Subshape(c->input(2), 0, a_rank - 2, &a_shape_prefix)); + + ShapeHandle output; + TF_RETURN_IF_ERROR(c->Concatenate(a_shape_prefix, c->input(3), &output)); + + c->set_output(0, output); + return OkStatus(); + }); // Rotate. REGISTER_OP("RotationKeyGen64") .Input("context: variant") .Input("key: variant") .Output("rotation_key: variant") - .SetIsStateful(); + .SetShapeFn(ScalarShape); REGISTER_OP("Roll64") .Input("rotation_key: variant") .Input("value: variant") .Input("shift: int64") .Output("rotated_value: variant") - .SetIsStateful(); + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->input(1)); + return OkStatus(); + }); REGISTER_OP("ReduceSumByRotation64") .Input("value: variant") .Input("rotation_key: variant") .Output("repeated_reduce_sum: variant") - .SetIsStateful(); + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->input(0)); + return OkStatus(); + }); REGISTER_OP("ReduceSum64") .Input("value: variant") .Input("axis: int64") .Output("repeated_reduce_sum: variant") - .SetIsStateful(); + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->input(0)); + return OkStatus(); + }); // Modulus switching. REGISTER_OP("ModulusReduceContext64") .Input("context: variant") .Output("reduced_context: variant") - .SetIsStateful(); + .SetShapeFn(ScalarShape); REGISTER_OP("ModulusReduceKey64") .Input("key: variant") .Output("reduced_key: variant") - .SetIsStateful(); + .SetShapeFn(ScalarShape); REGISTER_OP("ModulusReduceCt64") .Input("context: variant") .Input("value: variant") .Output("reduced_value: variant") - .SetIsStateful(); + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->input(1)); + return OkStatus(); + }); REGISTER_OP("ModulusReducePt64") .Input("context: variant") .Input("value: variant") .Output("reduced_value: variant") - .SetIsStateful(); + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->input(1)); + return OkStatus(); + }); // Shape kernels. REGISTER_OP("ExpandDimsVariant") .Input("value: variant") - .Input("axis: int32") + .Attr("axis: int") .Output("expanded_value: variant") - .SetIsStateful(); + .SetShapeFn([](InferenceContext* c) { + tsl::int32 rank = c->Rank(c->input(0)); + + tsl::int32 axis; + TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis)); + + //Check that axis is in the correct range. + if (axis < -rank || axis > rank) { + return tensorflow::errors::InvalidArgument( + "axis must be in the range [-rank, rank], got ", axis); + } + + if (axis < 0) { + axis += rank; + } + + ShapeHandle prefix; + TF_RETURN_IF_ERROR(c->Subshape(c->input(0), 0, axis, &prefix)); + + ShapeHandle postfix; + TF_RETURN_IF_ERROR(c->Subshape(c->input(0), axis, rank - 1, &postfix)); + + ShapeHandle output; + ShapeHandle axis_dim = c->MakeShape({1}); + TF_RETURN_IF_ERROR(c->Concatenate(prefix, axis_dim, &output)); + TF_RETURN_IF_ERROR(c->Concatenate(output, postfix, &output)); + + c->set_output(0, output); + return OkStatus(); + }); diff --git a/tf_shell/python/shell_tensor.py b/tf_shell/python/shell_tensor.py index 11102ef..0f71d9c 100644 --- a/tf_shell/python/shell_tensor.py +++ b/tf_shell/python/shell_tensor.py @@ -584,6 +584,7 @@ def to_tensorflow(s_tensor, key=None): key._raw_key, s_tensor._raw, dtype=shell_dtype, + batching_dim=s_tensor._context.num_slots, ) else: @@ -593,6 +594,7 @@ def to_tensorflow(s_tensor, key=None): s_tensor._context._raw_context, s_tensor._raw, dtype=shell_dtype, + batching_dim=s_tensor._context.num_slots, ) # Shell tensor represents floats as integers * scaling_factor. From aa15e167434be8f63c3c59524b6627dc77dc0477 Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Mon, 6 May 2024 14:50:29 +0000 Subject: [PATCH 05/22] Shell types extend tf ExtensionType. --- tf_shell/__init__.py | 3 + tf_shell/python/shell_context.py | 118 ++++----- tf_shell/python/shell_key.py | 57 ++--- tf_shell/python/shell_tensor.py | 398 +++++++++++++++---------------- tf_shell/test/context_test.py | 12 +- tf_shell/test/rotation_test.py | 2 +- 6 files changed, 270 insertions(+), 320 deletions(-) diff --git a/tf_shell/__init__.py b/tf_shell/__init__.py index 7572034..f564bdb 100644 --- a/tf_shell/__init__.py +++ b/tf_shell/__init__.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from tf_shell.python.shell_tensor import ShellTensor64 +from tf_shell.python.shell_tensor import mod_reduce_tensor64 from tf_shell.python.shell_tensor import to_shell_plaintext from tf_shell.python.shell_tensor import to_encrypted from tf_shell.python.shell_tensor import to_tensorflow @@ -27,9 +28,11 @@ from tf_shell.python.shell_context import ShellContext64 from tf_shell.python.shell_context import create_context64 +from tf_shell.python.shell_context import mod_reduce_context64 from tf_shell.python.shell_key import ShellKey64 from tf_shell.python.shell_key import create_key64 +from tf_shell.python.shell_key import mod_reduce_key64 from tf_shell.python.shell_key import ShellRotationKey64 from tf_shell.python.shell_key import create_rotation_key64 diff --git a/tf_shell/python/shell_context.py b/tf_shell/python/shell_context.py index 8231da7..d2ca990 100644 --- a/tf_shell/python/shell_context.py +++ b/tf_shell/python/shell_context.py @@ -14,106 +14,76 @@ # See the License for the specific language governing permissions and # limitations under the License. import tf_shell.python.ops.shell_ops as shell_ops -import math -import random +import tensorflow as tf +import typing + + +class ShellContext64(tf.experimental.ExtensionType): + _raw_context: tf.Tensor + log_n: int + num_slots: int + two_n: int + main_moduli: tf.Tensor + level: int + aux_moduli: tf.Tensor + plaintext_modulus: int + noise_variance: int + noise_bits: int + scaling_factor: int + mul_depth_supported: int + seed: str - -class ShellContext64(object): def __init__( self, - shell_context, + _raw_context, log_n, main_moduli, aux_moduli, plaintext_modulus, noise_variance, - scaling_factor, # The field version of fixed point fractional bits. + scaling_factor, mul_depth_supported, seed, ): - self._raw_context = shell_context + self._raw_context = _raw_context self.log_n = log_n self.num_slots = 2**log_n - self.two_n = 2 ** (self.log_n + 1) + self.two_n = 2 ** (log_n + 1) self.main_moduli = main_moduli + self.level = len(main_moduli) self.aux_moduli = aux_moduli self.plaintext_modulus = plaintext_modulus self.noise_variance = noise_variance - if noise_variance % 2 == 0: - self.noise_bits = noise_variance.bit_length() + if self.noise_variance % 2 == 0: + self.noise_bits = self.noise_variance.bit_length() else: - self.noise_bits = noise_variance.bit_length() + 1 + self.noise_bits = self.noise_variance.bit_length() + 1 self.scaling_factor = scaling_factor self.mul_depth_supported = mul_depth_supported self.seed = seed - @property - def level(self): - return len(self.main_moduli) - - @property - def Q(self): - if not hasattr(self, "_Q"): - self._Q = 1 - for x in self.main_moduli: - self._Q *= x - return self._Q - - def __lt__(self, other): - return self.level < other.level - - def __le__(self, other): - return self.level <= other.level - - def __gt__(self, other): - return self.level > other.level - def __ge__(self, other): - return self.level >= other.level +def mod_reduce_context64(context): + if not isinstance(context, ShellContext64): + raise ValueError("context must be a ShellContext64.") - def __eq__(self, other): - return ( - self.log_n == other.log_n - and self.main_moduli == other.main_moduli - and self.aux_moduli == other.aux_moduli - and self.plaintext_modulus == other.plaintext_modulus - and self.noise_variance == other.noise_variance - and self.seed == other.seed - ) + assert context.mul_depth_supported > 0, "Not enough multiplication primes." - def __ne__(self, other): - return not self.__eq__(other) + smaller_context = shell_ops.modulus_reduce_context64(context._raw_context) - def __hash__(self): - return ( - hash(tuple(self.main_moduli)) - ^ hash(tuple(self.aux_moduli)) - ^ hash(self.plaintext_modulus) - ^ hash(self.noise_variance) - ^ hash(self.seed) - ) - - def get_mod_reduced(self): - assert self.mul_depth_supported > 0, "Not enough multiplication primes." - - if hasattr(self, "_mod_reduced"): - return self._mod_reduced - - smaller_context = shell_ops.modulus_reduce_context64(self._raw_context) - - self._mod_reduced = ShellContext64( - shell_context=smaller_context, - log_n=self.log_n, - main_moduli=self.main_moduli[:-1], - aux_moduli=self.aux_moduli, - plaintext_modulus=self.plaintext_modulus, - noise_variance=self.noise_variance, - scaling_factor=self.scaling_factor, - mul_depth_supported=self.mul_depth_supported - 1, - seed=self.seed, - ) + mod_reduced = ShellContext64( + _raw_context=smaller_context, + log_n=context.log_n, + main_moduli=context.main_moduli[:-1], + aux_moduli=context.aux_moduli, + plaintext_modulus=context.plaintext_modulus, + noise_variance=context.noise_variance, + scaling_factor=context.scaling_factor, + mul_depth_supported=context.mul_depth_supported - 1, + seed=context.seed, + ) - return self._mod_reduced + return mod_reduced def create_context64( @@ -140,7 +110,7 @@ def create_context64( ) return ShellContext64( - shell_context=shell_context, + _raw_context=shell_context, log_n=log_n, main_moduli=main_moduli, aux_moduli=aux_moduli, diff --git a/tf_shell/python/shell_key.py b/tf_shell/python/shell_key.py index 03dadb6..bf5d6fe 100644 --- a/tf_shell/python/shell_key.py +++ b/tf_shell/python/shell_key.py @@ -15,42 +15,37 @@ # limitations under the License. import tf_shell.python.ops.shell_ops as shell_ops from tf_shell.python.shell_context import ShellContext64 +from tf_shell.python.shell_context import mod_reduce_context64 +import tensorflow as tf +import typing -class ShellKey64(object): - def __init__( - self, - raw_key, - level, - ): - self._raw_key = raw_key - self.level = level - - def get_mod_reduced(self): - if hasattr(self, "_mod_reduced"): - return self._mod_reduced - - smaller_key = shell_ops.modulus_reduce_key64(self._raw_key) - self._mod_reduced = ShellKey64(smaller_key, self.level - 1) - return self._mod_reduced +class ShellKey64(tf.experimental.ExtensionType): + _raw_key: tf.Tensor + level: int def create_key64(context): if not isinstance(context, ShellContext64): - raise ValueError("Context must be a ShellContext64") + raise ValueError("context must be a ShellContext64") + + return ShellKey64( + _raw_key=shell_ops.key_gen64(context._raw_context), level=context.level + ) + + +def mod_reduce_key64(key): + if not isinstance(key, ShellKey64): + raise ValueError("key must be a ShellKey64") - raw_key = shell_ops.key_gen64(context._raw_context) - return ShellKey64(raw_key, context.level) + smaller_raw_key = shell_ops.modulus_reduce_key64(key._raw_key) + mod_reduced = ShellKey64(smaller_raw_key, key.level - 1) + return mod_reduced -class ShellRotationKey64(object): - def __init__( - self, - raw_rot_keys_at_level, - context, - ): - self._raw_rot_keys_at_level = raw_rot_keys_at_level - self._context = context +class ShellRotationKey64(tf.experimental.ExtensionType): + _raw_rot_keys_at_level: typing.Mapping[int, tf.Tensor] + context: ShellContext64 def _get_key_at_level(self, level): if level not in self._raw_rot_keys_at_level: @@ -65,7 +60,7 @@ def create_rotation_key64(context, key, skip_at_mul_depth=[]): generating keys at levels (particular number of moduli) at which no rotations are required.""" if not isinstance(context, ShellContext64): - raise ValueError("Context must be a ShellContext64.") + raise ValueError("context must be a ShellContext64.") if context.level != key.level: raise ValueError("Context and key levels must match.") @@ -79,7 +74,7 @@ def create_rotation_key64(context, key, skip_at_mul_depth=[]): if context.mul_depth_supported == 0 or context.level == 1: break - context = context.get_mod_reduced() - key = key.get_mod_reduced() + context = mod_reduce_context64(context) + key = mod_reduce_key64(key) - return ShellRotationKey64(raw_rot_keys_at_level, context) + return ShellRotationKey64(_raw_rot_keys_at_level=raw_rot_keys_at_level, context=context) diff --git a/tf_shell/python/shell_tensor.py b/tf_shell/python/shell_tensor.py index 0f71d9c..845bdf4 100644 --- a/tf_shell/python/shell_tensor.py +++ b/tf_shell/python/shell_tensor.py @@ -17,51 +17,31 @@ import tensorflow as tf import tf_shell.python.ops.shell_ops as shell_ops from tf_shell.python.shell_context import ShellContext64 +from tf_shell.python.shell_context import mod_reduce_context64 from tf_shell.python.shell_key import ShellKey64 +from tf_shell.python.shell_key import mod_reduce_key64 from tf_shell.python.shell_key import ShellRotationKey64 -class ShellTensor64(object): - is_tensor_like = True # needed to pass tf.is_tensor, new as of TF 2.2+ - - def __init__( - self, - value, - context, - underlying_dtype, - scaling_factor, - is_enc, - noise_bit_count, - ): - assert isinstance( - value, tf.Tensor - ), f"Should be variant tensor, instead got {type(value)}" - - assert ( - value.dtype is tf.variant - ), f"Should be variant tensor, instead got {value.dtype}" - - assert isinstance(context, ShellContext64), f"Should be ShellContext64" - - self._raw = value - self._context = context - self._underlying_dtype = underlying_dtype - self._scaling_factor = scaling_factor - self._is_enc = is_enc - - self._noise_bit_count = noise_bit_count +class ShellTensor64(tf.experimental.ExtensionType): + _raw_tensor: tf.Tensor + _context: ShellContext64 + _underlying_dtype: tf.DType + _scaling_factor: int + _is_enc: bool + _noise_bit_count: tf.Tensor @property def shape(self): - return [self._context.num_slots] + self._raw.shape + return [self._context.num_slots] + self._raw_tensor.shape @property def name(self): - return self._raw.name + return self._raw_tensor.name @property def dtype(self): - return self._raw.name + return self._raw_tensor.name @property def plaintext_dtype(self): @@ -90,12 +70,12 @@ def __getitem__(self, slice): f"ShellTensor does not support intra-slot slicing. Use `:` on the first dimension. Got {slice}" ) return ShellTensor64( - value=self._raw[slice[1:]], - context=self._context, - underlying_dtype=self._underlying_dtype, - scaling_factor=self._scaling_factor, - is_enc=self.is_encrypted, - noise_bit_count=self._noise_bit_count, + _raw_tensor=self._raw_tensor[slice[1:]], + _context=self._context, + _underlying_dtype=self._underlying_dtype, + _scaling_factor=self._scaling_factor, + _is_enc=self.is_encrypted, + _noise_bit_count=self._noise_bit_count, ) def __add__(self, other): @@ -103,33 +83,33 @@ def __add__(self, other): matched_self, matched_other = _match_moduli_and_scaling(self, other) if self.is_encrypted and other.is_encrypted: - result_raw = shell_ops.add_ct_ct64( - matched_self._raw, matched_other._raw + result_raw_tensor = shell_ops.add_ct_ct64( + matched_self._raw_tensor, matched_other._raw_tensor ) elif self.is_encrypted and not other.is_encrypted: - result_raw = shell_ops.add_ct_pt64( - matched_self._raw, matched_other._raw + result_raw_tensor = shell_ops.add_ct_pt64( + matched_self._raw_tensor, matched_other._raw_tensor ) elif not self.is_encrypted and other.is_encrypted: - result_raw = shell_ops.add_ct_pt64( - matched_other._raw, matched_self._raw + result_raw_tensor = shell_ops.add_ct_pt64( + matched_other._raw_tensor, matched_self._raw_tensor ) elif not self.is_encrypted and not other.is_encrypted: - result_raw = shell_ops.add_pt_pt64( + result_raw_tensor = shell_ops.add_pt_pt64( matched_self._context._raw_context, - matched_self._raw, - matched_other._raw, + matched_self._raw_tensor, + matched_other._raw_tensor, ) else: raise ValueError("Invalid operands") return ShellTensor64( - value=result_raw, - context=matched_self._context, - underlying_dtype=self._underlying_dtype, - scaling_factor=self._scaling_factor, - is_enc=self._is_enc or other._is_enc, - noise_bit_count=self._noise_bit_count + 1, + _raw_tensor=result_raw_tensor, + _context=matched_self._context, + _underlying_dtype=self._underlying_dtype, + _scaling_factor=self._scaling_factor, + _is_enc=self._is_enc or other._is_enc, + _noise_bit_count=self._noise_bit_count + 1, ) elif isinstance(other, tf.Tensor): @@ -161,34 +141,34 @@ def __sub__(self, other): matched_self, matched_other = _match_moduli_and_scaling(self, other) if self.is_encrypted and other.is_encrypted: - result_raw = shell_ops.sub_ct_ct64( - matched_self._raw, matched_other._raw + result_raw_tensor = shell_ops.sub_ct_ct64( + matched_self._raw_tensor, matched_other._raw_tensor ) elif self.is_encrypted and not other.is_encrypted: - result_raw = shell_ops.sub_ct_pt64( - matched_self._raw, matched_other._raw + result_raw_tensor = shell_ops.sub_ct_pt64( + matched_self._raw_tensor, matched_other._raw_tensor ) elif not self.is_encrypted and other.is_encrypted: negative_other = -matched_other - result_raw = shell_ops.add_ct_pt64( - negative_other._raw, matched_self._raw + result_raw_tensor = shell_ops.add_ct_pt64( + negative_other._raw_tensor, matched_self._raw_tensor ) elif not self.is_encrypted and not other.is_encrypted: - result_raw = shell_ops.sub_pt_pt64( + result_raw_tensor = shell_ops.sub_pt_pt64( matched_self._context._raw_context, - matched_self._raw, - matched_other._raw, + matched_self._raw_tensor, + matched_other._raw_tensor, ) else: raise ValueError("Invalid operands") return ShellTensor64( - value=result_raw, - context=matched_self._context, - underlying_dtype=self._underlying_dtype, - scaling_factor=self._scaling_factor, - is_enc=self._is_enc or other._is_enc, - noise_bit_count=self._noise_bit_count + 1, + _raw_tensor=result_raw_tensor, + _context=matched_self._context, + _underlying_dtype=self._underlying_dtype, + _scaling_factor=self._scaling_factor, + _is_enc=self._is_enc or other._is_enc, + _noise_bit_count=self._noise_bit_count + 1, ) elif isinstance(other, tf.Tensor): if other.shape == (1,) or other.shape == (): @@ -229,20 +209,20 @@ def __rsub__(self, other): if self_matched.is_encrypted: negative_self_matched = -self_matched raw_result = shell_ops.add_ct_pt64( - negative_self_matched._raw, other_matched._raw + negative_self_matched._raw_tensor, other_matched._raw_tensor ) else: raw_result = shell_ops.sub_pt_pt64( - self._context._raw_context, other_matched._raw, self_matched._raw + self._context._raw_context, other_matched._raw_tensor, self_matched._raw_tensor ) return ShellTensor64( - value=raw_result, - context=self._context, - underlying_dtype=self._underlying_dtype, - scaling_factor=self._scaling_factor, - is_enc=self._is_enc, - noise_bit_count=self._noise_bit_count + 1, + _raw_tensor=raw_result, + _context=self._context, + _underlying_dtype=self._underlying_dtype, + _scaling_factor=self._scaling_factor, + _is_enc=self._is_enc, + _noise_bit_count=self._noise_bit_count + 1, ) else: # Try to import the unknown operand to a TensorFlow tensor and @@ -257,17 +237,17 @@ def __rsub__(self, other): def __neg__(self): if self.is_encrypted: - raw_result = shell_ops.neg_ct64(self._raw) + raw_result = shell_ops.neg_ct64(self._raw_tensor) else: - raw_result = shell_ops.neg_pt64(self._context._raw_context, self._raw) + raw_result = shell_ops.neg_pt64(self._context._raw_context, self._raw_tensor) return ShellTensor64( - value=raw_result, - context=self._context, - underlying_dtype=self._underlying_dtype, - scaling_factor=self._scaling_factor, - is_enc=self._is_enc, - noise_bit_count=self._noise_bit_count + 1, + _raw_tensor=raw_result, + _context=self._context, + _underlying_dtype=self._underlying_dtype, + _scaling_factor=self._scaling_factor, + _is_enc=self._is_enc, + _noise_bit_count=self._noise_bit_count + 1, ) def __mul__(self, other): @@ -276,32 +256,32 @@ def __mul__(self, other): if self.is_encrypted and other.is_encrypted: raw_result = shell_ops.mul_ct_ct64( - matched_self._raw, matched_other._raw + matched_self._raw_tensor, matched_other._raw_tensor ) elif self.is_encrypted and not other.is_encrypted: raw_result = shell_ops.mul_ct_pt64( - matched_self._raw, matched_other._raw + matched_self._raw_tensor, matched_other._raw_tensor ) elif not self.is_encrypted and other.is_encrypted: raw_result = shell_ops.mul_ct_pt64( - matched_other._raw, matched_self._raw + matched_other._raw_tensor, matched_self._raw_tensor ) elif not self.is_encrypted and not other.is_encrypted: raw_result = shell_ops.mul_pt_pt64( matched_self._context._raw_context, - matched_self._raw, - matched_other._raw, + matched_self._raw_tensor, + matched_other._raw_tensor, ) else: raise ValueError("Invalid operands") return ShellTensor64( - value=raw_result, - context=matched_self._context, - underlying_dtype=self._underlying_dtype, - scaling_factor=matched_self._scaling_factor**2, - is_enc=self._is_enc or other._is_enc, - noise_bit_count=matched_self._noise_bit_count + _raw_tensor=raw_result, + _context=matched_self._context, + _underlying_dtype=self._underlying_dtype, + _scaling_factor=matched_self._scaling_factor**2, + _is_enc=self._is_enc or other._is_enc, + _noise_bit_count=matched_self._noise_bit_count + matched_other._noise_bit_count, ) elif isinstance(other, tf.Tensor): @@ -315,20 +295,20 @@ def __mul__(self, other): if self.is_encrypted: raw_result = shell_ops.mul_ct_tf_scalar64( - self._context._raw_context, self._raw, other + self._context._raw_context, self._raw_tensor, other ) else: raw_result = shell_ops.mul_pt_tf_scalar64( - self._context._raw_context, self._raw, other + self._context._raw_context, self._raw_tensor, other ) return ShellTensor64( - value=raw_result, - context=self._context, - underlying_dtype=self._underlying_dtype, - scaling_factor=self._scaling_factor**2, - is_enc=self._is_enc, - noise_bit_count=self.noise_bits + self._context.noise_bits, + _raw_tensor=raw_result, + _context=self._context, + _underlying_dtype=self._underlying_dtype, + _scaling_factor=self._scaling_factor**2, + _is_enc=self._is_enc, + _noise_bit_count=self.noise_bits + self._context.noise_bits, ) else: @@ -352,49 +332,51 @@ def __mul__(self, other): def __rmul__(self, other): return self * other - def get_mod_reduced(self): - """Switches the ShellTensor to a new context with different moduli. If - preserve_plaintext is True (default), the plaintext value will be - maintained through the modulus switch. If preserve_plaintext is False, - the plaintext will be divided by the ratio of the new and old moduli.""" - if hasattr(self, "_mod_reduced"): - return self._mod_reduced +def mod_reduce_tensor64(shell_tensor): + """Switches the ShellTensor to a new context with different moduli. If + preserve_plaintext is True (default), the plaintext value will be + maintained through the modulus switch. If preserve_plaintext is False, + the plaintext will be divided by the ratio of the new and old moduli.""" - # Switch to the new context and moduli. - if self.is_encrypted: - op = shell_ops.modulus_reduce_ct64 - else: - op = shell_ops.modulus_reduce_pt64 + assert isinstance( + shell_tensor, ShellTensor64 + ), f"shell_tensor must be a ShellTensor64, instead got {type(shell_tensor)}" - raw_result = op( - self._context._raw_context, - self._raw, - ) + # Switch to the new context and moduli. + if shell_tensor.is_encrypted: + op = shell_ops.modulus_reduce_ct64 + else: + op = shell_ops.modulus_reduce_pt64 - reduced_self = ShellTensor64( - value=raw_result, - context=self._context.get_mod_reduced(), - underlying_dtype=self._underlying_dtype, - scaling_factor=self._scaling_factor, - is_enc=self._is_enc, - noise_bit_count=self.noise_bits - - self._context.main_moduli[-1].bit_length() - + 1, - ) + raw_result = op( + shell_tensor._context._raw_context, + shell_tensor._raw_tensor, + ) - # Cache the result. - self._mod_reduced = reduced_self + reduced_noise = tf.cast(tf.math.ceil( + tf.math.log(tf.cast(shell_tensor._context.main_moduli[-1], tf.float32)) + / tf.math.log(tf.cast(2, tf.float32)) + ), tf.int32) + + reduced_self = ShellTensor64( + _raw_tensor=raw_result, + _context=mod_reduce_context64(shell_tensor._context), + _underlying_dtype=shell_tensor._underlying_dtype, + _scaling_factor=shell_tensor._scaling_factor, + _is_enc=shell_tensor._is_enc, + _noise_bit_count=shell_tensor.noise_bits - reduced_noise, + ) - return reduced_self + return reduced_self def _match_moduli_and_scaling(x, y): # Mod switch to the smaller modulus of the two. while x._context.level > y._context.level: - x = x.get_mod_reduced() + x = mod_reduce_tensor64(x) while x._context.level < y._context.level: - y = y.get_mod_reduced() + y = mod_reduce_tensor64(y) # Match the scaling factors. # First make sure the scaling factors are compatible. @@ -416,23 +398,23 @@ def _match_moduli_and_scaling(x, y): def _match_shape(x, y): # Match the shape of x and y via broadcasting. - if tf.size(x._raw) > tf.size(y._raw): + if tf.size(x._raw_tensor) > tf.size(y._raw_tensor): y = ShellTensor64( - value=tf.broadcast_to(y._raw, tf.shape(x._raw)), - context=y._context, - underlying_dtype=y._underlying_dtype, - scaling_factor=y._scaling_factor, - is_enc=y._is_enc, - noise_bit_count=y._noise_bit_count, + _raw_tensor=tf.broadcast_to(y._raw_tensor, tf.shape(x._raw_tensor)), + _context=y._context, + _underlying_dtype=y._underlying_dtype, + _scaling_factor=y._scaling_factor, + _is_enc=y._is_enc, + _noise_bit_count=y._noise_bit_count, ) - elif tf.size(x._raw) < tf.size(y._raw): + elif tf.size(x._raw_tensor) < tf.size(y._raw_tensor): x = ShellTensor64( - value=tf.broadcast_to(x._raw, tf.shape(y._raw)), - context=y._context, - underlying_dtype=y._underlying_dtype, - scaling_factor=y._scaling_factor, - is_enc=y._is_enc, - noise_bit_count=y._noise_bit_count, + _raw_tensor=tf.broadcast_to(x._raw_tensor, tf.shape(y._raw_tensor)), + _context=y._context, + _underlying_dtype=y._underlying_dtype, + _scaling_factor=y._scaling_factor, + _is_enc=y._is_enc, + _noise_bit_count=y._noise_bit_count, ) return x, y @@ -508,12 +490,12 @@ def to_shell_plaintext(tensor, context): scaled_tensor = tf.pad(scaled_tensor, padding) return ShellTensor64( - value=shell_ops.polynomial_import64(context._raw_context, scaled_tensor), - context=context, - underlying_dtype=tensor.dtype, - scaling_factor=context.scaling_factor, - is_enc=False, - noise_bit_count=context.noise_bits, + _raw_tensor=shell_ops.polynomial_import64(context._raw_context, scaled_tensor), + _context=context, + _underlying_dtype=tensor.dtype, + _scaling_factor=context.scaling_factor, + _is_enc=False, + _noise_bit_count=context.noise_bits, ) else: try: @@ -535,16 +517,16 @@ def to_encrypted(x, key, context=None): return x # Do nothing, already encrypted. else: return ShellTensor64( - value=shell_ops.encrypt64( + _raw_tensor=shell_ops.encrypt64( x._context._raw_context, key._raw_key, - x._raw, + x._raw_tensor, ), - context=x._context, - underlying_dtype=x._underlying_dtype, - scaling_factor=x._scaling_factor, - is_enc=True, - noise_bit_count=x._noise_bit_count, + _context=x._context, + _underlying_dtype=x._underlying_dtype, + _scaling_factor=x._scaling_factor, + _is_enc=True, + _noise_bit_count=x._noise_bit_count, ) else: if not isinstance(context, ShellContext64): @@ -576,13 +558,13 @@ def to_tensorflow(s_tensor, key=None): # Mod reduce the key to match the level of the ciphertext. while key.level > s_tensor._context.level: - key = key.get_mod_reduced() + key = mod_reduce_key64(key) # Decrypt op returns a tf Tensor. tf_tensor = shell_ops.decrypt64( s_tensor._context._raw_context, key._raw_key, - s_tensor._raw, + s_tensor._raw_tensor, dtype=shell_dtype, batching_dim=s_tensor._context.num_slots, ) @@ -592,7 +574,7 @@ def to_tensorflow(s_tensor, key=None): # Always convert to int64, then handle the fixed point as appropriate. tf_tensor = shell_ops.polynomial_export64( s_tensor._context._raw_context, - s_tensor._raw, + s_tensor._raw_tensor, dtype=shell_dtype, batching_dim=s_tensor._context.num_slots, ) @@ -621,12 +603,12 @@ def roll(x, shift, rotation_key): shift = tf.cast(shift, tf.int64) return ShellTensor64( - value=shell_ops.roll64(raw_rotation_key, x._raw, shift), - context=x._context, - underlying_dtype=x._underlying_dtype, - scaling_factor=x._scaling_factor, - is_enc=True, - noise_bit_count=x._noise_bit_count + 6, # TODO correct? + _raw_tensor=shell_ops.roll64(raw_rotation_key, x._raw_tensor, shift), + _context=x._context, + _underlying_dtype=x._underlying_dtype, + _scaling_factor=x._scaling_factor, + _is_enc=True, + _noise_bit_count=x._noise_bit_count + 6, # TODO correct? ) elif isinstance(x, tf.Tensor): return tf.roll(x, shift) @@ -659,12 +641,12 @@ def reduce_sum(x, axis, rotation_key=None): ) return ShellTensor64( - value=shell_ops.reduce_sum_by_rotation64(x._raw, raw_rotation_key), - context=x._context, - underlying_dtype=x._underlying_dtype, - scaling_factor=x._scaling_factor, - is_enc=True, - noise_bit_count=result_noise_bits, + _raw_tensor=shell_ops.reduce_sum_by_rotation64(x._raw_tensor, raw_rotation_key), + _context=x._context, + _underlying_dtype=x._underlying_dtype, + _scaling_factor=x._scaling_factor, + _is_enc=True, + _noise_bit_count=result_noise_bits, ) else: if axis >= len(x.shape): @@ -673,12 +655,12 @@ def reduce_sum(x, axis, rotation_key=None): result_noise_bits = x._noise_bit_count + x.shape[axis].bit_length() + 1 return ShellTensor64( - value=shell_ops.reduce_sum64(x._raw, axis), - context=x._context, - underlying_dtype=x._underlying_dtype, - scaling_factor=x._scaling_factor, - is_enc=True, - noise_bit_count=result_noise_bits, + _raw_tensor=shell_ops.reduce_sum64(x._raw_tensor, axis), + _context=x._context, + _underlying_dtype=x._underlying_dtype, + _scaling_factor=x._scaling_factor, + _is_enc=True, + _noise_bit_count=result_noise_bits, ) elif isinstance(x, tf.Tensor): return tf.reduce_sum(x, axis) @@ -715,16 +697,16 @@ def matmul(x, y, rotation_key=None): reduce_sum_noise = multiplication_noise + x.shape[1].bit_length() return ShellTensor64( - value=shell_ops.mat_mul_ct_pt64( + _raw_tensor=shell_ops.mat_mul_ct_pt64( x._context._raw_context, - x._raw, + x._raw_tensor, scaled_y, ), - context=x._context, - underlying_dtype=x._underlying_dtype, - scaling_factor=x._scaling_factor**2, - is_enc=True, - noise_bit_count=reduce_sum_noise, + _context=x._context, + _underlying_dtype=x._underlying_dtype, + _scaling_factor=x._scaling_factor**2, + _is_enc=True, + _noise_bit_count=reduce_sum_noise, ) elif isinstance(x, tf.Tensor) and isinstance(y, ShellTensor64): @@ -757,17 +739,17 @@ def matmul(x, y, rotation_key=None): reduce_sum_noise = rotation_noise + y._context.num_slots.bit_length() return ShellTensor64( - value=shell_ops.mat_mul_pt_ct64( + _raw_tensor=shell_ops.mat_mul_pt_ct64( y._context._raw_context, raw_rotation_key, scaled_x, - y._raw, + y._raw_tensor, ), - context=y._context, - underlying_dtype=y._underlying_dtype, - scaling_factor=y._scaling_factor**2, - is_enc=True, - noise_bit_count=reduce_sum_noise, + _context=y._context, + _underlying_dtype=y._underlying_dtype, + _scaling_factor=y._scaling_factor**2, + _is_enc=True, + _noise_bit_count=reduce_sum_noise, ) elif isinstance(x, ShellTensor64) and isinstance(y, ShellTensor64): @@ -790,12 +772,12 @@ def expand_dims(x, axis=-1): "Cannot expand dims at axis 0 for ShellTensor64, this is the batching dimension." ) return ShellTensor64( - value=shell_ops.expand_dims_variant(x._raw, axis), - context=x._context, - underlying_dtype=x._underlying_dtype, - scaling_factor=x._scaling_factor, - is_enc=x._is_enc, - noise_bit_count=x._noise_bit_count, + _raw_tensor=shell_ops.expand_dims_variant(x._raw_tensor, axis), + _context=x._context, + _underlying_dtype=x._underlying_dtype, + _scaling_factor=x._scaling_factor, + _is_enc=x._is_enc, + _noise_bit_count=x._noise_bit_count, ) elif isinstance(x, tf.Tensor): return tf.expand_dims(x, axis) @@ -811,12 +793,12 @@ def reshape(x, shape): "Cannot reshape axis 0 for ShellTensor64, this is the batching dimension." ) return ShellTensor64( - value=tf.reshape(x._raw, shape[1:]), - context=x._context, - underlying_dtype=x._underlying_dtype, - scaling_factor=x._scaling_factor, - is_enc=x._is_enc, - noise_bit_count=x._noise_bit_count, + _raw_tensor=tf.reshape(x._raw_tensor, shape[1:]), + _context=x._context, + _underlying_dtype=x._underlying_dtype, + _scaling_factor=x._scaling_factor, + _is_enc=x._is_enc, + _noise_bit_count=x._noise_bit_count, ) elif isinstance(x, tf.Tensor): return tf.reshape(x, shape) diff --git a/tf_shell/test/context_test.py b/tf_shell/test/context_test.py index def4ca1..60498ca 100644 --- a/tf_shell/test/context_test.py +++ b/tf_shell/test/context_test.py @@ -34,13 +34,13 @@ def test_create_context(self): # The ratio between the smaller and the larger context should be # the last modulus in the chain. - smaller_context = context.get_mod_reduced() - self.assertAllClose(context.Q / smaller_context.Q, ql) + smaller_context = tf_shell.mod_reduce_context64(context) + self.assertAllClose(context.main_moduli[:-1], smaller_context.main_moduli) # The ratio between the smaller and the larger context should be the # scaling factor. - even_smaller_context = smaller_context.get_mod_reduced() - self.assertAllClose(context.Q / even_smaller_context.Q, ql * ql2) + even_smaller_context = tf_shell.mod_reduce_context64(smaller_context) + self.assertAllClose(context.main_moduli[:-2], even_smaller_context.main_moduli) def test_mod_reduce_context(self): # Num plaintext bits: 48, noise bits: 65 @@ -59,10 +59,10 @@ def test_mod_reduce_context(self): ea = tf_shell.to_encrypted(sa, key) # Mod reducing should not affect the plaintext value. - smaller_sa = sa.get_mod_reduced() + smaller_sa = tf_shell.mod_reduce_tensor64(sa) self.assertAllClose(a, tf_shell.to_tensorflow(smaller_sa)) - smaller_ea = ea.get_mod_reduced() + smaller_ea = tf_shell.mod_reduce_tensor64(ea) self.assertAllClose(a, tf_shell.to_tensorflow(smaller_ea, key)) # Check the arguments were not modified diff --git a/tf_shell/test/rotation_test.py b/tf_shell/test/rotation_test.py index 8bae3b5..c5fff55 100644 --- a/tf_shell/test/rotation_test.py +++ b/tf_shell/test/rotation_test.py @@ -154,7 +154,7 @@ def _test_roll_mod_reduced(self, test_context, roll_num): enc = tf_shell.to_encrypted(s, test_context.key) # Test roll on a mod reduced ciphertext. - enc_reduced = enc.get_mod_reduced() + enc_reduced = tf_shell.mod_reduce_tensor64(enc) rolled_enc_reduced = tf_shell.roll( enc_reduced, roll_num, test_context.rotation_key ) From 9446794a78267d457baf9a0fdd682cb0b8dac3c3 Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Wed, 8 May 2024 03:58:47 +0000 Subject: [PATCH 06/22] Tests and examples use autograph for shell ops. --- examples/label_dp_sgd.ipynb | 132 +++--- examples/label_dp_sgd_post_scale.ipynb | 448 ++---------------- tf_shell/python/shell_tensor.py | 82 ++-- tf_shell_ml/loss.py | 3 +- tf_shell_ml/test/mnist_enc_backprop_test.py | 3 +- tf_shell_ml/test/mnist_noenc_backprop_test.py | 63 +-- tf_shell_ml/test/mnist_post_scale_test.py | 36 +- 7 files changed, 251 insertions(+), 516 deletions(-) diff --git a/examples/label_dp_sgd.ipynb b/examples/label_dp_sgd.ipynb index 3cd6950..5dd5caa 100644 --- a/examples/label_dp_sgd.ipynb +++ b/examples/label_dp_sgd.ipynb @@ -38,8 +38,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-04-26 15:59:42.151758: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", - "2024-04-26 15:59:42.173513: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "2024-05-08 01:19:38.518051: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", + "2024-05-08 01:19:38.539912: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" ] } @@ -156,6 +156,7 @@ "metadata": {}, "outputs": [], "source": [ + "@tf.function\n", "def train_step(x, enc_y):\n", " # Forward pass always in plaintext\n", " y_1 = hidden_layer(x)\n", @@ -170,8 +171,8 @@ " # 10 classes. tf-shell instead back propagates in two mini-batches per batch\n", " # resulting in two gradients of shape [10]. Furthermore, the gradients are\n", " # in an \"expanded\" form where the gradient is repeated by the size of the\n", - " # mini-batch. Said another way, if real_grad_top/bottom is the \"real\"\n", - " # gradient of shape [10] from the top/bottom halves of the batch:\n", + " # batch. Said another way, if real_grad_top/bottom is the \"real\" gradient of\n", + " # shape [10] from the top/bottom halves of the batch:\n", " #\n", " # dJ_dw = tf.concat([\n", " # tf.repeat(\n", @@ -192,6 +193,7 @@ " return dJ_dw1[0], dJ_dw0[0]\n", "\n", "\n", + "@tf.function\n", "def train_step_wrapper(x_batch, y_batch):\n", " x_batch = tf.cast(x_batch, tf.float32)\n", " y_batch = tf.cast(y_batch, tf.float32)\n", @@ -253,212 +255,230 @@ "To start tensorboard, run: tensorboard --logdir /tmp/tflogs\n", "\n", "Start of epoch 0\n", - "Epoch: 0, Batch: 0 / 15, Time Stamp: 0.07017874717712402\n", - "\taccuracy: 0.11117256432771683\n" + "Epoch: 0, Batch: 0 / 15, Time Stamp: 0.06978559494018555\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-04-26 16:05:57.677839: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-05-08 01:19:54.275369: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:104] Profiler session initializing.\n", + "2024-05-08 01:19:54.275391: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:119] Profiler session started.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 1 / 15, Time Stamp: 360.4916572570801\n", - "\taccuracy: 0.1150442510843277\n" + "WARNING:tensorflow:Error while stopping profiler: Cannot export profiling results. No profiler is running.\n", + "\taccuracy: 0.15154866874217987\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-04-26 16:11:59.988257: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-05-08 01:26:14.168977: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:70] Profiler session collecting data.\n", + "2024-05-08 01:26:14.181516: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:131] Profiler session tear down.\n", + "2024-05-08 01:26:14.182392: I external/local_tsl/tsl/profiler/rpc/client/save_profile.cc:144] Collecting XSpace to repository: /tmp/tflogs/pt-20240508-011954/plugins/profile/2024_05_08_01_26_14/e81647a0f462.xplane.pb\n", + "2024-05-08 01:26:14.216694: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 2 / 15, Time Stamp: 722.7873704433441\n", - "\taccuracy: 0.11117256432771683\n" + "Epoch: 0, Batch: 1 / 15, Time Stamp: 380.09435200691223\n", + "\taccuracy: 0.16648229956626892\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-04-26 16:18:09.918304: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-05-08 01:32:34.396483: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 3 / 15, Time Stamp: 1092.714866399765\n", - "\taccuracy: 0.10951327532529831\n" + "Epoch: 0, Batch: 2 / 15, Time Stamp: 760.1941771507263\n", + "\taccuracy: 0.14712388813495636\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-04-26 16:24:15.157070: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-05-08 01:38:54.856343: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 4 / 15, Time Stamp: 1457.9539773464203\n", - "\taccuracy: 0.11946902424097061\n" + "Epoch: 0, Batch: 3 / 15, Time Stamp: 1140.6538624763489\n", + "\taccuracy: 0.16592919826507568\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-04-26 16:30:21.150770: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-05-08 01:45:17.714885: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 5 / 15, Time Stamp: 1823.9476954936981\n", - "\taccuracy: 0.12555310130119324\n" + "Epoch: 0, Batch: 4 / 15, Time Stamp: 1523.5124411582947\n", + "\taccuracy: 0.1692477911710739\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-04-26 16:36:25.598224: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-05-08 01:51:34.837468: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 6 / 15, Time Stamp: 2188.3950967788696\n", - "\taccuracy: 0.13993363082408905\n" + "Epoch: 0, Batch: 5 / 15, Time Stamp: 1900.635261774063\n", + "\taccuracy: 0.17865043878555298\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-04-26 16:42:28.388670: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-05-08 01:57:54.261685: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 7 / 15, Time Stamp: 2551.185366868973\n", - "\taccuracy: 0.15873894095420837\n" + "Epoch: 0, Batch: 6 / 15, Time Stamp: 2280.059217453003\n", + "\taccuracy: 0.18860618770122528\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-04-26 16:48:32.459850: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-05-08 02:04:14.045979: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 8 / 15, Time Stamp: 2915.256863594055\n", - "\taccuracy: 0.1692477911710739\n" + "Epoch: 0, Batch: 7 / 15, Time Stamp: 2659.84348654747\n", + "\taccuracy: 0.1946902722120285\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-04-26 16:54:34.126222: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-05-08 02:10:38.937992: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 9 / 15, Time Stamp: 3276.922953605652\n", - "\taccuracy: 0.17865043878555298\n" + "Epoch: 0, Batch: 8 / 15, Time Stamp: 3044.7355086803436\n", + "\taccuracy: 0.20243363082408905\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-04-26 17:00:35.599672: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-05-08 02:16:59.938804: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 10 / 15, Time Stamp: 3638.3964817523956\n", - "\taccuracy: 0.1875\n" + "Epoch: 0, Batch: 9 / 15, Time Stamp: 3425.7369046211243\n", + "\taccuracy: 0.22676990926265717\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-04-26 17:06:38.453861: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-05-08 02:23:25.627886: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 11 / 15, Time Stamp: 4001.250802755356\n", - "\taccuracy: 0.20630531013011932\n" + "Epoch: 0, Batch: 10 / 15, Time Stamp: 3811.425350189209\n", + "\taccuracy: 0.24834071099758148\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-04-26 17:12:41.857820: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-05-08 02:29:44.326151: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 12 / 15, Time Stamp: 4364.654639005661\n", - "\taccuracy: 0.21902655065059662\n" + "Epoch: 0, Batch: 11 / 15, Time Stamp: 4190.123619794846\n", + "\taccuracy: 0.26216813921928406\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-04-26 17:18:46.174162: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-05-08 02:35:59.882177: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 13 / 15, Time Stamp: 4728.971588611603\n", - "\taccuracy: 0.24170354008674622\n", - "Epoch: 0, Batch: 14 / 15, Time Stamp: 5091.019182920456\n", - "Total plaintext training time: 5091.019740104675 seconds\n" + "Epoch: 0, Batch: 12 / 15, Time Stamp: 4565.6800808906555\n", + "\taccuracy: 0.2815265357494354\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-04-26 17:24:48.222336: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-05-08 02:42:16.677880: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 0, Batch: 13 / 15, Time Stamp: 4942.475761651993\n", + "\taccuracy: 0.28595131635665894\n", + "Epoch: 0, Batch: 14 / 15, Time Stamp: 5321.51356959343\n", + "Total training time: 5321.514094829559 seconds\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-05-08 02:48:35.716055: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] } ], @@ -486,25 +506,29 @@ " break\n", "\n", " # If using deferred execution, one can trace and profile the training.\n", - " # tf.summary.trace_on(graph=True, profiler=True, profiler_outdir=logdir)\n", + " if step == 0:\n", + " tf.summary.trace_on(graph=True, profiler=True, profiler_outdir=logdir)\n", "\n", " train_step_wrapper(x_batch, y_batch)\n", "\n", - " # with writer.as_default():\n", - " # tf.summary.trace_export(\n", - " # name=\"tf_shell_example_label_dp_sgd\", step=(epoch + 1) * step\n", - " # )\n", + " if step == 0:\n", + " with writer.as_default():\n", + " tf.summary.trace_export(\n", + " name=\"label_dp_sgd\", step=(epoch + 1) * step\n", + " )\n", "\n", " # Check the accuracy.\n", " average_loss = 0\n", " average_accuracy = 0\n", " for x, y in val_dataset:\n", " y_pred = output_layer(hidden_layer(x))\n", + " loss = tf.reduce_mean(loss_fn(y, y_pred))\n", " accuracy = tf.reduce_mean(\n", " tf.cast(\n", " tf.equal(tf.argmax(y, axis=1), tf.argmax(y_pred, axis=1)), tf.float32\n", " )\n", " )\n", + " average_loss += loss\n", " average_accuracy += accuracy\n", " average_loss /= len(val_dataset)\n", " average_accuracy /= len(val_dataset)\n", @@ -517,7 +541,7 @@ " )\n", "\n", "\n", - "print(f\"Total plaintext training time: {time.time() - start_time} seconds\")" + "print(f\"Total training time: {time.time() - start_time} seconds\")" ] } ], diff --git a/examples/label_dp_sgd_post_scale.ipynb b/examples/label_dp_sgd_post_scale.ipynb index a724fb0..4ad4148 100644 --- a/examples/label_dp_sgd_post_scale.ipynb +++ b/examples/label_dp_sgd_post_scale.ipynb @@ -40,8 +40,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-04-26 15:20:18.850048: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", - "2024-04-26 15:20:18.873633: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "2024-05-08 03:16:15.442044: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", + "2024-05-08 03:16:15.593868: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" ] } @@ -151,6 +151,7 @@ "metadata": {}, "outputs": [], "source": [ + "@tf.function\n", "def train_step(x, y):\n", " \"\"\"One step of training with using the \"post scale\" approach.\n", "\n", @@ -188,9 +189,14 @@ " y_pred = model(x, training=False)\n", "\n", " # Compute y_pred - y (where y is encrypted).\n", - " scalars = y_pred - y # dJ/dy_pred\n", + " scalars = y.__rsub__(y_pred) # dJ/dy_pred\n", " # ^ batch_size x num output classes.\n", "\n", + " # Expand the last dim so that the subsequent multiplication is\n", + " # broadcasted.\n", + " scalars = tf_shell.expand_dims(scalars, axis=-1)\n", + " # ^ batch_size x num output classes x 1\n", + "\n", " # Scale each gradient. Since 'scalars' may be a vector of ciphertexts, this\n", " # requires multiplying plaintext gradient for the specific layer (2d) by the\n", " # ciphertext (scalar). To do so efficiently under encryption requires\n", @@ -206,14 +212,9 @@ " packable_grad = tf.reshape(layer_grad_full, [batch_sz, num_output_classes, -1])\n", " # ^ batch_size x num output classes x flattened weights\n", "\n", - " # Expand the last dim so that the subsequent multiplication is\n", - " # broadcasted.\n", - " expanded_scalars = tf_shell.expand_dims(scalars, axis=-1)\n", - " # ^ batch_size x num output classes x 1\n", - "\n", " # Scale the gradient precursors.\n", - " scaled_grad = packable_grad * expanded_scalars\n", - " # ^ dy_pred/dW * dJ/dy_pred = dJ/dW\n", + " scaled_grad = scalars * packable_grad\n", + " # ^ dJ/dW = dJ/dy_pred * dy_pred/dW \n", "\n", " # Sum over the output classes.\n", " scaled_grad = tf_shell.reduce_sum(scaled_grad, axis=1)\n", @@ -236,6 +237,7 @@ " return ps_grads\n", "\n", "\n", + "@tf.function\n", "def train_step_wrapper(x_batch, y_batch):\n", " # Encrypt\n", " enc_y_batch = tf_shell.to_encrypted(y_batch, secret_key, context)\n", @@ -280,439 +282,81 @@ "To start tensorboard, run: tensorboard --logdir /tmp/tflogs\n", "\n", "Start of epoch 0\n", - "Epoch: 0, Batch: 0 / 30, Time Stamp: 0.06940650939941406\n", - "\taccuracy: 0.06139380484819412\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-26 15:21:43.872918: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, Batch: 1 / 30, Time Stamp: 82.75300288200378\n", - "\taccuracy: 0.08683628588914871\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-26 15:22:52.390485: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, Batch: 2 / 30, Time Stamp: 151.25495791435242\n", - "\taccuracy: 0.12721239030361176\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-26 15:24:00.741010: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, Batch: 3 / 30, Time Stamp: 219.60518217086792\n", - "\taccuracy: 0.1548672616481781\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-26 15:25:14.001617: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, Batch: 4 / 30, Time Stamp: 292.86561703681946\n", - "WARNING:tensorflow:5 out of the last 5 calls to .f at 0x7fbeeb495a20> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", - "\taccuracy: 0.17643804848194122\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-26 15:26:22.228981: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, Batch: 5 / 30, Time Stamp: 361.09285974502563\n", - "WARNING:tensorflow:6 out of the last 6 calls to .f at 0x7fbeeb495a20> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", - "\taccuracy: 0.19081857800483704\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-26 15:27:31.338013: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, Batch: 6 / 30, Time Stamp: 430.2023301124573\n", - "\taccuracy: 0.2101769894361496\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-26 15:28:42.044969: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, Batch: 7 / 30, Time Stamp: 500.9089617729187\n", - "\taccuracy: 0.21902655065059662\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-26 15:29:54.669820: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, Batch: 8 / 30, Time Stamp: 573.5338339805603\n", - "\taccuracy: 0.22400441765785217\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-26 15:31:03.402643: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, Batch: 9 / 30, Time Stamp: 642.2663412094116\n", - "\taccuracy: 0.2317477911710739\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-26 15:32:11.294701: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, Batch: 10 / 30, Time Stamp: 710.1584296226501\n", - "\taccuracy: 0.24668142199516296\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-26 15:33:19.957185: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, Batch: 11 / 30, Time Stamp: 778.8208358287811\n", - "\taccuracy: 0.26493361592292786\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-26 15:34:32.724597: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, Batch: 12 / 30, Time Stamp: 851.5886828899384\n", - "\taccuracy: 0.2887168228626251\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-26 15:35:40.807055: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, Batch: 13 / 30, Time Stamp: 919.6708896160126\n", - "\taccuracy: 0.3163716793060303\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-26 15:36:49.307414: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, Batch: 14 / 30, Time Stamp: 988.1711373329163\n", - "\taccuracy: 0.35011062026023865\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-26 15:37:56.881641: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, Batch: 15 / 30, Time Stamp: 1055.7455956935883\n", - "\taccuracy: 0.3794247806072235\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-26 15:39:10.054974: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, Batch: 16 / 30, Time Stamp: 1128.918863773346\n", - "\taccuracy: 0.41261062026023865\n" + "Epoch: 0, Batch: 0 / 30, Time Stamp: 0.07824254035949707\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-04-26 15:40:17.709488: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-05-08 03:16:18.337431: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:104] Profiler session initializing.\n", + "2024-05-08 03:16:18.337456: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:119] Profiler session started.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 17 / 30, Time Stamp: 1196.5733196735382\n", - "\taccuracy: 0.451880544424057\n" + "WARNING:tensorflow:Error while stopping profiler: Cannot export profiling results. No profiler is running.\n", + "\taccuracy: 0.13440264761447906\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-04-26 15:41:25.546967: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-05-08 03:17:43.002417: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:70] Profiler session collecting data.\n", + "2024-05-08 03:17:43.013809: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:131] Profiler session tear down.\n", + "2024-05-08 03:17:43.015068: I external/local_tsl/tsl/profiler/rpc/client/save_profile.cc:144] Collecting XSpace to repository: /tmp/tflogs/pt-20240508-031618/plugins/profile/2024_05_08_03_17_43/e81647a0f462.xplane.pb\n", + "2024-05-08 03:17:43.097148: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 18 / 30, Time Stamp: 1264.4107003211975\n", - "\taccuracy: 0.4933628439903259\n" + "Epoch: 0, Batch: 1 / 30, Time Stamp: 84.8675389289856\n", + "\taccuracy: 0.14435841143131256\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-04-26 15:42:33.810712: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-05-08 03:18:51.448626: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 19 / 30, Time Stamp: 1332.6749150753021\n", - "\taccuracy: 0.5221238732337952\n" + "Epoch: 0, Batch: 2 / 30, Time Stamp: 153.19263339042664\n", + "\taccuracy: 0.15597344934940338\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-04-26 15:43:46.661953: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-05-08 03:19:58.580131: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 20 / 30, Time Stamp: 1405.525713443756\n", - "\taccuracy: 0.5365044474601746\n" + "Epoch: 0, Batch: 3 / 30, Time Stamp: 220.32476949691772\n" ] }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-26 15:44:54.080379: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, Batch: 21 / 30, Time Stamp: 1472.944087266922\n", - "\taccuracy: 0.5553097128868103\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-26 15:46:01.651246: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, Batch: 22 / 30, Time Stamp: 1540.5149257183075\n", - "\taccuracy: 0.5636062026023865\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-26 15:47:10.968213: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, Batch: 23 / 30, Time Stamp: 1609.8318963050842\n", - "\taccuracy: 0.571349561214447\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-26 15:48:23.901182: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, Batch: 24 / 30, Time Stamp: 1682.7653839588165\n", - "\taccuracy: 0.5818583965301514\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-26 15:49:33.029080: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, Batch: 25 / 30, Time Stamp: 1751.8927001953125\n", - "\taccuracy: 0.5923672318458557\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-26 15:50:41.049063: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, Batch: 26 / 30, Time Stamp: 1819.9127042293549\n", - "\taccuracy: 0.6100663542747498\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-26 15:51:49.092682: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, Batch: 27 / 30, Time Stamp: 1887.9565889835358\n", - "\taccuracy: 0.6299778819084167\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-26 15:53:02.865500: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, Batch: 28 / 30, Time Stamp: 1961.7292737960815\n", - "\taccuracy: 0.6493362784385681\n", - "Epoch: 0, Batch: 29 / 30, Time Stamp: 2030.1340026855469\n", - "Total plaintext training time: 2030.1345376968384 seconds\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-26 15:54:11.269970: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n", + "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n", + "\u001b[1;31mClick here for more info. \n", + "\u001b[1;31mView Jupyter log for further details." ] } ], @@ -740,38 +384,42 @@ " break\n", "\n", " # If using deferred execution, one can trace and profile the training.\n", - " # tf.summary.trace_on(graph=True, profiler=True, profiler_outdir=logdir)\n", + " if step == 0:\n", + " tf.summary.trace_on(graph=True, profiler=True, profiler_outdir=logdir)\n", "\n", " train_step_wrapper(x_batch, y_batch)\n", "\n", - " # with writer.as_default():\n", - " # tf.summary.trace_export(\n", - " # name=\"tf_shell_example_label_dp_sgd\", step=(epoch + 1) * step\n", - " # )\n", + " if step == 0:\n", + " with writer.as_default():\n", + " tf.summary.trace_export(\n", + " name=\"label_dp_sgd_post_scale\", step=(epoch + 1) * step\n", + " )\n", "\n", " # Check the accuracy.\n", " average_loss = 0\n", " average_accuracy = 0\n", " for x, y in val_dataset:\n", " y_pred = model(x, training=False)\n", + " loss = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(y, y_pred))\n", " accuracy = tf.reduce_mean(\n", " tf.cast(\n", " tf.equal(tf.argmax(y, axis=1), tf.argmax(y_pred, axis=1)), tf.float32\n", " )\n", " )\n", " average_accuracy += accuracy\n", + " average_loss += loss\n", " average_loss /= len(val_dataset)\n", " average_accuracy /= len(val_dataset)\n", " tf.print(f\"\\taccuracy: {accuracy}\")\n", "\n", " with writer.as_default():\n", - " tf.summary.scalar(\"loss\", average_loss, step=(epoch + 1) * batch_size - 1)\n", + " tf.summary.scalar(\"loss\", average_loss, step=(epoch + 1) * step)\n", " tf.summary.scalar(\n", - " \"accuracy\", average_accuracy, step=(epoch + 1) * batch_size - 1\n", + " \"accuracy\", average_accuracy, step=(epoch + 1) * step\n", " )\n", "\n", "\n", - "print(f\"Total plaintext training time: {time.time() - start_time} seconds\")" + "print(f\"Total training time: {time.time() - start_time} seconds\")" ] } ], diff --git a/tf_shell/python/shell_tensor.py b/tf_shell/python/shell_tensor.py index 845bdf4..ec3842e 100644 --- a/tf_shell/python/shell_tensor.py +++ b/tf_shell/python/shell_tensor.py @@ -35,6 +35,10 @@ class ShellTensor64(tf.experimental.ExtensionType): def shape(self): return [self._context.num_slots] + self._raw_tensor.shape + @property + def dtype(self): + return tf.variant + @property def name(self): return self._raw_tensor.name @@ -213,7 +217,9 @@ def __rsub__(self, other): ) else: raw_result = shell_ops.sub_pt_pt64( - self._context._raw_context, other_matched._raw_tensor, self_matched._raw_tensor + self._context._raw_context, + other_matched._raw_tensor, + self_matched._raw_tensor, ) return ShellTensor64( @@ -239,7 +245,9 @@ def __neg__(self): if self.is_encrypted: raw_result = shell_ops.neg_ct64(self._raw_tensor) else: - raw_result = shell_ops.neg_pt64(self._context._raw_context, self._raw_tensor) + raw_result = shell_ops.neg_pt64( + self._context._raw_context, self._raw_tensor + ) return ShellTensor64( _raw_tensor=raw_result, @@ -354,10 +362,13 @@ def mod_reduce_tensor64(shell_tensor): shell_tensor._raw_tensor, ) - reduced_noise = tf.cast(tf.math.ceil( - tf.math.log(tf.cast(shell_tensor._context.main_moduli[-1], tf.float32)) - / tf.math.log(tf.cast(2, tf.float32)) - ), tf.int32) + reduced_noise = tf.cast( + tf.math.ceil( + tf.math.log(tf.cast(shell_tensor._context.main_moduli[-1], tf.float32)) + / tf.math.log(tf.cast(2, tf.float32)) + ), + tf.int32, + ) reduced_self = ShellTensor64( _raw_tensor=raw_result, @@ -397,25 +408,35 @@ def _match_moduli_and_scaling(x, y): def _match_shape(x, y): - # Match the shape of x and y via broadcasting. - if tf.size(x._raw_tensor) > tf.size(y._raw_tensor): - y = ShellTensor64( - _raw_tensor=tf.broadcast_to(y._raw_tensor, tf.shape(x._raw_tensor)), - _context=y._context, - _underlying_dtype=y._underlying_dtype, - _scaling_factor=y._scaling_factor, - _is_enc=y._is_enc, - _noise_bit_count=y._noise_bit_count, - ) - elif tf.size(x._raw_tensor) < tf.size(y._raw_tensor): - x = ShellTensor64( + # Match the shape of x and y via broadcasting. Note, this copies the data, + # fully materializing the new tensors shape. In the future, shell_ops + # should suuport broadcasting directly, avoiding the copy. + + x = tf.cond( + tf.size(x._raw_tensor) < tf.size(y._raw_tensor), + lambda: ShellTensor64( _raw_tensor=tf.broadcast_to(x._raw_tensor, tf.shape(y._raw_tensor)), + _context=x._context, + _underlying_dtype=x._underlying_dtype, + _scaling_factor=x._scaling_factor, + _is_enc=x._is_enc, + _noise_bit_count=x._noise_bit_count, + ), + lambda: x, + ) + + y = tf.cond( + tf.size(x._raw_tensor) > tf.size(y._raw_tensor), + lambda: ShellTensor64( + _raw_tensor=tf.broadcast_to(y._raw_tensor, tf.shape(x._raw_tensor)), _context=y._context, _underlying_dtype=y._underlying_dtype, _scaling_factor=y._scaling_factor, _is_enc=y._is_enc, _noise_bit_count=y._noise_bit_count, - ) + ), + lambda: y, + ) return x, y @@ -490,7 +511,9 @@ def to_shell_plaintext(tensor, context): scaled_tensor = tf.pad(scaled_tensor, padding) return ShellTensor64( - _raw_tensor=shell_ops.polynomial_import64(context._raw_context, scaled_tensor), + _raw_tensor=shell_ops.polynomial_import64( + context._raw_context, scaled_tensor + ), _context=context, _underlying_dtype=tensor.dtype, _scaling_factor=context.scaling_factor, @@ -636,12 +659,12 @@ def reduce_sum(x, axis, rotation_key=None): # reduce sum does log2(num_slots) rotations and additions. # TODO: add noise from rotations? - result_noise_bits = ( - x._noise_bit_count + x._context.num_slots.bit_length() + 1, - ) + result_noise_bits = x._noise_bit_count + x._context.noise_bits return ShellTensor64( - _raw_tensor=shell_ops.reduce_sum_by_rotation64(x._raw_tensor, raw_rotation_key), + _raw_tensor=shell_ops.reduce_sum_by_rotation64( + x._raw_tensor, raw_rotation_key + ), _context=x._context, _underlying_dtype=x._underlying_dtype, _scaling_factor=x._scaling_factor, @@ -649,10 +672,13 @@ def reduce_sum(x, axis, rotation_key=None): _noise_bit_count=result_noise_bits, ) else: - if axis >= len(x.shape): - raise ValueError("Axis greater than number of dimensions") - - result_noise_bits = x._noise_bit_count + x.shape[axis].bit_length() + 1 + result_noise_bits = x._noise_bit_count + tf.cast( + tf.math.ceil( + tf.math.log(tf.cast(tf.shape(x._raw_tensor)[axis - 1], tf.float32)) + / tf.math.log(tf.constant(2, dtype=tf.float32)) + ), + tf.int32, + ) return ShellTensor64( _raw_tensor=shell_ops.reduce_sum64(x._raw_tensor, axis), diff --git a/tf_shell_ml/loss.py b/tf_shell_ml/loss.py index ad12132..6520254 100644 --- a/tf_shell_ml/loss.py +++ b/tf_shell_ml/loss.py @@ -15,6 +15,7 @@ # limitations under the License. from tensorflow.nn import softmax from tensorflow.math import log +import tf_shell class CategoricalCrossentropy: @@ -28,7 +29,7 @@ def __call__(self, y_true, y_pred): batch_size = y_true.shape.as_list()[0] batch_size_inv = 1 / batch_size out = -y_true * log(y_pred) - cce = out.reduce_sum() * batch_size_inv + cce = tf_shell.reduce_sum(out, axis=0) * batch_size_inv return cce def grad(self, y_true, y_pred): diff --git a/tf_shell_ml/test/mnist_enc_backprop_test.py b/tf_shell_ml/test/mnist_enc_backprop_test.py index 423211e..424fa82 100644 --- a/tf_shell_ml/test/mnist_enc_backprop_test.py +++ b/tf_shell_ml/test/mnist_enc_backprop_test.py @@ -88,6 +88,7 @@ loss_fn = tf_shell_ml.CategoricalCrossentropy() +@tf.function def train_step(x, y): # Forward pass. y_1 = hidden_layer(x) @@ -110,7 +111,7 @@ def train_step(x, y): class TestMNISTBackprop(tf.test.TestCase): - def test_mnist_plaintext_backprop(self): + def test_mnist_enc_backprop(self): (x_batch, y_batch) = next(iter(train_dataset)) # Plaintext backprop splitting the batch in half vertically. diff --git a/tf_shell_ml/test/mnist_noenc_backprop_test.py b/tf_shell_ml/test/mnist_noenc_backprop_test.py index d3ffb7d..b08eed0 100644 --- a/tf_shell_ml/test/mnist_noenc_backprop_test.py +++ b/tf_shell_ml/test/mnist_noenc_backprop_test.py @@ -64,6 +64,7 @@ optimizer = tf.keras.optimizers.Adam(0.01) +@tf.function def train_step(x, y): # Forward pass. y_1 = hidden_layer(x) @@ -84,38 +85,44 @@ class TestMNISTBackprop(tf.test.TestCase): # Test plaintext training using tf_shell_ml primitives. def test_mnist_plaintext_backprop(self): - for epoch in range(epochs): - for step, (x_batch, y_batch) in enumerate(train_dataset.take(batch_size)): - # Plaintext backprop splitting the batch in half vertically. - output_layer_grad, hidden_layer_grad = train_step(x_batch, y_batch) - - # To directly apply the weights, use the following: - # output_layer.weights[0] = output_layer.weights[0] - 0.01 * output_layer_grad[0] - # hidden_layer.weights[0] = hidden_layer.weights[0] - 0.01 * hidden_layer_grad[0] - - optimizer.apply_gradients( - zip( - output_layer_grad + hidden_layer_grad, - output_layer.weights + hidden_layer.weights, + + # Test both eager and graph mode. + for is_eager in [True, False]: + tf.config.run_functions_eagerly(is_eager) + + # Train the model. + for epoch in range(epochs): + for step, (x_batch, y_batch) in enumerate(train_dataset.take(batch_size)): + # Plaintext backprop splitting the batch in half vertically. + output_layer_grad, hidden_layer_grad = train_step(x_batch, y_batch) + + # To directly apply the weights, use the following: + # output_layer.weights[0] = output_layer.weights[0] - 0.01 * output_layer_grad[0] + # hidden_layer.weights[0] = hidden_layer.weights[0] - 0.01 * hidden_layer_grad[0] + + optimizer.apply_gradients( + zip( + output_layer_grad + hidden_layer_grad, + output_layer.weights + hidden_layer.weights, + ) ) - ) - - average_accuracy = 0.0 - for x, y in val_dataset: - y_pred = output_layer(hidden_layer(x)) - accuracy = tf.reduce_mean( - tf.cast( - tf.equal(tf.argmax(y, axis=1), tf.argmax(y_pred, axis=1)), - tf.float32, + + average_accuracy = 0.0 + for x, y in val_dataset: + y_pred = output_layer(hidden_layer(x)) + accuracy = tf.reduce_mean( + tf.cast( + tf.equal(tf.argmax(y, axis=1), tf.argmax(y_pred, axis=1)), + tf.float32, + ) ) - ) - average_accuracy += accuracy - average_accuracy /= len(val_dataset) + average_accuracy += accuracy + average_accuracy /= len(val_dataset) - print(f"Accuracy: {average_accuracy}") + print(f"Accuracy: {average_accuracy}") - # Ensure the model is learning. - self.assertAllGreater(average_accuracy, 0.9) + # Ensure the model is learning. + self.assertAllGreater(average_accuracy, 0.9) if __name__ == "__main__": diff --git a/tf_shell_ml/test/mnist_post_scale_test.py b/tf_shell_ml/test/mnist_post_scale_test.py index f5969f5..1e6abd1 100644 --- a/tf_shell_ml/test/mnist_post_scale_test.py +++ b/tf_shell_ml/test/mnist_post_scale_test.py @@ -64,6 +64,7 @@ ) +@tf.function def train_step(x, y): """One step of training with using the "post scale" approach. @@ -100,7 +101,8 @@ def train_step(x, y): y_pred = model(x, training=False) # Compute y_pred - y (where y may be encrypted). - scalars = y_pred - y # dJ/dy_pred + # scalars = y_pred - y # dJ/dy_pred + scalars = y.__rsub__(y_pred) # dJ/dy_pred # ^ batch_size x num output classes. # Expand the last dim so that the subsequent multiplications are @@ -124,8 +126,8 @@ def train_step(x, y): # ^ batch_size x num output classes x flattened weights # Scale the gradient precursors. - scaled_grad = packable_grad * scalars - # ^ dy_pred/dW * dJ/dy_pred = dJ/dW + scaled_grad = scalars * packable_grad + # ^ dJ/dW = dJ/dy_pred * dy_pred/dW # Sum over the output classes. scaled_grad = tf_shell.reduce_sum(scaled_grad, axis=1) @@ -154,7 +156,9 @@ def train_step(x, y): class TestPlaintextPostScale(tf.test.TestCase): - def test_mnist_post_scale(self): + def test_mnist_post_scale_eager(self): + tf.config.run_functions_eagerly(True) + (x_batch, y_batch) = next(iter(train_dataset)) # Plaintext @@ -171,6 +175,30 @@ def test_mnist_post_scale(self): atol=1 / context.scaling_factor * context.num_slots, ) + def test_mnist_post_scale_autograph(self): + tf.config.run_functions_eagerly(False) + + (x_batch, y_batch) = next(iter(train_dataset)) + + # Plaintext + ps_grads = train_step(x_batch, y_batch) + + # With autograph on (eagerly off), the tf.function trace cannot be + # reused between plaintext and encrypted calls. Reset the graph + # between plaintext and encrypted train_step() calls. + tf.keras.backend.clear_session() + + # Encrypted + enc_y_batch = tf_shell.to_encrypted(y_batch, key, context) + shell_ps_grads = train_step(x_batch, enc_y_batch) + + # Compare the gradients. + self.assertAllClose( + ps_grads, + shell_ps_grads, + atol=1 / context.scaling_factor * context.num_slots, + ) + if __name__ == "__main__": unittest.main() From 7cc3ad624cfca26a35d8059b27e4e27d4e7c9eff Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Wed, 8 May 2024 06:02:02 +0000 Subject: [PATCH 07/22] Faster rotation test via profiling and removing redundant tests. --- tf_shell/cc/kernels/polynomial_kernels.cc | 4 ++-- tf_shell/cc/kernels/rotation_kernels.cc | 15 ++++++++++++--- tf_shell/cc/kernels/symmetric_kernels.cc | 4 ++-- tf_shell/test/rotation_test.py | 11 +++++++++++ 4 files changed, 27 insertions(+), 7 deletions(-) diff --git a/tf_shell/cc/kernels/polynomial_kernels.cc b/tf_shell/cc/kernels/polynomial_kernels.cc index e90e371..93295ba 100644 --- a/tf_shell/cc/kernels/polynomial_kernels.cc +++ b/tf_shell/cc/kernels/polynomial_kernels.cc @@ -139,7 +139,7 @@ class PolynomialImportOp : public OpKernel { auto thread_pool = op_ctx->device()->tensorflow_cpu_worker_threads()->workers; int const cost_per_import = - 0.08f * num_slots; // ns, measured on log_n = 11 + 70 * num_slots; // ns, measured on log_n = 11 thread_pool->ParallelFor(flat_input.dimension(1), cost_per_import, import_in_range); } @@ -234,7 +234,7 @@ class PolynomialExportOp : public OpKernel { auto thread_pool = op_ctx->device()->tensorflow_cpu_worker_threads()->workers; int const cost_per_export = - 0.08f * num_slots; // ns, measured on log_n = 11 + 70 * num_slots; // ns, measured on log_n = 11 thread_pool->ParallelFor(flat_output.dimension(1), cost_per_export, export_in_range); } diff --git a/tf_shell/cc/kernels/rotation_kernels.cc b/tf_shell/cc/kernels/rotation_kernels.cc index 0eddc20..15a84e0 100644 --- a/tf_shell/cc/kernels/rotation_kernels.cc +++ b/tf_shell/cc/kernels/rotation_kernels.cc @@ -246,7 +246,7 @@ class RollOp : public OpKernel { auto thread_pool = op_ctx->device()->tensorflow_cpu_worker_threads()->workers; int const cost_per_rot = - 1000000 * num_components; // ns, measured on log_n = 11 + 500 * num_slots * num_components; // ns, measured on log_n = 11 thread_pool->ParallelFor(flat_output.dimension(0), cost_per_rot, roll_in_range); } @@ -271,6 +271,16 @@ class ReduceSumByRotationOp : public OpKernel { auto flat_value = value.flat(); + // Recover num_slots from first ciphertext to validate shift argument. + SymmetricCtVariant const* ct_var = + std::move(flat_value(0).get>()); + OP_REQUIRES( + op_ctx, ct_var != nullptr, + InvalidArgument("SymmetricCtVariant a did not unwrap successfully.")); + SymmetricCt const& ct = ct_var->ct; + int num_slots = 1 << ct.LogN(); + int num_components = ct.NumModuli(); + // Recover the input rotation keys. OP_REQUIRES_VALUE(RotationKeyVariant const* rotation_key_var, op_ctx, GetVariant>(op_ctx, 1)); @@ -296,7 +306,6 @@ class ReduceSumByRotationOp : public OpKernel { InvalidArgument( "SymmetricCtVariant a did not unwrap successfully.")); SymmetricCt sum = ct_var->ct; // deep copy to start the sum. - int const num_slots = 1 << sum.LogN(); // Add the rotations to the sum. // Note the ciphertext rotations operate on each half of the @@ -323,7 +332,7 @@ class ReduceSumByRotationOp : public OpKernel { auto thread_pool = op_ctx->device()->tensorflow_cpu_worker_threads()->workers; - int const cost_per_reduce = 38306686; // ns measured on log_n = 11 + int const cost_per_reduce = 18000 * num_slots; // ns measured on log_n = 11 thread_pool->ParallelFor(flat_output.dimension(0), cost_per_reduce, reduce_in_range); } diff --git a/tf_shell/cc/kernels/symmetric_kernels.cc b/tf_shell/cc/kernels/symmetric_kernels.cc index 389bbda..1fc9656 100644 --- a/tf_shell/cc/kernels/symmetric_kernels.cc +++ b/tf_shell/cc/kernels/symmetric_kernels.cc @@ -133,7 +133,7 @@ class EncryptOp : public OpKernel { auto thread_pool = op_ctx->device()->tensorflow_cpu_worker_threads()->workers; - int const cost_per_enc = 2200 * num_slots; // ns, measured on log_n = 11 + int const cost_per_enc = 6000 * num_slots; // ns, measured on log_n = 11 thread_pool->ParallelFor(flat_output.dimension(0), cost_per_enc, enc_in_range); } @@ -215,7 +215,7 @@ class DecryptOp : public OpKernel { auto thread_pool = op_ctx->device()->tensorflow_cpu_worker_threads()->workers; - int const cost_per_dec = 0.12f * num_slots; // ns, measured on log_n = 11 + int const cost_per_dec = 75 * num_slots; // ns, measured on log_n = 11 thread_pool->ParallelFor(flat_output.dimension(1), cost_per_dec, dec_in_range); } diff --git a/tf_shell/test/rotation_test.py b/tf_shell/test/rotation_test.py index c5fff55..499c470 100644 --- a/tf_shell/test/rotation_test.py +++ b/tf_shell/test/rotation_test.py @@ -114,7 +114,18 @@ def _test_roll(self, test_context, roll_num): self.assertAllClose(rolled_tftensor, rolled_result, atol=1e-3) def test_roll(self): + # Testing all contexts for all possible rotations is slow. Instead, + # test a subset of rotations for each context, and one context tests + # all rotations. for test_context in self.test_contexts: + rotation_range = test_context.shell_context.num_slots // 2 - 1 + for roll_num in [-rotation_range, rotation_range, -1, 0, 1]: + with self.subTest( + f"roll with context {test_context}, rotating by {roll_num}" + ): + self._test_roll(test_context, roll_num) + + for test_context in [self.test_contexts[0]]: rotation_range = test_context.shell_context.num_slots // 2 - 1 for roll_num in range(-rotation_range, rotation_range, 1): with self.subTest( From e8762686126c181dc50320268f22495e0f37953e Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Mon, 13 May 2024 09:16:05 +0000 Subject: [PATCH 08/22] Shell ops perform broadcasting without relying on tf.broadcast_to(). --- tf_shell/cc/kernels/add_kernels.cc | 72 +++++---- tf_shell/cc/kernels/mul_kernels.cc | 69 ++++++--- tf_shell/cc/kernels/rotation_kernels.cc | 67 +++++--- tf_shell/cc/kernels/shape_kernels.cc | 30 ++-- tf_shell/cc/kernels/utils.h | 55 +++++++ tf_shell/cc/ops/shell_ops.cc | 179 ++++++++++++++++++++-- tf_shell/python/shell_tensor.py | 54 +------ tf_shell/test/shape_test.py | 4 +- tf_shell_ml/test/mnist_post_scale_test.py | 5 +- 9 files changed, 382 insertions(+), 153 deletions(-) diff --git a/tf_shell/cc/kernels/add_kernels.cc b/tf_shell/cc/kernels/add_kernels.cc index 928ddb4..f8772de 100644 --- a/tf_shell/cc/kernels/add_kernels.cc +++ b/tf_shell/cc/kernels/add_kernels.cc @@ -99,18 +99,24 @@ class AddCtCtOp : public OpKernel { Tensor const& a = op_ctx->input(0); Tensor const& b = op_ctx->input(1); - // Check the inputs have the same shape. This Op does not support - // broadcasting. - OP_REQUIRES(op_ctx, a.shape() == b.shape(), - InvalidArgument("Inputs must have the same shape.")); + BCast bcast(BCast::FromShape(a.shape()), BCast::FromShape(b.shape()), + /*fewer_dims_optimization=*/true); + OP_REQUIRES( + op_ctx, bcast.IsValid(), + InvalidArgument("Invalid broadcast between ", a.shape().DebugString(), + " and ", b.shape().DebugString())); + auto flat_a = MyBFlat(op_ctx, a, bcast.x_reshape(), bcast.x_bcast()); + auto flat_b = MyBFlat(op_ctx, b, bcast.y_reshape(), bcast.y_bcast()); + + // Check the inputs have the same shape. + OP_REQUIRES( + op_ctx, flat_a.size() == flat_b.size(), + InvalidArgument("Broadcasted inputs must have the same shape.")); // Allocate the output tensor which is the same size as one of the inputs. Tensor* output; - OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, a.shape(), &output)); - - // Set up flat views of the inputs and output tensors. - auto flat_a = a.flat(); - auto flat_b = b.flat(); + TensorShape output_shape = BCast::ToShape(bcast.output_shape()); + OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, output_shape, &output)); auto flat_output = output->flat(); for (int i = 0; i < flat_output.dimension(0); ++i) { @@ -152,18 +158,24 @@ class AddCtPtOp : public OpKernel { Tensor const& a = op_ctx->input(0); Tensor const& b = op_ctx->input(1); - // Check the inputs have the same shape. This Op does not support - // broadcasting. - OP_REQUIRES(op_ctx, a.shape() == b.shape(), - InvalidArgument("Inputs must have the same shape.")); + BCast bcast(BCast::FromShape(a.shape()), BCast::FromShape(b.shape()), + /*fewer_dims_optimization=*/true); + OP_REQUIRES( + op_ctx, bcast.IsValid(), + InvalidArgument("Invalid broadcast between ", a.shape().DebugString(), + " and ", b.shape().DebugString())); + auto flat_a = MyBFlat(op_ctx, a, bcast.x_reshape(), bcast.x_bcast()); + auto flat_b = MyBFlat(op_ctx, b, bcast.y_reshape(), bcast.y_bcast()); + + // Check the inputs have the same shape. + OP_REQUIRES( + op_ctx, flat_a.size() == flat_b.size(), + InvalidArgument("Broadcasted inputs must have the same shape.")); // Allocate the output tensor which is the same size as one of the inputs. Tensor* output; - OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, a.shape(), &output)); - - // Set up flat views of the inputs and output tensors. - auto flat_a = a.flat(); - auto flat_b = b.flat(); + TensorShape output_shape = BCast::ToShape(bcast.output_shape()); + OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, output_shape, &output)); auto flat_output = output->flat(); for (int i = 0; i < flat_output.dimension(0); ++i) { @@ -209,18 +221,24 @@ class AddPtPtOp : public OpKernel { Tensor const& a = op_ctx->input(1); Tensor const& b = op_ctx->input(2); - // Check the inputs have the same shape. This Op does not support - // broadcasting. - OP_REQUIRES(op_ctx, a.shape() == b.shape(), - InvalidArgument("Inputs must have the same shape.")); + BCast bcast(BCast::FromShape(a.shape()), BCast::FromShape(b.shape()), + /*fewer_dims_optimization=*/true); + OP_REQUIRES( + op_ctx, bcast.IsValid(), + InvalidArgument("Invalid broadcast between ", a.shape().DebugString(), + " and ", b.shape().DebugString())); + auto flat_a = MyBFlat(op_ctx, a, bcast.x_reshape(), bcast.x_bcast()); + auto flat_b = MyBFlat(op_ctx, b, bcast.y_reshape(), bcast.y_bcast()); + + // Check the inputs have the same shape. + OP_REQUIRES( + op_ctx, flat_a.size() == flat_b.size(), + InvalidArgument("Broadcasted inputs must have the same shape.")); // Allocate the output tensor which is the same size as one of the inputs. Tensor* output; - OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, a.shape(), &output)); - - // Set up flat views of the inputs and output tensors. - auto flat_a = a.flat(); - auto flat_b = b.flat(); + TensorShape output_shape = BCast::ToShape(bcast.output_shape()); + OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, output_shape, &output)); auto flat_output = output->flat(); for (int i = 0; i < flat_output.dimension(0); ++i) { diff --git a/tf_shell/cc/kernels/mul_kernels.cc b/tf_shell/cc/kernels/mul_kernels.cc index f3b110b..33de13e 100644 --- a/tf_shell/cc/kernels/mul_kernels.cc +++ b/tf_shell/cc/kernels/mul_kernels.cc @@ -57,17 +57,24 @@ class MulCtCtOp : public OpKernel { Tensor const& a = op_ctx->input(0); Tensor const& b = op_ctx->input(1); - // Check that the inputs have the same shape. - OP_REQUIRES(op_ctx, a.shape() == b.shape(), - InvalidArgument("Inputs must have the same shape.")); + BCast bcast(BCast::FromShape(a.shape()), BCast::FromShape(b.shape()), + /*fewer_dims_optimization=*/true); + OP_REQUIRES( + op_ctx, bcast.IsValid(), + InvalidArgument("Invalid broadcast between ", a.shape().DebugString(), + " and ", b.shape().DebugString())); + auto flat_a = MyBFlat(op_ctx, a, bcast.x_reshape(), bcast.x_bcast()); + auto flat_b = MyBFlat(op_ctx, b, bcast.y_reshape(), bcast.y_bcast()); + + // Check the inputs have the same shape. + OP_REQUIRES( + op_ctx, flat_a.size() == flat_b.size(), + InvalidArgument("Broadcasted inputs must have the same shape.")); // Allocate the output tensor which is the same shape as each of the inputs. Tensor* output; - OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, a.shape(), &output)); - - // Set up the flat views of the input and output tensors. - auto flat_a = a.flat(); - auto flat_b = b.flat(); + TensorShape output_shape = BCast::ToShape(bcast.output_shape()); + OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, output_shape, &output)); auto flat_output = output->flat(); // Multiply each pair of ciphertexts and store the result in the output. @@ -109,17 +116,24 @@ class MulCtPtOp : public OpKernel { Tensor const& a = op_ctx->input(0); Tensor const& b = op_ctx->input(1); - // Check that the inputs have the same shape. - OP_REQUIRES(op_ctx, a.shape() == b.shape(), - InvalidArgument("Inputs must have the same shape.")); + BCast bcast(BCast::FromShape(a.shape()), BCast::FromShape(b.shape()), + /*fewer_dims_optimization=*/true); + OP_REQUIRES( + op_ctx, bcast.IsValid(), + InvalidArgument("Invalid broadcast between ", a.shape().DebugString(), + " and ", b.shape().DebugString())); + auto flat_a = MyBFlat(op_ctx, a, bcast.x_reshape(), bcast.x_bcast()); + auto flat_b = MyBFlat(op_ctx, b, bcast.y_reshape(), bcast.y_bcast()); + + // Check the inputs have the same shape. + OP_REQUIRES( + op_ctx, flat_a.size() == flat_b.size(), + InvalidArgument("Broadcasted inputs must have the same shape.")); // Allocate the output tensor which is the same shape as each of the inputs. Tensor* output; - OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, a.shape(), &output)); - - // Set up the flat views of the inputs and output tensors. - auto flat_a = a.flat(); - auto flat_b = b.flat(); + TensorShape output_shape = BCast::ToShape(bcast.output_shape()); + OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, output_shape, &output)); auto flat_output = output->flat(); for (int i = 0; i < flat_output.dimension(0); ++i) { @@ -259,17 +273,24 @@ class MulPtPtOp : public OpKernel { Tensor const& a = op_ctx->input(1); Tensor const& b = op_ctx->input(2); - // Check that the inputs have the same shape. - OP_REQUIRES(op_ctx, a.shape() == b.shape(), - InvalidArgument("Inputs must have the same shape.")); + BCast bcast(BCast::FromShape(a.shape()), BCast::FromShape(b.shape()), + /*fewer_dims_optimization=*/true); + OP_REQUIRES( + op_ctx, bcast.IsValid(), + InvalidArgument("Invalid broadcast between ", a.shape().DebugString(), + " and ", b.shape().DebugString())); + auto flat_a = MyBFlat(op_ctx, a, bcast.x_reshape(), bcast.x_bcast()); + auto flat_b = MyBFlat(op_ctx, b, bcast.y_reshape(), bcast.y_bcast()); + + // Check the inputs have the same shape. + OP_REQUIRES( + op_ctx, flat_a.size() == flat_b.size(), + InvalidArgument("Broadcasted inputs must have the same shape.")); // Allocate the output tensor which is the same shape as each of the inputs. Tensor* output; - OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, a.shape(), &output)); - - // Set up the flat views of the inputs and output tensors. - auto flat_a = a.flat(); - auto flat_b = b.flat(); + TensorShape output_shape = BCast::ToShape(bcast.output_shape()); + OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, output_shape, &output)); auto flat_output = output->flat(); for (int i = 0; i < flat_output.dimension(0); ++i) { diff --git a/tf_shell/cc/kernels/rotation_kernels.cc b/tf_shell/cc/kernels/rotation_kernels.cc index 15a84e0..e16e1af 100644 --- a/tf_shell/cc/kernels/rotation_kernels.cc +++ b/tf_shell/cc/kernels/rotation_kernels.cc @@ -214,8 +214,10 @@ class RollOp : public OpKernel { RotationKey const* key; if (shift != 0) { - OP_REQUIRES(op_ctx, shift - 1 < keys.size(), // Skip key at zero. - InvalidArgument("No key for shift of '", shift, "'")); + OP_REQUIRES( + op_ctx, + shift - 1 < static_cast(keys.size()), // Skip key at zero. + InvalidArgument("No key for shift of '", shift, "'")); key = &keys[shift - 1]; // Skip key at zero. } @@ -312,8 +314,10 @@ class ReduceSumByRotationOp : public OpKernel { // ciphertext separately. So the max rotation is by half the number // of slots. for (int shift = 1; shift < num_slots / 2; shift <<= 1) { - OP_REQUIRES(op_ctx, shift - 1 < keys.size(), // Skip key at zero. - InvalidArgument("No key for shift of '", shift, "'")); + OP_REQUIRES( + op_ctx, + shift - 1 < static_cast(keys.size()), // Skip key at zero. + InvalidArgument("No key for shift of '", shift, "'")); RotationKey const* key = &keys[shift - 1]; // Skip key at zero. // Rotate by the shift. @@ -332,7 +336,8 @@ class ReduceSumByRotationOp : public OpKernel { auto thread_pool = op_ctx->device()->tensorflow_cpu_worker_threads()->workers; - int const cost_per_reduce = 18000 * num_slots; // ns measured on log_n = 11 + int const cost_per_reduce = + 9000 * num_slots * num_components; // ns measured on log_n = 11 thread_pool->ParallelFor(flat_output.dimension(0), cost_per_reduce, reduce_in_range); } @@ -345,8 +350,20 @@ class ReduceSumOp : public OpKernel { using RotationKey = rlwe::RnsGaloisKey; using SymmetricCt = rlwe::RnsBgvCiphertext; + int dim_to_reduce; + public: - explicit ReduceSumOp(OpKernelConstruction* op_ctx) : OpKernel(op_ctx) {} + explicit ReduceSumOp(OpKernelConstruction* op_ctx) : OpKernel(op_ctx) { + // Get the dimension to reduce over from the op attributes. + OP_REQUIRES_OK(op_ctx, op_ctx->GetAttr("axis", &dim_to_reduce)); + + // Recall first dimension of a shell variant tensor is the packing + // dimension. We don't allow expanding this dimension. + OP_REQUIRES( + op_ctx, dim_to_reduce != 0, + InvalidArgument("ReduceSumOp cannot reduce over packing axis (zero'th " + "dimension). See ReduceSumByRotationOp.")); + } void Compute(OpKernelContext* op_ctx) override { // Recover the input tensor. @@ -354,32 +371,34 @@ class ReduceSumOp : public OpKernel { OP_REQUIRES(op_ctx, value.dim_size(0) > 0, InvalidArgument("Cannot reduce_sum an empty ciphertext.")); - // Recover the axis to reduce over. - Tensor const& axis_tensor = op_ctx->input(1); - OP_REQUIRES(op_ctx, axis_tensor.NumElements() == 1, - InvalidArgument("axis must be scalar, saw shape: ", - axis_tensor.shape().DebugString())); - OP_REQUIRES_VALUE(int64 axis, op_ctx, GetScalar(op_ctx, 1)); - - // The axis to reduce over. - int dim_to_reduce = axis - 1; + OP_REQUIRES( + op_ctx, dim_to_reduce != 0, + InvalidArgument("ReduceSumOp cannot reduce over packing axis (zero'th " + "dimension). See ReduceSumByRotationOp.")); + + // We emulate numpy's interpretation of the dim axis when + // -input.dims() >= dim <= input.dims(). + int clamped_dim = dim_to_reduce; + if (clamped_dim < 0) { + clamped_dim += value.dims() + 1; // + 1 for packing dim. + } else if (clamped_dim > 0) { + clamped_dim -= 1; // -1 for packing dimension. + } // Check axis is within dim size. - OP_REQUIRES(op_ctx, dim_to_reduce < value.dims(), - InvalidArgument("Cannot reduce_sum over polynomial_axis '", - dim_to_reduce, "' (axis '", axis, - "') for input with shape ", - value.shape().DebugString())); + OP_REQUIRES( + op_ctx, clamped_dim >= 0 && clamped_dim < value.dims(), + InvalidArgument("Cannot reduce_sum over polynomial_axis '", clamped_dim, + "for input with shape ", value.shape().DebugString())); - uint8_t dim_sz_to_reduce = value.dim_size(dim_to_reduce); + uint8_t dim_sz_to_reduce = value.dim_size(clamped_dim); - // Since the first dimension is the batching dimension, subtract 1. - auto flat_value = value.flat_inner_outer_dims(dim_to_reduce - 1); + auto flat_value = value.flat_inner_outer_dims(clamped_dim - 1); // Setup the output. Tensor* output; auto output_shape = value.shape(); - OP_REQUIRES_OK(op_ctx, output_shape.RemoveDimWithStatus(dim_to_reduce)); + OP_REQUIRES_OK(op_ctx, output_shape.RemoveDimWithStatus(clamped_dim)); OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, output_shape, &output)); // Setup a shape to access the output Tensor as a flat Tensor, with the // same indexing as the input Tensor excluding the dimension to reduce. diff --git a/tf_shell/cc/kernels/shape_kernels.cc b/tf_shell/cc/kernels/shape_kernels.cc index 4093e5c..5443fc8 100644 --- a/tf_shell/cc/kernels/shape_kernels.cc +++ b/tf_shell/cc/kernels/shape_kernels.cc @@ -48,14 +48,24 @@ class ExpandDimsVariantOp : public OpKernel { // Recall first dimension of a shell variant tensor is the packing // dimension. We don't allow expanding this dimension. OP_REQUIRES(op_ctx, dim != 0, InvalidArgument("Invalid dimension index.")); - dim += dim > 0 ? -1 : 0; } void Compute(OpKernelContext* ctx) override { - OP_REQUIRES( - ctx, (dim >= -1 - ctx->input(0).dims() && dim <= ctx->input(0).dims()), - InvalidArgument("Tried to expand dim index ", dim, " for tensor with ", - ctx->input(0).dims(), " dimensions.")); + OP_REQUIRES(ctx, dim != 0, InvalidArgument("Invalid dimension index.")); + + // We emulate numpy's interpretation of the dim axis when + // -input.dims() >= dim <= input.dims(). + int clamped_dim = dim; + if (clamped_dim < 0) { + clamped_dim += ctx->input(0).dims() + 1; // + 1 for packing dim. + } else if (clamped_dim > 0) { + clamped_dim -= 1; // -1 for packing dimension. + } + + OP_REQUIRES(ctx, clamped_dim >= 0 && clamped_dim <= ctx->input(0).dims(), + InvalidArgument("Tried to expand dim index ", clamped_dim, + " for tensor with ", ctx->input(0).dims(), + " dimensions.")); auto existing_dims = ctx->input(0).shape().dim_sizes(); // Safe - # elements in tensor dims bounded. @@ -65,15 +75,9 @@ class ExpandDimsVariantOp : public OpKernel { new_shape[i] = existing_dims[i]; } - // We emulate numpy's interpretation of the dim axis when - // -input.dims() >= dim <= input.dims(). - if (dim < 0) { - dim += existing_dims.size() + 1; - } - // Clamp to the end if needed. - dim = std::min(dim, existing_dims_size); - new_shape.emplace(new_shape.begin() + dim, 1); + clamped_dim = std::min(clamped_dim, existing_dims_size); + new_shape.emplace(new_shape.begin() + clamped_dim, 1); TensorShape const output_shape(new_shape); Tensor* output = nullptr; diff --git a/tf_shell/cc/kernels/utils.h b/tf_shell/cc/kernels/utils.h index 53875b2..becea8f 100644 --- a/tf_shell/cc/kernels/utils.h +++ b/tf_shell/cc/kernels/utils.h @@ -20,14 +20,20 @@ #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/util/bcast.h" +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +using tensorflow::BCast; using tensorflow::OkStatus; using tensorflow::OpKernelContext; using tensorflow::StatusOr; using tensorflow::Tensor; +using tensorflow::TensorShape; using tensorflow::TensorShapeUtils; +using tensorflow::TTypes; using tensorflow::Variant; using tensorflow::errors::InvalidArgument; +using tensorflow::errors::Unimplemented; using std::vector; @@ -86,6 +92,55 @@ StatusOr GetVariant(OpKernelContext* ctx, int index) { return t; } +template +inline Eigen::Tensor BFlat( + OpKernelContext* op_ctx, Tensor const& t, BCast::Vec const& x_reshape, + BCast::Vec const& x_bcast) { + // A TensorFlow is a TTypes::Tensor (aka Eigen::TensorMap). + // Eigen::TensorMap is a view into an Eigen::Tensor. When performing a + // reshape, broadcast, or even eval on an Eigen::TensorMap, it cannot be + // assigned to another Eigen::TensorMap. This is why the following code + // assigns the result of the reshape to an Eigen::Tensor. + // + // For a demo, see https://godbolt.org/z/41xvWvb63 + typedef Eigen::Tensor + ETensor; + + ETensor reshaped_t = t.template shaped(x_reshape); + + ETensor broadcasted_t = + reshaped_t.broadcast(BCast::ToIndexArray(x_bcast)); + + return std::move( + broadcasted_t.reshape(BCast::ToIndexArray<1>({broadcasted_t.size()}))); +} + +inline Eigen::Tensor MyBFlat( + OpKernelContext* op_ctx, Tensor const& t, BCast::Vec const& x_reshape, + BCast::Vec const& x_bcast) { + // Uses the switch statement approach as in: + // `tensorflow/tensorflow/core/kernels/broadcast_to_op.h` + int const ndims = x_reshape.size(); + switch (ndims) { + case 1: + return std::move(BFlat<1>(op_ctx, t, x_reshape, x_bcast)); + case 2: + return std::move(BFlat<2>(op_ctx, t, x_reshape, x_bcast)); + case 3: + return std::move(BFlat<3>(op_ctx, t, x_reshape, x_bcast)); + case 4: + return std::move(BFlat<4>(op_ctx, t, x_reshape, x_bcast)); + case 5: + return std::move(BFlat<5>(op_ctx, t, x_reshape, x_bcast)); + case 6: + return std::move(BFlat<6>(op_ctx, t, x_reshape, x_bcast)); + default: + op_ctx->SetStatus(Unimplemented("Broadcast ", t.DebugString(), + " is not supported yet.")); + return std::move(BFlat<1>(op_ctx, t, x_reshape, x_bcast)); + } +} + // Status macros from // https://github.com/abseil/abseil-cpp/issues/976#issuecomment-1664601671 // diff --git a/tf_shell/cc/ops/shell_ops.cc b/tf_shell/cc/ops/shell_ops.cc index ec13826..98a323d 100644 --- a/tf_shell/cc/ops/shell_ops.cc +++ b/tf_shell/cc/ops/shell_ops.cc @@ -17,6 +17,8 @@ #include "tensorflow/core/framework/shape_inference.h" using tensorflow::OkStatus; +using tensorflow::errors::InvalidArgument; +using tensorflow::shape_inference::DimensionHandle; using tensorflow::shape_inference::InferenceContext; using tensorflow::shape_inference::ScalarShape; using tensorflow::shape_inference::ShapeHandle; @@ -114,7 +116,18 @@ REGISTER_OP("AddCtCt64") .Input("b: variant") .Output("c: variant") .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->input(0)); + auto a_sz = c->NumElements(c->input(0)); + auto b_sz = c->NumElements(c->input(1)); + DimensionHandle out_sz; + TF_RETURN_IF_ERROR(c->Max(a_sz, b_sz, &out_sz)); + + if (c->Value(out_sz) == c->Value(a_sz)) { + c->set_output(0, c->input(0)); + } else if (c->Value(out_sz) == c->Value(b_sz)) { + c->set_output(0, c->input(1)); + } else { + c->set_output(0, c->UnknownShape()); + } return OkStatus(); }); @@ -123,7 +136,18 @@ REGISTER_OP("AddCtPt64") .Input("b: variant") .Output("c: variant") .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->input(0)); + auto a_sz = c->NumElements(c->input(0)); + auto b_sz = c->NumElements(c->input(1)); + DimensionHandle out_sz; + TF_RETURN_IF_ERROR(c->Max(a_sz, b_sz, &out_sz)); + + if (c->Value(out_sz) == c->Value(a_sz)) { + c->set_output(0, c->input(0)); + } else if (c->Value(out_sz) == c->Value(b_sz)) { + c->set_output(0, c->input(1)); + } else { + c->set_output(0, c->UnknownShape()); + } return OkStatus(); }); @@ -133,7 +157,18 @@ REGISTER_OP("AddPtPt64") .Input("b: variant") .Output("c: variant") .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->input(0)); + auto a_sz = c->NumElements(c->input(0)); + auto b_sz = c->NumElements(c->input(1)); + DimensionHandle out_sz; + TF_RETURN_IF_ERROR(c->Max(a_sz, b_sz, &out_sz)); + + if (c->Value(out_sz) == c->Value(a_sz)) { + c->set_output(0, c->input(0)); + } else if (c->Value(out_sz) == c->Value(b_sz)) { + c->set_output(0, c->input(1)); + } else { + c->set_output(0, c->UnknownShape()); + } return OkStatus(); }); @@ -142,7 +177,18 @@ REGISTER_OP("SubCtCt64") .Input("b: variant") .Output("c: variant") .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->input(0)); + auto a_sz = c->NumElements(c->input(0)); + auto b_sz = c->NumElements(c->input(1)); + DimensionHandle out_sz; + TF_RETURN_IF_ERROR(c->Max(a_sz, b_sz, &out_sz)); + + if (c->Value(out_sz) == c->Value(a_sz)) { + c->set_output(0, c->input(0)); + } else if (c->Value(out_sz) == c->Value(b_sz)) { + c->set_output(0, c->input(1)); + } else { + c->set_output(0, c->UnknownShape()); + } return OkStatus(); }); @@ -151,7 +197,18 @@ REGISTER_OP("SubCtPt64") .Input("b: variant") .Output("c: variant") .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->input(0)); + auto a_sz = c->NumElements(c->input(0)); + auto b_sz = c->NumElements(c->input(1)); + DimensionHandle out_sz; + TF_RETURN_IF_ERROR(c->Max(a_sz, b_sz, &out_sz)); + + if (c->Value(out_sz) == c->Value(a_sz)) { + c->set_output(0, c->input(0)); + } else if (c->Value(out_sz) == c->Value(b_sz)) { + c->set_output(0, c->input(1)); + } else { + c->set_output(0, c->UnknownShape()); + } return OkStatus(); }); @@ -161,7 +218,18 @@ REGISTER_OP("SubPtPt64") .Input("b: variant") .Output("c: variant") .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->input(1)); + auto a_sz = c->NumElements(c->input(0)); + auto b_sz = c->NumElements(c->input(1)); + DimensionHandle out_sz; + TF_RETURN_IF_ERROR(c->Max(a_sz, b_sz, &out_sz)); + + if (c->Value(out_sz) == c->Value(a_sz)) { + c->set_output(0, c->input(0)); + } else if (c->Value(out_sz) == c->Value(b_sz)) { + c->set_output(0, c->input(1)); + } else { + c->set_output(0, c->UnknownShape()); + } return OkStatus(); }); @@ -188,7 +256,18 @@ REGISTER_OP("MulCtCt64") .Input("b: variant") .Output("c: variant") .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->input(0)); + auto a_sz = c->NumElements(c->input(0)); + auto b_sz = c->NumElements(c->input(1)); + DimensionHandle out_sz; + TF_RETURN_IF_ERROR(c->Max(a_sz, b_sz, &out_sz)); + + if (c->Value(out_sz) == c->Value(a_sz)) { + c->set_output(0, c->input(0)); + } else if (c->Value(out_sz) == c->Value(b_sz)) { + c->set_output(0, c->input(1)); + } else { + c->set_output(0, c->UnknownShape()); + } return OkStatus(); }); @@ -197,7 +276,18 @@ REGISTER_OP("MulCtPt64") .Input("b: variant") .Output("c: variant") .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->input(0)); + auto a_sz = c->NumElements(c->input(0)); + auto b_sz = c->NumElements(c->input(1)); + DimensionHandle out_sz; + TF_RETURN_IF_ERROR(c->Max(a_sz, b_sz, &out_sz)); + + if (c->Value(out_sz) == c->Value(a_sz)) { + c->set_output(0, c->input(0)); + } else if (c->Value(out_sz) == c->Value(b_sz)) { + c->set_output(0, c->input(1)); + } else { + c->set_output(0, c->UnknownShape()); + } return OkStatus(); }); @@ -229,7 +319,18 @@ REGISTER_OP("MulPtPt64") .Input("b: variant") .Output("c: variant") .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->input(0)); + auto a_sz = c->NumElements(c->input(0)); + auto b_sz = c->NumElements(c->input(1)); + DimensionHandle out_sz; + TF_RETURN_IF_ERROR(c->Max(a_sz, b_sz, &out_sz)); + + if (c->Value(out_sz) == c->Value(a_sz)) { + c->set_output(0, c->input(0)); + } else if (c->Value(out_sz) == c->Value(b_sz)) { + c->set_output(0, c->input(1)); + } else { + c->set_output(0, c->UnknownShape()); + } return OkStatus(); }); @@ -290,16 +391,66 @@ REGISTER_OP("ReduceSumByRotation64") .Input("rotation_key: variant") .Output("repeated_reduce_sum: variant") .SetShapeFn([](InferenceContext* c) { + // ReduceSum over the packing dimension does not change the shape. c->set_output(0, c->input(0)); return OkStatus(); }); REGISTER_OP("ReduceSum64") .Input("value: variant") - .Input("axis: int64") + .Attr("axis: int") .Output("repeated_reduce_sum: variant") .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->input(0)); + tsl::int32 rank = c->Rank(c->input(0)); + + tsl::int32 axis; + TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis)); + + // Check that axis is in the correct range. + if (axis == 0) { + return InvalidArgument( + "axis may not be zero. See ReduceSumByRotation()"); + } + + // Recall first dimension of a shell variant tensor is the packing + // dimension. + int clamped_axis = axis; + if (clamped_axis < 0) { + clamped_axis += rank + 1; + } else if (clamped_axis > 0) { + clamped_axis -= 1; + } + + if (clamped_axis < 0 || clamped_axis > rank) { + return InvalidArgument("axis must be in the range [0, rank], got ", + clamped_axis); + } + + ShapeHandle output; + + // If this op ever supports keepdim=True, use the following shape. + // DimensionHandle reduced_dim = c->MakeDim({1}); + // TF_RETURN_IF_ERROR(c->ReplaceDim(c->input(0), clamped_axis, + // reduced_dim, &output)); + + // This op currently only supports keepdim=False whose shape is computed + // via the following. + ShapeHandle prefix; + TF_RETURN_IF_ERROR(c->Subshape(c->input(0), 0, clamped_axis, &prefix)); + + ShapeHandle postfix; + TF_RETURN_IF_ERROR( + c->Subshape(c->input(0), clamped_axis + 1, rank, &postfix)); + + if (clamped_axis == 0) { + output = postfix; + } else if (clamped_axis == rank - 1) { + output = prefix; + } else { + TF_RETURN_IF_ERROR(c->Concatenate(prefix, postfix, &output)); + } + + c->set_output(0, output); return OkStatus(); }); @@ -343,10 +494,10 @@ REGISTER_OP("ExpandDimsVariant") tsl::int32 axis; TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis)); - //Check that axis is in the correct range. + // Check that axis is in the correct range. if (axis < -rank || axis > rank) { - return tensorflow::errors::InvalidArgument( - "axis must be in the range [-rank, rank], got ", axis); + return InvalidArgument("axis must be in the range [-rank, rank], got ", + axis); } if (axis < 0) { diff --git a/tf_shell/python/shell_tensor.py b/tf_shell/python/shell_tensor.py index ec3842e..49f8a81 100644 --- a/tf_shell/python/shell_tensor.py +++ b/tf_shell/python/shell_tensor.py @@ -206,20 +206,16 @@ def __rsub__(self, other): # zeros out to the number of slots. shell_other = to_shell_plaintext(other, self._context) - # Match the shapes via broadcasting. This is after importing to - # save NTTs. - self_matched, other_matched = _match_shape(self, shell_other) - - if self_matched.is_encrypted: - negative_self_matched = -self_matched + if self.is_encrypted: + negative_self= -self raw_result = shell_ops.add_ct_pt64( - negative_self_matched._raw_tensor, other_matched._raw_tensor + negative_self._raw_tensor, shell_other._raw_tensor ) else: raw_result = shell_ops.sub_pt_pt64( self._context._raw_context, - other_matched._raw_tensor, - self_matched._raw_tensor, + shell_other._raw_tensor, + self._raw_tensor, ) return ShellTensor64( @@ -402,42 +398,6 @@ def _match_moduli_and_scaling(x, y): while x._scaling_factor < y._scaling_factor: x = x * y._scaling_factor - x, y = _match_shape(x, y) - - return x, y - - -def _match_shape(x, y): - # Match the shape of x and y via broadcasting. Note, this copies the data, - # fully materializing the new tensors shape. In the future, shell_ops - # should suuport broadcasting directly, avoiding the copy. - - x = tf.cond( - tf.size(x._raw_tensor) < tf.size(y._raw_tensor), - lambda: ShellTensor64( - _raw_tensor=tf.broadcast_to(x._raw_tensor, tf.shape(y._raw_tensor)), - _context=x._context, - _underlying_dtype=x._underlying_dtype, - _scaling_factor=x._scaling_factor, - _is_enc=x._is_enc, - _noise_bit_count=x._noise_bit_count, - ), - lambda: x, - ) - - y = tf.cond( - tf.size(x._raw_tensor) > tf.size(y._raw_tensor), - lambda: ShellTensor64( - _raw_tensor=tf.broadcast_to(y._raw_tensor, tf.shape(x._raw_tensor)), - _context=y._context, - _underlying_dtype=y._underlying_dtype, - _scaling_factor=y._scaling_factor, - _is_enc=y._is_enc, - _noise_bit_count=y._noise_bit_count, - ), - lambda: y, - ) - return x, y @@ -681,7 +641,7 @@ def reduce_sum(x, axis, rotation_key=None): ) return ShellTensor64( - _raw_tensor=shell_ops.reduce_sum64(x._raw_tensor, axis), + _raw_tensor=shell_ops.reduce_sum64(x._raw_tensor, axis=axis), _context=x._context, _underlying_dtype=x._underlying_dtype, _scaling_factor=x._scaling_factor, @@ -798,7 +758,7 @@ def expand_dims(x, axis=-1): "Cannot expand dims at axis 0 for ShellTensor64, this is the batching dimension." ) return ShellTensor64( - _raw_tensor=shell_ops.expand_dims_variant(x._raw_tensor, axis), + _raw_tensor=shell_ops.expand_dims_variant(x._raw_tensor, axis=axis), _context=x._context, _underlying_dtype=x._underlying_dtype, _scaling_factor=x._scaling_factor, diff --git a/tf_shell/test/shape_test.py b/tf_shell/test/shape_test.py index 595db49..fc75f60 100644 --- a/tf_shell/test/shape_test.py +++ b/tf_shell/test/shape_test.py @@ -101,12 +101,12 @@ def _test_expand_dims(self, test_context, dim): def test_expand_dims(self): for test_context in self.test_contexts: - for dim in range(1, len(test_context.outer_shape) + 1, 1): + for dim in range(1, len(test_context.outer_shape) + 2, 1): with self.subTest( f"expand_dims on dimension {dim} with context `{test_context}`." ): self._test_expand_dims(test_context, dim) - for dim in range(-len(test_context.outer_shape) + 1, -1, 1): + for dim in range(-len(test_context.outer_shape) - 1, 0, 1): with self.subTest( f"expand_dims on dimension {dim} with context `{test_context}`." ): diff --git a/tf_shell_ml/test/mnist_post_scale_test.py b/tf_shell_ml/test/mnist_post_scale_test.py index 1e6abd1..4e1cdd4 100644 --- a/tf_shell_ml/test/mnist_post_scale_test.py +++ b/tf_shell_ml/test/mnist_post_scale_test.py @@ -131,12 +131,12 @@ def train_step(x, y): # Sum over the output classes. scaled_grad = tf_shell.reduce_sum(scaled_grad, axis=1) - # ^ batch_size x 1 x flattened weights + # ^ batch_size x flattened weights # In the real world, this approach would also likely require clipping # the gradient, and adding DP noise. - # Reshape to remove the '1' dimension in the middle. + # Reshape to unflatten the weights. scaled_grad = tf_shell.reshape(scaled_grad, [batch_sz] + grad_shape) # ^ batch_size x weights @@ -175,6 +175,7 @@ def test_mnist_post_scale_eager(self): atol=1 / context.scaling_factor * context.num_slots, ) +class TestPlaintextPostScale(tf.test.TestCase): def test_mnist_post_scale_autograph(self): tf.config.run_functions_eagerly(False) From f5aff3e59a74cd10e3af31edd8c13879e700bbc4 Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Mon, 13 May 2024 09:17:10 +0000 Subject: [PATCH 09/22] Ring formatters. --- tf_shell/cc/kernels/polynomial_kernels.cc | 6 ++---- tf_shell/python/shell_key.py | 4 +++- tf_shell/python/shell_tensor.py | 2 +- tf_shell_ml/test/mnist_noenc_backprop_test.py | 4 +++- tf_shell_ml/test/mnist_post_scale_test.py | 3 ++- 5 files changed, 11 insertions(+), 8 deletions(-) diff --git a/tf_shell/cc/kernels/polynomial_kernels.cc b/tf_shell/cc/kernels/polynomial_kernels.cc index 93295ba..677c00d 100644 --- a/tf_shell/cc/kernels/polynomial_kernels.cc +++ b/tf_shell/cc/kernels/polynomial_kernels.cc @@ -138,8 +138,7 @@ class PolynomialImportOp : public OpKernel { auto thread_pool = op_ctx->device()->tensorflow_cpu_worker_threads()->workers; - int const cost_per_import = - 70 * num_slots; // ns, measured on log_n = 11 + int const cost_per_import = 70 * num_slots; // ns, measured on log_n = 11 thread_pool->ParallelFor(flat_input.dimension(1), cost_per_import, import_in_range); } @@ -233,8 +232,7 @@ class PolynomialExportOp : public OpKernel { auto thread_pool = op_ctx->device()->tensorflow_cpu_worker_threads()->workers; - int const cost_per_export = - 70 * num_slots; // ns, measured on log_n = 11 + int const cost_per_export = 70 * num_slots; // ns, measured on log_n = 11 thread_pool->ParallelFor(flat_output.dimension(1), cost_per_export, export_in_range); } diff --git a/tf_shell/python/shell_key.py b/tf_shell/python/shell_key.py index bf5d6fe..2695850 100644 --- a/tf_shell/python/shell_key.py +++ b/tf_shell/python/shell_key.py @@ -77,4 +77,6 @@ def create_rotation_key64(context, key, skip_at_mul_depth=[]): context = mod_reduce_context64(context) key = mod_reduce_key64(key) - return ShellRotationKey64(_raw_rot_keys_at_level=raw_rot_keys_at_level, context=context) + return ShellRotationKey64( + _raw_rot_keys_at_level=raw_rot_keys_at_level, context=context + ) diff --git a/tf_shell/python/shell_tensor.py b/tf_shell/python/shell_tensor.py index 49f8a81..b668e83 100644 --- a/tf_shell/python/shell_tensor.py +++ b/tf_shell/python/shell_tensor.py @@ -207,7 +207,7 @@ def __rsub__(self, other): shell_other = to_shell_plaintext(other, self._context) if self.is_encrypted: - negative_self= -self + negative_self = -self raw_result = shell_ops.add_ct_pt64( negative_self._raw_tensor, shell_other._raw_tensor ) diff --git a/tf_shell_ml/test/mnist_noenc_backprop_test.py b/tf_shell_ml/test/mnist_noenc_backprop_test.py index b08eed0..eaf4e30 100644 --- a/tf_shell_ml/test/mnist_noenc_backprop_test.py +++ b/tf_shell_ml/test/mnist_noenc_backprop_test.py @@ -92,7 +92,9 @@ def test_mnist_plaintext_backprop(self): # Train the model. for epoch in range(epochs): - for step, (x_batch, y_batch) in enumerate(train_dataset.take(batch_size)): + for step, (x_batch, y_batch) in enumerate( + train_dataset.take(batch_size) + ): # Plaintext backprop splitting the batch in half vertically. output_layer_grad, hidden_layer_grad = train_step(x_batch, y_batch) diff --git a/tf_shell_ml/test/mnist_post_scale_test.py b/tf_shell_ml/test/mnist_post_scale_test.py index 4e1cdd4..9f92e5c 100644 --- a/tf_shell_ml/test/mnist_post_scale_test.py +++ b/tf_shell_ml/test/mnist_post_scale_test.py @@ -175,12 +175,13 @@ def test_mnist_post_scale_eager(self): atol=1 / context.scaling_factor * context.num_slots, ) + class TestPlaintextPostScale(tf.test.TestCase): def test_mnist_post_scale_autograph(self): tf.config.run_functions_eagerly(False) (x_batch, y_batch) = next(iter(train_dataset)) - + # Plaintext ps_grads = train_step(x_batch, y_batch) From c4a52220ff331895f20e1c54d18acf5b79ad2919 Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Fri, 17 May 2024 19:16:46 +0000 Subject: [PATCH 10/22] HE-specific graph based optimizations. This optimization reorders ciphertext - plaintext operations which look like ((ct * pt) * pt) to (ct * (pt * pt). --- tf_shell/BUILD | 30 ++ tf_shell/cc/kernels/add_kernels.cc | 19 +- tf_shell/cc/kernels/mul_kernels.cc | 14 +- tf_shell/cc/ops/shell_ops.cc | 89 +++--- tf_shell/cc/optimizers/ct_pt.cc | 274 ++++++++++++++++++ tf_shell/cc/optimizers/ct_pt.h | 30 ++ tf_shell/python/optimizers/__init__.py | 15 + .../python/optimizers/shell_optimizers.py | 154 ++++++++++ tf_shell/python/shell_tensor.py | 48 ++- tf_shell/test/BUILD | 13 + tf_shell/test/ct_pt_optimizer_test.py | 175 +++++++++++ 11 files changed, 792 insertions(+), 69 deletions(-) create mode 100644 tf_shell/cc/optimizers/ct_pt.cc create mode 100644 tf_shell/cc/optimizers/ct_pt.h create mode 100644 tf_shell/python/optimizers/__init__.py create mode 100644 tf_shell/python/optimizers/shell_optimizers.py create mode 100644 tf_shell/test/ct_pt_optimizer_test.py diff --git a/tf_shell/BUILD b/tf_shell/BUILD index 7ce3bb5..6a6e4cd 100644 --- a/tf_shell/BUILD +++ b/tf_shell/BUILD @@ -48,12 +48,41 @@ py_library( srcs_version = "PY3", ) +cc_binary( + name = "python/optimizers/_ct_pt_optimizer.so", + srcs = [ + "cc/optimizers/ct_pt.cc", + "cc/optimizers/ct_pt.h", + ], + copts = [ + "-pthread", + "-fPIC", + ], + linkshared = 1, + deps = [ + "//third_party/tensorflow:hermetic_tf", + "@com_google_protobuf//:protobuf", + ], +) + +py_library( + name = "shell_optimizers_py", + srcs = [ + "python/optimizers/shell_optimizers.py", + ], + data = [ + ":python/optimizers/_ct_pt_optimizer.so", + ], + srcs_version = "PY3", +) + py_library( name = "tf_shell_lib", srcs = [ "__init__.py", "python/__init__.py", "python/ops/__init__.py", + "python/optimizers/__init__.py", "python/shell_context.py", "python/shell_key.py", "python/shell_tensor.py", @@ -62,6 +91,7 @@ py_library( visibility = ["//visibility:public"], deps = [ ":shell_ops_py", + ":shell_optimizers_py", ], ) diff --git a/tf_shell/cc/kernels/add_kernels.cc b/tf_shell/cc/kernels/add_kernels.cc index f8772de..1d92983 100644 --- a/tf_shell/cc/kernels/add_kernels.cc +++ b/tf_shell/cc/kernels/add_kernels.cc @@ -95,9 +95,10 @@ class AddCtCtOp : public OpKernel { explicit AddCtCtOp(OpKernelConstruction* op_ctx) : OpKernel(op_ctx) {} void Compute(OpKernelContext* op_ctx) override { - // Unpack the input arguments. - Tensor const& a = op_ctx->input(0); - Tensor const& b = op_ctx->input(1); + // Unpack the input arguments. The 0th argument is the context, which is not + // directly used in this op but required for graph optimization. + Tensor const& a = op_ctx->input(1); + Tensor const& b = op_ctx->input(2); BCast bcast(BCast::FromShape(a.shape()), BCast::FromShape(b.shape()), /*fewer_dims_optimization=*/true); @@ -154,9 +155,10 @@ class AddCtPtOp : public OpKernel { explicit AddCtPtOp(OpKernelConstruction* op_ctx) : OpKernel(op_ctx) {} void Compute(OpKernelContext* op_ctx) override { - // Unpack the input arguments. - Tensor const& a = op_ctx->input(0); - Tensor const& b = op_ctx->input(1); + // Unpack the input arguments. The 0th argument is the context, which is not + // directly used in this op but required for graph optimization. + Tensor const& a = op_ctx->input(1); + Tensor const& b = op_ctx->input(2); BCast bcast(BCast::FromShape(a.shape()), BCast::FromShape(b.shape()), /*fewer_dims_optimization=*/true); @@ -279,8 +281,9 @@ class NegCtOp : public OpKernel { explicit NegCtOp(OpKernelConstruction* op_ctx) : OpKernel(op_ctx) {} void Compute(OpKernelContext* op_ctx) override { - // Unpack the input argument. - Tensor const& a = op_ctx->input(0); + // Unpack the input arguments. The 0th argument is the context, which is not + // directly used in this op but required for graph optimization. + Tensor const& a = op_ctx->input(1); // Allocate the output tensor which is the same size as the input. Tensor* output; diff --git a/tf_shell/cc/kernels/mul_kernels.cc b/tf_shell/cc/kernels/mul_kernels.cc index 33de13e..d56871c 100644 --- a/tf_shell/cc/kernels/mul_kernels.cc +++ b/tf_shell/cc/kernels/mul_kernels.cc @@ -53,9 +53,10 @@ class MulCtCtOp : public OpKernel { explicit MulCtCtOp(OpKernelConstruction* op_ctx) : OpKernel(op_ctx) {} void Compute(OpKernelContext* op_ctx) override { - // Get the input tensors. - Tensor const& a = op_ctx->input(0); - Tensor const& b = op_ctx->input(1); + // Unpack the input arguments. The 0th argument is the context, which is not + // directly used in this op but required for graph optimization. + Tensor const& a = op_ctx->input(1); + Tensor const& b = op_ctx->input(2); BCast bcast(BCast::FromShape(a.shape()), BCast::FromShape(b.shape()), /*fewer_dims_optimization=*/true); @@ -112,9 +113,10 @@ class MulCtPtOp : public OpKernel { explicit MulCtPtOp(OpKernelConstruction* op_ctx) : OpKernel(op_ctx) {} void Compute(OpKernelContext* op_ctx) override { - // Get the input tensors. - Tensor const& a = op_ctx->input(0); - Tensor const& b = op_ctx->input(1); + // Unpack the input arguments. The 0th argument is the context, which is not + // directly used in this op but required for graph optimization. + Tensor const& a = op_ctx->input(1); + Tensor const& b = op_ctx->input(2); BCast bcast(BCast::FromShape(a.shape()), BCast::FromShape(b.shape()), /*fewer_dims_optimization=*/true); diff --git a/tf_shell/cc/ops/shell_ops.cc b/tf_shell/cc/ops/shell_ops.cc index 98a323d..dac495d 100644 --- a/tf_shell/cc/ops/shell_ops.cc +++ b/tf_shell/cc/ops/shell_ops.cc @@ -54,8 +54,6 @@ REGISTER_OP("PolynomialImport64") return OkStatus(); }); -// Output shape depends on content of context object -// so no SetShapeFn() for this Op. REGISTER_OP("PolynomialExport64") .Attr("dtype: {uint8, int8, uint16, int16, uint32, int32, uint64, int64}") .Attr("batching_dim: int") @@ -69,7 +67,7 @@ REGISTER_OP("PolynomialExport64") ShapeHandle output; TF_RETURN_IF_ERROR( - c->Concatenate(c->input(1), batching_dim_shape, &output)); + c->Concatenate(batching_dim_shape, c->input(1), &output)); c->set_output(0, output); return OkStatus(); @@ -104,7 +102,7 @@ REGISTER_OP("Decrypt64") ShapeHandle output; TF_RETURN_IF_ERROR( - c->Concatenate(c->input(1), batching_dim_shape, &output)); + c->Concatenate(batching_dim_shape, c->input(2), &output)); c->set_output(0, output); return OkStatus(); @@ -112,19 +110,20 @@ REGISTER_OP("Decrypt64") // Add and subtract. REGISTER_OP("AddCtCt64") + .Input("context: variant") .Input("a: variant") .Input("b: variant") .Output("c: variant") .SetShapeFn([](InferenceContext* c) { - auto a_sz = c->NumElements(c->input(0)); - auto b_sz = c->NumElements(c->input(1)); + auto a_sz = c->NumElements(c->input(1)); + auto b_sz = c->NumElements(c->input(2)); DimensionHandle out_sz; TF_RETURN_IF_ERROR(c->Max(a_sz, b_sz, &out_sz)); if (c->Value(out_sz) == c->Value(a_sz)) { - c->set_output(0, c->input(0)); - } else if (c->Value(out_sz) == c->Value(b_sz)) { c->set_output(0, c->input(1)); + } else if (c->Value(out_sz) == c->Value(b_sz)) { + c->set_output(0, c->input(2)); } else { c->set_output(0, c->UnknownShape()); } @@ -132,19 +131,20 @@ REGISTER_OP("AddCtCt64") }); REGISTER_OP("AddCtPt64") + .Input("context: variant") .Input("a: variant") .Input("b: variant") .Output("c: variant") .SetShapeFn([](InferenceContext* c) { - auto a_sz = c->NumElements(c->input(0)); - auto b_sz = c->NumElements(c->input(1)); + auto a_sz = c->NumElements(c->input(1)); + auto b_sz = c->NumElements(c->input(2)); DimensionHandle out_sz; TF_RETURN_IF_ERROR(c->Max(a_sz, b_sz, &out_sz)); if (c->Value(out_sz) == c->Value(a_sz)) { - c->set_output(0, c->input(0)); - } else if (c->Value(out_sz) == c->Value(b_sz)) { c->set_output(0, c->input(1)); + } else if (c->Value(out_sz) == c->Value(b_sz)) { + c->set_output(0, c->input(2)); } else { c->set_output(0, c->UnknownShape()); } @@ -157,15 +157,15 @@ REGISTER_OP("AddPtPt64") .Input("b: variant") .Output("c: variant") .SetShapeFn([](InferenceContext* c) { - auto a_sz = c->NumElements(c->input(0)); - auto b_sz = c->NumElements(c->input(1)); + auto a_sz = c->NumElements(c->input(1)); + auto b_sz = c->NumElements(c->input(2)); DimensionHandle out_sz; TF_RETURN_IF_ERROR(c->Max(a_sz, b_sz, &out_sz)); if (c->Value(out_sz) == c->Value(a_sz)) { - c->set_output(0, c->input(0)); - } else if (c->Value(out_sz) == c->Value(b_sz)) { c->set_output(0, c->input(1)); + } else if (c->Value(out_sz) == c->Value(b_sz)) { + c->set_output(0, c->input(2)); } else { c->set_output(0, c->UnknownShape()); } @@ -173,19 +173,20 @@ REGISTER_OP("AddPtPt64") }); REGISTER_OP("SubCtCt64") + .Input("context: variant") .Input("a: variant") .Input("b: variant") .Output("c: variant") .SetShapeFn([](InferenceContext* c) { - auto a_sz = c->NumElements(c->input(0)); - auto b_sz = c->NumElements(c->input(1)); + auto a_sz = c->NumElements(c->input(1)); + auto b_sz = c->NumElements(c->input(2)); DimensionHandle out_sz; TF_RETURN_IF_ERROR(c->Max(a_sz, b_sz, &out_sz)); if (c->Value(out_sz) == c->Value(a_sz)) { - c->set_output(0, c->input(0)); - } else if (c->Value(out_sz) == c->Value(b_sz)) { c->set_output(0, c->input(1)); + } else if (c->Value(out_sz) == c->Value(b_sz)) { + c->set_output(0, c->input(2)); } else { c->set_output(0, c->UnknownShape()); } @@ -193,19 +194,20 @@ REGISTER_OP("SubCtCt64") }); REGISTER_OP("SubCtPt64") + .Input("context: variant") .Input("a: variant") .Input("b: variant") .Output("c: variant") .SetShapeFn([](InferenceContext* c) { - auto a_sz = c->NumElements(c->input(0)); - auto b_sz = c->NumElements(c->input(1)); + auto a_sz = c->NumElements(c->input(1)); + auto b_sz = c->NumElements(c->input(2)); DimensionHandle out_sz; TF_RETURN_IF_ERROR(c->Max(a_sz, b_sz, &out_sz)); if (c->Value(out_sz) == c->Value(a_sz)) { - c->set_output(0, c->input(0)); - } else if (c->Value(out_sz) == c->Value(b_sz)) { c->set_output(0, c->input(1)); + } else if (c->Value(out_sz) == c->Value(b_sz)) { + c->set_output(0, c->input(2)); } else { c->set_output(0, c->UnknownShape()); } @@ -218,15 +220,15 @@ REGISTER_OP("SubPtPt64") .Input("b: variant") .Output("c: variant") .SetShapeFn([](InferenceContext* c) { - auto a_sz = c->NumElements(c->input(0)); - auto b_sz = c->NumElements(c->input(1)); + auto a_sz = c->NumElements(c->input(1)); + auto b_sz = c->NumElements(c->input(2)); DimensionHandle out_sz; TF_RETURN_IF_ERROR(c->Max(a_sz, b_sz, &out_sz)); if (c->Value(out_sz) == c->Value(a_sz)) { - c->set_output(0, c->input(0)); - } else if (c->Value(out_sz) == c->Value(b_sz)) { c->set_output(0, c->input(1)); + } else if (c->Value(out_sz) == c->Value(b_sz)) { + c->set_output(0, c->input(2)); } else { c->set_output(0, c->UnknownShape()); } @@ -234,10 +236,11 @@ REGISTER_OP("SubPtPt64") }); REGISTER_OP("NegCt64") + .Input("context: variant") .Input("value: variant") .Output("negated_value: variant") .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->input(0)); + c->set_output(0, c->input(1)); return OkStatus(); }); @@ -246,25 +249,26 @@ REGISTER_OP("NegPt64") .Input("value: variant") .Output("negated_value: variant") .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->input(0)); + c->set_output(0, c->input(1)); return OkStatus(); }); // Multiply. REGISTER_OP("MulCtCt64") + .Input("context: variant") .Input("a: variant") .Input("b: variant") .Output("c: variant") .SetShapeFn([](InferenceContext* c) { - auto a_sz = c->NumElements(c->input(0)); - auto b_sz = c->NumElements(c->input(1)); + auto a_sz = c->NumElements(c->input(1)); + auto b_sz = c->NumElements(c->input(2)); DimensionHandle out_sz; TF_RETURN_IF_ERROR(c->Max(a_sz, b_sz, &out_sz)); if (c->Value(out_sz) == c->Value(a_sz)) { - c->set_output(0, c->input(0)); - } else if (c->Value(out_sz) == c->Value(b_sz)) { c->set_output(0, c->input(1)); + } else if (c->Value(out_sz) == c->Value(b_sz)) { + c->set_output(0, c->input(2)); } else { c->set_output(0, c->UnknownShape()); } @@ -272,19 +276,20 @@ REGISTER_OP("MulCtCt64") }); REGISTER_OP("MulCtPt64") + .Input("context: variant") .Input("a: variant") .Input("b: variant") .Output("c: variant") .SetShapeFn([](InferenceContext* c) { - auto a_sz = c->NumElements(c->input(0)); - auto b_sz = c->NumElements(c->input(1)); + auto a_sz = c->NumElements(c->input(1)); + auto b_sz = c->NumElements(c->input(2)); DimensionHandle out_sz; TF_RETURN_IF_ERROR(c->Max(a_sz, b_sz, &out_sz)); if (c->Value(out_sz) == c->Value(a_sz)) { - c->set_output(0, c->input(0)); - } else if (c->Value(out_sz) == c->Value(b_sz)) { c->set_output(0, c->input(1)); + } else if (c->Value(out_sz) == c->Value(b_sz)) { + c->set_output(0, c->input(2)); } else { c->set_output(0, c->UnknownShape()); } @@ -319,15 +324,15 @@ REGISTER_OP("MulPtPt64") .Input("b: variant") .Output("c: variant") .SetShapeFn([](InferenceContext* c) { - auto a_sz = c->NumElements(c->input(0)); - auto b_sz = c->NumElements(c->input(1)); + auto a_sz = c->NumElements(c->input(1)); + auto b_sz = c->NumElements(c->input(2)); DimensionHandle out_sz; TF_RETURN_IF_ERROR(c->Max(a_sz, b_sz, &out_sz)); if (c->Value(out_sz) == c->Value(a_sz)) { - c->set_output(0, c->input(0)); - } else if (c->Value(out_sz) == c->Value(b_sz)) { c->set_output(0, c->input(1)); + } else if (c->Value(out_sz) == c->Value(b_sz)) { + c->set_output(0, c->input(2)); } else { c->set_output(0, c->UnknownShape()); } diff --git a/tf_shell/cc/optimizers/ct_pt.cc b/tf_shell/cc/optimizers/ct_pt.cc new file mode 100644 index 0000000..065030d --- /dev/null +++ b/tf_shell/cc/optimizers/ct_pt.cc @@ -0,0 +1,274 @@ +#include "ct_pt.h" + +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/functions.h" +#include "tensorflow/core/grappler/utils/graph_view.h" +#include "tensorflow/core/grappler/utils/topological_sort.h" + +namespace tensorflow { +namespace grappler { + +namespace { + +constexpr bool const debug = false; + +struct RemapperContext { + explicit RemapperContext(GrapplerItem* item, Status* status) + : nodes_to_preserve(item->NodesToPreserve()), + graph_view(&item->graph, status), + graph_properties(*item) {} + + std::unordered_set nodes_to_preserve; + utils::MutableGraphView graph_view; + GraphProperties graph_properties; +}; + +constexpr char kAddCtPt[] = "AddCtPt64"; +constexpr char kSubCtPt[] = "SubCtPt64"; +constexpr char kMulCtPt[] = "MulCtPt64"; + +constexpr char kAddPtPt[] = "AddPtPt64"; +constexpr char kSubPtPt[] = "SubPtPt64"; +constexpr char kMulPtPt[] = "MulPtPt64"; + +bool IsAddCtPt(NodeDef const& node) { return node.op() == kAddCtPt; } +bool IsSubCtPt(NodeDef const& node) { return node.op() == kSubCtPt; } +bool IsMulCtPt(NodeDef const& node) { return node.op() == kMulCtPt; } + +char const* GetOpFromCtPt(NodeDef const& node, bool is_ct_pt) { + if (IsAddCtPt(node)) { + return is_ct_pt ? kAddCtPt : kAddPtPt; + } else if (IsSubCtPt(node)) { + return is_ct_pt ? kSubCtPt : kAddPtPt; // PtPt replaces sub with add. + } else if (IsMulCtPt(node)) { + return is_ct_pt ? kMulCtPt : kMulPtPt; + } + + return nullptr; +} + +struct ReorderArith { + int shell_context_node_index; + int outer_node_index; + int inner_node_index; + int outer_pt_node_index; + int inner_ct_node_index; + int inner_pt_node_index; +}; + +void PrintReorderArith(RemapperContext& ctx, ReorderArith const& reorder) { + auto const* outer_node = + ctx.graph_view.GetNode(reorder.outer_node_index)->node(); + auto const* inner_node = + ctx.graph_view.GetNode(reorder.inner_node_index)->node(); + auto const* outer_pt_node = + ctx.graph_view.GetNode(reorder.outer_pt_node_index)->node(); + auto const* inner_ct_node = + ctx.graph_view.GetNode(reorder.inner_ct_node_index)->node(); + auto const* inner_pt_node = + ctx.graph_view.GetNode(reorder.inner_pt_node_index)->node(); + + std::cout << outer_node->name() << " ( " << inner_node->name() << " ( " + << inner_ct_node->name() << " , " << inner_pt_node->name() << " ), " + << outer_pt_node->name() << " ) " << std::endl; +} + +// Returns true if the node_index points to the outermost add of the pattern +// outer_op(inner_op(ct, pt), pt) and fills the ReorderArith struct accordingly. +// If the outer_op is add or sub, the inner_op must be add or sub. +// If instead the outer_op is mul, the inner_op must be mul. +bool FindAddOrSub(RemapperContext& ctx, int node_index, ReorderArith* reorder) { + // Check given node is op(ct, pt). + auto const* outer_node_view = ctx.graph_view.GetNode(node_index); + auto const* outer_node_def = outer_node_view->node(); + + if (!IsAddCtPt(*outer_node_def) && !IsSubCtPt(*outer_node_def) && + !IsMulCtPt(*outer_node_def)) { + return false; + } + + // Next, check the feed node ct at input 0 is the output of another + // CtPt op. + auto const& outer_fanin_0 = outer_node_view->GetRegularFanin(0); + auto const* context_node_view = outer_fanin_0.node_view(); + + auto const& outer_fanin_1 = outer_node_view->GetRegularFanin(1); + auto const* inner_node_view = outer_fanin_1.node_view(); + auto const* inner_node_def = inner_node_view->node(); + + auto const& outer_fanin_2 = outer_node_view->GetRegularFanin(2); + auto const* outer_pt_node_view = outer_fanin_2.node_view(); + + // If the outer op is add or sub, the inner op must be add or sub as well. + if (!IsMulCtPt(*outer_node_def) && !IsAddCtPt(*inner_node_def) && + !IsSubCtPt(*inner_node_def)) { + return false; + } + + // If the outer op is mul, the inner op must be mul as well. + if (IsMulCtPt(*outer_node_def) && !IsMulCtPt(*inner_node_def)) { + return false; + } + + auto const& inner_fanin_0 = inner_node_view->GetRegularFanin(0); + auto const* inner_context_node_view = inner_fanin_0.node_view(); + + // If the contexts do not match, the pattern should not be matched.. + if (context_node_view->node_index() != inner_context_node_view->node_index()) + return false; + + auto const& inner_fanin_1 = inner_node_view->GetRegularFanin(1); + auto const* inner_ct_node_view = inner_fanin_1.node_view(); + + auto const& inner_fanin_2 = inner_node_view->GetRegularFanin(2); + auto const* inner_pt_node_view = inner_fanin_2.node_view(); + + ReorderArith new_reorder{ + .shell_context_node_index = context_node_view->node_index(), + .outer_node_index = node_index, + .inner_node_index = inner_node_view->node_index(), + .outer_pt_node_index = outer_pt_node_view->node_index(), + .inner_ct_node_index = inner_ct_node_view->node_index(), + .inner_pt_node_index = inner_pt_node_view->node_index()}; + + if constexpr (debug) { + std::cout << "Found pattern:"; + PrintReorderArith(ctx, new_reorder); + } + + *reorder = new_reorder; + + return true; +} + +// This function replaces the pattern outer_op(inner_op(ct, pt), pt) with +// outer_op(ct, inner_op(pt, pt)). +Status ApplyReorderArith(RemapperContext* ctx, ReorderArith const& reorder, + std::vector* nodes_to_delete) { + GraphDef const* graph = ctx->graph_view.graph(); + utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); + Status status; + + // First replace the inner node with a pt pt node. + NodeDef const& inner_node_def = graph->node(reorder.inner_node_index); + NodeDef new_pt_pt_inner; + + auto const new_inner_op = GetOpFromCtPt(inner_node_def, /*is_ct_pt=*/false); + if (new_inner_op == nullptr) { + return errors::Internal("Inner node is not supported type."); + } + new_pt_pt_inner.set_op(new_inner_op); + // Name of the new node needs to be the same as the old, even though the + // op is different, so downstream nodes can still find it. + new_pt_pt_inner.set_name(inner_node_def.name()); + new_pt_pt_inner.set_device(inner_node_def.device()); + + NodeDef const& shell_context_node = + graph->node(reorder.shell_context_node_index); + new_pt_pt_inner.add_input(shell_context_node.name()); + NodeDef const& inner_pt = graph->node(reorder.inner_pt_node_index); + new_pt_pt_inner.add_input(inner_pt.name()); + NodeDef const& outer_pt = graph->node(reorder.outer_pt_node_index); + new_pt_pt_inner.add_input(outer_pt.name()); + + // Replace the outer node with a ct pt op, where pt comes from + // new_pt_pt_inner created above. + NodeDef const& outer_node_def = graph->node(reorder.outer_node_index); + NodeDef new_outer; + + auto const new_outer_op = GetOpFromCtPt(outer_node_def, /*is_ct_pt=*/true); + if (new_outer_op == nullptr) { + return errors::Internal("Inner node is not supported type."); + } + new_outer.set_op(new_outer_op); + // Name of the new node needs to be the same as the old, even though the + // op is different, so downstream nodes can still find it. + new_outer.set_name(outer_node_def.name()); + new_outer.set_device(outer_node_def.device()); + + NodeDef const& inner_ct = graph->node(reorder.inner_ct_node_index); + new_outer.add_input(shell_context_node.name()); + new_outer.add_input(inner_ct.name()); + new_outer.add_input(new_pt_pt_inner.name()); + + if constexpr (debug) { + std::cout << "New outer node: " << new_outer.DebugString() << std::endl; + std::cout << "New inner node: " << new_pt_pt_inner.DebugString() + << std::endl; + } + + // Add the new nodes to the graph. + mutation->AddNode(std::move(new_outer), &status); + TF_RETURN_IF_ERROR(status); + mutation->AddNode(std::move(new_pt_pt_inner), &status); + TF_RETURN_IF_ERROR(status); + + (*nodes_to_delete)[reorder.outer_node_index] = true; + (*nodes_to_delete)[reorder.inner_node_index] = true; + + return OkStatus(); +} + +} // namespace + +CtPtOptimizer::CtPtOptimizer() {} + +Status CtPtOptimizer::Init( + tensorflow::RewriterConfig_CustomGraphOptimizer const* config) { + return OkStatus(); +} + +Status CtPtOptimizer::Optimize(Cluster* cluster, GrapplerItem const& item, + GraphDef* optimized_graph) { + GrapplerItem mutable_item = item; + Status status; + RemapperContext ctx(&mutable_item, &status); + TF_RETURN_IF_ERROR(status); + + // Topological sort and processing the nodes in reverse requires only + // one pass on all the nodes. + TF_RETURN_IF_ERROR( + ctx.graph_view.SortTopologically(/*ignore_cycles=*/false, {})); + + bool finished = false; + while (!finished) { + int const num_nodes = mutable_item.graph.node_size(); + std::vector nodes_to_delete(num_nodes); + finished = true; + + for (int i = num_nodes - 1; i >= 0; --i) { + if (nodes_to_delete[i]) { + continue; + } + + // Remap op( op(ct, pt), pt) to op(ct, op(pt, pt)). + ReorderArith reorder; + if (FindAddOrSub(ctx, i, &reorder)) { + TF_RETURN_IF_ERROR(ApplyReorderArith(&ctx, reorder, &nodes_to_delete)); + finished = false; + } + } + + // Remove nodes. + utils::Mutation* mutation = ctx.graph_view.GetMutationBuilder(); + for (int i = 0; i < num_nodes; ++i) { + if (nodes_to_delete[i]) { + mutation->RemoveNode(ctx.graph_view.GetNode(i)); + } + } + TF_RETURN_IF_ERROR(mutation->Apply()); + } + + *optimized_graph = std::move(mutable_item.graph); + + return OkStatus(); +} + +REGISTER_GRAPH_OPTIMIZER(CtPtOptimizer); + +} // namespace grappler +} // namespace tensorflow \ No newline at end of file diff --git a/tf_shell/cc/optimizers/ct_pt.h b/tf_shell/cc/optimizers/ct_pt.h new file mode 100644 index 0000000..bf2a540 --- /dev/null +++ b/tf_shell/cc/optimizers/ct_pt.h @@ -0,0 +1,30 @@ +#pragma once + +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/utils/functions.h" + +namespace tensorflow { +namespace grappler { + +class CtPtOptimizer : public CustomGraphOptimizer { + public: + CtPtOptimizer(); + + Status Init( + tensorflow::RewriterConfig_CustomGraphOptimizer const* config) override; + + string name() const override { return name_; } + + bool UsesFunctionLibrary() const override { return false; } + + Status Optimize(Cluster* cluster, GrapplerItem const& item, + GraphDef* optimized_graph) override; + + private: + string const name_ = "CtPtOptimizer"; +}; + +} // namespace grappler +} // namespace tensorflow \ No newline at end of file diff --git a/tf_shell/python/optimizers/__init__.py b/tf_shell/python/optimizers/__init__.py new file mode 100644 index 0000000..4855b8c --- /dev/null +++ b/tf_shell/python/optimizers/__init__.py @@ -0,0 +1,15 @@ +#!/usr/bin/python +# +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tf_shell/python/optimizers/shell_optimizers.py b/tf_shell/python/optimizers/shell_optimizers.py new file mode 100644 index 0000000..29146d9 --- /dev/null +++ b/tf_shell/python/optimizers/shell_optimizers.py @@ -0,0 +1,154 @@ +#!/usr/bin/python +# +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import load_library +from tensorflow.python.platform import resource_loader + +shell_ops = load_library.load_op_library( + resource_loader.get_path_to_datafile("_ct_pt_optimizer.so") +) + +# Based on https://github.com/openvinotoolkit/openvino_tensorflow/blob/d9dcb9d4c5932d0a8e9a3633d4134ae5841af6c1/python/openvino_tensorflow/__init__.in.py + +from tensorflow.python.framework import ops +from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.python.grappler import tf_optimizer +from tensorflow.python.framework import convert_to_constants +from tensorflow.python.training import saver +from tensorflow.python.util import nest +from tensorflow.python.eager import context +from tensorflow.python.eager import wrap_function +from tensorflow.core.function.polymorphism import function_type as function_type_lib + + +rewriter_config = rewriter_config_pb2.RewriterConfig() +rewriter_config.meta_optimizer_iterations = rewriter_config_pb2.RewriterConfig.ONE +shell_optimizer = rewriter_config.custom_optimizers.add() +shell_optimizer.name = "CtPtOptimizer" + + +def optimize_shell_graph(func): + # Converting var2consts for larger models might take a long time + frozen_func = convert_to_constants.convert_variables_to_constants_v2( + func, lower_control_flow=False, aggressive_inlining=True + ) + + meta_graph_def = saver.export_meta_graph( + graph_def=frozen_func.graph.as_graph_def(add_shapes=True), + graph=frozen_func.graph, + ) + + # print("orig graph def", meta_graph_def) + + fetch_collection = meta_graph_pb2.CollectionDef() + for array in frozen_func.outputs: + fetch_collection.node_list.value.append(array.name) + + # Grappler determines fetch ops from collection 'train_op'. + meta_graph_def.collection_def[ops.GraphKeys.TRAIN_OP].CopyFrom(fetch_collection) + + grappler_session_config = config_pb2.ConfigProto() + grappler_session_config.graph_options.rewrite_options.CopyFrom(rewriter_config) + optimized_graph_def = tf_optimizer.OptimizeGraph( + grappler_session_config, meta_graph_def, graph_id=b"tf_graph" + ) + + # print("opt graph def", optimized_graph_def) + + # Swap original function with optimized function in TF's context + for f in optimized_graph_def.library.function: + while context.context().has_function(f.signature.name): + context.context().remove_function(f.signature.name) + + try: + optimized_func = wrap_function.function_from_graph_def( + optimized_graph_def, + [tensor.name for tensor in frozen_func.inputs], + [tensor.name for tensor in frozen_func.outputs], + ) + except Exception as e: + raise ValueError( + "Could not wrap the optimized graph. Did the shell optimizer remove" + " some of the inputs or outputs? Original error: " + str(e) + ) + + optimized_func.graph.structured_outputs = nest.pack_sequence_as( + func.graph.structured_outputs, + optimized_func.graph.structured_outputs, + expand_composites=True, # required for extension types e.g. ShellTensor + ) + + optimized_func.graph.structured_input_signature = func.structured_input_signature + + # `optimized_func` is a WrappedFunction holding an AtomicFunction which + # derives a `function_type` from the structured_input_signature and + # structured_outputs. This is used to flatten the arguments when calling the + # function, which usually isn't important when the arguments are just + # Tensors but when they are composite (e.g. ShellTensor) the flattening + # becomes important. + # + # There are two bugs in Tensorflow that make this tricky. + # 1) First, we need to force update the `function_type` to reflect the new + # structured input and output signatures. Ideally, when + # function_from_graph_def() calls prune to create the new WrappedFunction, + # it would correctly set the pruned_graph.structured_input_signature instead + # of None, and it would set the structured_outputs to the "structured" + # typespec of the outputs, instead of the flattened version. Once + # b/129646028 is fixed (prune supports composite tensors), this can + # hopefully be removed. For more info, see line 390 of + # tensorflow/python/eager/wrap_function.py in tensorflow 2.16.1. + # 2) Second, after calling the optimized_func, the output args are never + # unflattened. This is a bug in TensorFlow and a fix PR is submitted at + # https://github.com/tensorflow/tensorflow/pull/67612. + # For now, we require calling code to run something like: + # my_func_output = optimized_func.function_type.pack_output(my_func_output) + + updated_fn_type = function_type_lib.from_structured_signature( + optimized_func.graph.structured_input_signature, + optimized_func.graph.structured_outputs, + optimized_func.graph.function_captures.capture_types, + ) + optimized_func._function_type = updated_fn_type + + return optimized_func + + +# Here is a method to enable custom optimizers described by +# https://github.com/tensorflow/tensorflow/issues/55451#issuecomment-1147065792 +def enable_tf_shell_optimizer(optimizers): + from tensorflow.core.protobuf import config_pb2 + from tensorflow.python.framework import ops + from tensorflow.python.grappler import tf_optimizer + from tensorflow.python.framework import meta_graph + from tensorflow.core.protobuf import rewriter_config_pb2 + from tensorflow.core.framework import graph_pb2 + from tensorflow.python.eager import context + + rewriter_config = rewriter_config_pb2.RewriterConfig() + rewriter_config.meta_optimizer_iterations = rewriter_config_pb2.RewriterConfig.ONE + for optimizer in optimizers: + custom_optimizer = rewriter_config.custom_optimizers.add() + custom_optimizer.name = optimizer + grappler_session_config = context.context().config + grappler_session_config.graph_options.rewrite_options.CopyFrom(rewriter_config) + + grappler_options = context.FunctionCallOptions(config_proto=grappler_session_config) + context.context().function_call_options = grappler_options diff --git a/tf_shell/python/shell_tensor.py b/tf_shell/python/shell_tensor.py index b668e83..125bdf2 100644 --- a/tf_shell/python/shell_tensor.py +++ b/tf_shell/python/shell_tensor.py @@ -88,15 +88,21 @@ def __add__(self, other): if self.is_encrypted and other.is_encrypted: result_raw_tensor = shell_ops.add_ct_ct64( - matched_self._raw_tensor, matched_other._raw_tensor + matched_self._context._raw_context, + matched_self._raw_tensor, + matched_other._raw_tensor, ) elif self.is_encrypted and not other.is_encrypted: result_raw_tensor = shell_ops.add_ct_pt64( - matched_self._raw_tensor, matched_other._raw_tensor + matched_self._context._raw_context, + matched_self._raw_tensor, + matched_other._raw_tensor, ) elif not self.is_encrypted and other.is_encrypted: result_raw_tensor = shell_ops.add_ct_pt64( - matched_other._raw_tensor, matched_self._raw_tensor + matched_self._context._raw_context, + matched_other._raw_tensor, + matched_self._raw_tensor, ) elif not self.is_encrypted and not other.is_encrypted: result_raw_tensor = shell_ops.add_pt_pt64( @@ -146,16 +152,22 @@ def __sub__(self, other): if self.is_encrypted and other.is_encrypted: result_raw_tensor = shell_ops.sub_ct_ct64( - matched_self._raw_tensor, matched_other._raw_tensor + matched_self._context._raw_context, + matched_self._raw_tensor, + matched_other._raw_tensor, ) elif self.is_encrypted and not other.is_encrypted: result_raw_tensor = shell_ops.sub_ct_pt64( - matched_self._raw_tensor, matched_other._raw_tensor + matched_self._context._raw_context, + matched_self._raw_tensor, + matched_other._raw_tensor, ) elif not self.is_encrypted and other.is_encrypted: negative_other = -matched_other result_raw_tensor = shell_ops.add_ct_pt64( - negative_other._raw_tensor, matched_self._raw_tensor + matched_self._context._raw_context, + negative_other._raw_tensor, + matched_self._raw_tensor, ) elif not self.is_encrypted and not other.is_encrypted: result_raw_tensor = shell_ops.sub_pt_pt64( @@ -209,7 +221,9 @@ def __rsub__(self, other): if self.is_encrypted: negative_self = -self raw_result = shell_ops.add_ct_pt64( - negative_self._raw_tensor, shell_other._raw_tensor + self._context._raw_context, + negative_self._raw_tensor, + shell_other._raw_tensor, ) else: raw_result = shell_ops.sub_pt_pt64( @@ -239,7 +253,9 @@ def __rsub__(self, other): def __neg__(self): if self.is_encrypted: - raw_result = shell_ops.neg_ct64(self._raw_tensor) + raw_result = shell_ops.neg_ct64( + self._context._raw_context, self._raw_tensor + ) else: raw_result = shell_ops.neg_pt64( self._context._raw_context, self._raw_tensor @@ -260,15 +276,21 @@ def __mul__(self, other): if self.is_encrypted and other.is_encrypted: raw_result = shell_ops.mul_ct_ct64( - matched_self._raw_tensor, matched_other._raw_tensor + matched_self._context._raw_context, + matched_self._raw_tensor, + matched_other._raw_tensor, ) elif self.is_encrypted and not other.is_encrypted: raw_result = shell_ops.mul_ct_pt64( - matched_self._raw_tensor, matched_other._raw_tensor + matched_self._context._raw_context, + matched_self._raw_tensor, + matched_other._raw_tensor, ) elif not self.is_encrypted and other.is_encrypted: raw_result = shell_ops.mul_ct_pt64( - matched_other._raw_tensor, matched_self._raw_tensor + matched_self._context._raw_context, + matched_other._raw_tensor, + matched_self._raw_tensor, ) elif not self.is_encrypted and not other.is_encrypted: raw_result = shell_ops.mul_pt_pt64( @@ -394,9 +416,9 @@ def _match_moduli_and_scaling(x, y): ) while x._scaling_factor > y._scaling_factor: - y = y * x._scaling_factor + y = y.__mul__(x._scaling_factor) while x._scaling_factor < y._scaling_factor: - x = x * y._scaling_factor + x = x.__mul__(y._scaling_factor) return x, y diff --git a/tf_shell/test/BUILD b/tf_shell/test/BUILD index 734d7fa..102e2c3 100644 --- a/tf_shell/test/BUILD +++ b/tf_shell/test/BUILD @@ -107,3 +107,16 @@ py_test( requirement("tensorflow-cpu"), ], ) + +py_test( + name = "ct_pt_optimizer_test", + size = "medium", + srcs = [ + "ct_pt_optimizer_test.py", + ], + imports = ["./"], + deps = [ + "//tf_shell:tf_shell_lib", + requirement("tensorflow-cpu"), + ], +) diff --git a/tf_shell/test/ct_pt_optimizer_test.py b/tf_shell/test/ct_pt_optimizer_test.py new file mode 100644 index 0000000..bbf8cd8 --- /dev/null +++ b/tf_shell/test/ct_pt_optimizer_test.py @@ -0,0 +1,175 @@ +import tensorflow as tf +import tf_shell +import tf_shell.python.optimizers.shell_optimizers as shell_optimizers +import test_utils + + +# These test cases are for the CtPtOptimizer, which optimizes the tf graph by +# reordering ciphertext - plaintext operations which look like +# ((ct + pt) + pt) to (ct + (pt + pt). + +@tf.function +def ct_pt_pt_add(ct, pt): + return ((((((ct + pt) + pt) + pt) + pt) + pt) + pt) + pt + +@tf.function +def ct_pt_pt_sub(ct, pt): + return ((((((ct - pt) - pt) - pt) - pt) - pt) - pt) - pt + +@tf.function +def ct_pt_pt_add_sub(ct, pt): + return ((((((ct + pt) - pt) + pt) - pt) + pt) - pt) + pt + +@tf.function +def ct_pt_pt_mul(ct, pt): + return ((ct * pt) * pt) + +@tf.function +def ct_pt_pt_add_mul_no_opt(ct, pt): + # This should not be optimized, mul and add are not commutative. + return ((ct + pt) * pt) + +@tf.function +def ct_pt_pt_mul_add_no_opt(ct, pt): + # This should not be optimized, mul and add are not commutative. + return ((ct * pt) + pt) + + +def count_ct_pt_ops(graph, op_name): + num_ct_pt_ops = 0 + for node in graph.as_graph_def().node: + if node.op == op_name: + num_ct_pt_ops += 1 + return num_ct_pt_ops + + +class TestCtPtOptimizer(tf.test.TestCase): + test_contexts = None + + @classmethod + def setUpClass(cls): + int_dtypes = [ + tf.uint8, + tf.int8, + tf.uint16, + tf.int16, + tf.uint32, + tf.int32, + tf.uint64, + tf.int64, + ] + cls.test_contexts = [] + + cls.test_contexts.append( + test_utils.TestContext( + outer_shape=[1], + plaintext_dtype=tf.float32, + log_n=11, + main_moduli=[8556589057, 8388812801], + aux_moduli=[], + plaintext_modulus=40961, + scaling_factor=1, + mul_depth_supported=0, + ) + ) + + def _test_func(self, test_context, tf_func, num_pts, num_opt_pts, op_name): + a = test_utils.uniform_for_n_muls(test_context, num_pts + 1) + b = test_utils.uniform_for_n_muls(test_context, num_pts + 1) + + ct_a = tf_shell.to_encrypted(a, test_context.key, test_context.shell_context) + pt_b = tf_shell.to_shell_plaintext(b, test_context.shell_context) + + # Sanity check the plain TensorFlow function correctly computes the + # correct value. + enc_c = tf_func(ct_a, pt_b) + self.assertAllClose( + tf_shell.to_tensorflow(enc_c, test_context.key), + tf_func(a, b), + atol=1 / test_context.shell_context.scaling_factor * num_pts, + ) + + func = tf_func.get_concrete_function(ct_a, pt_b) + orig_num_ops = count_ct_pt_ops(func.graph, op_name) + self.assertEqual(orig_num_ops, num_pts) + + # print("\noriginal graph:") + # for node in func.graph.as_graph_def().node: + # print(f'{node.name} {node.op}({node.input})') + + # Optimize the graph using tf_shells HE-specific optimizers. + optimized_func = shell_optimizers.optimize_shell_graph(func) + # Call the optimized function. + enc_c = optimized_func(ct_a, pt_b) + # Can remove pack_output above if + # https://github.com/tensorflow/tensorflow/pull/67612 is merged. + enc_c = optimized_func.function_type.pack_output(enc_c) + opt_num_ops = count_ct_pt_ops(optimized_func.graph, op_name) + + # print("\noptimized graph:") + # for node in optimized_func.graph.as_graph_def().node: + # print(f'{node.name} {node.op}({node.input})') + + self.assertEqual(opt_num_ops, 1) + + # Check the optimized graph still computes the correct value. + self.assertAllClose( + tf_shell.to_tensorflow(enc_c, test_context.key), + tf_func(a, b), + atol=1 / test_context.shell_context.scaling_factor * num_pts, + ) + + def test_func(self): + for test_context in self.test_contexts: + with self.subTest(f"Optimizer for func ct_pt_pt_add."): + self._test_func(test_context, ct_pt_pt_add, 7, 1, "AddCtPt64") + with self.subTest(f"Optimizer for func ct_pt_pt_sub."): + self._test_func(test_context, ct_pt_pt_sub, 7, 1, "SubCtPt64") + with self.subTest(f"Optimizer for func ct_pt_pt_mul."): + self._test_func(test_context, ct_pt_pt_mul, 2, 1, "MulCtPt64") + with self.subTest(f"Optimizer for func ct_pt_pt_add_mul_no_opt."): + self._test_func(test_context, ct_pt_pt_add_mul_no_opt, 1, 1, "MulCtPt64") + self._test_func(test_context, ct_pt_pt_add_mul_no_opt, 1, 1, "AddCtPt64") + with self.subTest(f"Optimizer for func ct_pt_pt_mul_add_no_opt."): + self._test_func(test_context, ct_pt_pt_mul_add_no_opt, 1, 1, "MulCtPt64") + self._test_func(test_context, ct_pt_pt_mul_add_no_opt, 1, 1, "AddCtPt64") + + +class TestCtPtAutoEnableOptimizer(tf.test.TestCase): + def test_auto_optimize(self): + from timeit import timeit + + context = tf_shell.create_context64( + log_n=10, + main_moduli=[8556589057, 8388812801], + plaintext_modulus=40961, + scaling_factor=3, + mul_depth_supported=3, + seed="test_seed", + ) + + secret_key = tf_shell.create_key64(context) + + a = tf.random.uniform([context.num_slots, 40000], dtype=tf.float32, maxval=10) + b = tf.random.uniform([context.num_slots, 40000], dtype=tf.float32, maxval=10) + + ct_a = tf_shell.to_encrypted(a, secret_key, context) + pt_b = tf_shell.to_shell_plaintext(b, context) + + # Call the function as usual. + unopt_time = timeit(lambda: ct_pt_pt_add(ct_a, pt_b), number=1) + + # Turn on automatic optimization. Note there is no way to get the + # optimized graph from the tf.function so we need to rely on timing info + # to make sure it's turned on. + shell_optimizers.enable_tf_shell_optimizer(["CtPtOptimizer"]) + + opt_time = timeit(lambda: ct_pt_pt_add(ct_a, pt_b), number=1) + + # Optimized time should be twice as fast due to the two ciphertext + # components, but give it some slack and check if it is 1.7x faster. + self.assertLess(opt_time, unopt_time / 1.7) + + +if __name__ == "__main__": + tf.test.main() From 32bfc009a5344b362e87497516d9d3992187c373 Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Sun, 19 May 2024 06:26:54 +0000 Subject: [PATCH 11/22] Remove duplicate dtype. --- tf_shell/python/shell_tensor.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tf_shell/python/shell_tensor.py b/tf_shell/python/shell_tensor.py index 125bdf2..ec7862f 100644 --- a/tf_shell/python/shell_tensor.py +++ b/tf_shell/python/shell_tensor.py @@ -43,10 +43,6 @@ def dtype(self): def name(self): return self._raw_tensor.name - @property - def dtype(self): - return self._raw_tensor.name - @property def plaintext_dtype(self): return self._underlying_dtype From 7000d431649c35607726165af032d0f336b9fef1 Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Tue, 21 May 2024 07:50:55 +0000 Subject: [PATCH 12/22] Broadcast to flat tensors templated based on tensor dtype. --- tf_shell/cc/kernels/add_kernels.cc | 12 ++++++------ tf_shell/cc/kernels/mul_kernels.cc | 12 ++++++------ tf_shell/cc/kernels/utils.h | 25 +++++++++++++------------ 3 files changed, 25 insertions(+), 24 deletions(-) diff --git a/tf_shell/cc/kernels/add_kernels.cc b/tf_shell/cc/kernels/add_kernels.cc index 1d92983..0518e66 100644 --- a/tf_shell/cc/kernels/add_kernels.cc +++ b/tf_shell/cc/kernels/add_kernels.cc @@ -106,8 +106,8 @@ class AddCtCtOp : public OpKernel { op_ctx, bcast.IsValid(), InvalidArgument("Invalid broadcast between ", a.shape().DebugString(), " and ", b.shape().DebugString())); - auto flat_a = MyBFlat(op_ctx, a, bcast.x_reshape(), bcast.x_bcast()); - auto flat_b = MyBFlat(op_ctx, b, bcast.y_reshape(), bcast.y_bcast()); + auto flat_a = MyBFlat(op_ctx, a, bcast.x_reshape(), bcast.x_bcast()); + auto flat_b = MyBFlat(op_ctx, b, bcast.y_reshape(), bcast.y_bcast()); // Check the inputs have the same shape. OP_REQUIRES( @@ -166,8 +166,8 @@ class AddCtPtOp : public OpKernel { op_ctx, bcast.IsValid(), InvalidArgument("Invalid broadcast between ", a.shape().DebugString(), " and ", b.shape().DebugString())); - auto flat_a = MyBFlat(op_ctx, a, bcast.x_reshape(), bcast.x_bcast()); - auto flat_b = MyBFlat(op_ctx, b, bcast.y_reshape(), bcast.y_bcast()); + auto flat_a = MyBFlat(op_ctx, a, bcast.x_reshape(), bcast.x_bcast()); + auto flat_b = MyBFlat(op_ctx, b, bcast.y_reshape(), bcast.y_bcast()); // Check the inputs have the same shape. OP_REQUIRES( @@ -229,8 +229,8 @@ class AddPtPtOp : public OpKernel { op_ctx, bcast.IsValid(), InvalidArgument("Invalid broadcast between ", a.shape().DebugString(), " and ", b.shape().DebugString())); - auto flat_a = MyBFlat(op_ctx, a, bcast.x_reshape(), bcast.x_bcast()); - auto flat_b = MyBFlat(op_ctx, b, bcast.y_reshape(), bcast.y_bcast()); + auto flat_a = MyBFlat(op_ctx, a, bcast.x_reshape(), bcast.x_bcast()); + auto flat_b = MyBFlat(op_ctx, b, bcast.y_reshape(), bcast.y_bcast()); // Check the inputs have the same shape. OP_REQUIRES( diff --git a/tf_shell/cc/kernels/mul_kernels.cc b/tf_shell/cc/kernels/mul_kernels.cc index d56871c..c199e86 100644 --- a/tf_shell/cc/kernels/mul_kernels.cc +++ b/tf_shell/cc/kernels/mul_kernels.cc @@ -64,8 +64,8 @@ class MulCtCtOp : public OpKernel { op_ctx, bcast.IsValid(), InvalidArgument("Invalid broadcast between ", a.shape().DebugString(), " and ", b.shape().DebugString())); - auto flat_a = MyBFlat(op_ctx, a, bcast.x_reshape(), bcast.x_bcast()); - auto flat_b = MyBFlat(op_ctx, b, bcast.y_reshape(), bcast.y_bcast()); + auto flat_a = MyBFlat(op_ctx, a, bcast.x_reshape(), bcast.x_bcast()); + auto flat_b = MyBFlat(op_ctx, b, bcast.y_reshape(), bcast.y_bcast()); // Check the inputs have the same shape. OP_REQUIRES( @@ -124,8 +124,8 @@ class MulCtPtOp : public OpKernel { op_ctx, bcast.IsValid(), InvalidArgument("Invalid broadcast between ", a.shape().DebugString(), " and ", b.shape().DebugString())); - auto flat_a = MyBFlat(op_ctx, a, bcast.x_reshape(), bcast.x_bcast()); - auto flat_b = MyBFlat(op_ctx, b, bcast.y_reshape(), bcast.y_bcast()); + auto flat_a = MyBFlat(op_ctx, a, bcast.x_reshape(), bcast.x_bcast()); + auto flat_b = MyBFlat(op_ctx, b, bcast.y_reshape(), bcast.y_bcast()); // Check the inputs have the same shape. OP_REQUIRES( @@ -281,8 +281,8 @@ class MulPtPtOp : public OpKernel { op_ctx, bcast.IsValid(), InvalidArgument("Invalid broadcast between ", a.shape().DebugString(), " and ", b.shape().DebugString())); - auto flat_a = MyBFlat(op_ctx, a, bcast.x_reshape(), bcast.x_bcast()); - auto flat_b = MyBFlat(op_ctx, b, bcast.y_reshape(), bcast.y_bcast()); + auto flat_a = MyBFlat(op_ctx, a, bcast.x_reshape(), bcast.x_bcast()); + auto flat_b = MyBFlat(op_ctx, b, bcast.y_reshape(), bcast.y_bcast()); // Check the inputs have the same shape. OP_REQUIRES( diff --git a/tf_shell/cc/kernels/utils.h b/tf_shell/cc/kernels/utils.h index becea8f..367975c 100644 --- a/tf_shell/cc/kernels/utils.h +++ b/tf_shell/cc/kernels/utils.h @@ -92,8 +92,8 @@ StatusOr GetVariant(OpKernelContext* ctx, int index) { return t; } -template -inline Eigen::Tensor BFlat( +template +inline Eigen::Tensor BFlat( OpKernelContext* op_ctx, Tensor const& t, BCast::Vec const& x_reshape, BCast::Vec const& x_bcast) { // A TensorFlow is a TTypes::Tensor (aka Eigen::TensorMap). @@ -103,10 +103,10 @@ inline Eigen::Tensor BFlat( // assigns the result of the reshape to an Eigen::Tensor. // // For a demo, see https://godbolt.org/z/41xvWvb63 - typedef Eigen::Tensor + typedef Eigen::Tensor ETensor; - ETensor reshaped_t = t.template shaped(x_reshape); + ETensor reshaped_t = t.template shaped(x_reshape); ETensor broadcasted_t = reshaped_t.broadcast(BCast::ToIndexArray(x_bcast)); @@ -115,7 +115,8 @@ inline Eigen::Tensor BFlat( broadcasted_t.reshape(BCast::ToIndexArray<1>({broadcasted_t.size()}))); } -inline Eigen::Tensor MyBFlat( +template +inline Eigen::Tensor MyBFlat( OpKernelContext* op_ctx, Tensor const& t, BCast::Vec const& x_reshape, BCast::Vec const& x_bcast) { // Uses the switch statement approach as in: @@ -123,21 +124,21 @@ inline Eigen::Tensor MyBFlat( int const ndims = x_reshape.size(); switch (ndims) { case 1: - return std::move(BFlat<1>(op_ctx, t, x_reshape, x_bcast)); + return std::move(BFlat(op_ctx, t, x_reshape, x_bcast)); case 2: - return std::move(BFlat<2>(op_ctx, t, x_reshape, x_bcast)); + return std::move(BFlat(op_ctx, t, x_reshape, x_bcast)); case 3: - return std::move(BFlat<3>(op_ctx, t, x_reshape, x_bcast)); + return std::move(BFlat(op_ctx, t, x_reshape, x_bcast)); case 4: - return std::move(BFlat<4>(op_ctx, t, x_reshape, x_bcast)); + return std::move(BFlat(op_ctx, t, x_reshape, x_bcast)); case 5: - return std::move(BFlat<5>(op_ctx, t, x_reshape, x_bcast)); + return std::move(BFlat(op_ctx, t, x_reshape, x_bcast)); case 6: - return std::move(BFlat<6>(op_ctx, t, x_reshape, x_bcast)); + return std::move(BFlat(op_ctx, t, x_reshape, x_bcast)); default: op_ctx->SetStatus(Unimplemented("Broadcast ", t.DebugString(), " is not supported yet.")); - return std::move(BFlat<1>(op_ctx, t, x_reshape, x_bcast)); + return std::move(BFlat(op_ctx, t, x_reshape, x_bcast)); } } From 5ada2a64a3617b13ab7e354d57bff329f5f2939b Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Tue, 21 May 2024 07:52:50 +0000 Subject: [PATCH 13/22] Broadcast over dimension 0 and mul uses faster fused op. --- tf_shell/cc/kernels/mul_kernels.cc | 78 +++++++++++++++++------------- tf_shell/python/shell_tensor.py | 16 +++++- tf_shell/test/add_test.py | 44 ++++++++++++++++- tf_shell/test/mul_test.py | 35 ++++++++++++++ 4 files changed, 138 insertions(+), 35 deletions(-) diff --git a/tf_shell/cc/kernels/mul_kernels.cc b/tf_shell/cc/kernels/mul_kernels.cc index c199e86..4d36122 100644 --- a/tf_shell/cc/kernels/mul_kernels.cc +++ b/tf_shell/cc/kernels/mul_kernels.cc @@ -187,46 +187,32 @@ class MulShellTfScalarOp : public OpKernel { Tensor const& a = op_ctx->input(1); Tensor const& b = op_ctx->input(2); - OP_REQUIRES(op_ctx, b.dims() == 0 && b.NumElements() == 1, - InvalidArgument("Plaintext must be scalar. Instead got shape:", - b.shape().DebugString())); + BCast bcast(BCast::FromShape(a.shape()), BCast::FromShape(b.shape()), + /*fewer_dims_optimization=*/false); + OP_REQUIRES( + op_ctx, bcast.IsValid(), + InvalidArgument("Invalid broadcast between ", a.shape().DebugString(), + " and ", b.shape().DebugString())); + auto flat_a = a.flat(); // a is not broadcasted, just b. + auto flat_b = MyBFlat(op_ctx, b, bcast.y_reshape(), bcast.y_bcast()); + + // Check the inputs have the same shape. + OP_REQUIRES( + op_ctx, flat_a.size() == flat_b.size(), + InvalidArgument("Broadcasted inputs must have the same shape.")); // Allocate the output tensor which is the same shape as the first input. Tensor* output; OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, a.shape(), &output)); - - // Set up the flat views of the inputs and output tensors. - auto flat_a = a.flat(); - auto flat_b = b.flat(); auto flat_output = output->flat(); - // First, encode the scalar b. - T wrapped_b; - if constexpr (std::is_signed::value) { - // SHELL is built on the assumption that the plaintext type (in this - // case `PtT`) will always fit into the ciphertext underlying type - // (in this case `T`). E.g. the plaintext modulus is stored as the - // ciphertext type. This is true even in the RNS code paths. This means - // that this function can convert `PtT` to a signed version of `T`, - // then modulus switch into plaintext field t and type `T` without - // overflow. - using SignedInteger = std::make_signed_t; - - SignedInteger signed_b = static_cast(flat_b(0)); - - // Map signed integers into the plaintext modulus field. - OP_REQUIRES_VALUE( - std::vector wrapped_b_vector, op_ctx, - (encoder->template WrapSigned({signed_b}))); - - wrapped_b = wrapped_b_vector[0]; - } else { - // Since From and To are both unsigned, just cast and copy. - wrapped_b = static_cast(flat_b(0)); - } - - // Now multiply every polynomial in a by the same b. + // Now multiply. for (int i = 0; i < flat_output.dimension(0); ++i) { + // First encode the scalar b + // TDOO(jchoncholas): encode all scalars at once beforehand. + T wrapped_b; + EncodeScalar(op_ctx, flat_b(i), encoder, &wrapped_b); + CtOrPolyVariant const* ct_or_pt_var = std::move(flat_a(i).get()); OP_REQUIRES(op_ctx, ct_or_pt_var != nullptr, @@ -254,6 +240,32 @@ class MulShellTfScalarOp : public OpKernel { } } } + +private: + void EncodeScalar(OpKernelContext* op_ctx, PtT const& val, Encoder const* encoder, T* wrapped_val) { + if constexpr (std::is_signed::value) { + // SHELL is built on the assumption that the plaintext type (in this + // case `PtT`) will always fit into the ciphertext underlying type + // (in this case `T`). E.g. the plaintext modulus is stored as the + // ciphertext type. This is true even in the RNS code paths. This means + // that this function can convert `PtT` to a signed version of `T`, + // then modulus switch into plaintext field t and type `T` without + // overflow. + using SignedInteger = std::make_signed_t; + + SignedInteger signed_val = static_cast(val); + + // Map signed integers into the plaintext modulus field. + OP_REQUIRES_VALUE( + std::vector wrapped_val_vector, op_ctx, + (encoder->template WrapSigned({signed_val}))); + + *wrapped_val = wrapped_val_vector[0]; + } else { + // Since From and To are both unsigned, just cast and copy. + *wrapped_val = static_cast(val); + } + } }; template diff --git a/tf_shell/python/shell_tensor.py b/tf_shell/python/shell_tensor.py index ec7862f..029ea79 100644 --- a/tf_shell/python/shell_tensor.py +++ b/tf_shell/python/shell_tensor.py @@ -124,6 +124,9 @@ def __add__(self, other): # with zeros replicate the scalar across all slots. other = tf.broadcast_to(other, (self._context.num_slots, 1)) + elif other.shape[0] == 1 and len(other.shape) == len(self.shape): + other = tf.broadcast_to(other, [self._context.num_slots] + other.shape[1:]) + # Lift tensorflow tensor to shell tensor with the same scaling # factor as self and attempt the addition again. so = to_shell_plaintext(other, self._context) @@ -188,6 +191,9 @@ def __sub__(self, other): # with zeros replicate the scalar across all slots. other = tf.broadcast_to(other, (self._context.num_slots, 1)) + elif other.shape[0] == 1 and len(other.shape) == len(self.shape): + other = tf.broadcast_to(other, [self._context.num_slots] + other.shape[1:]) + # Lift tensorflow tensor to shell tensor with the same scaling # factor as self and attempt the subtraction again. shell_other = to_shell_plaintext(other, self._context) @@ -210,6 +216,10 @@ def __rsub__(self, other): # with zeros replicate the scalar across all slots. other = tf.broadcast_to(other, (self._context.num_slots, 1)) + elif other.shape[0] == 1 and len(other.shape) == len(self.shape): + other = tf.broadcast_to(other, [self._context.num_slots] + other.shape[1:]) + + # Import to a shell plaintext, which pads the first dimension with # zeros out to the number of slots. shell_other = to_shell_plaintext(other, self._context) @@ -310,7 +320,11 @@ def __mul__(self, other): # Multiplying by a scalar uses a special op which is more efficient # than the caller creating creating a ShellTensor the same # dimensions as self and multiplying. - if other.shape == (1,) or other.shape == (): + if ( + other.shape == (1,) + or other.shape == () + or (other.shape[0] == 1 and len(other.shape) == len(self.shape)) + ): # Encode the other scalar tensor to the same scaling factor as # self. other = _encode_scaling(other, self._scaling_factor) diff --git a/tf_shell/test/add_test.py b/tf_shell/test/add_test.py index 21924bb..ccca1a8 100644 --- a/tf_shell/test/add_test.py +++ b/tf_shell/test/add_test.py @@ -553,7 +553,7 @@ def _test_ct_scalar_add(self, test_context): # This test performs one addition. _, max_val = test_utils.get_bounds_for_n_adds(test_context, 1) a = test_utils.uniform_for_n_adds(test_context, 1) - b = test_utils.uniform_for_n_adds(test_context, 1, shape=[1]) + b = test_utils.uniform_for_n_adds(test_context, 1) except Exception as e: print( f"Note: Skipping test ct_scalar_add with context `{test_context}`. Not enough precision to support this test." @@ -561,6 +561,10 @@ def _test_ct_scalar_add(self, test_context): print(e) return + # Resize b so the size of the first dimension is 1. This is the + # ciphertext packing dimension. + b = tf.expand_dims(b[0], axis=0) + sa = tf_shell.to_shell_plaintext(a, test_context.shell_context) sc = sa + b @@ -586,6 +590,44 @@ def test_ct_scalar_add(self): with self.subTest(f"ct_scalar_add with context `{test_context}`."): self._test_ct_scalar_add(test_context) + def _test_ct_single_scalar_add(self, test_context): + try: + # This test performs one addition. + _, max_val = test_utils.get_bounds_for_n_adds(test_context, 1) + a = test_utils.uniform_for_n_adds(test_context, 1) + b = test_utils.uniform_for_n_adds(test_context, 1, shape=[1]) + except Exception as e: + print( + f"Note: Skipping test ct_scalar_add with context `{test_context}`. Not enough precision to support this test." + ) + print(e) + return + + sa = tf_shell.to_shell_plaintext(a, test_context.shell_context) + + sc = sa + b + self.assertAllClose(a + b, tf_shell.to_tensorflow(sc), atol=1e-3) + + if test_context.plaintext_dtype.is_unsigned: + # To test subtraction, ensure that a > b to avoid underflow. + # a + max_val is safe, because max_val is the total range / 2 and + # a is less than max_val. + max_val = int(max_val) + saa = tf_shell.to_shell_plaintext(a + max_val, test_context.shell_context) + ee = saa - b + self.assertAllClose(a + max_val - b, tf_shell.to_tensorflow(ee), atol=1e-3) + else: + sd = sa - b + self.assertAllClose(a - b, tf_shell.to_tensorflow(sd), atol=1e-3) + + # Ensure initial arguments are not modified. + self.assertAllClose(a, tf_shell.to_tensorflow(sa)) + + def test_ct_single_scalar_add(self): + for test_context in self.test_contexts: + with self.subTest(f"ct_scalar_add with context `{test_context}`."): + self._test_ct_scalar_add(test_context) + def _test_pt_pt_add(self, test_context): try: # This test performs one addition. diff --git a/tf_shell/test/mul_test.py b/tf_shell/test/mul_test.py index b003c80..0ace622 100644 --- a/tf_shell/test/mul_test.py +++ b/tf_shell/test/mul_test.py @@ -217,6 +217,12 @@ def _test_ct_tf_scalar_mul(self, test_context): ) print(e) return + + # Resize b so the size of the first dimension is 1. This is the + # ciphertext packing dimension and tests the code path in tf-shell + # where this operation is performed using a special scalar op which + # does not require computing the NTT of the plaintext. + b = tf.expand_dims(b[0], axis=0) sa = tf_shell.to_shell_plaintext(a, test_context.shell_context) ea = tf_shell.to_encrypted(sa, test_context.key) @@ -235,6 +241,35 @@ def test_ct_tf_scalar_mul(self): with self.subTest(f"ct_tf_scalar_mul with context `{test_context}`."): self._test_ct_tf_scalar_mul(test_context) + def _test_ct_tf_single_scalar_mul(self, test_context): + try: + # This test performs one multiplication. + a = test_utils.uniform_for_n_muls(test_context, 1) + b = test_utils.uniform_for_n_muls(test_context, 1, shape=[1]) + except Exception as e: + print( + f"Note: Skipping test ct_tf_scalar_mul with context {test_context}. Not enough precision to support this test." + ) + print(e) + return + + sa = tf_shell.to_shell_plaintext(a, test_context.shell_context) + ea = tf_shell.to_encrypted(sa, test_context.key) + + ec = ea * b + self.assertAllClose(a * b, tf_shell.to_tensorflow(ec, test_context.key)) + + ed = b * ea + self.assertAllClose(a * b, tf_shell.to_tensorflow(ed, test_context.key)) + + # Check the arguments were not modified. + self.assertAllClose(a, tf_shell.to_tensorflow(ea, test_context.key)) + + def test_ct_tf_single_scalar_mul(self): + for test_context in self.test_contexts: + with self.subTest(f"ct_tf_scalar_mul with context `{test_context}`."): + self._test_ct_tf_scalar_mul(test_context) + def _test_ct_tf_mul(self, test_context): try: # This test performs one multiplication. From c90b208b298a603624f59bc4ba52630824a08ca1 Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Tue, 21 May 2024 08:05:32 +0000 Subject: [PATCH 14/22] Dropout layer. --- tf_shell_ml/BUILD | 1 + tf_shell_ml/__init__.py | 1 + tf_shell_ml/dense.py | 1 - tf_shell_ml/dropout.py | 72 ++++++++++++++++++++++ tf_shell_ml/test/BUILD | 10 +++ tf_shell_ml/test/dropout_test.py | 101 +++++++++++++++++++++++++++++++ 6 files changed, 185 insertions(+), 1 deletion(-) create mode 100644 tf_shell_ml/dropout.py create mode 100644 tf_shell_ml/test/dropout_test.py diff --git a/tf_shell_ml/BUILD b/tf_shell_ml/BUILD index f898996..f03e966 100644 --- a/tf_shell_ml/BUILD +++ b/tf_shell_ml/BUILD @@ -6,6 +6,7 @@ py_library( "__init__.py", "activation.py", "dense.py", + "dropout.py", "loss.py", ], visibility = ["//visibility:public"], diff --git a/tf_shell_ml/__init__.py b/tf_shell_ml/__init__.py index 923f02f..654142e 100644 --- a/tf_shell_ml/__init__.py +++ b/tf_shell_ml/__init__.py @@ -16,5 +16,6 @@ from __future__ import absolute_import from tf_shell_ml.dense import ShellDense +from tf_shell_ml.dropout import ShellDropout from tf_shell_ml.activation import relu, relu_deriv, sigmoid, sigmoid_deriv from tf_shell_ml.loss import CategoricalCrossentropy diff --git a/tf_shell_ml/dense.py b/tf_shell_ml/dense.py index 2572a23..51b7511 100644 --- a/tf_shell_ml/dense.py +++ b/tf_shell_ml/dense.py @@ -59,7 +59,6 @@ def build(self, input_shape): def __call__(self, inputs): if not self.built: self.build(inputs.shape) - self.built = True self._layer_input = inputs diff --git a/tf_shell_ml/dropout.py b/tf_shell_ml/dropout.py new file mode 100644 index 0000000..42fc8df --- /dev/null +++ b/tf_shell_ml/dropout.py @@ -0,0 +1,72 @@ +#!/usr/bin/python +# +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tensorflow as tf +from tensorflow.python.keras import initializers +import tf_shell + + +class ShellDropout: + def __init__( + self, + rate, + noise_shape=None, + seed=None, + per_batch=False, + ): + self.rate = float(rate) + if (self.rate < 0.0) or (self.rate >= 1.0): + raise ValueError( + "rate must be a float in the range [0, 1), got {}".format(self.rate) + ) + self.noise_shape = noise_shape + self.seed = seed + self.per_batch = per_batch + if self.per_batch and self.noise_shape is not None: + raise ValueError("noise_shape must be None when per_batch is True") + + self.built = False + + def build(self, input_shape): + self.units_in = int(input_shape[1]) + + self.built = True + + def __call__(self, inputs, training=False): + if not self.built: + self.build(inputs.shape) + + if not training or self.rate == 0.0: + return inputs + + if self.per_batch: + dummy_input = tf.ones([1] + inputs.shape[1:]) + else: + dummy_input = tf.ones(inputs.shape) + + dropout_mask = tf.nn.dropout( + dummy_input, + self.rate, + noise_shape=self.noise_shape, + seed=self.seed, + ) + + self._layer_intermediate = dropout_mask + self.outputs = inputs * dropout_mask + return self.outputs + + def backward(self, dy): + d_x = dy * self._layer_intermediate + return d_x diff --git a/tf_shell_ml/test/BUILD b/tf_shell_ml/test/BUILD index 68b9987..d3ac2fb 100644 --- a/tf_shell_ml/test/BUILD +++ b/tf_shell_ml/test/BUILD @@ -29,3 +29,13 @@ py_test( requirement("tensorflow-cpu"), ], ) + +py_test( + name = "dropout_test", + size = "medium", + srcs = ["dropout_test.py"], + deps = [ + "//tf_shell_ml", + requirement("tensorflow-cpu"), + ], +) \ No newline at end of file diff --git a/tf_shell_ml/test/dropout_test.py b/tf_shell_ml/test/dropout_test.py new file mode 100644 index 0000000..2586483 --- /dev/null +++ b/tf_shell_ml/test/dropout_test.py @@ -0,0 +1,101 @@ +#!/usr/bin/python +# +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +import tensorflow as tf +import keras +import numpy as np +import tf_shell +import tf_shell_ml + +# # Num plaintext bits: 32, noise bits: 84 +# # Max representable value: 654624 +# context = tf_shell.create_context64( +# log_n=11, +# main_moduli=[288230376151748609, 144115188076060673], +# plaintext_modulus=4294991873, +# scaling_factor=3, +# mul_depth_supported=3, +# seed="test_seed", +# ) +# 61 bits of security according to lattice estimator primal_bdd. +# Runtime 170 seconds (83ms/example). + +# Num plaintext bits: 32, noise bits: 84 +# Max representable value: 654624 +context = tf_shell.create_context64( + log_n=12, + main_moduli=[288230376151760897, 288230376152137729], + plaintext_modulus=4294991873, + scaling_factor=3, + mul_depth_supported=3, + seed="test_seed", +) +# 120 bits of security according to lattice estimator primal_bdd. +# Runtime 388 seconds (95ms/example). + +key = tf_shell.create_key64(context) +rotation_key = tf_shell.create_rotation_key64(context, key) + + +class TestDropout(tf.test.TestCase): + def _test_dropout_forward(self, per_batch): + # First check plaintext forward pass. + x = tf.random.uniform((context.num_slots, 100)) + 1 + + dropout_layer = tf_shell_ml.ShellDropout(0.2, per_batch=per_batch) + + notrain_y = dropout_layer(x, training=False) + self.assertAllEqual(notrain_y, x) + + train_y = dropout_layer(x, training=True) + self.assertLess(tf.math.count_nonzero(train_y), tf.size(train_y, out_type=tf.int64)) + + enc_x = tf_shell.to_encrypted(x, key, context) + dropout_layer = tf_shell_ml.ShellDropout(0.2, per_batch=per_batch) + + notrain_enc_y = dropout_layer(enc_x, training=False) + self.assertAllClose(tf_shell.to_tensorflow(notrain_enc_y, key), x, atol=1/context.scaling_factor) + + enc_train_y = dropout_layer(enc_x, training=True) + dec_train_y = tf_shell.to_tensorflow(enc_train_y, key) + self.assertLess(tf.math.count_nonzero(dec_train_y), tf.size(dec_train_y, out_type=tf.int64)) + + def _test_dropout_back(self, per_batch): + x = tf.random.uniform((context.num_slots, 100)) + 1 + + dropout_layer = tf_shell_ml.ShellDropout(0.2, per_batch=per_batch) + + notrain_y = dropout_layer(x, training=True) + dy = tf.ones_like(notrain_y) + + dx = dropout_layer.backward(dy) + + enc_dy = tf_shell.to_encrypted(dy, key, context) + enc_dx = dropout_layer.backward(enc_dy) + dec_dx = tf_shell.to_tensorflow(enc_dx, key) + self.assertAllClose(dx, dec_dx, atol=1/context.scaling_factor) + + def test_dropout(self): + self._test_dropout_forward(False) + self._test_dropout_back(False) + + def test_dropout_per_batch(self): + self._test_dropout_forward(True) + self._test_dropout_back(True) + + +if __name__ == "__main__": + unittest.main() From 9aae9c06c91fe43652d5ab8633cebd4f240106a5 Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Thu, 23 May 2024 07:05:50 +0000 Subject: [PATCH 15/22] Embedding layer supporting encrypted backprop. --- tf_shell_ml/BUILD | 1 + tf_shell_ml/__init__.py | 3 +- tf_shell_ml/embedding.py | 232 +++++++++++++++++++++++++++++ tf_shell_ml/test/BUILD | 10 ++ tf_shell_ml/test/embedding_test.py | 118 +++++++++++++++ 5 files changed, 363 insertions(+), 1 deletion(-) create mode 100644 tf_shell_ml/embedding.py create mode 100644 tf_shell_ml/test/embedding_test.py diff --git a/tf_shell_ml/BUILD b/tf_shell_ml/BUILD index f03e966..25c9f41 100644 --- a/tf_shell_ml/BUILD +++ b/tf_shell_ml/BUILD @@ -7,6 +7,7 @@ py_library( "activation.py", "dense.py", "dropout.py", + "embedding.py", "loss.py", ], visibility = ["//visibility:public"], diff --git a/tf_shell_ml/__init__.py b/tf_shell_ml/__init__.py index 654142e..ed0a1cc 100644 --- a/tf_shell_ml/__init__.py +++ b/tf_shell_ml/__init__.py @@ -15,7 +15,8 @@ # limitations under the License. from __future__ import absolute_import +from tf_shell_ml.activation import relu, relu_deriv, sigmoid, sigmoid_deriv from tf_shell_ml.dense import ShellDense from tf_shell_ml.dropout import ShellDropout -from tf_shell_ml.activation import relu, relu_deriv, sigmoid, sigmoid_deriv +from tf_shell_ml.embedding import ShellEmbedding from tf_shell_ml.loss import CategoricalCrossentropy diff --git a/tf_shell_ml/embedding.py b/tf_shell_ml/embedding.py new file mode 100644 index 0000000..5a07fd6 --- /dev/null +++ b/tf_shell_ml/embedding.py @@ -0,0 +1,232 @@ +#!/usr/bin/python +# +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tensorflow as tf +from tensorflow.python.keras import initializers +import tf_shell + + +class ShellEmbedding: + def __init__( + self, + input_dim, + output_dim, + embeddings_initializer="uniform", + ): + self.input_dim = int(input_dim) + self.output_dim = int(output_dim) + self.embeddings_initializer = initializers.get(embeddings_initializer) + + self.weights = [] + self.grad_map = None + self.build() + + def build(self): + self.embeddings = tf.Variable( + self.embeddings_initializer([self.input_dim, self.output_dim]) + ) + self.weights.append(self.embeddings) + self.built = True + + def _reset_grad_map(self): + # The gradient map stores temerary gradients between microbatches. + # + # Ideally the MutableHashTable would store values of type ShellTensor64 + # but it does not support ExtensionTypes. Instead, store the raw variant + # tensor held inside the ShellTensor64 class. + # + # Even worse, the MutableHashTable does not support variant tensors + # which are non-scalar (more than one dimension). + # see: https://github.com/tensorflow/tensorflow/blob/7b077cd4a83bcf3e296796bc7a32ff9e68a10fa2/tensorflow/core/kernels/lookup_table_op.cc#L1188 + # So instead, use a MutableDenseHashTable. + + default_row = tf_shell.to_shell_plaintext( + tf.zeros([self._shell_context.num_slots, self.output_dim]), + self._shell_context, + )._raw_tensor + + self.grad_map = tf.lookup.experimental.DenseHashTable( + key_dtype=tf.int64, + value_dtype=tf.variant, + default_value=default_row, + empty_key=-1, + deleted_key=-2, + ) + self.grad_map_count = tf.lookup.experimental.MutableHashTable( + key_dtype=tf.int64, value_dtype=tf.int64, default_value=0 + ) + + def __call__(self, inputs): + if inputs.dtype != tf.int64: + raise ValueError( + f"Embedding layer expects int64 input. Got {inputs.dtype}." + ) + if inputs.ndim != 2: + raise ValueError( + f"Embedding layer expects rank 2 input. Got {inputs}." + ) + + self._layer_input = inputs + outputs = tf.experimental.numpy.take(self.embeddings, inputs, axis=0) + return outputs + + def _add_or_insert_grad_row(self, row_index, shell_tensor): + """Add a gradient row to the hash map if it does not exist, otherwise + add the gradient to the existing row. + """ + c = self.grad_map_count.lookup(row_index) + + if c == tf.constant(0, dtype=tf.int64): + self.grad_map_count.insert(row_index, 1) + self.grad_map.insert(row_index, shell_tensor._raw_tensor) + else: + self.grad_map_count.insert(row_index, c + 1) + + raw_row = self.grad_map.lookup(row_index) + # raw_row shape: (batch_size, output_dim) + + # First create a ShellTensor from the variant raw_tensor + # stored in the hash map. + shell_row = tf_shell.ShellTensor64( + _raw_tensor=raw_row, + _context=self._shell_context, + _underlying_dtype=self._underlying_dtype, + _scaling_factor=self._scaling_factor, + _is_enc=self._is_enc, + _noise_bit_count=self._orig_noise_bit_count + tf.cast(c, tf.int32), + ) + row_sum = shell_row + shell_tensor + self.grad_map.insert(row_index, row_sum._raw_tensor) + + def backward_accum(self, dy, rotation_key): + """Accumulate the gradients for the embedding layer. Unlike the + other layers backward() methods, this does not return gradients. + + The tricky party about backpropagation of encrypted gradients through + the embedding layer is that each sample in dy must be applied to a + row of the embedding matrix. This requires splitting the batch-axis + packing of dy. + + The strategy is to copy dy `batch_size` times, and rotate each so b'th + copy has sample b in packing position zero. Next, the samples are + grouped by which row of the embedding matrix they are applied to + (which is based on self._layer_input). + """ + + self._shell_context = dy._context + self._underlying_dtype = dy._underlying_dtype + self._scaling_factor = dy._scaling_factor + self._is_enc = dy._is_enc + self._orig_noise_bit_count = dy._noise_bit_count + # dy = tf_shell.to_shell_plaintext(dy, self.shell_context) + + if self.grad_map is None: + self._reset_grad_map() + + # tfshell uses batch axis packing (two batches per ciphertext). + batch_size = self._layer_input.shape[0] // 2 + + # if len(dy.shape) == tf.constant(2): + # skip_word_dim = True + # sentence_len = 1 + # else: + # skip_word_dim = False + # sentence_len = dy.shape[1] + + if dy.ndim != self._layer_input.ndim + 1: + raise ValueError( + f"Embedding layer dy ndims exptected {self._layer_input.ndim + 1}. Got {dy}." + ) + sentence_len = dy.shape[1] + + for word in tf.range(sentence_len): + # For every sample in the batch, rotate dy so that the sample's grad is + # in the first position, and sum based on layer_input. Store the running + # total in grad_map. + for b in tf.range(batch_size): + # if skip_word_dim: + # dy_b = tf_shell.roll(dy, -b, rotation_key) + # else: + dy_b = tf_shell.roll(dy[:, word], -b, rotation_key) + + # Examine the row of `_layer input` to determine which row of + # the embedding matrix to apply the gradient to. Note this is a + # scalar of type integer and there is one for each batch (and + # two batches per ciphertext). + # if skip_word_dim: + # row_index_bottom = self._layer_input[b] + # row_index_top = self._layer_input[b + batch_size] + # else: + row_index_bottom = self._layer_input[b, word] + row_index_top = self._layer_input[b + batch_size, word] + + self._add_or_insert_grad_row(row_index_bottom, dy_b) + self._add_or_insert_grad_row(row_index_top, dy_b) + + def _lookup_shell_tensor(self, row_index, count, secret_key): + grad_row = self.grad_map.lookup(row_index) + + # Turn the grad variant tensor back into a ShellTensor64. + grad_row = tf_shell.ShellTensor64( + _raw_tensor=grad_row, + _context=self._shell_context, + _underlying_dtype=self._underlying_dtype, + _scaling_factor=self._scaling_factor, + _is_enc=self._is_enc, + _noise_bit_count=self._orig_noise_bit_count + tf.cast(count, tf.int32), + ) + + # Decrypt and expand dims to [1, output_dim] + grad_row = tf_shell.to_tensorflow(grad_row, secret_key) + grad_row = tf.expand_dims(grad_row[0, :], axis=0) + return grad_row + + def decrypt_grad(self, secret_key): + """Get the accumulated gradients in matrix form which can then be + applied to the weights. This method returns two gradients. + """ + + if self.grad_map is None: + return None + + # Start building the gradient tensor. + c = self.grad_map_count.lookup(0) + if c == tf.constant(0, dtype=tf.int64): + # Initialize if just begining and no grad update for the + # first row exists. + grads = tf.zeros([1, self.output_dim], dtype=self.embeddings.dtype) + else: + grads = self._lookup_shell_tensor(0, c, secret_key) + + # Build up the gradient tensor row by row where each row corresponds to + # a sample in the batch. + for i in tf.range(1, self.input_dim, dtype=tf.int64): + tf.autograph.experimental.set_loop_options( + maximum_iterations=self.input_dim, + shape_invariants=[(grads, tf.TensorShape([None, self.output_dim]))] + ) + + c = self.grad_map_count.lookup(i) + + # If no gradient was accumulated for this row, add a row of zeros. + if c == tf.constant(0, dtype=tf.int64): + grads = tf.concat([grads, tf.zeros([1, self.output_dim])], axis=0) + else: + grad_row = self._lookup_shell_tensor(i, c, secret_key) + grads = tf.concat([grads, grad_row], axis=0) + + # Reset the gradient accumulator and return. + self.grad_map = None + return grads diff --git a/tf_shell_ml/test/BUILD b/tf_shell_ml/test/BUILD index d3ac2fb..ec53c59 100644 --- a/tf_shell_ml/test/BUILD +++ b/tf_shell_ml/test/BUILD @@ -38,4 +38,14 @@ py_test( "//tf_shell_ml", requirement("tensorflow-cpu"), ], +) + +py_test( + name = "embedding_test", + size = "medium", + srcs = ["embedding_test.py"], + deps = [ + "//tf_shell_ml", + requirement("tensorflow-cpu"), + ], ) \ No newline at end of file diff --git a/tf_shell_ml/test/embedding_test.py b/tf_shell_ml/test/embedding_test.py new file mode 100644 index 0000000..ce54712 --- /dev/null +++ b/tf_shell_ml/test/embedding_test.py @@ -0,0 +1,118 @@ +#!/usr/bin/python +# +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +import tensorflow as tf +import keras +import numpy as np +import tf_shell +import tf_shell_ml + +# # Num plaintext bits: 32, noise bits: 84 +# # Max representable value: 654624 +# context = tf_shell.create_context64( +# log_n=11, +# main_moduli=[288230376151748609, 144115188076060673], +# plaintext_modulus=4294991873, +# scaling_factor=3, +# mul_depth_supported=3, +# seed="test_seed", +# ) +# 61 bits of security according to lattice estimator primal_bdd. +# Runtime 170 seconds (83ms/example). + +# Num plaintext bits: 32, noise bits: 84 +# Max representable value: 654624 +context = tf_shell.create_context64( + log_n=12, + main_moduli=[288230376151760897, 288230376152137729], + plaintext_modulus=4294991873, + scaling_factor=3, + mul_depth_supported=3, + seed="test_seed", +) +# 120 bits of security according to lattice estimator primal_bdd. +# Runtime 388 seconds (95ms/example). + +key = tf_shell.create_key64(context) +rotation_key = tf_shell.create_rotation_key64(context, key) + + +class TestEmbedding(tf.test.TestCase): + def test_embedding_forward(self): + input_dim = 100 + output_dim = 10 + embedding_layer = tf_shell_ml.ShellEmbedding(input_dim, output_dim) + + # First check plaintext forward pass. + x = tf.zeros((context.num_slots, 5), dtype=tf.int64) + y = embedding_layer(x) + + # Check that the output is the same for the same inputs. + for i in range(1, context.num_slots): + self.assertAllEqual(y[0], y[i]) + + # Next check encrypted forward pass throws an error. + enc_x = tf_shell.to_encrypted(x, key, context) + try: + should_fail = embedding_layer(enc_x) + except: + pass + else: + raise ValueError( + "Embedding layer forward with encrypted value should fail." + ) + + def _test_embedding(self): + input_dim = 100 + output_dim = 10 + embedding_layer = tf_shell_ml.ShellEmbedding(input_dim, output_dim) + + special_index = 2 + x = tf.ones((context.num_slots, 1), dtype=tf.int64) * special_index + + @tf.function + def forward_backward(x): + y = embedding_layer(x) + + dy = tf.ones_like(y) + enc_dy = tf_shell.to_encrypted(dy, key, context) + + embedding_layer.backward_accum(enc_dy, rotation_key) + dx = embedding_layer.decrypt_grad(key) + return dx + + dx = forward_backward(x) + + self.assertAllEqual(dx[special_index,:], tf.constant(context.num_slots, shape=(output_dim,))) + + # Make sure the rest of the gradient elements are 0 + for i in range(0, input_dim): + if i == special_index: + self.assertAllEqual(dx[special_index,:], tf.constant(context.num_slots, shape=(output_dim,))) + else: + self.assertAllEqual(dx[i,:], tf.constant(0, shape=(output_dim,))) + + def test_embedding_eager(self): + tf.config.run_functions_eagerly(True) + self._test_embedding() + + def test_embedding_defer(self): + tf.config.run_functions_eagerly(False) + self._test_embedding() + + +if __name__ == "__main__": + unittest.main() From 43da433839e871cf86a6da9b78a7fd894824f5e5 Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Wed, 29 May 2024 06:24:22 +0000 Subject: [PATCH 16/22] Simplify ml tests and examples. --- .devcontainer/devcontainer.json | 15 ++---- examples/label_dp_sgd.ipynb | 36 +++++--------- examples/label_dp_sgd_post_scale.ipynb | 58 +++++------------------ tf_shell_ml/test/mnist_post_scale_test.py | 37 +++++---------- 4 files changed, 40 insertions(+), 106 deletions(-) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 3ce30e0..9e0cd89 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -8,11 +8,8 @@ } }, - // Configure tool-specific properties. "customizations": { - // Configure properties specific to VS Code. "vscode": { - // Add the IDs of extensions you want installed when the container is created. "extensions": [ "ms-python.python", "ms-vscode.cpptools-extension-pack", @@ -20,19 +17,13 @@ "minherz.copyright-inserter", "DavidAnson.vscode-markdownlint", "yzhang.markdown-all-in-one", - "ms-python.black-formatter" + "ms-python.black-formatter", + "ms-toolsai.jupyter", + "ms-toolsai.tensorboard" ] } }, "mounts": [ "source=/tmp,target=/tmp,type=bind,consistency=cached" ], - - // Use 'forwardPorts' to make a list of ports inside the container available locally. - // "forwardPorts": [], - // Use 'postCreateCommand' to run commands after the container is created. - //"postCreateCommand": "echo hi", - // Uncomment when using a ptrace-based debugger like C++, Go, and Rust - // "runArgs": [ "--cap-add=SYS_PTRACE", "--security-opt", "seccomp=unconfined" ], - // Set `remoteUser` to `root` to connect as root instead. More info: https://aka.ms/vscode-remote/containers/non-root. "remoteUser": "vscode" } diff --git a/examples/label_dp_sgd.ipynb b/examples/label_dp_sgd.ipynb index 5dd5caa..341eeda 100644 --- a/examples/label_dp_sgd.ipynb +++ b/examples/label_dp_sgd.ipynb @@ -190,7 +190,7 @@ " #\n", " # Only return the weight gradients at [0], not the bias gradients at [1].\n", " # The bias is not used in this test.\n", - " return dJ_dw1[0], dJ_dw0[0]\n", + " return [dJ_dw1[0], dJ_dw0[0]]\n", "\n", "\n", "@tf.function\n", @@ -204,32 +204,22 @@ " # Run the training step. The top and bottom halves of the batch are\n", " # treated as two separate mini-batches run in parallel to maximize\n", " # efficiency.\n", - " enc_output_layer_grad, enc_hidden_layer_grad = train_step(x_batch, enc_y_batch)\n", + " enc_grads = train_step(x_batch, enc_y_batch)\n", "\n", " # Decrypt the weight gradients. In practice, the gradients should be\n", " # noised before decrypting.\n", - " repeated_output_layer_grad = tf_shell.to_tensorflow(\n", - " enc_output_layer_grad, secret_key\n", - " )\n", - " repeated_hidden_layer_grad = tf_shell.to_tensorflow(\n", - " enc_hidden_layer_grad, secret_key\n", - " )\n", + " repeated_grads = [tf_shell.to_tensorflow(g, secret_key) for g in enc_grads]\n", "\n", - " # Apply the gradients to the model. We choose the first dimension at\n", - " # index 0 arbitrarily. The weight gradients are repeated across the\n", - " # first dimension. See note in train_step for more information.\n", - " optimizer.apply_gradients(\n", - " zip(\n", - " [repeated_output_layer_grad[0], repeated_hidden_layer_grad[0]],\n", - " output_layer.weights + hidden_layer.weights,\n", - " )\n", - " )\n", - " optimizer.apply_gradients(\n", - " zip(\n", - " [repeated_output_layer_grad[batch_size // 2], repeated_hidden_layer_grad[batch_size // 2]],\n", - " output_layer.weights + hidden_layer.weights,\n", - " )\n", - " )" + " # Pull out grads from the top and bottom batches.\n", + " top_grad = [g[0] for g in repeated_grads]\n", + " bottom_grad = [g[batch_size // 2] for g in repeated_grads]\n", + "\n", + " # Decrypt the weight gradients. In practice, the gradients should be\n", + " # noised before decrypting.\n", + " weights = output_layer.weights + hidden_layer.weights\n", + "\n", + " optimizer.apply_gradients(zip(top_grad, weights))\n", + " optimizer.apply_gradients(zip(bottom_grad, weights))" ] }, { diff --git a/examples/label_dp_sgd_post_scale.ipynb b/examples/label_dp_sgd_post_scale.ipynb index 4ad4148..0971889 100644 --- a/examples/label_dp_sgd_post_scale.ipynb +++ b/examples/label_dp_sgd_post_scale.ipynb @@ -40,8 +40,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-05-08 03:16:15.442044: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", - "2024-05-08 03:16:15.593868: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "2024-05-28 19:37:40.084959: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", + "2024-05-28 19:37:40.228766: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" ] } @@ -246,10 +246,8 @@ " ps_grads = train_step(x_batch, enc_y_batch)\n", "\n", " # Decrypt\n", - " grads = []\n", - " for enc_g in ps_grads:\n", - " grads.append(tf_shell.to_tensorflow(enc_g, secret_key)[0])\n", - " # ^ take the first element because the grad sum is repeated over the batching dim.\n", + " grads = [tf_shell.to_tensorflow(enc_g, secret_key)[0] for enc_g in ps_grads]\n", + " # ^ take the first element with [0] because the grad sum is repeated over the batching dim.\n", "\n", " model.optimizer.apply_gradients(\n", " zip(\n", @@ -282,15 +280,15 @@ "To start tensorboard, run: tensorboard --logdir /tmp/tflogs\n", "\n", "Start of epoch 0\n", - "Epoch: 0, Batch: 0 / 30, Time Stamp: 0.07824254035949707\n" + "Epoch: 0, Batch: 0 / 30, Time Stamp: 0.06798219680786133\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-05-08 03:16:18.337431: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:104] Profiler session initializing.\n", - "2024-05-08 03:16:18.337456: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:119] Profiler session started.\n" + "2024-05-28 19:37:42.830181: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:104] Profiler session initializing.\n", + "2024-05-28 19:37:42.830206: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:119] Profiler session started.\n" ] }, { @@ -298,54 +296,24 @@ "output_type": "stream", "text": [ "WARNING:tensorflow:Error while stopping profiler: Cannot export profiling results. No profiler is running.\n", - "\taccuracy: 0.13440264761447906\n" + "\taccuracy: 0.1305309683084488\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-05-08 03:17:43.002417: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:70] Profiler session collecting data.\n", - "2024-05-08 03:17:43.013809: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:131] Profiler session tear down.\n", - "2024-05-08 03:17:43.015068: I external/local_tsl/tsl/profiler/rpc/client/save_profile.cc:144] Collecting XSpace to repository: /tmp/tflogs/pt-20240508-031618/plugins/profile/2024_05_08_03_17_43/e81647a0f462.xplane.pb\n", - "2024-05-08 03:17:43.097148: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-05-28 19:39:03.997126: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:70] Profiler session collecting data.\n", + "2024-05-28 19:39:04.007926: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:131] Profiler session tear down.\n", + "2024-05-28 19:39:04.008962: I external/local_tsl/tsl/profiler/rpc/client/save_profile.cc:144] Collecting XSpace to repository: /tmp/tflogs/pt-20240528-193742/plugins/profile/2024_05_28_19_39_04/e81647a0f462.xplane.pb\n", + "2024-05-28 19:39:04.061259: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 1 / 30, Time Stamp: 84.8675389289856\n", - "\taccuracy: 0.14435841143131256\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-05-08 03:18:51.448626: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, Batch: 2 / 30, Time Stamp: 153.19263339042664\n", - "\taccuracy: 0.15597344934940338\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-05-08 03:19:58.580131: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 0, Batch: 3 / 30, Time Stamp: 220.32476949691772\n" + "Epoch: 0, Batch: 1 / 30, Time Stamp: 81.33036208152771\n" ] }, { diff --git a/tf_shell_ml/test/mnist_post_scale_test.py b/tf_shell_ml/test/mnist_post_scale_test.py index 9f92e5c..f3cb8d5 100644 --- a/tf_shell_ml/test/mnist_post_scale_test.py +++ b/tf_shell_ml/test/mnist_post_scale_test.py @@ -156,14 +156,20 @@ def train_step(x, y): class TestPlaintextPostScale(tf.test.TestCase): - def test_mnist_post_scale_eager(self): - tf.config.run_functions_eagerly(True) + def _test_mnist_post_scale(self, eager_mode): + tf.config.run_functions_eagerly(eager_mode) (x_batch, y_batch) = next(iter(train_dataset)) # Plaintext ps_grads = train_step(x_batch, y_batch) + if not eager_mode: + # With autograph on (eagerly off), the tf.function trace cannot be + # reused between plaintext and encrypted calls. Reset the graph + # between plaintext and encrypted train_step() calls. + tf.keras.backend.clear_session() + # Encrypted enc_y_batch = tf_shell.to_encrypted(y_batch, key, context) shell_ps_grads = train_step(x_batch, enc_y_batch) @@ -175,31 +181,10 @@ def test_mnist_post_scale_eager(self): atol=1 / context.scaling_factor * context.num_slots, ) - -class TestPlaintextPostScale(tf.test.TestCase): - def test_mnist_post_scale_autograph(self): - tf.config.run_functions_eagerly(False) - - (x_batch, y_batch) = next(iter(train_dataset)) - - # Plaintext - ps_grads = train_step(x_batch, y_batch) - - # With autograph on (eagerly off), the tf.function trace cannot be - # reused between plaintext and encrypted calls. Reset the graph - # between plaintext and encrypted train_step() calls. + def test_mnist_post_scale(self): + self._test_mnist_post_scale(eager_mode=False) tf.keras.backend.clear_session() - - # Encrypted - enc_y_batch = tf_shell.to_encrypted(y_batch, key, context) - shell_ps_grads = train_step(x_batch, enc_y_batch) - - # Compare the gradients. - self.assertAllClose( - ps_grads, - shell_ps_grads, - atol=1 / context.scaling_factor * context.num_slots, - ) + self._test_mnist_post_scale(eager_mode=True) if __name__ == "__main__": From f533b16c36aa2a01d32e048861710df221b734f0 Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Mon, 10 Jun 2024 20:41:26 +0000 Subject: [PATCH 17/22] tf_shell supports broadcast_to. --- tf_shell/__init__.py | 1 + tf_shell/python/shell_tensor.py | 31 +++++++++++++++++++++++++++++-- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/tf_shell/__init__.py b/tf_shell/__init__.py index f564bdb..7c5a075 100644 --- a/tf_shell/__init__.py +++ b/tf_shell/__init__.py @@ -25,6 +25,7 @@ from tf_shell.python.shell_tensor import matmul from tf_shell.python.shell_tensor import expand_dims from tf_shell.python.shell_tensor import reshape +from tf_shell.python.shell_tensor import broadcast_to from tf_shell.python.shell_context import ShellContext64 from tf_shell.python.shell_context import create_context64 diff --git a/tf_shell/python/shell_tensor.py b/tf_shell/python/shell_tensor.py index 029ea79..1d81c40 100644 --- a/tf_shell/python/shell_tensor.py +++ b/tf_shell/python/shell_tensor.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math import tensorflow as tf import tf_shell.python.ops.shell_ops as shell_ops from tf_shell.python.shell_context import ShellContext64 @@ -35,6 +34,10 @@ class ShellTensor64(tf.experimental.ExtensionType): def shape(self): return [self._context.num_slots] + self._raw_tensor.shape + @property + def ndim(self): + return self._raw_tensor.ndim + 1 + @property def dtype(self): return tf.variant @@ -690,12 +693,16 @@ def matmul(x, y, rotation_key=None): """Matrix multiplication is specialized to whether the operands are plaintext or ciphertext. - matmul(ciphertext, plaintext) works as in Tensorflow. + matmul(ciphertext, plaintext) works the same way as Tensorflow. matmul(plaintext, ciphertext) in tf-shell has slightly different semantics than plaintext / Tensorflow. tf-shell affects top and bottom halves independently, as well as the first dimension repeating the sum of either the halves.""" + + if len(x.shape) < 2 or len(y.shape) < 2: + raise ValueError(f"matmul not supported for tensors with rank < 2. Got {x.shape} and {y.shape}.") + if isinstance(x, ShellTensor64) and isinstance(y, tf.Tensor): if x._underlying_dtype != y.dtype: raise ValueError( @@ -822,3 +829,23 @@ def reshape(x, shape): return tf.reshape(x, shape) else: raise ValueError("Unsupported type for expand_dims") + +def broadcast_to(x, shape): + if isinstance(x, ShellTensor64): + if shape[0] != x._context.num_slots: + raise ValueError( + "Cannot broadcast_to over axis 0 for ShellTensor64, this is the batching dimension." + ) + + return ShellTensor64( + _raw_tensor=tf.broadcast_to(x._raw_tensor, shape[1:]), + _context=x._context, + _underlying_dtype=x._underlying_dtype, + _scaling_factor=x._scaling_factor, + _is_enc=x._is_enc, + _noise_bit_count=x._noise_bit_count, + ) + elif isinstance(x, tf.Tensor): + return tf.broadcast_to(x, shape) + else: + raise ValueError("Unsupported type for expand_dims") From 829d8cf6c8a485700055c2636a61846cbc78114c Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Mon, 10 Jun 2024 20:41:47 +0000 Subject: [PATCH 18/22] BinaryCrossentropy loss function. --- tf_shell_ml/__init__.py | 1 + tf_shell_ml/loss.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/tf_shell_ml/__init__.py b/tf_shell_ml/__init__.py index ed0a1cc..adb132b 100644 --- a/tf_shell_ml/__init__.py +++ b/tf_shell_ml/__init__.py @@ -20,3 +20,4 @@ from tf_shell_ml.dropout import ShellDropout from tf_shell_ml.embedding import ShellEmbedding from tf_shell_ml.loss import CategoricalCrossentropy +from tf_shell_ml.loss import BinaryCrossentropy diff --git a/tf_shell_ml/loss.py b/tf_shell_ml/loss.py index 6520254..ef1cb92 100644 --- a/tf_shell_ml/loss.py +++ b/tf_shell_ml/loss.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from tensorflow.nn import softmax +from tensorflow.nn import sigmoid from tensorflow.math import log import tf_shell @@ -48,3 +49,37 @@ def grad(self, y_true, y_pred): # grad = grad * batch_size_inv raise NotImplementedError("Multiply by scalar op not implemented.") return grad + + +class BinaryCrossentropy: + + def __init__(self, from_logits=False, lazy_normalization=True): + self.from_logits = from_logits + self.lazy_normalization = lazy_normalization + + def __call__(self, y_true, y_pred): + if self.from_logits: + y_pred = sigmoid(y_pred) + + batch_size = y_true.shape.as_list()[0] + batch_size_inv = 1 / batch_size + out = -(y_true * log(y_pred) + (1 - y_true) * log(1 - y_pred)) + bce = tf_shell.reduce_sum(out, axis=0) * batch_size_inv + return bce + + def grad(self, y_true, y_pred): + if self.from_logits: + y_pred = sigmoid(y_pred) + + # When using deferred execution, we need to use the __rsub__ method + # otherwise it tries to go through the y tensors __sub__ method which + # fails when y_true is encrypted (a ShellTensor64). + # grad = y_pred - y_true + grad = y_true.__rsub__(y_pred) + + if not self.lazy_normalization: + # batch_size = y_true.shape.as_list()[0] + # batch_size_inv = 1 / batch_size + # grad = grad * batch_size_inv + raise NotImplementedError("Multiply by scalar op not implemented.") + return grad From c7e13a28499580f8ed8c31edeaec74bc975884c0 Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Mon, 10 Jun 2024 20:42:30 +0000 Subject: [PATCH 19/22] Op level broadcasting does not require copy. --- tf_shell/cc/kernels/add_kernels.cc | 45 +++++------ tf_shell/cc/kernels/mul_kernels.cc | 116 +++++++++++++++-------------- tf_shell/cc/kernels/utils.h | 111 +++++++++++++++------------ tf_shell/test/mul_test.py | 2 +- 4 files changed, 145 insertions(+), 129 deletions(-) diff --git a/tf_shell/cc/kernels/add_kernels.cc b/tf_shell/cc/kernels/add_kernels.cc index 0518e66..e7c924a 100644 --- a/tf_shell/cc/kernels/add_kernels.cc +++ b/tf_shell/cc/kernels/add_kernels.cc @@ -106,13 +106,10 @@ class AddCtCtOp : public OpKernel { op_ctx, bcast.IsValid(), InvalidArgument("Invalid broadcast between ", a.shape().DebugString(), " and ", b.shape().DebugString())); - auto flat_a = MyBFlat(op_ctx, a, bcast.x_reshape(), bcast.x_bcast()); - auto flat_b = MyBFlat(op_ctx, b, bcast.y_reshape(), bcast.y_bcast()); - - // Check the inputs have the same shape. - OP_REQUIRES( - op_ctx, flat_a.size() == flat_b.size(), - InvalidArgument("Broadcasted inputs must have the same shape.")); + auto flat_a = a.flat(); + auto flat_b = b.flat(); + IndexConverterFunctor a_bcaster(bcast.output_shape(), a.shape()); + IndexConverterFunctor b_bcaster(bcast.output_shape(), b.shape()); // Allocate the output tensor which is the same size as one of the inputs. Tensor* output; @@ -122,14 +119,14 @@ class AddCtCtOp : public OpKernel { for (int i = 0; i < flat_output.dimension(0); ++i) { SymmetricCtVariant const* ct_a_var = - std::move(flat_a(i).get>()); + std::move(flat_a(a_bcaster(i)).get>()); OP_REQUIRES(op_ctx, ct_a_var != nullptr, InvalidArgument("SymmetricCtVariant at flat index: ", i, " for input a did not unwrap successfully.")); SymmetricCt const& ct_a = ct_a_var->ct; SymmetricCtVariant const* ct_b_var = - std::move(flat_b(i).get>()); + std::move(flat_b(b_bcaster(i)).get>()); OP_REQUIRES(op_ctx, ct_b_var != nullptr, InvalidArgument("SymmetricCtVariant at flat index: ", i, " for input b did not unwrap successfully.")); @@ -166,13 +163,10 @@ class AddCtPtOp : public OpKernel { op_ctx, bcast.IsValid(), InvalidArgument("Invalid broadcast between ", a.shape().DebugString(), " and ", b.shape().DebugString())); - auto flat_a = MyBFlat(op_ctx, a, bcast.x_reshape(), bcast.x_bcast()); - auto flat_b = MyBFlat(op_ctx, b, bcast.y_reshape(), bcast.y_bcast()); - - // Check the inputs have the same shape. - OP_REQUIRES( - op_ctx, flat_a.size() == flat_b.size(), - InvalidArgument("Broadcasted inputs must have the same shape.")); + auto flat_a = a.flat(); + auto flat_b = b.flat(); + IndexConverterFunctor a_bcaster(bcast.output_shape(), a.shape()); + IndexConverterFunctor b_bcaster(bcast.output_shape(), b.shape()); // Allocate the output tensor which is the same size as one of the inputs. Tensor* output; @@ -182,14 +176,14 @@ class AddCtPtOp : public OpKernel { for (int i = 0; i < flat_output.dimension(0); ++i) { SymmetricCtVariant const* ct_a_var = - std::move(flat_a(i).get>()); + std::move(flat_a(a_bcaster(i)).get>()); OP_REQUIRES(op_ctx, ct_a_var != nullptr, InvalidArgument("SymmetricCtVariant at flat index: ", i, " for input a did not unwrap successfully.")); SymmetricCt const& ct_a = ct_a_var->ct; PolynomialVariant const* pv_b_var = - std::move(flat_b(i).get>()); + std::move(flat_b(b_bcaster(i)).get>()); OP_REQUIRES(op_ctx, pv_b_var != nullptr, InvalidArgument("PolynomialVariant at flat index: ", i, " for input b did not unwrap successfully.")); @@ -229,13 +223,10 @@ class AddPtPtOp : public OpKernel { op_ctx, bcast.IsValid(), InvalidArgument("Invalid broadcast between ", a.shape().DebugString(), " and ", b.shape().DebugString())); - auto flat_a = MyBFlat(op_ctx, a, bcast.x_reshape(), bcast.x_bcast()); - auto flat_b = MyBFlat(op_ctx, b, bcast.y_reshape(), bcast.y_bcast()); - - // Check the inputs have the same shape. - OP_REQUIRES( - op_ctx, flat_a.size() == flat_b.size(), - InvalidArgument("Broadcasted inputs must have the same shape.")); + auto flat_a = a.flat(); + auto flat_b = b.flat(); + IndexConverterFunctor a_bcaster(bcast.output_shape(), a.shape()); + IndexConverterFunctor b_bcaster(bcast.output_shape(), b.shape()); // Allocate the output tensor which is the same size as one of the inputs. Tensor* output; @@ -245,14 +236,14 @@ class AddPtPtOp : public OpKernel { for (int i = 0; i < flat_output.dimension(0); ++i) { PolynomialVariant const* pv_a_var = - std::move(flat_a(i).get>()); + std::move(flat_a(a_bcaster(i)).get>()); OP_REQUIRES(op_ctx, pv_a_var != nullptr, InvalidArgument("PolynomialVariant at flat index: ", i, " for input a did not unwrap successfully.")); RnsPolynomial const& pt_a = pv_a_var->poly; PolynomialVariant const* pv_b_var = - std::move(flat_b(i).get>()); + std::move(flat_b(b_bcaster(i)).get>()); OP_REQUIRES(op_ctx, pv_b_var != nullptr, InvalidArgument("PolynomialVariant at flat index: ", i, " for input b did not unwrap successfully.")); diff --git a/tf_shell/cc/kernels/mul_kernels.cc b/tf_shell/cc/kernels/mul_kernels.cc index 4d36122..8c0e1f6 100644 --- a/tf_shell/cc/kernels/mul_kernels.cc +++ b/tf_shell/cc/kernels/mul_kernels.cc @@ -64,13 +64,10 @@ class MulCtCtOp : public OpKernel { op_ctx, bcast.IsValid(), InvalidArgument("Invalid broadcast between ", a.shape().DebugString(), " and ", b.shape().DebugString())); - auto flat_a = MyBFlat(op_ctx, a, bcast.x_reshape(), bcast.x_bcast()); - auto flat_b = MyBFlat(op_ctx, b, bcast.y_reshape(), bcast.y_bcast()); - - // Check the inputs have the same shape. - OP_REQUIRES( - op_ctx, flat_a.size() == flat_b.size(), - InvalidArgument("Broadcasted inputs must have the same shape.")); + auto flat_a = a.flat(); + auto flat_b = b.flat(); + IndexConverterFunctor a_bcaster(bcast.output_shape(), a.shape()); + IndexConverterFunctor b_bcaster(bcast.output_shape(), b.shape()); // Allocate the output tensor which is the same shape as each of the inputs. Tensor* output; @@ -81,14 +78,14 @@ class MulCtCtOp : public OpKernel { // Multiply each pair of ciphertexts and store the result in the output. for (int i = 0; i < flat_output.dimension(0); ++i) { SymmetricCtVariant const* ct_a_var = - std::move(flat_a(i).get>()); + std::move(flat_a(a_bcaster(i)).get>()); OP_REQUIRES(op_ctx, ct_a_var != nullptr, InvalidArgument("SymmetricCtVariant at flat index:", i, " for input a did not unwrap successfully.")); SymmetricCt const& ct_a = ct_a_var->ct; SymmetricCtVariant const* ct_b_var = - std::move(flat_b(i).get>()); + std::move(flat_b(b_bcaster(i)).get>()); OP_REQUIRES(op_ctx, ct_b_var != nullptr, InvalidArgument("SymmetricCtVariant at flat index:", i, " for input b did not unwrap successfully.")); @@ -124,13 +121,10 @@ class MulCtPtOp : public OpKernel { op_ctx, bcast.IsValid(), InvalidArgument("Invalid broadcast between ", a.shape().DebugString(), " and ", b.shape().DebugString())); - auto flat_a = MyBFlat(op_ctx, a, bcast.x_reshape(), bcast.x_bcast()); - auto flat_b = MyBFlat(op_ctx, b, bcast.y_reshape(), bcast.y_bcast()); - - // Check the inputs have the same shape. - OP_REQUIRES( - op_ctx, flat_a.size() == flat_b.size(), - InvalidArgument("Broadcasted inputs must have the same shape.")); + auto flat_a = a.flat(); + auto flat_b = b.flat(); + IndexConverterFunctor a_bcaster(bcast.output_shape(), a.shape()); + IndexConverterFunctor b_bcaster(bcast.output_shape(), b.shape()); // Allocate the output tensor which is the same shape as each of the inputs. Tensor* output; @@ -138,32 +132,52 @@ class MulCtPtOp : public OpKernel { OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, output_shape, &output)); auto flat_output = output->flat(); - for (int i = 0; i < flat_output.dimension(0); ++i) { - SymmetricCtVariant const* ct_a_var = - std::move(flat_a(i).get>()); - OP_REQUIRES(op_ctx, ct_a_var != nullptr, - InvalidArgument("SymmetricCtVariant at flat index:", i, - " for input a did not unwrap successfully.")); - SymmetricCt const& ct_a = ct_a_var->ct; + // Recover num_slots from first ciphertext. + SymmetricCtVariant const* ct_var = + std::move(flat_a(0).get>()); + OP_REQUIRES( + op_ctx, ct_var != nullptr, + InvalidArgument("SymmetricCtVariant a did not unwrap successfully.")); + SymmetricCt const& ct = ct_var->ct; + int num_slots = 1 << ct.LogN(); + int num_components = ct.NumModuli(); - PolynomialVariant const* pv_b_var = - std::move(flat_b(i).get>()); - OP_REQUIRES(op_ctx, pv_b_var != nullptr, - InvalidArgument("PolynomialVariant at flat index:", i, - " for input b did not unwrap successfully.")); - RnsPolynomial const& pt_b = pv_b_var->poly; + auto mul_in_range = [&](int start, int end) { + for (int i = start; i < end; ++i) { + SymmetricCtVariant const* ct_a_var = + std::move(flat_a(a_bcaster(i)).get>()); + OP_REQUIRES( + op_ctx, ct_a_var != nullptr, + InvalidArgument("SymmetricCtVariant at flat index:", i, + " for input a did not unwrap successfully.")); + SymmetricCt const& ct_a = ct_a_var->ct; - OP_REQUIRES_VALUE(SymmetricCt ct_c, op_ctx, - ct_a * pt_b); // shell absorb operation + PolynomialVariant const* pv_b_var = + std::move(flat_b(b_bcaster(i)).get>()); + OP_REQUIRES( + op_ctx, pv_b_var != nullptr, + InvalidArgument("PolynomialVariant at flat index:", i, + " for input b did not unwrap successfully.")); + RnsPolynomial const& pt_b = pv_b_var->poly; - SymmetricCtVariant ct_c_var(std::move(ct_c)); - flat_output(i) = std::move(ct_c_var); - } + OP_REQUIRES_VALUE(SymmetricCt ct_c, op_ctx, + ct_a * pt_b); // shell absorb operation + + SymmetricCtVariant ct_c_var(std::move(ct_c)); + flat_output(i) = std::move(ct_c_var); + } + }; + + auto thread_pool = + op_ctx->device()->tensorflow_cpu_worker_threads()->workers; + int const cost_per_mul = 30 * num_slots * num_components; + thread_pool->ParallelFor(flat_output.dimension(0), cost_per_mul, + mul_in_range); } }; -// This Op can multiply either a shell ciphertext or a plaintext polynomial by a -// plaintext scalar, depending on the class template. +// This Op can multiply either a shell ciphertext or a plaintext polynomial by +// a plaintext scalar, depending on the class template. template class MulShellTfScalarOp : public OpKernel { private: @@ -194,12 +208,8 @@ class MulShellTfScalarOp : public OpKernel { InvalidArgument("Invalid broadcast between ", a.shape().DebugString(), " and ", b.shape().DebugString())); auto flat_a = a.flat(); // a is not broadcasted, just b. - auto flat_b = MyBFlat(op_ctx, b, bcast.y_reshape(), bcast.y_bcast()); - - // Check the inputs have the same shape. - OP_REQUIRES( - op_ctx, flat_a.size() == flat_b.size(), - InvalidArgument("Broadcasted inputs must have the same shape.")); + auto flat_b = b.flat(); + IndexConverterFunctor b_bcaster(bcast.output_shape(), b.shape()); // Allocate the output tensor which is the same shape as the first input. Tensor* output; @@ -211,7 +221,7 @@ class MulShellTfScalarOp : public OpKernel { // First encode the scalar b // TDOO(jchoncholas): encode all scalars at once beforehand. T wrapped_b; - EncodeScalar(op_ctx, flat_b(i), encoder, &wrapped_b); + EncodeScalar(op_ctx, flat_b(b_bcaster(i)), encoder, &wrapped_b); CtOrPolyVariant const* ct_or_pt_var = std::move(flat_a(i).get()); @@ -241,7 +251,7 @@ class MulShellTfScalarOp : public OpKernel { } } -private: + private: void EncodeScalar(OpKernelContext* op_ctx, PtT const& val, Encoder const* encoder, T* wrapped_val) { if constexpr (std::is_signed::value) { // SHELL is built on the assumption that the plaintext type (in this @@ -293,15 +303,13 @@ class MulPtPtOp : public OpKernel { op_ctx, bcast.IsValid(), InvalidArgument("Invalid broadcast between ", a.shape().DebugString(), " and ", b.shape().DebugString())); - auto flat_a = MyBFlat(op_ctx, a, bcast.x_reshape(), bcast.x_bcast()); - auto flat_b = MyBFlat(op_ctx, b, bcast.y_reshape(), bcast.y_bcast()); - - // Check the inputs have the same shape. - OP_REQUIRES( - op_ctx, flat_a.size() == flat_b.size(), - InvalidArgument("Broadcasted inputs must have the same shape.")); + auto flat_a = a.flat(); + auto flat_b = b.flat(); + IndexConverterFunctor a_bcaster(bcast.output_shape(), a.shape()); + IndexConverterFunctor b_bcaster(bcast.output_shape(), b.shape()); - // Allocate the output tensor which is the same shape as each of the inputs. + // Allocate the output tensor which is the same shape as each of the + // inputs. Tensor* output; TensorShape output_shape = BCast::ToShape(bcast.output_shape()); OP_REQUIRES_OK(op_ctx, op_ctx->allocate_output(0, output_shape, &output)); @@ -309,14 +317,14 @@ class MulPtPtOp : public OpKernel { for (int i = 0; i < flat_output.dimension(0); ++i) { PolynomialVariant const* pv_a_var = - std::move(flat_a(i).get>()); + std::move(flat_a(a_bcaster(i)).get>()); OP_REQUIRES(op_ctx, pv_a_var != nullptr, InvalidArgument("PolynomialVariant at flat index:", i, " for input a did not unwrap successfully.")); RnsPolynomial const& pt_a = pv_a_var->poly; PolynomialVariant const* pv_b_var = - std::move(flat_b(i).get>()); + std::move(flat_b(b_bcaster(i)).get>()); OP_REQUIRES(op_ctx, pv_b_var != nullptr, InvalidArgument("PolynomialVariant at flat index:", i, " for input b did not unwrap successfully.")); diff --git a/tf_shell/cc/kernels/utils.h b/tf_shell/cc/kernels/utils.h index 367975c..1d06aca 100644 --- a/tf_shell/cc/kernels/utils.h +++ b/tf_shell/cc/kernels/utils.h @@ -92,55 +92,72 @@ StatusOr GetVariant(OpKernelContext* ctx, int index) { return t; } -template -inline Eigen::Tensor BFlat( - OpKernelContext* op_ctx, Tensor const& t, BCast::Vec const& x_reshape, - BCast::Vec const& x_bcast) { - // A TensorFlow is a TTypes::Tensor (aka Eigen::TensorMap). - // Eigen::TensorMap is a view into an Eigen::Tensor. When performing a - // reshape, broadcast, or even eval on an Eigen::TensorMap, it cannot be - // assigned to another Eigen::TensorMap. This is why the following code - // assigns the result of the reshape to an Eigen::Tensor. - // - // For a demo, see https://godbolt.org/z/41xvWvb63 - typedef Eigen::Tensor - ETensor; - - ETensor reshaped_t = t.template shaped(x_reshape); - - ETensor broadcasted_t = - reshaped_t.broadcast(BCast::ToIndexArray(x_bcast)); - - return std::move( - broadcasted_t.reshape(BCast::ToIndexArray<1>({broadcasted_t.size()}))); -} +// A class to help switching indexing schemes from a broadcasted tensor to the +// underlying tensor which resides in memory. +// +// Background: +// A TensorFlow Tensor is of type TTypes::Tensor (aka Eigen::TensorMap). +// Eigen::TensorMap is like a view into an Eigen::Tensor. When performing a +// reshape, broadcast, or even an eval operation on an Eigen::TensorMap, it +// cannot be assigned to another Eigen::TensorMap. This means that broadcasting +// a tensor requires fully materializing it (i.e. a deep copy). This is slow an +// unnecessary. Instead, this class will switch the indexing scheme from the +// a broadcasted tensor to the underlying tensor. +// +// For a demo of why TensorMaps cannot avoid materializing after a broadcast +// operation, see https://godbolt.org/z/41xvWvb63. +// +// Say a tensor should be broadcasted from `underlying_shape` to `bc_shape`. +// When this class is called via the () operator, it returns the index into the +// underlying tensor computed from the flat broadcasted index `bc_flat_index`. +class IndexConverterFunctor { + public: + IndexConverterFunctor(BCast::Vec const& bc_shape, + tensorflow::TensorShape const& underlying_shape) + : bc_shape_(bc_shape), underlying_shape_(underlying_shape) { + if (BCast::ToShape(bc_shape) == underlying_shape) { + functor_ = &IndexConverterFunctor::identity; + } else { + functor_ = &IndexConverterFunctor::broadcastToUnderlyingIndex; + }; + } -template -inline Eigen::Tensor MyBFlat( - OpKernelContext* op_ctx, Tensor const& t, BCast::Vec const& x_reshape, - BCast::Vec const& x_bcast) { - // Uses the switch statement approach as in: - // `tensorflow/tensorflow/core/kernels/broadcast_to_op.h` - int const ndims = x_reshape.size(); - switch (ndims) { - case 1: - return std::move(BFlat(op_ctx, t, x_reshape, x_bcast)); - case 2: - return std::move(BFlat(op_ctx, t, x_reshape, x_bcast)); - case 3: - return std::move(BFlat(op_ctx, t, x_reshape, x_bcast)); - case 4: - return std::move(BFlat(op_ctx, t, x_reshape, x_bcast)); - case 5: - return std::move(BFlat(op_ctx, t, x_reshape, x_bcast)); - case 6: - return std::move(BFlat(op_ctx, t, x_reshape, x_bcast)); - default: - op_ctx->SetStatus(Unimplemented("Broadcast ", t.DebugString(), - " is not supported yet.")); - return std::move(BFlat(op_ctx, t, x_reshape, x_bcast)); + inline int operator()(int bc_flat_index) { + return std::invoke(functor_, this, bc_flat_index); } -} + + private: + int (IndexConverterFunctor::*functor_)(int); + BCast::Vec const& bc_shape_; + tensorflow::TensorShape const& underlying_shape_; + + int identity(int bc_flat_index) { return bc_flat_index; } + + int broadcastToUnderlyingIndex(int bc_flat_index) { + // First convert the flat indexing scheme to the output_shape. + auto const bc_ndims = bc_shape_.size(); + std::vector bc_full_index(bc_ndims); + for (int i = bc_ndims - 1; i >= 0; --i) { + bc_full_index[i] = bc_flat_index % bc_shape_[i]; + bc_flat_index /= bc_shape_[i]; + } + + // Undo the broadcasting. + assert(bcast.size() == bc_ndims); + for (size_t i = 0; i < bc_ndims; ++i) { + bc_full_index[i] %= underlying_shape_.dim_size(i); + } + + // Compute the flat index of the underlying shape. + int underlying_index = 0; + for (int i = 0; i < underlying_shape_.dims(); ++i) { + underlying_index *= underlying_shape_.dim_size(i); + underlying_index += bc_full_index[i]; + } + + return underlying_index; + } +}; // Status macros from // https://github.com/abseil/abseil-cpp/issues/976#issuecomment-1664601671 diff --git a/tf_shell/test/mul_test.py b/tf_shell/test/mul_test.py index 0ace622..1d831b4 100644 --- a/tf_shell/test/mul_test.py +++ b/tf_shell/test/mul_test.py @@ -199,7 +199,7 @@ def _test_ct_pt_mul_with_broadcast(self, test_context): self.assertAllClose(b, tf_shell.to_tensorflow(sb)) self.assertAllClose(a, tf_shell.to_tensorflow(ea, test_context.key)) - def test_ct_pt_mul(self): + def test_ct_pt_mul_with_broadcast(self): for test_context in self.test_contexts: with self.subTest( f"ct_pt_mul_with_broadcast with context `{test_context}`." From f6f9e508bc69149dc6f68cf9c067e33730875f07 Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Tue, 11 Jun 2024 00:22:44 +0000 Subject: [PATCH 20/22] Rerun examples with new op broadcasting. --- examples/.gitignore | 1 + examples/benchmark.ipynb | 309 ++++++++++++++ examples/intro.ipynb | 33 +- examples/label_dp_sgd.ipynb | 193 +++++---- examples/label_dp_sgd_post_scale.ipynb | 556 +++++++++++++++++++++---- examples/label_dp_sgd_sentiment.ipynb | 461 ++++++++++++++++++++ 6 files changed, 1370 insertions(+), 183 deletions(-) create mode 100644 examples/.gitignore create mode 100644 examples/benchmark.ipynb create mode 100644 examples/label_dp_sgd_sentiment.ipynb diff --git a/examples/.gitignore b/examples/.gitignore new file mode 100644 index 0000000..12a43fe --- /dev/null +++ b/examples/.gitignore @@ -0,0 +1 @@ +tflogs \ No newline at end of file diff --git a/examples/benchmark.ipynb b/examples/benchmark.ipynb new file mode 100644 index 0000000..924dc1b --- /dev/null +++ b/examples/benchmark.ipynb @@ -0,0 +1,309 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Introduction to tf-shell\n", + "\n", + "To get started, `pip install tf-shell`. tf-shell has a few modules, the one used\n", + "in this notebook is `tf_shell`." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 21:30:04.256734: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", + "2024-06-10 21:30:04.257664: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2024-06-10 21:30:04.291533: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2024-06-10 21:30:04.428601: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", + "2024-06-10 21:30:05.195988: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" + ] + } + ], + "source": [ + "import tf_shell\n", + "import tensorflow as tf\n", + "import timeit\n", + "\n", + "context = tf_shell.create_context64(\n", + " log_n=10,\n", + " main_moduli=[8556589057, 8388812801],\n", + " plaintext_modulus=40961,\n", + " scaling_factor=3,\n", + " mul_depth_supported=3,\n", + " seed=\"test_seed\",\n", + ")\n", + "\n", + "secret_key = tf_shell.create_key64(context)\n", + "rotation_key = tf_shell.create_rotation_key64(context, secret_key)\n", + "\n", + "a = tf.random.uniform([context.num_slots, 55555], dtype=tf.float32, maxval=10)\n", + "b = tf.random.uniform([55555, 333], dtype=tf.float32, maxval=10)\n", + "c = tf.random.uniform([2, context.num_slots], dtype=tf.float32, maxval=10)\n", + "d = tf.random.uniform([context.num_slots, 4444], dtype=tf.float32, maxval=10)\n", + "\n", + "enc_a = tf_shell.to_encrypted(a, secret_key, context)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.4906675929996709\n" + ] + } + ], + "source": [ + "def to_pt():\n", + " return tf_shell.to_shell_plaintext(a, context)\n", + "\n", + "time = min(timeit.Timer(to_pt).repeat(repeat=3, number=1))\n", + "print(time)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "5.263423050000711\n" + ] + } + ], + "source": [ + "def enc():\n", + " return tf_shell.to_encrypted(d, secret_key, context)\n", + "\n", + "time = min(timeit.Timer(enc).repeat(repeat=3, number=1))\n", + "print(time)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.5277276859997073\n" + ] + } + ], + "source": [ + "def dec():\n", + " return tf_shell.to_tensorflow(enc_a, secret_key)\n", + "\n", + "time = min(timeit.Timer(dec).repeat(repeat=3, number=1))\n", + "print(time)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.4192462440005329\n" + ] + } + ], + "source": [ + "def ct_ct_add():\n", + " return enc_a + enc_a\n", + "\n", + "time = min(timeit.Timer(ct_ct_add).repeat(repeat=3, number=1))\n", + "print(time)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.4219015720009338\n" + ] + } + ], + "source": [ + "def ct_ct_sub():\n", + " return enc_a - enc_a\n", + "\n", + "time = min(timeit.Timer(ct_ct_sub).repeat(repeat=3, number=1))\n", + "print(time)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.8668678089998139\n" + ] + } + ], + "source": [ + "def ct_ct_mul():\n", + " return enc_a * enc_a\n", + "\n", + "time = min(timeit.Timer(ct_ct_mul).repeat(repeat=3, number=1))\n", + "print(time)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.7579904609992809\n" + ] + } + ], + "source": [ + "def ct_pt_add():\n", + " return enc_a + a\n", + "\n", + "time = min(timeit.Timer(ct_pt_add).repeat(repeat=3, number=1))\n", + "print(time)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.6268679120003071\n" + ] + } + ], + "source": [ + "def ct_pt_mul():\n", + " return enc_a * a\n", + "\n", + "time = min(timeit.Timer(ct_pt_mul).repeat(repeat=3, number=1))\n", + "print(time)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "25.57404864599812\n" + ] + } + ], + "source": [ + "def ct_pt_matmul():\n", + " return tf_shell.matmul(enc_a, b)\n", + "\n", + "time = min(timeit.Timer(ct_pt_matmul).repeat(repeat=3, number=1))\n", + "print(time)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "361.1888753159983\n" + ] + } + ], + "source": [ + "def pt_ct_matmul():\n", + " return tf_shell.matmul(c, enc_a, rotation_key)\n", + "\n", + "time = min(timeit.Timer(pt_ct_matmul).repeat(repeat=3, number=1))\n", + "print(time)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4.650902364999638\n" + ] + } + ], + "source": [ + "def ct_roll():\n", + " return tf_shell.roll(enc_a, 2, rotation_key)\n", + "\n", + "time = min(timeit.Timer(ct_roll).repeat(repeat=3, number=1))\n", + "print(time)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/intro.ipynb b/examples/intro.ipynb index 775460b..b07cbed 100644 --- a/examples/intro.ipynb +++ b/examples/intro.ipynb @@ -12,9 +12,22 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 21:54:19.630781: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", + "2024-06-10 21:54:19.631217: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2024-06-10 21:54:19.633550: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2024-06-10 21:54:19.663933: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", + "2024-06-10 21:54:20.301503: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" + ] + } + ], "source": [ "import tf_shell" ] @@ -38,7 +51,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -69,14 +82,14 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "The first 3 elements of the data are [0.62601924 4.461747 5.8008575 ]\n" + "The first 3 elements of the data are [0.05918741 3.8001454 5.9336624 ]\n" ] } ], @@ -99,7 +112,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -132,16 +145,16 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "enc: [0.6666667 4.3333335 5.6666665]\n", - "enc + enc: [ 1.3333334 8.666667 11.333333 ]\n", - "enc * enc: [ 0.44444445 18.777779 32.11111 ]\n" + "enc: [0. 3.6666667 6. ]\n", + "enc + enc: [ 0. 7.3333335 12. ]\n", + "enc * enc: [ 0. 13.444445 36. ]\n" ] } ], diff --git a/examples/label_dp_sgd.ipynb b/examples/label_dp_sgd.ipynb index 341eeda..1a75ba8 100644 --- a/examples/label_dp_sgd.ipynb +++ b/examples/label_dp_sgd.ipynb @@ -38,14 +38,18 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-05-08 01:19:38.518051: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", - "2024-05-08 01:19:38.539912: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", - "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" + "2024-06-10 22:24:25.445864: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", + "2024-06-10 22:24:25.446300: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2024-06-10 22:24:25.448282: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2024-06-10 22:24:25.475776: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", + "2024-06-10 22:24:26.095565: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" ] } ], "source": [ "import time\n", + "import os\n", "from datetime import datetime\n", "import tensorflow as tf\n", "import keras\n", @@ -99,11 +103,16 @@ "x_train, x_test = x_train / np.float32(255.0), x_test / np.float32(255.0)\n", "y_train, y_test = tf.one_hot(y_train, 10), tf.one_hot(y_test, 10)\n", "\n", + "epochs = 1\n", "train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n", - "train_dataset = train_dataset.shuffle(buffer_size=2048).batch(batch_size)\n", + "train_dataset = (\n", + " train_dataset.shuffle(buffer_size=2048)\n", + " .batch(batch_size, drop_remainder=True)\n", + " .repeat(count=epochs)\n", + ")\n", "\n", "val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))\n", - "val_dataset = val_dataset.batch(batch_size)" + "val_dataset = val_dataset.batch(batch_size, drop_remainder=True)" ] }, { @@ -243,17 +252,16 @@ "output_type": "stream", "text": [ "To start tensorboard, run: tensorboard --logdir /tmp/tflogs\n", - "\n", - "Start of epoch 0\n", - "Epoch: 0, Batch: 0 / 15, Time Stamp: 0.06978559494018555\n" + "\ttensorboard profiling requires: pip install tensorboard_plugin_profile\n", + "Batch: 0 / 14, Time Stamp: 0.06959199905395508\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-05-08 01:19:54.275369: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:104] Profiler session initializing.\n", - "2024-05-08 01:19:54.275391: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:119] Profiler session started.\n" + "2024-06-10 22:24:40.385175: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:104] Profiler session initializing.\n", + "2024-06-10 22:24:40.385203: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:119] Profiler session started.\n" ] }, { @@ -261,274 +269,261 @@ "output_type": "stream", "text": [ "WARNING:tensorflow:Error while stopping profiler: Cannot export profiling results. No profiler is running.\n", - "\taccuracy: 0.15154866874217987\n" + "\taccuracy: 0.119384765625\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-05-08 01:26:14.168977: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:70] Profiler session collecting data.\n", - "2024-05-08 01:26:14.181516: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:131] Profiler session tear down.\n", - "2024-05-08 01:26:14.182392: I external/local_tsl/tsl/profiler/rpc/client/save_profile.cc:144] Collecting XSpace to repository: /tmp/tflogs/pt-20240508-011954/plugins/profile/2024_05_08_01_26_14/e81647a0f462.xplane.pb\n", - "2024-05-08 01:26:14.216694: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-06-10 22:30:47.746997: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:70] Profiler session collecting data.\n", + "2024-06-10 22:30:47.758943: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:131] Profiler session tear down.\n", + "2024-06-10 22:30:47.759670: I external/local_tsl/tsl/profiler/rpc/client/save_profile.cc:144] Collecting XSpace to repository: /workspaces/tf-shell/examples/tflogs/dp-sgd-20240610-222440/plugins/profile/2024_06_10_22_30_47/e64b0b6b3843.xplane.pb\n", + "2024-06-10 22:30:47.796887: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 1 / 15, Time Stamp: 380.09435200691223\n", - "\taccuracy: 0.16648229956626892\n" + "Batch: 1 / 14, Time Stamp: 367.5231137275696\n", + "\taccuracy: 0.130615234375\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-05-08 01:32:34.396483: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-06-10 22:36:52.484897: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 2 / 15, Time Stamp: 760.1941771507263\n", - "\taccuracy: 0.14712388813495636\n" + "Batch: 2 / 14, Time Stamp: 732.1724491119385\n", + "\taccuracy: 0.12939453125\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-05-08 01:38:54.856343: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-06-10 22:42:57.920091: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 3 / 15, Time Stamp: 1140.6538624763489\n", - "\taccuracy: 0.16592919826507568\n" + "Batch: 3 / 14, Time Stamp: 1097.6079428195953\n", + "\taccuracy: 0.1337890625\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-05-08 01:45:17.714885: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-06-10 22:48:57.431460: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 4 / 15, Time Stamp: 1523.5124411582947\n", - "\taccuracy: 0.1692477911710739\n" + "Batch: 4 / 14, Time Stamp: 1457.1190173625946\n", + "\taccuracy: 0.144287109375\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-05-08 01:51:34.837468: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-06-10 22:54:55.687244: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 5 / 15, Time Stamp: 1900.635261774063\n", - "\taccuracy: 0.17865043878555298\n" + "Batch: 5 / 14, Time Stamp: 1815.3750908374786\n", + "\taccuracy: 0.161376953125\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-05-08 01:57:54.261685: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-06-10 23:00:51.842868: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 6 / 15, Time Stamp: 2280.059217453003\n", - "\taccuracy: 0.18860618770122528\n" + "Batch: 6 / 14, Time Stamp: 2171.5303070545197\n", + "\taccuracy: 0.1787109375\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-05-08 02:04:14.045979: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-06-10 23:06:46.069293: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 7 / 15, Time Stamp: 2659.84348654747\n", - "\taccuracy: 0.1946902722120285\n" + "Batch: 7 / 14, Time Stamp: 2525.7567131519318\n", + "\taccuracy: 0.189453125\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-05-08 02:10:38.937992: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-06-10 23:12:38.180800: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 8 / 15, Time Stamp: 3044.7355086803436\n", - "\taccuracy: 0.20243363082408905\n" + "Batch: 8 / 14, Time Stamp: 2877.8683331012726\n", + "\taccuracy: 0.207275390625\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-05-08 02:16:59.938804: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-06-10 23:18:34.076211: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 9 / 15, Time Stamp: 3425.7369046211243\n", - "\taccuracy: 0.22676990926265717\n" + "Batch: 9 / 14, Time Stamp: 3233.7637860774994\n", + "\taccuracy: 0.2255859375\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-05-08 02:23:25.627886: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-06-10 23:24:32.819610: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 10 / 15, Time Stamp: 3811.425350189209\n", - "\taccuracy: 0.24834071099758148\n" + "Batch: 10 / 14, Time Stamp: 3592.507264852524\n", + "\taccuracy: 0.2412109375\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-05-08 02:29:44.326151: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-06-10 23:30:24.891742: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 11 / 15, Time Stamp: 4190.123619794846\n", - "\taccuracy: 0.26216813921928406\n" + "Batch: 11 / 14, Time Stamp: 3944.5792849063873\n", + "\taccuracy: 0.255859375\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-05-08 02:35:59.882177: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-06-10 23:36:22.160473: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 12 / 15, Time Stamp: 4565.6800808906555\n", - "\taccuracy: 0.2815265357494354\n" + "Batch: 12 / 14, Time Stamp: 4301.848348855972\n", + "\taccuracy: 0.2705078125\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-05-08 02:42:16.677880: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-06-10 23:42:21.741076: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 13 / 15, Time Stamp: 4942.475761651993\n", - "\taccuracy: 0.28595131635665894\n", - "Epoch: 0, Batch: 14 / 15, Time Stamp: 5321.51356959343\n", - "Total training time: 5321.514094829559 seconds\n" + "Batch: 13 / 14, Time Stamp: 4661.428592205048\n", + "\taccuracy: 0.28564453125\n", + "Total training time: 5035.84995174408 seconds\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-05-08 02:48:35.716055: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-06-10 23:48:36.161868: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n", + "2024-06-10 23:48:36.164909: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] } ], "source": [ - "epochs = 1\n", "start_time = time.time()\n", "\n", "# Set up tensorboard logging.\n", "stamp = datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n", - "logdir = \"/tmp/tflogs/pt-%s\" % stamp\n", + "logdir = os.path.abspath(\"\") + \"/tflogs/dp-sgd-%s\" % stamp\n", "print(f\"To start tensorboard, run: tensorboard --logdir /tmp/tflogs\")\n", + "print(f\"\\ttensorboard profiling requires: pip install tensorboard_plugin_profile\")\n", "writer = tf.summary.create_file_writer(logdir)\n", "\n", - "for epoch in range(epochs):\n", - " print(\"\\nStart of epoch %d\" % (epoch,))\n", - "\n", - " # Iterate over the batches of the dataset.\n", - " for step, (x_batch, y_batch) in enumerate(train_dataset.take(batch_size)):\n", - " print(\n", - " f\"Epoch: {epoch}, Batch: {step} / {len(train_dataset)}, Time Stamp: {time.time() - start_time}\"\n", - " )\n", - "\n", - " # Skip the last batch if it is not full for performance.\n", - " if x_batch.shape[0] != batch_size:\n", - " break\n", + "for step, (x_batch, y_batch) in enumerate(train_dataset.take(batch_size)):\n", + " print(\n", + " f\"Batch: {step} / {len(train_dataset)}, Time Stamp: {time.time() - start_time}\"\n", + " )\n", "\n", - " # If using deferred execution, one can trace and profile the training.\n", - " if step == 0:\n", - " tf.summary.trace_on(graph=True, profiler=True, profiler_outdir=logdir)\n", + " if step == 0:\n", + " tf.summary.trace_on(graph=True, profiler=True, profiler_outdir=logdir)\n", "\n", - " train_step_wrapper(x_batch, y_batch)\n", + " train_step_wrapper(x_batch, y_batch)\n", "\n", - " if step == 0:\n", - " with writer.as_default():\n", - " tf.summary.trace_export(\n", - " name=\"label_dp_sgd\", step=(epoch + 1) * step\n", - " )\n", + " if step == 0:\n", + " with writer.as_default():\n", + " tf.summary.trace_export(name=\"label_dp_sgd\", step=step)\n", "\n", - " # Check the accuracy.\n", - " average_loss = 0\n", - " average_accuracy = 0\n", - " for x, y in val_dataset:\n", - " y_pred = output_layer(hidden_layer(x))\n", - " loss = tf.reduce_mean(loss_fn(y, y_pred))\n", - " accuracy = tf.reduce_mean(\n", - " tf.cast(\n", - " tf.equal(tf.argmax(y, axis=1), tf.argmax(y_pred, axis=1)), tf.float32\n", - " )\n", + " # Check the accuracy.\n", + " average_loss = 0\n", + " average_accuracy = 0\n", + " for x, y in val_dataset:\n", + " y_pred = output_layer(hidden_layer(x))\n", + " loss = tf.reduce_mean(loss_fn(y, y_pred))\n", + " accuracy = tf.reduce_mean(\n", + " tf.cast(\n", + " tf.equal(tf.argmax(y, axis=1), tf.argmax(y_pred, axis=1)), tf.float32\n", " )\n", - " average_loss += loss\n", - " average_accuracy += accuracy\n", - " average_loss /= len(val_dataset)\n", - " average_accuracy /= len(val_dataset)\n", - " tf.print(f\"\\taccuracy: {accuracy}\")\n", + " )\n", + " average_loss += loss\n", + " average_accuracy += accuracy\n", + " average_loss /= len(val_dataset)\n", + " average_accuracy /= len(val_dataset)\n", + " tf.print(f\"\\taccuracy: {accuracy}\")\n", "\n", - " with writer.as_default():\n", - " tf.summary.scalar(\"loss\", average_loss, step=(epoch + 1) * batch_size - 1)\n", - " tf.summary.scalar(\n", - " \"accuracy\", average_accuracy, step=(epoch + 1) * batch_size - 1\n", - " )\n", + " with writer.as_default():\n", + " tf.summary.scalar(\"loss\", average_loss, step=step)\n", + " tf.summary.scalar(\"accuracy\", average_accuracy, step=step)\n", "\n", "\n", "print(f\"Total training time: {time.time() - start_time} seconds\")" diff --git a/examples/label_dp_sgd_post_scale.ipynb b/examples/label_dp_sgd_post_scale.ipynb index 0971889..4aaca9d 100644 --- a/examples/label_dp_sgd_post_scale.ipynb +++ b/examples/label_dp_sgd_post_scale.ipynb @@ -40,9 +40,12 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-05-28 19:37:40.084959: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", - "2024-05-28 19:37:40.228766: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", - "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" + "2024-06-10 21:55:22.590363: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", + "2024-06-10 21:55:22.590673: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2024-06-10 21:55:22.592479: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2024-06-10 21:55:22.618681: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", + "2024-06-10 21:55:23.196862: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" ] } ], @@ -53,7 +56,8 @@ "import keras\n", "import numpy as np\n", "import tf_shell\n", - "import tf_shell_ml" + "import tf_shell_ml\n", + "import os" ] }, { @@ -102,11 +106,16 @@ "x_train, x_test = x_train / np.float32(255.0), x_test / np.float32(255.0)\n", "y_train, y_test = tf.one_hot(y_train, 10), tf.one_hot(y_test, 10)\n", "\n", + "epochs = 1\n", "train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n", - "train_dataset = train_dataset.shuffle(buffer_size=2048).batch(batch_size)\n", + "train_dataset = (\n", + " train_dataset.shuffle(buffer_size=2048)\n", + " .batch(batch_size, drop_remainder=True)\n", + " .repeat(count=epochs)\n", + ")\n", "\n", "val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))\n", - "val_dataset = val_dataset.batch(batch_size)" + "val_dataset = val_dataset.batch(batch_size, drop_remainder=True)" ] }, { @@ -246,12 +255,20 @@ " ps_grads = train_step(x_batch, enc_y_batch)\n", "\n", " # Decrypt\n", - " grads = [tf_shell.to_tensorflow(enc_g, secret_key)[0] for enc_g in ps_grads]\n", - " # ^ take the first element with [0] because the grad sum is repeated over the batching dim.\n", + " batch_sz = context.num_slots\n", + " top_grads = [tf_shell.to_tensorflow(enc_g, secret_key)[0] for enc_g in ps_grads]\n", + " bottom_grads = [tf_shell.to_tensorflow(enc_g, secret_key)[batch_sz // 2] for enc_g in ps_grads]\n", + " # ^ take the first element of each batch because the grad sum is repeated over the batching dim.\n", "\n", " model.optimizer.apply_gradients(\n", " zip(\n", - " grads,\n", + " top_grads,\n", + " model.trainable_weights\n", + " )\n", + " )\n", + " model.optimizer.apply_gradients(\n", + " zip(\n", + " bottom_grads,\n", " model.trainable_weights\n", " )\n", " )" @@ -277,18 +294,17 @@ "name": "stdout", "output_type": "stream", "text": [ - "To start tensorboard, run: tensorboard --logdir /tmp/tflogs\n", - "\n", - "Start of epoch 0\n", - "Epoch: 0, Batch: 0 / 30, Time Stamp: 0.06798219680786133\n" + "To start tensorboard, run: tensorboard --logdir /tmp/tflogs --host 0.0.0.0\n", + "\ttensorboard profiling requires: pip install tensorboard_plugin_profile\n", + "Batch: 0 / 29, Time Stamp: 0.07114076614379883\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-05-28 19:37:42.830181: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:104] Profiler session initializing.\n", - "2024-05-28 19:37:42.830206: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:119] Profiler session started.\n" + "2024-06-10 21:55:24.765112: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:104] Profiler session initializing.\n", + "2024-06-10 21:55:24.765142: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:119] Profiler session started.\n" ] }, { @@ -296,95 +312,487 @@ "output_type": "stream", "text": [ "WARNING:tensorflow:Error while stopping profiler: Cannot export profiling results. No profiler is running.\n", - "\taccuracy: 0.1305309683084488\n" + "\taccuracy: 0.07861328125\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 21:56:43.385829: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:70] Profiler session collecting data.\n", + "2024-06-10 21:56:43.396207: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:131] Profiler session tear down.\n", + "2024-06-10 21:56:43.397087: I external/local_tsl/tsl/profiler/rpc/client/save_profile.cc:144] Collecting XSpace to repository: /workspaces/tf-shell/examples/tflogs/post-scale-20240610-215524/plugins/profile/2024_06_10_21_56_43/e64b0b6b3843.xplane.pb\n", + "2024-06-10 21:56:43.457550: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 1 / 29, Time Stamp: 78.8082582950592\n", + "\taccuracy: 0.0966796875\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "2024-05-28 19:39:03.997126: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:70] Profiler session collecting data.\n", - "2024-05-28 19:39:04.007926: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:131] Profiler session tear down.\n", - "2024-05-28 19:39:04.008962: I external/local_tsl/tsl/profiler/rpc/client/save_profile.cc:144] Collecting XSpace to repository: /tmp/tflogs/pt-20240528-193742/plugins/profile/2024_05_28_19_39_04/e81647a0f462.xplane.pb\n", - "2024-05-28 19:39:04.061259: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + "2024-06-10 21:57:44.839372: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0, Batch: 1 / 30, Time Stamp: 81.33036208152771\n" + "Batch: 2 / 29, Time Stamp: 140.14883470535278\n", + "\taccuracy: 0.11572265625\n" ] }, { - "ename": "", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n", - "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n", - "\u001b[1;31mClick here for more info. \n", - "\u001b[1;31mView Jupyter log for further details." + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 21:58:47.916557: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 3 / 29, Time Stamp: 203.22569370269775\n", + "\taccuracy: 0.14111328125\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 21:59:46.291334: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 4 / 29, Time Stamp: 261.60073924064636\n", + "\taccuracy: 0.17041015625\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 22:00:43.119037: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 5 / 29, Time Stamp: 318.4287750720978\n", + "\taccuracy: 0.205078125\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 22:01:41.184261: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 6 / 29, Time Stamp: 376.4937229156494\n", + "\taccuracy: 0.236328125\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 22:02:38.663447: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 7 / 29, Time Stamp: 433.9729034900665\n", + "\taccuracy: 0.27001953125\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 22:03:36.325330: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 8 / 29, Time Stamp: 491.6346232891083\n", + "\taccuracy: 0.39208984375\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 22:04:34.804192: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 9 / 29, Time Stamp: 550.1134612560272\n", + "\taccuracy: 0.5087890625\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 22:05:31.525443: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 10 / 29, Time Stamp: 606.834728717804\n", + "\taccuracy: 0.55224609375\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 22:06:28.600805: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 11 / 29, Time Stamp: 663.9102034568787\n", + "\taccuracy: 0.55419921875\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 22:07:25.919843: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 12 / 29, Time Stamp: 721.229740858078\n", + "\taccuracy: 0.57177734375\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 22:08:23.280634: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 13 / 29, Time Stamp: 778.5899968147278\n", + "\taccuracy: 0.6083984375\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 22:09:20.475956: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 14 / 29, Time Stamp: 835.7853996753693\n", + "\taccuracy: 0.64404296875\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 22:10:17.289818: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 15 / 29, Time Stamp: 892.5991945266724\n", + "\taccuracy: 0.68896484375\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 22:11:14.235364: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 16 / 29, Time Stamp: 949.5446727275848\n", + "\taccuracy: 0.72509765625\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 22:12:11.477066: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 17 / 29, Time Stamp: 1006.7864482402802\n", + "\taccuracy: 0.74365234375\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 22:13:08.782675: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 18 / 29, Time Stamp: 1064.0921244621277\n", + "\taccuracy: 0.7548828125\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 22:14:06.054970: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 19 / 29, Time Stamp: 1121.3647689819336\n", + "\taccuracy: 0.763671875\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 22:15:04.266183: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 20 / 29, Time Stamp: 1179.5777134895325\n", + "\taccuracy: 0.77734375\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 22:16:01.601114: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 21 / 29, Time Stamp: 1236.9104754924774\n", + "\taccuracy: 0.78857421875\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 22:16:59.192321: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 22 / 29, Time Stamp: 1294.5017142295837\n", + "\taccuracy: 0.80078125\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 22:17:57.075967: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 23 / 29, Time Stamp: 1352.38556599617\n", + "\taccuracy: 0.8115234375\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 22:18:54.139450: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 24 / 29, Time Stamp: 1409.4490039348602\n", + "\taccuracy: 0.82421875\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 22:19:50.059527: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 25 / 29, Time Stamp: 1465.369089603424\n", + "\taccuracy: 0.83251953125\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 22:20:47.868621: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 26 / 29, Time Stamp: 1523.1779313087463\n", + "\taccuracy: 0.84033203125\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 22:21:44.155515: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 27 / 29, Time Stamp: 1579.4650423526764\n", + "\taccuracy: 0.84228515625\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 22:22:40.935333: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch: 28 / 29, Time Stamp: 1636.244945526123\n", + "\taccuracy: 0.84423828125\n", + "Total training time: 1693.0756685733795 seconds\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 22:23:37.765942: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n", + "2024-06-10 22:23:37.768871: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" ] } ], "source": [ - "epochs = 1\n", "start_time = time.time()\n", "\n", "# Set up tensorboard logging.\n", "stamp = datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n", - "logdir = \"/tmp/tflogs/pt-%s\" % stamp\n", - "print(f\"To start tensorboard, run: tensorboard --logdir /tmp/tflogs\")\n", + "logdir = os.path.abspath(\"\") + \"/tflogs/post-scale-%s\" % stamp\n", + "print(f\"To start tensorboard, run: tensorboard --logdir /tmp/tflogs --host 0.0.0.0\")\n", + "print(f\"\\ttensorboard profiling requires: pip install tensorboard_plugin_profile\")\n", "writer = tf.summary.create_file_writer(logdir)\n", "\n", - "for epoch in range(epochs):\n", - " print(\"\\nStart of epoch %d\" % (epoch,))\n", + "# Iterate over the batches of the dataset.\n", + "for step, (x_batch, y_batch) in enumerate(train_dataset.take(batch_size)):\n", + " print(\n", + " f\"Batch: {step} / {len(train_dataset)}, Time Stamp: {time.time() - start_time}\"\n", + " )\n", "\n", - " # Iterate over the batches of the dataset.\n", - " for step, (x_batch, y_batch) in enumerate(train_dataset.take(batch_size)):\n", - " print(\n", - " f\"Epoch: {epoch}, Batch: {step} / {len(train_dataset)}, Time Stamp: {time.time() - start_time}\"\n", - " )\n", + " if step == 0:\n", + " tf.summary.trace_on(graph=True, profiler=True, profiler_outdir=logdir)\n", "\n", - " # Skip the last batch if it is not full for performance.\n", - " if x_batch.shape[0] != batch_size:\n", - " break\n", - "\n", - " # If using deferred execution, one can trace and profile the training.\n", - " if step == 0:\n", - " tf.summary.trace_on(graph=True, profiler=True, profiler_outdir=logdir)\n", - "\n", - " train_step_wrapper(x_batch, y_batch)\n", - "\n", - " if step == 0:\n", - " with writer.as_default():\n", - " tf.summary.trace_export(\n", - " name=\"label_dp_sgd_post_scale\", step=(epoch + 1) * step\n", - " )\n", - "\n", - " # Check the accuracy.\n", - " average_loss = 0\n", - " average_accuracy = 0\n", - " for x, y in val_dataset:\n", - " y_pred = model(x, training=False)\n", - " loss = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(y, y_pred))\n", - " accuracy = tf.reduce_mean(\n", - " tf.cast(\n", - " tf.equal(tf.argmax(y, axis=1), tf.argmax(y_pred, axis=1)), tf.float32\n", - " )\n", - " )\n", - " average_accuracy += accuracy\n", - " average_loss += loss\n", - " average_loss /= len(val_dataset)\n", - " average_accuracy /= len(val_dataset)\n", - " tf.print(f\"\\taccuracy: {accuracy}\")\n", + " train_step_wrapper(x_batch, y_batch)\n", "\n", + " if step == 0:\n", " with writer.as_default():\n", - " tf.summary.scalar(\"loss\", average_loss, step=(epoch + 1) * step)\n", - " tf.summary.scalar(\n", - " \"accuracy\", average_accuracy, step=(epoch + 1) * step\n", + " tf.summary.trace_export(name=\"label_dp_sgd_post_scale\", step=step)\n", + "\n", + " # Check the accuracy.\n", + " average_loss = 0\n", + " average_accuracy = 0\n", + " for x, y in val_dataset:\n", + " y_pred = model(x, training=False)\n", + " loss = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(y, y_pred))\n", + " accuracy = tf.reduce_mean(\n", + " tf.cast(\n", + " tf.equal(tf.argmax(y, axis=1), tf.argmax(y_pred, axis=1)), tf.float32\n", " )\n", + " )\n", + " average_accuracy += accuracy\n", + " average_loss += loss\n", + " average_loss /= len(val_dataset)\n", + " average_accuracy /= len(val_dataset)\n", + " tf.print(f\"\\taccuracy: {accuracy}\")\n", + "\n", + " with writer.as_default():\n", + " tf.summary.scalar(\"loss\", average_loss, step=step)\n", + " tf.summary.scalar(\"accuracy\", average_accuracy, step=step)\n", "\n", "\n", "print(f\"Total training time: {time.time() - start_time} seconds\")" diff --git a/examples/label_dp_sgd_sentiment.ipynb b/examples/label_dp_sgd_sentiment.ipynb new file mode 100644 index 0000000..13254ec --- /dev/null +++ b/examples/label_dp_sgd_sentiment.ipynb @@ -0,0 +1,461 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Sentiment Analysis on IMDB dataset\n", + "\n", + "This notebook walks through how perform sentament analysis on the IMDB dataset.\n", + "In this setting, one party has the reviews and the other party has the labels.\n", + "The party with the labels is helping the party with the reviews train a model\n", + "without sharing the labels themselves.\n", + "\n", + "Before starting, install tf-shell and the dataset.\n", + "\n", + "```bash\n", + "pip install tf-shell\n", + "pip install tensorflow_hub tensorflow_datasets\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 23:49:11.867640: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", + "2024-06-10 23:49:11.867982: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2024-06-10 23:49:11.869644: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2024-06-10 23:49:11.895495: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", + "2024-06-10 23:49:12.500017: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n", + "/workspaces/tf-shell/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import time\n", + "from datetime import datetime\n", + "import tensorflow as tf\n", + "import tensorflow_hub as hub\n", + "import tensorflow_datasets as tfds\n", + "\n", + "import keras\n", + "import numpy as np\n", + "import tf_shell\n", + "import tf_shell_ml\n", + "import os" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Set up parameters for the SHELL encryption library.\n", + "context = tf_shell.create_context64(\n", + " log_n=12,\n", + " main_moduli=[288230376151760897, 288230376152137729],\n", + " plaintext_modulus=4294991873,\n", + " scaling_factor=3,\n", + " mul_depth_supported=3,\n", + " seed=\"test_seed\",\n", + ")\n", + "\n", + "# Create the secret key for encryption and a rotation key (rotation key is\n", + "# an auxilary key required for operations like roll or matmul).\n", + "secret_key = tf_shell.create_key64(context)\n", + "public_rotation_key = tf_shell.create_rotation_key64(context, secret_key)\n", + "\n", + "# The batch size is determined by the ciphertext parameters, specifically the\n", + "# schemes polynomial's ring degree because tf-shell uses batch axis packing.\n", + "# Furthermore, two micro-batches to run in parallel.\n", + "batch_size = context.num_slots" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Setup IMDB dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 23:49:58.773005: W external/local_tsl/tsl/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with \"NOT_FOUND: Could not locate the credentials file.\". Retrieving token from GCE failed with \"FAILED_PRECONDITION: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host: metadata.google.internal\".\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1mDownloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /home/vscode/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Dl Size...: 100%|██████████| 80/80 [00:13<00:00, 5.97 MiB/s]rl]\n", + "Dl Completed...: 100%|██████████| 1/1 [00:13<00:00, 13.39s/ url]\n", + "2024-06-10 23:51:28.949194: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n", + "2024-06-10 23:51:28.949578: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1mDataset imdb_reviews downloaded and prepared to /home/vscode/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.\u001b[0m\n", + "Review: This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it.\n", + "Label: 0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 23:51:30.315651: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Most used words: ['', '[UNK]', 'the', 'a', 'and', 'of', 'to', 'is', 'in', 'it']\n", + "Dictionary size: 50000\n" + ] + } + ], + "source": [ + "# Split the training set into 60% and 40% to end up with 15,000 examples\n", + "# for training, 10,000 examples for validation and 25,000 examples for testing.\n", + "train_data, val_data, test_data = tfds.load(\n", + " name=\"imdb_reviews\", \n", + " split=('train[:60%]', 'train[60%:]', 'test'),\n", + " as_supervised=True)\n", + "\n", + "# Print the first example.\n", + "for review, label in train_data.take(1):\n", + " print(\"Review:\", review.numpy().decode('utf-8'))\n", + " print(\"Label:\", label.numpy())\n", + "\n", + "epochs = 3\n", + "train_data = train_data.shuffle(buffer_size=2048).batch(batch_size, drop_remainder=True).repeat(count=epochs)\n", + "val_data = val_data.shuffle(buffer_size=2048).batch(batch_size, drop_remainder=True)\n", + "test_data = test_data.shuffle(buffer_size=2048).batch(batch_size, drop_remainder=True)\n", + "\n", + "vocab_size = 50000 # This dataset has 92061 unique words.\n", + "max_length = 200\n", + "embedding_dim = 50\n", + "\n", + "vectorize_layer = tf.keras.layers.TextVectorization(\n", + " max_tokens=vocab_size,\n", + " output_mode='int',\n", + " output_sequence_length=max_length)\n", + " # TODO use pad_to_max_tokens instead of output_sequence_length?\n", + "\n", + "vectorize_layer.adapt(train_data.map(lambda text, label: text))\n", + "\n", + "print(\"Most used words:\", vectorize_layer.get_vocabulary()[:10])\n", + "print(\"Dictionary size:\", len(vectorize_layer.get_vocabulary()))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# Create the trainable layers.\n", + "embedding_layer = tf_shell_ml.ShellEmbedding(\n", + " vocab_size + 1, # +1 for OOV token.\n", + " embedding_dim,\n", + ")\n", + "# TODO dropout layer?\n", + "hidden_layer = tf_shell_ml.ShellDense(\n", + " 16,\n", + " activation=tf_shell_ml.relu,\n", + " activation_deriv=tf_shell_ml.relu_deriv,\n", + ")\n", + "# TODO dropout layer?\n", + "output_layer = tf_shell_ml.ShellDense(1,\n", + " activation=tf.nn.softmax,\n", + ")\n", + "\n", + "## Call the layers once to create the weights.\n", + "#y1 = hidden_layer(tf.zeros((batch_size, 784)))\n", + "#y2 = output_layer(y1)\n", + "\n", + "loss_fn = tf_shell_ml.BinaryCrossentropy()\n", + "optimizer = tf.keras.optimizers.Adam(0.1)\n", + "emb_optimizer = tf.keras.optimizers.Adam(0.1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, define the `train_step` function which will be called for each batch on an\n", + "encrypted batch of labels, y. The function first does a forward on the plaintext\n", + "image x to compute a predicted label, then does backpropagation using the\n", + "encrypted label y." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "def train_step(x, enc_y):\n", + " # Forward pass always in plaintext\n", + " # y //= max_length # Normalize for reduce_sum.\n", + " y = embedding_layer(x)\n", + " y = tf_shell.reshape(y, (batch_size, max_length * embedding_dim))\n", + " # Could also do reduce_sum and division to replicate GlobalAveragePooling1D layer.\n", + " # y = tf_shell.reduce_sum(y, axis=1)\n", + "\n", + " y = hidden_layer(y)\n", + " y_pred = output_layer(y)\n", + "\n", + " # Backward pass.\n", + " dJ_dy_pred = loss_fn.grad(enc_y, y_pred)\n", + " dJ_dw2, dJ_dx2 = output_layer.backward(dJ_dy_pred, public_rotation_key)\n", + " dJ_dw1, dJ_dx1 = hidden_layer.backward(dJ_dx2, public_rotation_key)\n", + "\n", + " dJ_dx1_reshaped = tf_shell.reshape(dJ_dx1, (batch_size, max_length, embedding_dim))\n", + " # Could also tile up to this shape to replicate GlobalAveragePooling1D layer.\n", + " # dJ_dx1_reshaped = tf_shell.broadcast_to(\n", + " # dJ_dx1, (batch_size, max_length, embedding_dim)\n", + " # )\n", + "\n", + " embedding_layer.backward_accum(dJ_dx1_reshaped, public_rotation_key)\n", + "\n", + " # dJ_dw0, _ = embedding_layer.backward(dJ_dx1_reshaped, public_rotation_key)\n", + "\n", + " # dJ_dw0, the embedding layer gradient, would usually have outer shape [1]\n", + " # for the 1 output classes. tf-shell instead back propagates in two\n", + " # mini-batches per batch resulting in two gradients of shape [2].\n", + " # Furthermore, the gradients are in an \"expanded\" form where the gradient is\n", + " # repeated by the size of the batch. Said another way, if\n", + " # real_grad_top/bottom is the \"real\" gradient of shape [10] from the\n", + " # top/bottom halves of the batch:\n", + " #\n", + " # dJ_dw = tf.concat([\n", + " # tf.repeat(\n", + " # tf.expand_dims(real_grad_top, 0), repeats=[batch_sz // 2], axis=0\n", + " # ),\n", + " # tf.repeat(\n", + " # tf.expand_dims(real_grad_bottom, 0), repeats=[batch_sz // 2], axis=0\n", + " # )\n", + " # ])\n", + " #\n", + " # This repetition is result of the SHELL library using a packed\n", + " # representation of ciphertexts for efficiency. As such, if the ciphertexts\n", + " # need to be sent over the network, they may be masked and packed together\n", + " # before being transmitted to the party with the key.\n", + " #\n", + " # Only return the weight gradients at [0], not the bias gradients at [1].\n", + " # The bias is not used in this test.\n", + " # return [dJ_dw2[0], dJ_dw1[0], dJ_dw0[0]]\n", + " return [dJ_dw2[0], dJ_dw1[0]]\n", + "\n", + "\n", + "@tf.function\n", + "def train_step_wrapper(x_batch, y_batch):\n", + " # Encrypt the batch of secret labels y.\n", + " enc_y_batch = tf_shell.to_encrypted(y_batch, secret_key, context)\n", + "\n", + " # Run the training step. The top and bottom halves of the batch are\n", + " # treated as two separate mini-batches run in parallel to maximize\n", + " # efficiency.\n", + " enc_grads = train_step(x_batch, enc_y_batch)\n", + "\n", + " # Decrypt the weight gradients. In practice, the gradients should be\n", + " # noised before decrypting.\n", + " repeated_grads = [tf_shell.to_tensorflow(g, secret_key) for g in enc_grads]\n", + "\n", + " # Pull out grads from the top and bottom batches.\n", + " top_grad = [g[0] for g in repeated_grads]\n", + " bottom_grad = [g[batch_size // 2] for g in repeated_grads]\n", + "\n", + " # Apply the gradients to the model.\n", + " weights = output_layer.weights + hidden_layer.weights\n", + " optimizer.apply_gradients(zip(top_grad, weights))\n", + " optimizer.apply_gradients(zip(bottom_grad, weights))\n", + "\n", + " # Apply the embedding layer gradient (contains both batches).\n", + " # optimizer.apply_gradients(embedding_layer.decrypt_grad(secret_key), embedding_layer.weights)\n", + " emb_optimizer.apply_gradients(zip([embedding_layer.decrypt_grad(secret_key)], embedding_layer.weights))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here is the training loop. Each inner iteration runs two batches of size\n", + "$2^{12-1}$ simultaneously.\n", + "\n", + "Tensorboard can be used to visualize the training progress. See cell output for\n", + "command to start tensorboard." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "To start tensorboard, run: tensorboard --logdir /tmp/tflogs --host 0.0.0.0\n", + "\ttensorboard profiling requires: pip install tensorboard_plugin_profile\n", + "\ttrain loss: nan\taccuracy: 0.5001627802848816\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 23:51:31.724473: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\tvalidation loss: nan\taccuracy: 0.4974365234375\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-10 23:51:32.024516: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n", + "2024-06-10 23:51:32.059464: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:104] Profiler session initializing.\n", + "2024-06-10 23:51:32.059493: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:119] Profiler session started.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step: 0 / 9, Time Stamp: 1.4508495330810547\n" + ] + } + ], + "source": [ + "start_time = time.time()\n", + "tf.config.run_functions_eagerly(False)\n", + "\n", + "\n", + "def check_accuracy(dataset):\n", + " average_loss = 0\n", + " average_accuracy = 0\n", + " for x, y in dataset:\n", + " y = tf.cast(y, tf.float32)\n", + " y = tf.reshape(y, (batch_size, 1))\n", + "\n", + " y_pred = vectorize_layer(x)\n", + " # y_pred //= max_length # Normalize for reduce_sum.\n", + " y_pred = embedding_layer(y_pred)\n", + " y_pred = tf_shell.reshape(y_pred, (batch_size, max_length * embedding_dim))\n", + " # y_pred = tf_shell.reduce_sum(y_pred, axis=1)\n", + " y_pred = hidden_layer(y_pred)\n", + " y_pred = output_layer(y_pred)\n", + " \n", + " loss = tf.reduce_mean(loss_fn(y, y_pred))\n", + " accuracy = tf.reduce_mean(tf.cast(tf.equal(y, tf.round(y_pred)), tf.float32))\n", + " average_loss += loss\n", + " average_accuracy += accuracy\n", + " average_loss /= len(dataset)\n", + " average_accuracy /= len(dataset)\n", + "\n", + " return average_loss, average_accuracy\n", + "\n", + "\n", + "# Set up tensorboard logging.\n", + "stamp = datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n", + "logdir = os.path.abspath(\"\") + \"/tflogs/sentiment-%s\" % stamp\n", + "print(f\"To start tensorboard, run: tensorboard --logdir /tmp/tflogs --host 0.0.0.0\")\n", + "print(f\"\\ttensorboard profiling requires: pip install tensorboard_plugin_profile\")\n", + "writer = tf.summary.create_file_writer(logdir)\n", + "\n", + "# Initial accuracy\n", + "loss, accuracy = check_accuracy(train_data)\n", + "tf.print(f\"\\ttrain loss: {loss}\\taccuracy: {accuracy}\")\n", + "loss, accuracy = check_accuracy(val_data)\n", + "tf.print(f\"\\tvalidation loss: {loss}\\taccuracy: {accuracy}\")\n", + "\n", + "# Iterate over the batches of the dataset.\n", + "for step, (x_batch, y_batch) in enumerate(train_data.take(batch_size)):\n", + " print(\n", + " f\"Step: {step} / {len(train_data)}, Time Stamp: {time.time() - start_time}\"\n", + " )\n", + "\n", + " y_batch = tf.cast(y_batch, tf.float32)\n", + " y_batch = tf.reshape(y_batch, (batch_size, 1))\n", + "\n", + " if step == 0:\n", + " tf.summary.trace_on(graph=True, profiler=True, profiler_outdir=logdir)\n", + "\n", + " x_batch = vectorize_layer(x_batch) # No shape inference, do outside tf.function\n", + " train_step_wrapper(x_batch, y_batch)\n", + "\n", + " if step == 0:\n", + " with writer.as_default():\n", + " tf.summary.trace_export(name=\"sentiment\", step=step)\n", + "\n", + " loss, accuracy = check_accuracy(train_data)\n", + " tf.print(f\"\\ttrain loss: {loss}\\taccuracy: {accuracy}\")\n", + " loss, accuracy = check_accuracy(val_data)\n", + " tf.print(f\"\\tvalidation loss: {loss}\\taccuracy: {accuracy}\")\n", + "\n", + " with writer.as_default():\n", + " tf.summary.scalar(\"loss\", loss, step=step)\n", + " tf.summary.scalar(\"accuracy\", accuracy, step=step)\n", + "\n", + "\n", + "print(f\"Total training time: {time.time() - start_time} seconds\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} From cf0c776b8f0e0ff158d29fc01e760f2415573b60 Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Tue, 11 Jun 2024 00:24:17 +0000 Subject: [PATCH 21/22] Bump version to 0.1.3 --- MODULE.bazel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MODULE.bazel b/MODULE.bazel index 5d9a89c..2c5c913 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -1,6 +1,6 @@ module( name = "tf-shell", - version = "0.1.2", + version = "0.1.3", ) SUPPORTED_PYTHON_VERSIONS = [ From f225b2a2bbfa6b07ef03ad5c63f3730f3fb21a83 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 15 Jul 2024 18:48:23 +0000 Subject: [PATCH 22/22] Bump the pip group across 1 directory with 7 updates Bumps the pip group with 7 updates in the / directory: | Package | From | To | | --- | --- | --- | | [certifi](https://github.com/certifi/python-certifi) | `2023.7.22` | `2024.7.4` | | [idna](https://github.com/kjd/idna) | `3.4` | `3.7` | | [requests](https://github.com/psf/requests) | `2.31.0` | `2.32.2` | | [urllib3](https://github.com/urllib3/urllib3) | `1.26.16` | `1.26.19` | | [werkzeug](https://github.com/pallets/werkzeug) | `2.3.7` | `3.0.3` | | [setuptools](https://github.com/pypa/setuptools) | `68.2.2` | `70.0.0` | | [zipp](https://github.com/jaraco/zipp) | `3.17.0` | `3.19.1` | Updates `certifi` from 2023.7.22 to 2024.7.4 - [Commits](https://github.com/certifi/python-certifi/compare/2023.07.22...2024.07.04) Updates `idna` from 3.4 to 3.7 - [Release notes](https://github.com/kjd/idna/releases) - [Changelog](https://github.com/kjd/idna/blob/master/HISTORY.rst) - [Commits](https://github.com/kjd/idna/compare/v3.4...v3.7) Updates `requests` from 2.31.0 to 2.32.2 - [Release notes](https://github.com/psf/requests/releases) - [Changelog](https://github.com/psf/requests/blob/main/HISTORY.md) - [Commits](https://github.com/psf/requests/compare/v2.31.0...v2.32.2) Updates `urllib3` from 1.26.16 to 1.26.19 - [Release notes](https://github.com/urllib3/urllib3/releases) - [Changelog](https://github.com/urllib3/urllib3/blob/main/CHANGES.rst) - [Commits](https://github.com/urllib3/urllib3/compare/1.26.16...1.26.19) Updates `werkzeug` from 2.3.7 to 3.0.3 - [Release notes](https://github.com/pallets/werkzeug/releases) - [Changelog](https://github.com/pallets/werkzeug/blob/main/CHANGES.rst) - [Commits](https://github.com/pallets/werkzeug/compare/2.3.7...3.0.3) Updates `setuptools` from 68.2.2 to 70.0.0 - [Release notes](https://github.com/pypa/setuptools/releases) - [Changelog](https://github.com/pypa/setuptools/blob/main/NEWS.rst) - [Commits](https://github.com/pypa/setuptools/compare/v68.2.2...v70.0.0) Updates `zipp` from 3.17.0 to 3.19.1 - [Release notes](https://github.com/jaraco/zipp/releases) - [Changelog](https://github.com/jaraco/zipp/blob/main/NEWS.rst) - [Commits](https://github.com/jaraco/zipp/compare/v3.17.0...v3.19.1) --- updated-dependencies: - dependency-name: certifi dependency-type: direct:production dependency-group: pip - dependency-name: idna dependency-type: direct:production dependency-group: pip - dependency-name: requests dependency-type: direct:production dependency-group: pip - dependency-name: urllib3 dependency-type: direct:production dependency-group: pip - dependency-name: werkzeug dependency-type: direct:production dependency-group: pip - dependency-name: setuptools dependency-type: direct:production dependency-group: pip - dependency-name: zipp dependency-type: direct:production dependency-group: pip ... Signed-off-by: dependabot[bot] --- requirements_3_10.txt | 36 ++++++++++++++++++------------------ requirements_3_11.txt | 36 ++++++++++++++++++------------------ requirements_3_12.txt | 30 +++++++++++++++--------------- requirements_3_9.txt | 42 +++++++++++++++++++++--------------------- 4 files changed, 72 insertions(+), 72 deletions(-) diff --git a/requirements_3_10.txt b/requirements_3_10.txt index 071c3c8..91ae3df 100644 --- a/requirements_3_10.txt +++ b/requirements_3_10.txt @@ -39,9 +39,9 @@ black==24.4.0 \ --hash=sha256:f07b69fda20578367eaebbd670ff8fc653ab181e1ff95d84497f9fa20e7d0641 \ --hash=sha256:f95cece33329dc4aa3b0e1a771c41075812e46cf3d6e3f1dfe3d91ff09826ed2 # via -r requirements.in -certifi==2023.7.22 \ - --hash=sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082 \ - --hash=sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9 +certifi==2024.7.4 \ + --hash=sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b \ + --hash=sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90 # via requests charset-normalizer==3.2.0 \ --hash=sha256:04e57ab9fbf9607b77f7d057974694b4f6b142da9ed4a199859d9d4d5c63fe96 \ @@ -211,9 +211,9 @@ h5py==3.11.0 \ # via # keras # tensorflow-cpu -idna==3.4 \ - --hash=sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4 \ - --hash=sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2 +idna==3.7 \ + --hash=sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc \ + --hash=sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0 # via requests keras==3.2.1 \ --hash=sha256:0be1e89b041e697be562d8422ecb958ee5481acfc089913200926c561d258a03 \ @@ -451,9 +451,9 @@ pygments==2.17.2 \ --hash=sha256:b27c2826c47d0f3219f29554824c30c5e8945175d888647acd804ddd04af846c \ --hash=sha256:da46cec9fd2de5be3a8a784f434e4c4ab670b4ff54d605c4c2717e9d49c4c367 # via rich -requests==2.31.0 \ - --hash=sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f \ - --hash=sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1 +requests==2.32.2 \ + --hash=sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289 \ + --hash=sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c # via tensorflow-cpu rich==13.7.1 \ --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ @@ -525,13 +525,13 @@ typing-extensions==4.5.0 \ # black # optree # tensorflow-cpu -urllib3==1.26.16 \ - --hash=sha256:8d36afa7616d8ab714608411b4a3b13e58f463aee519024578e062e141dce20f \ - --hash=sha256:8f135f6502756bde6b2a9b28989df5fbe87c9970cecaa69041edcce7f0589b14 +urllib3==1.26.19 \ + --hash=sha256:37a0344459b199fce0e80b0d3569837ec6b6937435c5244e7fd73fa6006830f3 \ + --hash=sha256:3e3d753a8618b86d7de333b4223005f68720bcd6a7d2bcb9fbd2229ec7c1e429 # via requests -werkzeug==2.3.7 \ - --hash=sha256:2b8c0e447b4b9dbcc85dd97b6eeb4dcbaf6c8b6c3be0bd654e25553e0a2157d8 \ - --hash=sha256:effc12dba7f3bd72e605ce49807bbe692bd729c3bb122a3b91747a6ae77df528 +werkzeug==3.0.3 \ + --hash=sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18 \ + --hash=sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8 # via tensorboard wheel==0.41.2 \ --hash=sha256:0c5ac5ff2afb79ac23ab82bab027a0be7b5dbcf2e54dc50efe4bf507de1f7985 \ @@ -615,9 +615,9 @@ wrapt==1.14.1 \ # via tensorflow-cpu # The following packages are considered to be unsafe in a requirements file: -setuptools==68.2.2 \ - --hash=sha256:4ac1475276d2f1c48684874089fefcd83bd7162ddaafb81fac866ba0db282a87 \ - --hash=sha256:b454a35605876da60632df1a60f736524eb73cc47bbc9f3f1ef1b644de74fd2a +setuptools==70.0.0 \ + --hash=sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4 \ + --hash=sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0 # via # tensorboard # tensorflow-cpu diff --git a/requirements_3_11.txt b/requirements_3_11.txt index bf47004..6eff70c 100644 --- a/requirements_3_11.txt +++ b/requirements_3_11.txt @@ -39,9 +39,9 @@ black==24.4.0 \ --hash=sha256:f07b69fda20578367eaebbd670ff8fc653ab181e1ff95d84497f9fa20e7d0641 \ --hash=sha256:f95cece33329dc4aa3b0e1a771c41075812e46cf3d6e3f1dfe3d91ff09826ed2 # via -r requirements.in -certifi==2023.7.22 \ - --hash=sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082 \ - --hash=sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9 +certifi==2024.7.4 \ + --hash=sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b \ + --hash=sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90 # via requests charset-normalizer==3.2.0 \ --hash=sha256:04e57ab9fbf9607b77f7d057974694b4f6b142da9ed4a199859d9d4d5c63fe96 \ @@ -211,9 +211,9 @@ h5py==3.11.0 \ # via # keras # tensorflow-cpu -idna==3.4 \ - --hash=sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4 \ - --hash=sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2 +idna==3.7 \ + --hash=sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc \ + --hash=sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0 # via requests keras==3.2.1 \ --hash=sha256:0be1e89b041e697be562d8422ecb958ee5481acfc089913200926c561d258a03 \ @@ -451,9 +451,9 @@ pygments==2.17.2 \ --hash=sha256:b27c2826c47d0f3219f29554824c30c5e8945175d888647acd804ddd04af846c \ --hash=sha256:da46cec9fd2de5be3a8a784f434e4c4ab670b4ff54d605c4c2717e9d49c4c367 # via rich -requests==2.31.0 \ - --hash=sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f \ - --hash=sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1 +requests==2.32.2 \ + --hash=sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289 \ + --hash=sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c # via tensorflow-cpu rich==13.7.1 \ --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ @@ -520,13 +520,13 @@ typing-extensions==4.5.0 \ # via # optree # tensorflow-cpu -urllib3==1.26.16 \ - --hash=sha256:8d36afa7616d8ab714608411b4a3b13e58f463aee519024578e062e141dce20f \ - --hash=sha256:8f135f6502756bde6b2a9b28989df5fbe87c9970cecaa69041edcce7f0589b14 +urllib3==1.26.19 \ + --hash=sha256:37a0344459b199fce0e80b0d3569837ec6b6937435c5244e7fd73fa6006830f3 \ + --hash=sha256:3e3d753a8618b86d7de333b4223005f68720bcd6a7d2bcb9fbd2229ec7c1e429 # via requests -werkzeug==2.3.7 \ - --hash=sha256:2b8c0e447b4b9dbcc85dd97b6eeb4dcbaf6c8b6c3be0bd654e25553e0a2157d8 \ - --hash=sha256:effc12dba7f3bd72e605ce49807bbe692bd729c3bb122a3b91747a6ae77df528 +werkzeug==3.0.3 \ + --hash=sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18 \ + --hash=sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8 # via tensorboard wheel==0.41.2 \ --hash=sha256:0c5ac5ff2afb79ac23ab82bab027a0be7b5dbcf2e54dc50efe4bf507de1f7985 \ @@ -610,9 +610,9 @@ wrapt==1.14.1 \ # via tensorflow-cpu # The following packages are considered to be unsafe in a requirements file: -setuptools==68.2.2 \ - --hash=sha256:4ac1475276d2f1c48684874089fefcd83bd7162ddaafb81fac866ba0db282a87 \ - --hash=sha256:b454a35605876da60632df1a60f736524eb73cc47bbc9f3f1ef1b644de74fd2a +setuptools==70.0.0 \ + --hash=sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4 \ + --hash=sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0 # via # tensorboard # tensorflow-cpu diff --git a/requirements_3_12.txt b/requirements_3_12.txt index 99cc837..456a54b 100644 --- a/requirements_3_12.txt +++ b/requirements_3_12.txt @@ -39,9 +39,9 @@ black==24.4.0 \ --hash=sha256:f07b69fda20578367eaebbd670ff8fc653ab181e1ff95d84497f9fa20e7d0641 \ --hash=sha256:f95cece33329dc4aa3b0e1a771c41075812e46cf3d6e3f1dfe3d91ff09826ed2 # via -r requirements.in -certifi==2024.2.2 \ - --hash=sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f \ - --hash=sha256:dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1 +certifi==2024.7.4 \ + --hash=sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b \ + --hash=sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90 # via requests charset-normalizer==3.3.2 \ --hash=sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027 \ @@ -471,9 +471,9 @@ pygments==2.17.2 \ --hash=sha256:b27c2826c47d0f3219f29554824c30c5e8945175d888647acd804ddd04af846c \ --hash=sha256:da46cec9fd2de5be3a8a784f434e4c4ab670b4ff54d605c4c2717e9d49c4c367 # via rich -requests==2.31.0 \ - --hash=sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f \ - --hash=sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1 +requests==2.32.2 \ + --hash=sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289 \ + --hash=sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c # via tensorflow-cpu rich==13.7.1 \ --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ @@ -519,13 +519,13 @@ typing-extensions==4.11.0 \ # via # optree # tensorflow-cpu -urllib3==2.2.1 \ - --hash=sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d \ - --hash=sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19 +urllib3==1.26.19 \ + --hash=sha256:37a0344459b199fce0e80b0d3569837ec6b6937435c5244e7fd73fa6006830f3 \ + --hash=sha256:3e3d753a8618b86d7de333b4223005f68720bcd6a7d2bcb9fbd2229ec7c1e429 # via requests -werkzeug==3.0.2 \ - --hash=sha256:3aac3f5da756f93030740bc235d3e09449efcf65f2f55e3602e1d851b8f48795 \ - --hash=sha256:e39b645a6ac92822588e7b39a692e7828724ceae0b0d702ef96701f90e70128d +werkzeug==3.0.3 \ + --hash=sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18 \ + --hash=sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8 # via tensorboard wheel==0.43.0 \ --hash=sha256:465ef92c69fa5c5da2d1cf8ac40559a8c940886afcef87dcf14b9470862f1d85 \ @@ -605,9 +605,9 @@ wrapt==1.16.0 \ # via tensorflow-cpu # The following packages are considered to be unsafe in a requirements file: -setuptools==69.5.1 \ - --hash=sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987 \ - --hash=sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32 +setuptools==70.0.0 \ + --hash=sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4 \ + --hash=sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0 # via # tensorboard # tensorflow-cpu diff --git a/requirements_3_9.txt b/requirements_3_9.txt index 8a48b31..c014e2b 100644 --- a/requirements_3_9.txt +++ b/requirements_3_9.txt @@ -39,9 +39,9 @@ black==24.4.0 \ --hash=sha256:f07b69fda20578367eaebbd670ff8fc653ab181e1ff95d84497f9fa20e7d0641 \ --hash=sha256:f95cece33329dc4aa3b0e1a771c41075812e46cf3d6e3f1dfe3d91ff09826ed2 # via -r requirements.in -certifi==2023.7.22 \ - --hash=sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082 \ - --hash=sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9 +certifi==2024.7.4 \ + --hash=sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b \ + --hash=sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90 # via requests charset-normalizer==3.2.0 \ --hash=sha256:04e57ab9fbf9607b77f7d057974694b4f6b142da9ed4a199859d9d4d5c63fe96 \ @@ -211,9 +211,9 @@ h5py==3.11.0 \ # via # keras # tensorflow-cpu -idna==3.4 \ - --hash=sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4 \ - --hash=sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2 +idna==3.7 \ + --hash=sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc \ + --hash=sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0 # via requests importlib-metadata==6.8.0 \ --hash=sha256:3ebb78df84a805d7698245025b975d9d67053cd94c79245ba4b3eb694abe68bb \ @@ -455,9 +455,9 @@ pygments==2.17.2 \ --hash=sha256:b27c2826c47d0f3219f29554824c30c5e8945175d888647acd804ddd04af846c \ --hash=sha256:da46cec9fd2de5be3a8a784f434e4c4ab670b4ff54d605c4c2717e9d49c4c367 # via rich -requests==2.31.0 \ - --hash=sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f \ - --hash=sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1 +requests==2.32.2 \ + --hash=sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289 \ + --hash=sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c # via tensorflow-cpu rich==13.7.1 \ --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ @@ -529,13 +529,13 @@ typing-extensions==4.5.0 \ # black # optree # tensorflow-cpu -urllib3==1.26.16 \ - --hash=sha256:8d36afa7616d8ab714608411b4a3b13e58f463aee519024578e062e141dce20f \ - --hash=sha256:8f135f6502756bde6b2a9b28989df5fbe87c9970cecaa69041edcce7f0589b14 +urllib3==1.26.19 \ + --hash=sha256:37a0344459b199fce0e80b0d3569837ec6b6937435c5244e7fd73fa6006830f3 \ + --hash=sha256:3e3d753a8618b86d7de333b4223005f68720bcd6a7d2bcb9fbd2229ec7c1e429 # via requests -werkzeug==2.3.7 \ - --hash=sha256:2b8c0e447b4b9dbcc85dd97b6eeb4dcbaf6c8b6c3be0bd654e25553e0a2157d8 \ - --hash=sha256:effc12dba7f3bd72e605ce49807bbe692bd729c3bb122a3b91747a6ae77df528 +werkzeug==3.0.3 \ + --hash=sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18 \ + --hash=sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8 # via tensorboard wheel==0.41.2 \ --hash=sha256:0c5ac5ff2afb79ac23ab82bab027a0be7b5dbcf2e54dc50efe4bf507de1f7985 \ @@ -617,15 +617,15 @@ wrapt==1.14.1 \ --hash=sha256:ee6acae74a2b91865910eef5e7de37dc6895ad96fa23603d1d27ea69df545015 \ --hash=sha256:ef3f72c9666bba2bab70d2a8b79f2c6d2c1a42a7f7e2b0ec83bb2f9e383950af # via tensorflow-cpu -zipp==3.17.0 \ - --hash=sha256:0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31 \ - --hash=sha256:84e64a1c28cf7e91ed2078bb8cc8c259cb19b76942096c8d7b84947690cabaf0 +zipp==3.19.1 \ + --hash=sha256:2828e64edb5386ea6a52e7ba7cdb17bb30a73a858f5eb6eb93d8d36f5ea26091 \ + --hash=sha256:35427f6d5594f4acf82d25541438348c26736fa9b3afa2754bcd63cdb99d8e8f # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: -setuptools==68.2.2 \ - --hash=sha256:4ac1475276d2f1c48684874089fefcd83bd7162ddaafb81fac866ba0db282a87 \ - --hash=sha256:b454a35605876da60632df1a60f736524eb73cc47bbc9f3f1ef1b644de74fd2a +setuptools==70.0.0 \ + --hash=sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4 \ + --hash=sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0 # via # tensorboard # tensorflow-cpu