From 3f96d1972b5633de0ea9b45c609fb455e421ae17 Mon Sep 17 00:00:00 2001 From: "kanvi.khanna" Date: Tue, 16 Oct 2018 16:52:53 -0700 Subject: [PATCH 1/9] Changes for skipping the Assert op --- src/CMakeLists.txt | 1 + src/ngraph_rewrite_pass.cc | 10 ++++++++ src/ngraph_skip_assert.cc | 51 ++++++++++++++++++++++++++++++++++++++ src/ngraph_skip_assert.h | 29 ++++++++++++++++++++++ test/python/test_assert.py | 48 +++++++++++++++++++++++++++++++++++ 5 files changed, 139 insertions(+) create mode 100644 src/ngraph_skip_assert.cc create mode 100644 src/ngraph_skip_assert.h create mode 100644 test/python/test_assert.py diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c831d6ac..a7dac1d6 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -44,6 +44,7 @@ set(SRC ngraph_utils.cc tf_graphcycles.cc tf_deadness_analysis.cc + ngraph_skip_assert.cc ) add_library(${LIB_NAME} SHARED ${SRC}) diff --git a/src/ngraph_rewrite_pass.cc b/src/ngraph_rewrite_pass.cc index 08747d6c..ee5a18c3 100644 --- a/src/ngraph_rewrite_pass.cc +++ b/src/ngraph_rewrite_pass.cc @@ -24,6 +24,7 @@ #include "ngraph_log.h" #include "ngraph_mark_for_clustering.h" #include "ngraph_rewrite_for_tracking.h" +#include "ngraph_skip_assert.h" #include "tf_graph_writer.h" @@ -203,6 +204,15 @@ class NGraphEncapsulationPass : public NGraphRewritePass { DumpGraphs(options, idx, "unmarked", "Unmarked Graph"); } + // Skip "Assert" if specifically asked by the user + if (std::getenv("NGRAPH_TF_SKIP_ASSERT") != nullptr) { + TF_RETURN_IF_ERROR(SkipAssert(options.graph->get())); + // If requested, dump unmarked graphs without asserts + if (DumpUnmarkedGraphs()) { + DumpGraphs(options, idx, "assert_skipped", "Unmarked Graph without Assert"); + } + } + // 1. Mark for clustering then, if requested, dump the graphs. TF_RETURN_IF_ERROR(MarkForClustering(options.graph->get())); if (DumpMarkedGraphs()) { diff --git a/src/ngraph_skip_assert.cc b/src/ngraph_skip_assert.cc new file mode 100644 index 00000000..a98136a1 --- /dev/null +++ b/src/ngraph_skip_assert.cc @@ -0,0 +1,51 @@ +/******************************************************************************* + * Copyright 2017-2018 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ +#include "ngraph_skip_assert.h" + +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/types.h" + +#include "ngraph_utils.h" + +using namespace std; + +namespace tensorflow { + +namespace ngraph_bridge { + +// +// Main entry point for skip assert. +// +Status SkipAssert(Graph* graph) { + for (auto node : graph->op_nodes()) { + if (node->type_string() == "Assert") { + NGRAPH_VLOG(4) << "Checking: " << node->name(); + for (auto edge : node->out_edges()) { + if (edge->IsControlEdge()) { + NGRAPH_VLOG(4) << "Control edge: " << node->name(); + graph->RemoveControlEdge(edge); + NGRAPH_VLOG(4) << "Control edge removed. "; + } + } + } + } + return Status::OK(); +} + +} // namespace ngraph_bridge + +} // namespace tensorflow diff --git a/src/ngraph_skip_assert.h b/src/ngraph_skip_assert.h new file mode 100644 index 00000000..182069fc --- /dev/null +++ b/src/ngraph_skip_assert.h @@ -0,0 +1,29 @@ +/******************************************************************************* + * Copyright 2017-2018 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#pragma once + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +namespace ngraph_bridge { + +Status SkipAssert(Graph* graph); + +} // namespace ngraph_bridge + +} // namespace tensorflow diff --git a/test/python/test_assert.py b/test/python/test_assert.py new file mode 100644 index 00000000..2ffc4681 --- /dev/null +++ b/test/python/test_assert.py @@ -0,0 +1,48 @@ +# ============================================================================== +# Copyright 2018 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""nGraph TensorFlow bridge abs operation test + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import pytest + +import tensorflow as tf +import os + +from common import NgraphTest + + +class TestAssertOperations(NgraphTest): + + def test_assert(self): + x = tf.constant([1,2]) + y = tf.constant([1,2]) + z = tf.constant([1,1]) + assert_op = tf.Assert(tf.less_equal(tf.reduce_max(z), 1), [z]) + + with tf.control_dependencies([assert_op]): + a2 = tf.add(z, y) + #a1 = tf.add(x, y) + + #a3 = tf.add(y,y) + def run_test(sess): + return sess.run(a2) + + self.with_ngraph(run_test) + From d79a4e68768dac36cddcbb27a92976a1bf14854b Mon Sep 17 00:00:00 2001 From: "kanvi.khanna" Date: Tue, 16 Oct 2018 17:34:38 -0700 Subject: [PATCH 2/9] Disable deadness check --- src/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a7dac1d6..39d4ecab 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -22,7 +22,7 @@ set(NGRAPH_DEVICE_INCLUDE_PATH ${CMAKE_CURRENT_SOURCE_DIR}) # For some reason the following is needed for ABI compatibility with TF. # (There must be some dependency on it in the struct/class definitions.) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNDEBUG") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNDEBUG -D NGRAPH_TF_DISABLE_DEADNESS_CHECK") #----------------------------------------------------------------------------------------------- # Compiler-specific logic... From 31e055c5e9bbff6fdc46e17e5f1aabaaaf36dfba Mon Sep 17 00:00:00 2001 From: "kanvi.khanna" Date: Wed, 17 Oct 2018 13:56:23 -0700 Subject: [PATCH 3/9] Skip Assert Added code for skipping the assert op if the env variable NGRAPH_TF_SKIP_ASSERT is set i.e basically get rid of the control edge Dump unmarked graphs with assert skipped Added python test for checking the assert is actually skipped --- src/CMakeLists.txt | 2 +- src/ngraph_rewrite_pass.cc | 18 ++++++++--------- .../{test_assert.py => test_skip_assert.py} | 20 +++++++++---------- 3 files changed, 20 insertions(+), 20 deletions(-) rename test/python/{test_assert.py => test_skip_assert.py} (78%) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 39d4ecab..a7dac1d6 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -22,7 +22,7 @@ set(NGRAPH_DEVICE_INCLUDE_PATH ${CMAKE_CURRENT_SOURCE_DIR}) # For some reason the following is needed for ABI compatibility with TF. # (There must be some dependency on it in the struct/class definitions.) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNDEBUG -D NGRAPH_TF_DISABLE_DEADNESS_CHECK") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNDEBUG") #----------------------------------------------------------------------------------------------- # Compiler-specific logic... diff --git a/src/ngraph_rewrite_pass.cc b/src/ngraph_rewrite_pass.cc index ee5a18c3..48daf793 100644 --- a/src/ngraph_rewrite_pass.cc +++ b/src/ngraph_rewrite_pass.cc @@ -151,6 +151,15 @@ class NGraphVariableCapturePass : public NGraphRewritePass { DumpGraphs(options, idx, "captured", "Graph With Variables Captured"); } + // Skip "Assert" if specifically asked by the user + if (std::getenv("NGRAPH_TF_SKIP_ASSERT") != nullptr) { + TF_RETURN_IF_ERROR(SkipAssert(options.graph->get())); + // If requested, dump unmarked graphs without asserts + if (DumpCapturedGraphs()) { + DumpGraphs(options, idx, "assert_skipped", "Captured Graph without Assert"); + } + } + return Status::OK(); } @@ -204,15 +213,6 @@ class NGraphEncapsulationPass : public NGraphRewritePass { DumpGraphs(options, idx, "unmarked", "Unmarked Graph"); } - // Skip "Assert" if specifically asked by the user - if (std::getenv("NGRAPH_TF_SKIP_ASSERT") != nullptr) { - TF_RETURN_IF_ERROR(SkipAssert(options.graph->get())); - // If requested, dump unmarked graphs without asserts - if (DumpUnmarkedGraphs()) { - DumpGraphs(options, idx, "assert_skipped", "Unmarked Graph without Assert"); - } - } - // 1. Mark for clustering then, if requested, dump the graphs. TF_RETURN_IF_ERROR(MarkForClustering(options.graph->get())); if (DumpMarkedGraphs()) { diff --git a/test/python/test_assert.py b/test/python/test_skip_assert.py similarity index 78% rename from test/python/test_assert.py rename to test/python/test_skip_assert.py index 2ffc4681..88231a28 100644 --- a/test/python/test_assert.py +++ b/test/python/test_skip_assert.py @@ -30,19 +30,19 @@ class TestAssertOperations(NgraphTest): - def test_assert(self): - x = tf.constant([1,2]) - y = tf.constant([1,2]) - z = tf.constant([1,1]) - assert_op = tf.Assert(tf.less_equal(tf.reduce_max(z), 1), [z]) + def test_skip_assert(self): + test_input = ((1,1)) + x = tf.placeholder(tf.int32, shape=(2,)) + y = tf.placeholder(tf.int32, shape=(2,)) + z = tf.placeholder(tf.int32, shape=(2,)) + assert_op = tf.Assert(tf.less_equal(tf.reduce_max(z), 1), [x]) with tf.control_dependencies([assert_op]): - a2 = tf.add(z, y) - #a1 = tf.add(x, y) - - #a3 = tf.add(y,y) + a2 = tf.add(x, y) + def run_test(sess): - return sess.run(a2) + return sess.run(a2, feed_dict={x:test_input, y:test_input, z:test_input}) self.with_ngraph(run_test) + From cbd15d779d76c6abde9bb00b18fb92c642012971 Mon Sep 17 00:00:00 2001 From: "kanvi.khanna" Date: Wed, 17 Oct 2018 14:05:48 -0700 Subject: [PATCH 4/9] Fix format --- src/ngraph_rewrite_pass.cc | 3 ++- src/ngraph_skip_assert.cc | 14 +++++++------- test/python/test_skip_assert.py | 13 ++++++++----- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/ngraph_rewrite_pass.cc b/src/ngraph_rewrite_pass.cc index 48daf793..8b54bc75 100644 --- a/src/ngraph_rewrite_pass.cc +++ b/src/ngraph_rewrite_pass.cc @@ -156,7 +156,8 @@ class NGraphVariableCapturePass : public NGraphRewritePass { TF_RETURN_IF_ERROR(SkipAssert(options.graph->get())); // If requested, dump unmarked graphs without asserts if (DumpCapturedGraphs()) { - DumpGraphs(options, idx, "assert_skipped", "Captured Graph without Assert"); + DumpGraphs(options, idx, "assert_skipped", + "Captured Graph without Assert"); } } diff --git a/src/ngraph_skip_assert.cc b/src/ngraph_skip_assert.cc index a98136a1..46ab4cc4 100644 --- a/src/ngraph_skip_assert.cc +++ b/src/ngraph_skip_assert.cc @@ -33,14 +33,14 @@ namespace ngraph_bridge { Status SkipAssert(Graph* graph) { for (auto node : graph->op_nodes()) { if (node->type_string() == "Assert") { - NGRAPH_VLOG(4) << "Checking: " << node->name(); - for (auto edge : node->out_edges()) { - if (edge->IsControlEdge()) { - NGRAPH_VLOG(4) << "Control edge: " << node->name(); - graph->RemoveControlEdge(edge); - NGRAPH_VLOG(4) << "Control edge removed. "; - } + NGRAPH_VLOG(4) << "Checking: " << node->name(); + for (auto edge : node->out_edges()) { + if (edge->IsControlEdge()) { + NGRAPH_VLOG(4) << "Control edge: " << node->name(); + graph->RemoveControlEdge(edge); + NGRAPH_VLOG(4) << "Control edge removed. "; } + } } } return Status::OK(); diff --git a/test/python/test_skip_assert.py b/test/python/test_skip_assert.py index 88231a28..48adc3fa 100644 --- a/test/python/test_skip_assert.py +++ b/test/python/test_skip_assert.py @@ -31,7 +31,7 @@ class TestAssertOperations(NgraphTest): def test_skip_assert(self): - test_input = ((1,1)) + test_input = ((1, 1)) x = tf.placeholder(tf.int32, shape=(2,)) y = tf.placeholder(tf.int32, shape=(2,)) z = tf.placeholder(tf.int32, shape=(2,)) @@ -39,10 +39,13 @@ def test_skip_assert(self): with tf.control_dependencies([assert_op]): a2 = tf.add(x, y) - + def run_test(sess): - return sess.run(a2, feed_dict={x:test_input, y:test_input, z:test_input}) + return sess.run( + a2, feed_dict={ + x: test_input, + y: test_input, + z: test_input + }) self.with_ngraph(run_test) - - From 98abc875c8567fb781a6bf81feedffd31d888408 Mon Sep 17 00:00:00 2001 From: "kanvi.khanna" Date: Thu, 18 Oct 2018 12:01:52 -0700 Subject: [PATCH 5/9] Incorporate comments --- src/CMakeLists.txt | 2 +- ...ngraph_skip_assert.cc => ngraph_disable_assert.cc} | 9 ++++----- src/{ngraph_skip_assert.h => ngraph_disable_assert.h} | 2 +- src/ngraph_rewrite_pass.cc | 11 ++++++----- .../{test_skip_assert.py => test_disable_assert.py} | 5 +++-- 5 files changed, 15 insertions(+), 14 deletions(-) rename src/{ngraph_skip_assert.cc => ngraph_disable_assert.cc} (86%) rename src/{ngraph_skip_assert.h => ngraph_disable_assert.h} (96%) rename test/python/{test_skip_assert.py => test_disable_assert.py} (92%) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a7dac1d6..b9894f89 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -44,7 +44,7 @@ set(SRC ngraph_utils.cc tf_graphcycles.cc tf_deadness_analysis.cc - ngraph_skip_assert.cc + ngraph_disable_assert.cc ) add_library(${LIB_NAME} SHARED ${SRC}) diff --git a/src/ngraph_skip_assert.cc b/src/ngraph_disable_assert.cc similarity index 86% rename from src/ngraph_skip_assert.cc rename to src/ngraph_disable_assert.cc index 46ab4cc4..ef1316cf 100644 --- a/src/ngraph_skip_assert.cc +++ b/src/ngraph_disable_assert.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ -#include "ngraph_skip_assert.h" +#include "ngraph_disable_assert.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" @@ -28,17 +28,16 @@ namespace tensorflow { namespace ngraph_bridge { // -// Main entry point for skip assert. +// Main entry point for disbale assert. // -Status SkipAssert(Graph* graph) { +Status DisableAssert(Graph* graph) { for (auto node : graph->op_nodes()) { if (node->type_string() == "Assert") { NGRAPH_VLOG(4) << "Checking: " << node->name(); for (auto edge : node->out_edges()) { if (edge->IsControlEdge()) { - NGRAPH_VLOG(4) << "Control edge: " << node->name(); + NGRAPH_VLOG(4) << "Removing control edge: " << edge->DebugString(); graph->RemoveControlEdge(edge); - NGRAPH_VLOG(4) << "Control edge removed. "; } } } diff --git a/src/ngraph_skip_assert.h b/src/ngraph_disable_assert.h similarity index 96% rename from src/ngraph_skip_assert.h rename to src/ngraph_disable_assert.h index 182069fc..960c3f15 100644 --- a/src/ngraph_skip_assert.h +++ b/src/ngraph_disable_assert.h @@ -22,7 +22,7 @@ namespace tensorflow { namespace ngraph_bridge { -Status SkipAssert(Graph* graph); +Status DisableAssert(Graph* graph); } // namespace ngraph_bridge diff --git a/src/ngraph_rewrite_pass.cc b/src/ngraph_rewrite_pass.cc index 8b54bc75..7b1d71ce 100644 --- a/src/ngraph_rewrite_pass.cc +++ b/src/ngraph_rewrite_pass.cc @@ -20,11 +20,11 @@ #include "ngraph_assign_clusters.h" #include "ngraph_capture_variables.h" #include "ngraph_deassign_clusters.h" +#include "ngraph_disable_assert.h" #include "ngraph_encapsulate_clusters.h" #include "ngraph_log.h" #include "ngraph_mark_for_clustering.h" #include "ngraph_rewrite_for_tracking.h" -#include "ngraph_skip_assert.h" #include "tf_graph_writer.h" @@ -152,12 +152,13 @@ class NGraphVariableCapturePass : public NGraphRewritePass { } // Skip "Assert" if specifically asked by the user - if (std::getenv("NGRAPH_TF_SKIP_ASSERT") != nullptr) { - TF_RETURN_IF_ERROR(SkipAssert(options.graph->get())); + if (std::getenv("NGRAPH_TF_DISABLE_ASSERTS") != nullptr) { + TF_RETURN_IF_ERROR(DisableAssert(options.graph->get())); + NGRAPH_VLOG(0) << "Model running with Asserts disabled."; // If requested, dump unmarked graphs without asserts if (DumpCapturedGraphs()) { - DumpGraphs(options, idx, "assert_skipped", - "Captured Graph without Assert"); + DumpGraphs(options, idx, "assert_disabled", + "Captured Graph without Asserts"); } } diff --git a/test/python/test_skip_assert.py b/test/python/test_disable_assert.py similarity index 92% rename from test/python/test_skip_assert.py rename to test/python/test_disable_assert.py index 48adc3fa..d1b67c8d 100644 --- a/test/python/test_skip_assert.py +++ b/test/python/test_disable_assert.py @@ -30,7 +30,7 @@ class TestAssertOperations(NgraphTest): - def test_skip_assert(self): + def test_disable_assert(self): test_input = ((1, 1)) x = tf.placeholder(tf.int32, shape=(2,)) y = tf.placeholder(tf.int32, shape=(2,)) @@ -48,4 +48,5 @@ def run_test(sess): z: test_input }) - self.with_ngraph(run_test) + assert ( + self.with_ngraph(run_test) == self.without_ngraph(run_test)).all() From 9acd86e4aa38c0b1118cdb0333f0721fdc654a34 Mon Sep 17 00:00:00 2001 From: "kanvi.khanna" Date: Thu, 18 Oct 2018 12:09:06 -0700 Subject: [PATCH 6/9] Corrected spelling mistakes --- src/ngraph_disable_assert.cc | 2 +- src/ngraph_rewrite_pass.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ngraph_disable_assert.cc b/src/ngraph_disable_assert.cc index ef1316cf..137bdb37 100644 --- a/src/ngraph_disable_assert.cc +++ b/src/ngraph_disable_assert.cc @@ -28,7 +28,7 @@ namespace tensorflow { namespace ngraph_bridge { // -// Main entry point for disbale assert. +// Main entry point for disable assert. // Status DisableAssert(Graph* graph) { for (auto node : graph->op_nodes()) { diff --git a/src/ngraph_rewrite_pass.cc b/src/ngraph_rewrite_pass.cc index 7b1d71ce..18029f1f 100644 --- a/src/ngraph_rewrite_pass.cc +++ b/src/ngraph_rewrite_pass.cc @@ -151,7 +151,7 @@ class NGraphVariableCapturePass : public NGraphRewritePass { DumpGraphs(options, idx, "captured", "Graph With Variables Captured"); } - // Skip "Assert" if specifically asked by the user + // Disable "Assert" if specifically asked by the user if (std::getenv("NGRAPH_TF_DISABLE_ASSERTS") != nullptr) { TF_RETURN_IF_ERROR(DisableAssert(options.graph->get())); NGRAPH_VLOG(0) << "Model running with Asserts disabled."; From 67bda5472044bc8e5fea0613435ced336c0efe4d Mon Sep 17 00:00:00 2001 From: "kanvi.khanna" Date: Fri, 19 Oct 2018 11:33:31 -0700 Subject: [PATCH 7/9] Fix segmentation fault Add more checks --- src/ngraph_disable_assert.cc | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/ngraph_disable_assert.cc b/src/ngraph_disable_assert.cc index 137bdb37..5d5f39c4 100644 --- a/src/ngraph_disable_assert.cc +++ b/src/ngraph_disable_assert.cc @@ -31,17 +31,24 @@ namespace ngraph_bridge { // Main entry point for disable assert. // Status DisableAssert(Graph* graph) { + std::vector edges; for (auto node : graph->op_nodes()) { if (node->type_string() == "Assert") { NGRAPH_VLOG(4) << "Checking: " << node->name(); for (auto edge : node->out_edges()) { if (edge->IsControlEdge()) { - NGRAPH_VLOG(4) << "Removing control edge: " << edge->DebugString(); - graph->RemoveControlEdge(edge); + if(edge != NULL) { + NGRAPH_VLOG(4) << "Collecting all the control edges"; + edges.push_back(edge); + } } } } } + for (auto edge : edges) { + NGRAPH_VLOG(4) << "Removing control edge: " << edge->DebugString(); + graph->RemoveControlEdge(edge); + } return Status::OK(); } From a540b2224725ce4833151c39c09ac6cb4ab656e0 Mon Sep 17 00:00:00 2001 From: "kanvi.khanna" Date: Fri, 19 Oct 2018 17:03:10 -0700 Subject: [PATCH 8/9] Add another test ci changes --- test/ci/run-premerge-ci-checks.sh | 4 ++++ test/python/test_disable_assert.py | 25 +++++++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/test/ci/run-premerge-ci-checks.sh b/test/ci/run-premerge-ci-checks.sh index e922ead0..2a0332fb 100755 --- a/test/ci/run-premerge-ci-checks.sh +++ b/test/ci/run-premerge-ci-checks.sh @@ -58,6 +58,10 @@ python -m pytest \ test_slice.py \ test_sigmoidgrad.py \ test_tanhgrad.py + +export NGRAPH_TF_DISABLE_ASSERTS=1 +python -m pytest test_disable_assert.py +unset NGRAPH_TF_DISABLE_ASSERTS popd echo "--------------------------------------------------------------------------" diff --git a/test/python/test_disable_assert.py b/test/python/test_disable_assert.py index d1b67c8d..3a7af9ba 100644 --- a/test/python/test_disable_assert.py +++ b/test/python/test_disable_assert.py @@ -50,3 +50,28 @@ def run_test(sess): assert ( self.with_ngraph(run_test) == self.without_ngraph(run_test)).all() + + def test_disable_assert_tf_fails_ng_pass(self): + test_input = ((2, 2)) + x = tf.placeholder(tf.int32, shape=(2,)) + y = tf.placeholder(tf.int32, shape=(2,)) + z = tf.placeholder(tf.int32, shape=(2,)) + assert_op = tf.Assert(tf.less_equal(tf.reduce_max(z), 1), [x]) + + with tf.control_dependencies([assert_op]): + a2 = tf.add(x, y) + + def run_test(sess): + return sess.run( + a2, feed_dict={ + x: test_input, + y: test_input, + z: test_input + }) + + try: + self.without_ngraph(run_test) + assert False + except tf.errors.InvalidArgumentError as e: + print("hfhrhgthgutrihy") + self.with_ngraph(run_test) From 47e065be26056811bbe4232b017822a959661ea0 Mon Sep 17 00:00:00 2001 From: "kanvi.khanna" Date: Fri, 19 Oct 2018 17:50:10 -0700 Subject: [PATCH 9/9] Minor changes --- src/ngraph_disable_assert.cc | 2 +- src/ngraph_rewrite_pass.cc | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/ngraph_disable_assert.cc b/src/ngraph_disable_assert.cc index 5d5f39c4..af34a0b8 100644 --- a/src/ngraph_disable_assert.cc +++ b/src/ngraph_disable_assert.cc @@ -37,7 +37,7 @@ Status DisableAssert(Graph* graph) { NGRAPH_VLOG(4) << "Checking: " << node->name(); for (auto edge : node->out_edges()) { if (edge->IsControlEdge()) { - if(edge != NULL) { + if (edge != NULL) { NGRAPH_VLOG(4) << "Collecting all the control edges"; edges.push_back(edge); } diff --git a/src/ngraph_rewrite_pass.cc b/src/ngraph_rewrite_pass.cc index 18029f1f..ce29e503 100644 --- a/src/ngraph_rewrite_pass.cc +++ b/src/ngraph_rewrite_pass.cc @@ -155,8 +155,8 @@ class NGraphVariableCapturePass : public NGraphRewritePass { if (std::getenv("NGRAPH_TF_DISABLE_ASSERTS") != nullptr) { TF_RETURN_IF_ERROR(DisableAssert(options.graph->get())); NGRAPH_VLOG(0) << "Model running with Asserts disabled."; - // If requested, dump unmarked graphs without asserts - if (DumpCapturedGraphs()) { + // If requested, dump captured graphs without asserts + if (DumpDisabledAssertsGraphs()) { DumpGraphs(options, idx, "assert_disabled", "Captured Graph without Asserts"); } @@ -174,6 +174,10 @@ class NGraphVariableCapturePass : public NGraphRewritePass { return DumpAllGraphs() || std::getenv("NGRAPH_TF_DUMP_CAPTURED_GRAPHS") != nullptr; } + static bool DumpDisabledAssertsGraphs() { + return DumpAllGraphs() || + std::getenv("NGRAPH_TF_DISABLE_ASSERTS") != nullptr; + } }; //