diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c831d6ac..b9894f89 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -44,6 +44,7 @@ set(SRC ngraph_utils.cc tf_graphcycles.cc tf_deadness_analysis.cc + ngraph_disable_assert.cc ) add_library(${LIB_NAME} SHARED ${SRC}) diff --git a/src/ngraph_disable_assert.cc b/src/ngraph_disable_assert.cc new file mode 100644 index 00000000..af34a0b8 --- /dev/null +++ b/src/ngraph_disable_assert.cc @@ -0,0 +1,57 @@ +/******************************************************************************* + * Copyright 2017-2018 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ +#include "ngraph_disable_assert.h" + +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/types.h" + +#include "ngraph_utils.h" + +using namespace std; + +namespace tensorflow { + +namespace ngraph_bridge { + +// +// Main entry point for disable assert. +// +Status DisableAssert(Graph* graph) { + std::vector edges; + for (auto node : graph->op_nodes()) { + if (node->type_string() == "Assert") { + NGRAPH_VLOG(4) << "Checking: " << node->name(); + for (auto edge : node->out_edges()) { + if (edge->IsControlEdge()) { + if (edge != NULL) { + NGRAPH_VLOG(4) << "Collecting all the control edges"; + edges.push_back(edge); + } + } + } + } + } + for (auto edge : edges) { + NGRAPH_VLOG(4) << "Removing control edge: " << edge->DebugString(); + graph->RemoveControlEdge(edge); + } + return Status::OK(); +} + +} // namespace ngraph_bridge + +} // namespace tensorflow diff --git a/src/ngraph_disable_assert.h b/src/ngraph_disable_assert.h new file mode 100644 index 00000000..960c3f15 --- /dev/null +++ b/src/ngraph_disable_assert.h @@ -0,0 +1,29 @@ +/******************************************************************************* + * Copyright 2017-2018 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#pragma once + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +namespace ngraph_bridge { + +Status DisableAssert(Graph* graph); + +} // namespace ngraph_bridge + +} // namespace tensorflow diff --git a/src/ngraph_rewrite_pass.cc b/src/ngraph_rewrite_pass.cc index 08747d6c..ce29e503 100644 --- a/src/ngraph_rewrite_pass.cc +++ b/src/ngraph_rewrite_pass.cc @@ -20,6 +20,7 @@ #include "ngraph_assign_clusters.h" #include "ngraph_capture_variables.h" #include "ngraph_deassign_clusters.h" +#include "ngraph_disable_assert.h" #include "ngraph_encapsulate_clusters.h" #include "ngraph_log.h" #include "ngraph_mark_for_clustering.h" @@ -150,6 +151,17 @@ class NGraphVariableCapturePass : public NGraphRewritePass { DumpGraphs(options, idx, "captured", "Graph With Variables Captured"); } + // Disable "Assert" if specifically asked by the user + if (std::getenv("NGRAPH_TF_DISABLE_ASSERTS") != nullptr) { + TF_RETURN_IF_ERROR(DisableAssert(options.graph->get())); + NGRAPH_VLOG(0) << "Model running with Asserts disabled."; + // If requested, dump captured graphs without asserts + if (DumpDisabledAssertsGraphs()) { + DumpGraphs(options, idx, "assert_disabled", + "Captured Graph without Asserts"); + } + } + return Status::OK(); } @@ -162,6 +174,10 @@ class NGraphVariableCapturePass : public NGraphRewritePass { return DumpAllGraphs() || std::getenv("NGRAPH_TF_DUMP_CAPTURED_GRAPHS") != nullptr; } + static bool DumpDisabledAssertsGraphs() { + return DumpAllGraphs() || + std::getenv("NGRAPH_TF_DISABLE_ASSERTS") != nullptr; + } }; // diff --git a/test/ci/run-premerge-ci-checks.sh b/test/ci/run-premerge-ci-checks.sh index e922ead0..2a0332fb 100755 --- a/test/ci/run-premerge-ci-checks.sh +++ b/test/ci/run-premerge-ci-checks.sh @@ -58,6 +58,10 @@ python -m pytest \ test_slice.py \ test_sigmoidgrad.py \ test_tanhgrad.py + +export NGRAPH_TF_DISABLE_ASSERTS=1 +python -m pytest test_disable_assert.py +unset NGRAPH_TF_DISABLE_ASSERTS popd echo "--------------------------------------------------------------------------" diff --git a/test/python/test_disable_assert.py b/test/python/test_disable_assert.py new file mode 100644 index 00000000..3a7af9ba --- /dev/null +++ b/test/python/test_disable_assert.py @@ -0,0 +1,77 @@ +# ============================================================================== +# Copyright 2018 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""nGraph TensorFlow bridge abs operation test + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import pytest + +import tensorflow as tf +import os + +from common import NgraphTest + + +class TestAssertOperations(NgraphTest): + + def test_disable_assert(self): + test_input = ((1, 1)) + x = tf.placeholder(tf.int32, shape=(2,)) + y = tf.placeholder(tf.int32, shape=(2,)) + z = tf.placeholder(tf.int32, shape=(2,)) + assert_op = tf.Assert(tf.less_equal(tf.reduce_max(z), 1), [x]) + + with tf.control_dependencies([assert_op]): + a2 = tf.add(x, y) + + def run_test(sess): + return sess.run( + a2, feed_dict={ + x: test_input, + y: test_input, + z: test_input + }) + + assert ( + self.with_ngraph(run_test) == self.without_ngraph(run_test)).all() + + def test_disable_assert_tf_fails_ng_pass(self): + test_input = ((2, 2)) + x = tf.placeholder(tf.int32, shape=(2,)) + y = tf.placeholder(tf.int32, shape=(2,)) + z = tf.placeholder(tf.int32, shape=(2,)) + assert_op = tf.Assert(tf.less_equal(tf.reduce_max(z), 1), [x]) + + with tf.control_dependencies([assert_op]): + a2 = tf.add(x, y) + + def run_test(sess): + return sess.run( + a2, feed_dict={ + x: test_input, + y: test_input, + z: test_input + }) + + try: + self.without_ngraph(run_test) + assert False + except tf.errors.InvalidArgumentError as e: + print("hfhrhgthgutrihy") + self.with_ngraph(run_test)