diff --git a/README.md b/README.md index 7d88996..da10f9d 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ To return the test results to the Nutter CLI: result.exit(dbutils) ``` -__Note:__ The call to result.exit, behind the scenes calls dbutils.notebook.exit, passing the serialized TestResults back to the CLI. At the current time, print statements do not work when dbutils.notebook.exit is called in a notebook, even if they are written prior to the call. For this reason, it is required to *temporarily* comment out result.exit(dbutils) when running the tests locally. +__Note:__ The call to result.exit, behind the scenes calls dbutils.notebook.exit, passing the serialized TestResults back to the CLI. At the current time, print statements do not work when `dbutils.notebook.exit` is called in a notebook, even if they are written prior to the call. For this reason, it is required to *temporarily* comment out `result.exit(dbutils)` when running the tests locally. The following defines a single test fixture named 'MyTestFixture' that has 1 TestCase named 'test_name': @@ -91,10 +91,9 @@ result = MyTestFixture().execute_tests() print(result.to_string()) # Comment out the next line (result.exit(dbutils)) to see the test result report from within the notebook result.exit(dbutils) - ``` -To execute the test from within the test notebook, simply run the cell containing the above code. At the current time, in order to see the below test result, you will have to comment out the call to result.exit(dbutils). That call is required to send the results, if the test is run from the CLI, so do not forget to uncomment after locally testing. +To execute the test from within the test notebook, simply run the cell containing the above code. At the current time, in order to see the below test result, you will have to comment out the call to `result.exit(dbutils)`. 
That call is required to send the results, if the test is run from the CLI, so do not forget to uncomment after locally testing. ``` Python Notebook: (local) - Lifecycle State: N/A, Result: N/A @@ -109,7 +108,7 @@ test_name (19.43149897100011 seconds) ### Test Cases -A test fixture can contain 1 or more test cases. Test cases are discovered when execute_tests() is called on the test fixture. Every test case is comprised of 1 required and 3 optional methods and are discovered by the following convention: prefix_testname, where valid prefixes are: before_, run_, assertion_, and after_. A test fixture that has run_fred and assertion_fred methods has 1 test case called 'fred'. The following are details about test case methods: +A test fixture can contain 1 or more test cases. Test cases are discovered when execute_tests() is called on the test fixture. Every test case is comprised of 1 required and 3 optional methods and are discovered by the following convention: prefix_testname, where valid prefixes are: `before_`, `run_`, `assertion_`, and `after_`. A test fixture that has run_fred and assertion_fred methods has 1 test case called 'fred'. The following are details about test case methods: * _before\_(testname)_ - (optional) - if provided, is run prior to the 'run_' method. This method can be used to setup any test pre-conditions @@ -165,13 +164,13 @@ class MultiTestFixture(NutterFixture): ### Multiple test assertions pattern with before_all -It is possible to support multiple assertions for a test by implementing a before_all method, no run methods and multiple assertion methods. In this pattern, the before_all method runs the notebook under test. There are no run methods. The assertion methods simply assert against what was done in before_all. +It is possible to support multiple assertions for a test by implementing a before_all method, no run methods and multiple assertion methods. In this pattern, the before_all method runs the notebook under test. 
There are no run methods. The assertion methods simply assert against what was done in before_all. ``` Python from runtime.nutterfixture import NutterFixture, tag class MultiTestFixture(NutterFixture): def before_all(self): - dbutils.notebook.run('notebook_under_test', 600, args) + dbutils.notebook.run('notebook_under_test', 600, args) … def assertion_test_case_1(self): @@ -271,7 +270,7 @@ pip install nutter __Note:__ It's recommended to install the Nutter CLI in a virtual environment. -Set the environment variables. +Set the necessary environment variables. Linux @@ -287,7 +286,7 @@ $env:DATABRICKS_HOST="HOST" $env:DATABRICKS_TOKEN="TOKEN" ``` -__Note:__ For more information about personal access tokens review [Databricks API Authentication](https://docs.azuredatabricks.net/dev-tools/api/latest/authentication.html). +__Note:__ For more information about personal access tokens review [Databricks Unified Authentication](https://learn.microsoft.com/en-us/azure/databricks/dev-tools/auth/env-vars). ### Listing test notebooks @@ -299,7 +298,7 @@ nutter list /dataload __Note:__ The Nutter CLI lists only tests notebooks that follow the naming convention for Nutter test notebooks. -By default the Nutter CLI lists test notebooks in the folder ignoring sub-folders. +By default the Nutter CLI lists test notebooks in the folder ignoring sub-folders. You can list all test notebooks in the folder structure using the ```--recursive``` flag. @@ -319,7 +318,19 @@ The following command executes the test notebook ```/dataload/test_sourceLoad``` nutter run dataload/test_sourceLoad --cluster_id 0123-12334-tonedabc --notebook_params "{\"example_key_1\": \"example_value_1\", \"example_key_2\": \"example_value_2\"}" ``` -__Note:__ In Azure Databricks you can get the cluster ID by selecting a cluster name from the Clusters tab and clicking on the JSON view. 
+Alternatively, you can specify the cluster by name instead of ID: + +```bash +nutter run dataload/test_sourceLoad --cluster_name "My Test Cluster" --notebook_params "{\"example_key_1\": \"example_value_1\", \"example_key_2\": \"example_value_2\"}" +``` + +Or run tests on serverless compute without needing a cluster: + +```bash +nutter run dataload/test_sourceLoad --serverless 1 --notebook_params "{\"example_key_1\": \"example_value_1\", \"example_key_2\": \"example_value_2\"}" +``` + +__Note:__ In Azure Databricks you can get the cluster ID by selecting a cluster name from the Clusters tab and clicking on the JSON view. When using `--cluster_name`, Nutter will automatically resolve the name to the cluster ID. When using `--serverless`, specify the environment version as an integer (e.g., 1) and tests will run on Databricks serverless compute. ### Run multiple tests notebooks @@ -329,7 +340,7 @@ The Nutter CLI supports the execution of multiple notebooks via name pattern mat Say the *dataload* folder has the following test notebooks: *test_srcLoad* and *test_srcValidation* with the notebook_param key-value pairs of ```{"example_key_1": "example_value_1", "example_key_2": "example_value_2"}```. The following command will result in the execution of both tests. ```bash -nutter run dataload/src* --cluster_id 0123-12334-tonedabc --notebook_params "{\"example_key_1\": \"example_value_1\", \"example_key_2\": \"example_value_2\"}" +nutter run dataload/src* --cluster_id 0123-12334-tonedabc --notebook_params "{\"example_key_1\": \"example_value_1\", \"example_key_2\": \"example_value_2\"}" ``` In addition, if you have tests in a hierarchical folder structure, you can recursively execute all tests by setting the ```--recursive``` flag. 
@@ -340,6 +351,12 @@ The following command will execute all tests in the folder structure within the nutter run dataload/ --cluster_id 0123-12334-tonedabc --recursive ``` +You can also run multiple tests recursively using serverless compute: + +```bash +nutter run dataload/ --serverless 1 --recursive +``` + ### Parallel Execution By default the Nutter CLI executes the test notebooks sequentially. The execution is a blocking operation that returns when the job reaches a terminal state or when the timeout expires. @@ -352,6 +369,12 @@ The following command executes all the tests in the *dataload* folder structure, nutter run dataload/ --cluster_id 0123-12334-tonedabc --recursive --max_parallel_tests 2 ``` +You can also run tests in parallel on serverless compute: + +```bash +nutter run dataload/ --serverless 1 --recursive --max_parallel_tests 2 +``` + __Note:__ Running tests notebooks in parallel introduces the risk of data race conditions when two or more tests notebooks modify the same tables or files at the same time. Before increasing the level of parallelism make sure that your tests cases modify only tables or files that are used or referenced within the scope of the test notebook. ## Nutter CLI Syntax and Flags @@ -360,28 +383,47 @@ __Note:__ Running tests notebooks in parallel introduces the risk of data race c ``` bash SYNOPSIS - nutter run TEST_PATTERN CLUSTER_ID + nutter run TEST_PATTERN POSITIONAL ARGUMENTS TEST_PATTERN - CLUSTER_ID + Type: str + Required: Yes + The pattern to match test notebooks. Can include wildcards. ``` ``` bash FLAGS + --cluster_id The Databricks cluster ID where tests will be executed. + Must specify one of: cluster_id, cluster_name, or serverless. + --cluster_name The Databricks cluster name where tests will be executed. + If provided, the cluster ID will be resolved automatically. + Must specify one of: cluster_id, cluster_name, or serverless. + --serverless Run tests on serverless compute. 
Specify the environment version as an integer (e.g., 1). + Must specify one of: cluster_id, cluster_name, or serverless. --timeout Execution timeout in seconds. Integer value. Default is 120 --junit_report Create a JUnit XML report from the test results. --tags_report Create a CSV report from the test results that includes the test cases tags. --max_parallel_tests Sets the level of parallelism for test notebook execution. - --recursive Executes all tests in the hierarchical folder structure. + --recursive Executes all tests in the hierarchical folder structure. --poll_wait_time Polling interval duration for notebook status. Default is 5 (5 seconds). - --notebook_params Allows parameters to be passed from the CLI tool to the test notebook. From the - notebook, these parameters can then be accessed by the notebook using + --notebook_params Allows parameters to be passed from the CLI tool to the test notebook. From the + notebook, these parameters can then be accessed by the notebook using the 'dbutils.widgets.get('key')' syntax. ``` -__Note:__ You can also use flags syntax for POSITIONAL ARGUMENTS +__Note:__ You can specify the compute environment in multiple ways: + +**Using a Cluster:** +1. As a positional argument (for backward compatibility): `nutter run test_pattern cluster-id` +2. Using the `--cluster_id` flag: `nutter run test_pattern --cluster_id cluster-id` +3. Using the `--cluster_name` flag: `nutter run test_pattern --cluster_name "My Cluster"` + +**Using Serverless Compute:** +4. Using the `--serverless` flag: `nutter run test_pattern --serverless 1` + +When using `--cluster_name`, Nutter will automatically look up the cluster ID. When using `--serverless`, tests will run on Databricks serverless compute without requiring a cluster. 
### List Command @@ -449,7 +491,7 @@ steps: In some scenarios, the notebooks under tests must be executed in a pre-configured test workspace, other than the development one, that contains the necessary pre-requisites such as test data, tables or mounted points. In such scenarios, you can use the pipeline to deploy the notebooks to the test workspace before executing the tests with Nutter. -The following sample pipeline uses the Databricks CLI to publish the notebooks from triggering branch to the test workspace. +The following sample pipeline uses the Databricks CLI to publish the notebooks from triggering branch to the test workspace. ```yaml @@ -515,10 +557,10 @@ pip install --force-reinstall pytest==5.0.1 Creating the wheel file and manually test wheel locally -1. Change directory to the root that contains setup.py -2. Update the version in the setup.py -3. Run the following command: python3 setup.py sdist bdist_wheel -4. (optional) Install the wheel locally by running: python3 -m pip install +1. Change directory to the root that contains `setup.py` +2. Update the version in the `setup.py` +3. Run the following command: `python3 setup.py sdist bdist_wheel` +4. 
(optional) Install the wheel locally by running: `python3 -m pip install ` ### Contribution Guidelines diff --git a/cli/eventhandlers.py b/cli/eventhandlers.py index 567613c..5be3fe7 100644 --- a/cli/eventhandlers.py +++ b/cli/eventhandlers.py @@ -26,8 +26,7 @@ def _get_and_handle(self, event_queue): try: event_instance = event_queue.get() if self._debug: - logging.debug( - 'Message from queue: {}'.format(event_instance)) + logging.debug(f'Message from queue: {event_instance}') return output = self._get_output(event_instance) self._print_output(output) @@ -44,7 +43,7 @@ def _get_output(self, event_instance): event_output = self._get_event_ouput(event_instance) if event_output is None: return - return '--> {}\n'.format(event_output) + return f'--> {event_output}\n' def _get_event_ouput(self, event_instance): if event_instance.event is NutterStatusEvents.TestsListing: @@ -64,33 +63,30 @@ def _get_event_ouput(self, event_instance): return '' def _handle_testlisting(self, event): - return 'Looking for tests in {}'.format(event.data) + return f'Looking for tests in {event.data}' def _handle_testlistingfiltered(self, event): self._filtered_tests = event.data - return '{} tests matched the pattern'.format(self._filtered_tests) + return f'{self._filtered_tests} tests matched the pattern' def _handle_testlistingresults(self, event): - return '{} tests found'.format(event.data) + return f'{event.data} tests found' def _handle_testsexecuted(self, event): - return '{} Success:{} {}'.format(event.data.notebook_path, - event.data.success, - event.data.notebook_run_page_url) + return f'{event.data.notebook_path} Success:{event.data.success} {event.data.notebook_run_page_url}' def _handle_testsexecutionrequest(self, event): - return 'Execution request: {}'.format(event.data) + return f'Execution request: {event.data}' def _handle_testscheduling(self, event): num_of_tests = self._num_of_test_to_execute() self._scheduled_tests += 1 - return '{} of {} tests scheduled for 
execution'.format(self._scheduled_tests, - num_of_tests) + return f'{self._scheduled_tests} of {num_of_tests} tests scheduled for execution' def _handle_testsexecutionresult(self, event): num_of_tests = self._num_of_test_to_execute() self._done_tests += 1 - return '{} of {} tests executed'.format(self._done_tests, num_of_tests) + return f'{self._done_tests} of {num_of_tests} tests executed' def _num_of_test_to_execute(self): if self._filtered_tests > 0: diff --git a/cli/nuttercli.py b/cli/nuttercli.py index 48318fb..dde5040 100644 --- a/cli/nuttercli.py +++ b/cli/nuttercli.py @@ -9,7 +9,8 @@ import datetime import common.api as api -from common.apiclient import DEFAULT_POLL_WAIT_TIME, InvalidConfigurationException +from common.apiclient import InvalidConfigurationException +from common.utils import get_nutter_version import common.resultsview as view from .eventhandlers import ConsoleEventHandler @@ -17,20 +18,9 @@ from .reportsman import ReportWriters from . import reportsman as reports -__version__ = '0.1.35' - -BUILD_NUMBER_ENV_VAR = 'NUTTER_BUILD_NUMBER' - - -def get_cli_version(): - build_number = os.environ.get(BUILD_NUMBER_ENV_VAR) - if build_number: - return '{}.{}'.format(__version__, build_number) - return __version__ - def get_cli_header(): - header = 'Nutter Version {}\n'.format(get_cli_version()) + header = f'Nutter Version {get_nutter_version()}\n' header += '+' * 50 header += '\n' @@ -50,32 +40,62 @@ def __init__(self, debug=False, log_to_file=False, version=False): self._set_nutter(debug) super().__init__() - def run(self, test_pattern, cluster_id, + def run(self, test_pattern, cluster_id=None, cluster_name=None, serverless=None, timeout=120, junit_report=False, tags_report=False, max_parallel_tests=1, - recursive=False, poll_wait_time=DEFAULT_POLL_WAIT_TIME, notebook_params=None): + recursive=False, notebook_params=None): try: - logging.debug(""" Running tests. 
test_pattern: {} cluster_id: {} notebook_params: {} timeout: {} + # Validate compute configuration + compute_options = [cluster_id is not None, cluster_name is not None, serverless is not None] + compute_count = sum(compute_options) + + if compute_count == 0: + self._logger.fatal("Must specify one of: --cluster_id, --cluster_name, or --serverless") + exit(1) + + if compute_count > 1: + self._logger.fatal("Cannot specify multiple compute options. Use only one of: --cluster_id, --cluster_name, or --serverless") + exit(1) + + # If cluster_name is provided, resolve it to cluster_id + if cluster_name is not None: + logging.debug(f"Resolving cluster name '{cluster_name}' to cluster ID") + try: + cluster_id = self._nutter.dbclient.get_cluster_id_by_name(cluster_name) + logging.debug(f"Resolved cluster name '{cluster_name}' to cluster ID '{cluster_id}'") + except ValueError as e: + self._logger.fatal(f"Error resolving cluster name: {e}") + exit(1) + + # Log execution parameters + if serverless: + logging.debug(f"Running tests with serverless compute (version: {serverless})") + else: + logging.debug(f"Running tests with cluster ID: {cluster_id}") + + logging.debug(""" Running tests. 
test_pattern: {} cluster_id: {} serverless: {} notebook_params: {} timeout: {} junit_report: {} max_parallel_tests: {} tags_report: {} recursive:{} """ - .format(test_pattern, cluster_id, timeout, + .format(test_pattern, cluster_id, serverless, timeout, junit_report, max_parallel_tests, tags_report, recursive, notebook_params)) - logging.debug("Executing test(s): {}".format(test_pattern)) + logging.debug(f"Executing test(s): {test_pattern}") if self._is_a_test_pattern(test_pattern): logging.debug('Executing pattern') results = self._nutter.run_tests( - test_pattern, cluster_id, timeout, - max_parallel_tests, recursive, poll_wait_time, notebook_params) + test_pattern, cluster_id=cluster_id, timeout=timeout, + max_parallel_tests=max_parallel_tests, recursive=recursive, + notebook_params=notebook_params, serverless=serverless) self._nutter.events_processor_wait() self._handle_results(results, junit_report, tags_report) return logging.debug('Executing single test') - result = self._nutter.run_test(test_pattern, cluster_id, - timeout, poll_wait_time) + result = self._nutter.run_test(test_pattern, cluster_id=cluster_id, + timeout=timeout, notebook_params=notebook_params, + serverless=serverless) self._handle_results([result], junit_report, tags_report) @@ -85,7 +105,7 @@ def run(self, test_pattern, cluster_id, def list(self, path, recursive=False): try: - logging.debug("Running tests. path: {}".format(path)) + logging.debug(f"Running tests. 
path: {path}") results = self._nutter.list_tests(path, recursive) self._nutter.events_processor_wait() self._display_list_results(results) @@ -101,7 +121,8 @@ def _handle_results(self, results, junit_report, tags_report): ExecutionResultsValidator().validate(results) - def _get_report_writer_manager(self, junit_report, tags_report): + @staticmethod + def _get_report_writer_manager(junit_report, tags_report): writers = 0 if junit_report: writers = ReportWriters.JUNIT @@ -110,15 +131,16 @@ def _get_report_writer_manager(self, junit_report, tags_report): return reports.get_report_writer_manager(writers) - def _handle_reports(self, report_manager, exec_results): + @staticmethod + def _handle_reports(report_manager, exec_results): if not report_manager.has_providers(): logging.debug('No providers were registered.') return for provider in report_manager.providers_names(): - print('Writing {} report.'.format(provider)) + print(f'Writing {provider} report.') for exec_result in exec_results: - t_result = api.to_testresults( + t_result = api.to_test_results( exec_result.notebook_result.exit_output) if t_result is None: print('Warning:') @@ -130,33 +152,37 @@ def _handle_reports(self, report_manager, exec_results): for file_name in report_manager.write(): print('File {} written'.format(file_name)) - def _display_list_results(self, results): + @staticmethod + def _display_list_results(results): list_results_view = view.get_list_results_view(results) view.print_results_view(list_results_view) - def _display_test_results(self, results): + @staticmethod + def _display_test_results(results): results_view = view.get_run_results_views(results) view.print_results_view(results_view) - def _is_a_test_pattern(self, pattern): + @staticmethod + def _is_a_test_pattern(pattern): segments = pattern.split('/') if len(segments) > 0: search_pattern = segments[len(segments)-1] if api.TestNotebook._is_valid_test_name(search_pattern): return False return True - logging.Fatal( + logging.fatal( """ 
Invalid argument. The value must be the full path to the test or a pattern """) - def _print_cli_header(self): + @staticmethod + def _print_cli_header(): print(get_cli_header()) def _set_nutter(self, debug): try: event_handler = ConsoleEventHandler(debug) self._nutter = api.get_nutter(event_handler) - except InvalidConfigurationException as ex: + except (InvalidConfigurationException, ValueError) as ex: logging.debug(ex) self._print_config_error_and_exit() @@ -166,17 +192,23 @@ def _handle_show_version(self, version): print(self._get_version_label()) exit(0) - def _get_version_label(self): - version = get_cli_version() + @staticmethod + def _get_version_label(): + version = get_nutter_version() return 'Nutter Version {}'.format(version) - def _print_config_error_and_exit(self): + @staticmethod + def _print_config_error_and_exit(): print(""" Invalid configuration.\n - DATABRICKS_HOST and DATABRICKS_TOKEN - environment variables are not set """) + Set relevant environment variables: i.e., DATABRICKS_HOST and DATABRICKS_TOKEN + Example: + export DATABRICKS_HOST= + export DATABRICKS_TOKEN= + """) exit(1) - def _set_debugging(self, debug, log_to_file): + @staticmethod + def _set_debugging(debug, log_to_file): if debug: log_name = None if log_to_file: diff --git a/cli/resultsvalidator.py b/cli/resultsvalidator.py index 9159c31..aa4bd26 100644 --- a/cli/resultsvalidator.py +++ b/cli/resultsvalidator.py @@ -20,33 +20,35 @@ def _validate_result(self, result): if not isinstance(result, ExecuteNotebookResult): raise ValueError("Expected ExecuteNotebookResult") if result.is_error: - msg = """ The job is not in a successfull terminal state. - Life cycle state:{} """.format(result.task_result_state) + msg = f""" The job is not in a successful terminal state. + Life cycle state:{result.task_result_state} """ raise JobExecutionFailureException(message=msg) if result.notebook_result.is_error: - msg = 'The notebook failed. 
result state:{}'.format( - result.notebook_result.result_state) + msg = f'The notebook failed. result state: {result.notebook_result.result_state}' raise NotebookExecutionFailureException(message=msg) self._validate_test_results(result.notebook_result.exit_output) - def _validate_test_results(self, exit_output): + @staticmethod + def _validate_test_results(exit_output): test_results = None try: test_results = TestResults().deserialize(exit_output) except Exception as ex: logging.debug(ex) - msg = """ The Notebook exit output value is invalid or missing. - Additional info: {} """.format(str(ex)) + msg = f""" The Notebook exit output value is invalid or missing. + Additional info: {ex} """ raise InvalidNotebookOutputException(msg) for test_result in test_results.results: if not test_result.passed: - msg = 'The Test Case: {} failed.'.format(test_result.test_name) + msg = f'The Test Case: {test_result.test_name} failed.' raise TestCaseFailureException(msg) class TestCaseFailureException(Exception): + __test__ = False # Tell pytest this is not a test class + def __init__(self, message): super().__init__(message) diff --git a/common/api.py b/common/api.py index 23608dc..e0b479a 100644 --- a/common/api.py +++ b/common/api.py @@ -4,7 +4,6 @@ """ from abc import abstractmethod, ABCMeta -from common.apiclient import DEFAULT_POLL_WAIT_TIME from . import utils from .testresult import TestResults from . import scheduler @@ -14,7 +13,7 @@ import enum import logging - +import os import re import importlib @@ -37,15 +36,13 @@ def get_report_writer(writer): return instance -def to_testresults(exit_output): +def to_test_results(exit_output): if not exit_output: return None try: return TestResults().deserialize(exit_output) except Exception as ex: - error = 'error while creating result from {}. Error: {}'.format( - ex, exit_output) - logging.debug(error) + logging.debug(f'error while creating result from {ex}. 
Error: {exit_output}') return None @@ -87,22 +84,21 @@ def list_tests(self, path, recursive=False): return tests - def run_test(self, testpath, cluster_id, - timeout=120, pull_wait_time=DEFAULT_POLL_WAIT_TIME, notebook_params=None): + def run_test(self, testpath, cluster_id=None, + timeout=120, notebook_params=None, serverless=None): self._add_status_event(NutterStatusEvents.TestExecutionRequest, testpath) test_notebook = TestNotebook.from_path(testpath) if test_notebook is None: raise InvalidTestException result = self.dbclient.execute_notebook( - test_notebook.path, cluster_id, - timeout=timeout, pull_wait_time=pull_wait_time, notebook_params=notebook_params) + test_notebook.path, cluster_id=cluster_id, + timeout=timeout, notebook_params=notebook_params, serverless=serverless) return result - def run_tests(self, pattern, cluster_id, - timeout=120, max_parallel_tests=1, recursive=False, - poll_wait_time=DEFAULT_POLL_WAIT_TIME, notebook_params=None): + def run_tests(self, pattern, cluster_id=None, + timeout=120, max_parallel_tests=1, recursive=False, notebook_params=None, serverless=None): self._add_status_event(NutterStatusEvents.TestExecutionRequest, pattern) root, pattern_to_match = self._get_root_and_pattern(pattern) @@ -119,7 +115,7 @@ def run_tests(self, pattern, cluster_id, NutterStatusEvents.TestsListingFiltered, len(filtered_notebooks)) return self._schedule_and_run( - filtered_notebooks, cluster_id, max_parallel_tests, timeout, poll_wait_time, notebook_params) + filtered_notebooks, cluster_id, max_parallel_tests, timeout, notebook_params, serverless) def events_processor_wait(self): if self._events_processor is None: @@ -140,7 +136,8 @@ def _list_tests(self, path, recursive): for test in self._list_tests(directory.path, True): yield test - def _get_status_events_handler(self, events_handler): + @staticmethod + def _get_status_events_handler(events_handler): if events_handler is None: return None processor = StatusEventsHandler(events_handler) @@ -150,11 
+147,12 @@ def _get_status_events_handler(self, events_handler): def _add_status_event(self, name, status): if self._events_processor is None: return - logging.debug('Status event. name:{} status:{}'.format(name, status)) + logging.debug(f'Status event. name:{name} status:{status}') self._events_processor.add_event(name, status) - def _get_root_and_pattern(self, pattern): + @staticmethod + def _get_root_and_pattern(pattern): segments = pattern.split('/') if len(segments) == 0: raise ValueError("Invalid pattern. The value must start with /") @@ -168,7 +166,7 @@ def _get_root_and_pattern(self, pattern): return root, valid_pattern def _schedule_and_run(self, test_notebooks, cluster_id, - max_parallel_tests, timeout, pull_wait_time, notebook_params=None): + max_parallel_tests, timeout, notebook_params=None, serverless=None): func_scheduler = scheduler.get_scheduler(max_parallel_tests) for test_notebook in test_notebooks: self._add_status_event( @@ -176,15 +174,16 @@ def _schedule_and_run(self, test_notebooks, cluster_id, logging.debug( 'Scheduling execution of: {}'.format(test_notebook.path)) func_scheduler.add_function(self._execute_notebook, - test_notebook.path, cluster_id, timeout, pull_wait_time, notebook_params) + test_notebook.path, cluster_id, timeout, notebook_params, serverless) return self._run_and_await(func_scheduler) - def _execute_notebook(self, test_notebook_path, cluster_id, timeout, pull_wait_time, notebook_params=None): - result = self.dbclient.execute_notebook(test_notebook_path, - cluster_id, timeout, pull_wait_time, notebook_params) + def _execute_notebook(self, test_notebook_path, cluster_id, timeout, notebook_params=None, serverless=None): + result = self.dbclient.execute_notebook(test_notebook_path, cluster_id=cluster_id, + timeout=timeout, notebook_params=notebook_params, + serverless=serverless) self._add_status_event(NutterStatusEvents.TestExecuted, ExecutionResultEventData.from_execution_results(result)) - logging.debug('Executed: 
{}'.format(test_notebook_path)) + logging.debug(f'Executed: {test_notebook_path}') return result def _run_and_await(self, func_scheduler): @@ -206,11 +205,13 @@ def _inspect_result(self, func_result): func_result.exception is not None)) if func_result.exception is not None: - logging.debug('Exception:{}'.format(func_result.exception)) + logging.debug(f'Exception:{func_result.exception}') raise func_result.exception class TestNotebook(object): + __test__ = False # Tell pytest this is not a test class + def __init__(self, name, path): if not self._is_valid_test_name(name): raise InvalidTestException @@ -223,7 +224,8 @@ def __eq__(self, obj): is_equal = obj.name == self.name and obj.path == self.path return isinstance(obj, TestNotebook) and is_equal - def get_test_name(self, name): + @staticmethod + def get_test_name(name): if name.lower().startswith('test_'): return name.split("_")[1] if name.lower().endswith('_test'): @@ -236,20 +238,22 @@ def from_path(cls, path): return None return cls(name, path) - @classmethod - def _is_valid_test_name(cls, name): - return utils.contains_test_prefix_or_surfix(name) + @staticmethod + def _is_valid_test_name(name): + return utils.contains_test_prefix_or_suffix(name) - @classmethod - def _get_notebook_name_from_path(cls, path): + @staticmethod + def _get_notebook_name_from_path(path): segments = path.split('/') if len(segments) == 0: raise ValueError('Invalid path. Path must start /') - name = segments[len(segments)-1] + name = segments[-1] return name class TestNamePatternMatcher(object): + __test__ = False # Tell pytest this is not a test class + def __init__(self, pattern): try: # * is an invalid regex in python @@ -259,7 +263,7 @@ def __init__(self, pattern): return re.compile(pattern) except re.error as ex: - logging.debug('Pattern could not be compiled. {}'.format(ex)) + logging.debug(f'Pattern could not be compiled. {ex}') raise ValueError( """ The pattern provided is invalid. 
The pattern must start with an alphanumeric character """) @@ -277,6 +281,7 @@ def filter_by_pattern(self, test_notebooks): results.append(test_notebook) return results + class ExecutionResultEventData(): def __init__(self, notebook_path, success, notebook_run_page_url): self.success = success @@ -287,11 +292,11 @@ def __init__(self, notebook_path, success, notebook_run_page_url): def from_execution_results(cls, exec_results): notebook_run_page_url = exec_results.notebook_run_page_url notebook_path = exec_results.notebook_path + success = False try: success = not exec_results.is_any_error except Exception as ex: - logging.debug("Error while creating the ExecutionResultEventData {}", ex) - success = False + logging.debug(f"Error while creating the ExecutionResultEventData {ex}") finally: return cls(notebook_path, success, notebook_run_page_url) diff --git a/common/apiclient.py b/common/apiclient.py index db282a8..4e34089 100644 --- a/common/apiclient.py +++ b/common/apiclient.py @@ -4,21 +4,20 @@ """ import uuid -import time -from databricks_api import DatabricksAPI -from . 
import authconfig as cfg, utils +import datetime from .apiclientresults import ExecuteNotebookResult, WorkspacePath -from .httpretrier import HTTPRetrier import logging +from .utils import get_nutter_version -DEFAULT_POLL_WAIT_TIME = 5 -MIN_TIMEOUT = 10 +from databricks.sdk import WorkspaceClient +from databricks.sdk.service.jobs import NotebookTask, Task, JobEnvironment +from databricks.sdk.service.compute import Environment -def databricks_client(): +MIN_TIMEOUT = 10 - db = DatabricksAPIClient() - return db +def databricks_client(): + return DatabricksAPIClient() class DatabricksAPIClient(object): @@ -26,20 +25,8 @@ class DatabricksAPIClient(object): """ def __init__(self): - config = cfg.get_auth_config() self.min_timeout = MIN_TIMEOUT - - if config is None: - raise InvalidConfigurationException - - # TODO: remove the dependency with this API, an instead use httpclient/requests - db = DatabricksAPI(host=config.host, - token=config.token) - self.inner_dbclient = db - - # The retrier uses the recommended defaults - # https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/latest/jobs - self._retrier = HTTPRetrier() + self.dbclient = WorkspaceClient(product="nutter", product_version=get_nutter_version()) def list_notebooks(self, path): workspace_objects = self.list_objects(path) @@ -47,95 +34,125 @@ def list_notebooks(self, path): return notebooks def list_objects(self, path): - objects = self.inner_dbclient.workspace.list(path) - logging.debug('Creating WorkspacePath for path {}'.format(path)) - logging.debug('List response: \n\t{}'.format(objects)) + objects = self.dbclient.workspace.list(path) + logging.debug(f'Creating WorkspacePath for path {path}') + logging.debug(f'List response: \n\t{objects}') workspace_path_obj = WorkspacePath.from_api_response(objects) logging.debug('WorkspacePath created') return workspace_path_obj - def execute_notebook(self, notebook_path, cluster_id, timeout=120, - pull_wait_time=DEFAULT_POLL_WAIT_TIME, - 
notebook_params=None): + def get_cluster_id_by_name(self, cluster_name): + """ + Get cluster ID by cluster name (case-insensitive). + + Args: + cluster_name: The name of the cluster to find + + Returns: + The cluster ID if found + + Raises: + ValueError: If cluster name is empty, not found, or multiple clusters with the same name exist + """ + if not cluster_name: + raise ValueError("empty cluster name") + + # List all clusters and filter by name (case-insensitive) + clusters = list(self.dbclient.clusters.list()) + cluster_name_lower = cluster_name.lower() + matching_clusters = [c for c in clusters if c.cluster_name and c.cluster_name.lower() == cluster_name_lower] + + if len(matching_clusters) == 0: + raise ValueError(f"No cluster found with name '{cluster_name}'") + elif len(matching_clusters) > 1: + raise ValueError(f"Multiple clusters found with name '{cluster_name}'. Please use cluster_id instead.") + + return matching_clusters[0].cluster_id + + def execute_notebook(self, notebook_path, cluster_id=None, timeout=120, + notebook_params=None, serverless=None): + """ + Execute a notebook on either a cluster or serverless compute. 
+ + Args: + notebook_path: Path to the notebook to execute + cluster_id: Cluster ID to run on (mutually exclusive with serverless) + timeout: Execution timeout in seconds (default: 120) + notebook_params: Parameters to pass to the notebook (dict) + serverless: Serverless environment version as integer (e.g., 1) (mutually exclusive with cluster_id) + + Raises: + ValueError: If validation fails + """ if not notebook_path: raise ValueError("empty path") - if not cluster_id: + + # Validate that either cluster_id or serverless is provided, but not both + if cluster_id is None and serverless is None: + raise ValueError("either cluster_id or serverless must be specified") + if cluster_id is not None and serverless is not None: + raise ValueError("cannot specify both cluster_id and serverless") + + # Validate cluster_id is not empty if provided + if cluster_id is not None and not cluster_id: raise ValueError("empty cluster id") + + # Validate serverless is an integer if provided + if serverless is not None: + if not isinstance(serverless, int): + raise ValueError("serverless must be an integer") + if timeout < self.min_timeout: raise ValueError( - "Timeout must be greater than {}".format(self.min_timeout)) + f"Timeout must be greater than {self.min_timeout}") if notebook_params is not None: if not isinstance(notebook_params, dict): - raise ValueError("Parameters must be in the form of a dictionary (See #run-single-test-notebook section in README)") - if pull_wait_time <= 1: - pull_wait_time = DEFAULT_POLL_WAIT_TIME + raise ValueError("Parameters must be in the form of a dictionary (See " + "#run-single-test-notebook section in README)") name = str(uuid.uuid1()) - ntask = self.__get_notebook_task(notebook_path, notebook_params) - - runid = self._retrier.execute(self.inner_dbclient.jobs.submit_run, - run_name=name, - existing_cluster_id=cluster_id, - notebook_task=ntask, - ) - - if 'run_id' not in runid: - raise NotebookTaskRunIDMissingException - - life_cycle_state, output 
= self.__pull_for_output( - runid['run_id'], timeout, pull_wait_time) - - return ExecuteNotebookResult.from_job_output(output) - - def __pull_for_output(self, run_id, timeout, pull_wait_time): - timedout = time.time() + timeout - output = {} - while time.time() < timedout: - output = self._retrier.execute( - self.inner_dbclient.jobs.get_run_output, run_id) - logging.debug(output) - - lcs = utils.recursive_find( - output, ['metadata', 'state', 'life_cycle_state']) - - # As per: - # https://docs.azuredatabricks.net/api/latest/jobs.html#jobsrunlifecyclestate - # All these are terminal states - if lcs == 'TERMINATED' or lcs == 'SKIPPED' or lcs == 'INTERNAL_ERROR': - logging.debug('Terminal state returned. {}'.format(lcs)) - return lcs, output - logging.debug('Not terminal state returned. Sleeping {}s'.format(pull_wait_time)) - time.sleep(pull_wait_time) - - self._raise_timeout(output) - - def _raise_timeout(self, output): - run_page_url = utils.recursive_find( - output, ['metadata', 'run_page_url']) - raise TimeOutException( - """ Timeout while waiting for the result of a test.\n - Check the status of the execution\n - Run page URL: {} """.format(run_page_url)) - - def __get_notebook_task(self, path, params): - ntask = {} - ntask['notebook_path'] = path - base_params = [] - if params is not None: - for key in params: - param = {} - param['key'] = key - param['value'] = params[key] - base_params.append(param) - ntask['base_parameters'] = base_params - - return ntask - - -class NotebookTaskRunIDMissingException(Exception): - pass + + # Configure task based on compute type + if serverless is not None: + # Use serverless compute with environment + # Convert integer to string for environment_version + ntask = Task( + notebook_task=NotebookTask(notebook_path, base_parameters=notebook_params), + environment_key="serverless", + task_key="a" + ) + + # Define serverless environment + environments = [ + JobEnvironment( + environment_key="serverless", + 
spec=Environment(environment_version=str(serverless)) + ) + ] + + run = self.dbclient.jobs.submit_and_wait( + tasks=[ntask], + run_name=name, + environments=environments, + timeout=datetime.timedelta(seconds=timeout) + ) + else: + # Use existing cluster + ntask = Task( + notebook_task=NotebookTask(notebook_path, base_parameters=notebook_params), + existing_cluster_id=cluster_id, + task_key="a" + ) + + run = self.dbclient.jobs.submit_and_wait( + tasks=[ntask], + run_name=name, + timeout=datetime.timedelta(seconds=timeout) + ) + + return ExecuteNotebookResult.from_job_output(run, self.dbclient) class InvalidConfigurationException(Exception): diff --git a/common/apiclientresults.py b/common/apiclientresults.py index a096349..cf657ef 100644 --- a/common/apiclientresults.py +++ b/common/apiclientresults.py @@ -2,82 +2,55 @@ Copyright (c) Microsoft Corporation. Licensed under the MIT license. """ +from typing import Iterator, Optional, Union + +from databricks.sdk import WorkspaceClient +from databricks.sdk.service.jobs import Run, RunResultState from . 
import utils from abc import ABCMeta from .testresult import TestResults import logging +from databricks.sdk.service.workspace import ObjectType, ObjectInfo -class ExecuteNotebookResult(object): - def __init__(self, life_cycle_state, notebook_path, - notebook_result, notebook_run_page_url): - self.task_result_state = life_cycle_state - self.notebook_path = notebook_path - self.notebook_result = notebook_result - self.notebook_run_page_url = notebook_run_page_url - - @classmethod - def from_job_output(cls, job_output): - life_cycle_state = utils.recursive_find( - job_output, ['metadata', 'state', 'life_cycle_state']) - notebook_path = utils.recursive_find( - job_output, ['metadata', 'task', 'notebook_task', 'notebook_path']) - notebook_run_page_url = utils.recursive_find( - job_output, ['metadata', 'run_page_url']) - notebook_result = NotebookOutputResult.from_job_output(job_output) - - return cls(life_cycle_state, notebook_path, - notebook_result, notebook_run_page_url) - - @property - def is_error(self): - # The assumption is that the task is an terminal state - # Success state must be TERMINATED all the others are considered failures - return self.task_result_state != 'TERMINATED' - - @property - def is_any_error(self): - if self.is_error: - return True - if self.notebook_result.is_error: - return True - if self.notebook_result.nutter_test_results is None: - return True - - for test_case in self.notebook_result.nutter_test_results.results: - if not test_case.passed: - return True - return False class NotebookOutputResult(object): - def __init__(self, result_state, exit_output, nutter_test_results): - self.result_state = result_state + def __init__(self, result_state: Union[RunResultState, str, None], exit_output, nutter_test_results): + if result_state is None: + self.result_state = None + elif isinstance(result_state, str): + self.result_state = result_state + else: + self.result_state = result_state.value self.exit_output = exit_output self.nutter_test_results 
= nutter_test_results @classmethod - def from_job_output(cls, job_output): + def from_job_output(cls, run: Run, dbclient: WorkspaceClient): + ntb_task = run.tasks[0] + run_id = ntb_task.run_id + + output = dbclient.jobs.get_run_output(run_id) + exit_output = '' nutter_test_results = '' - notebook_result_state = '' - if 'error' in job_output: - exit_output = job_output['error'] - - if 'notebook_output' in job_output: - notebook_result_state = utils.recursive_find( - job_output, ['metadata', 'state', 'result_state']) + notebook_result_state = ntb_task.state.result_state + if output.error: + exit_output = output - if 'result' in job_output['notebook_output']: - exit_output = job_output['notebook_output']['result'] - nutter_test_results = cls._get_nutter_test_results(exit_output) + if output.notebook_output is not None and output.notebook_output.result is not None: + exit_output = output.notebook_output.result + nutter_test_results = cls._get_nutter_test_results(exit_output) return cls(notebook_result_state, exit_output, nutter_test_results) @property def is_error(self): # https://docs.azuredatabricks.net/dev-tools/api/latest/jobs.html#jobsrunresultstate - return self.result_state != 'SUCCESS' and not self.is_run_from_notebook + return self.result_state != 'SUCCESS' and \ + self.result_state != 'SUCCESS_WITH_FAILURES' and \ + not self.is_run_from_notebook @property def is_run_from_notebook(self): @@ -98,12 +71,58 @@ def _to_nutter_test_results(cls, exit_output): try: return TestResults().deserialize(exit_output) except Exception as ex: - error = 'error while creating result from {}. Error: {}'.format( - ex, exit_output) + error = f'error while creating result from {ex}. 
Error: {exit_output}' logging.debug(error) return None +class ExecuteNotebookResult(object): + def __init__(self, task_result_state: Union[RunResultState, str, None], notebook_path: str, + notebook_result: NotebookOutputResult, notebook_run_page_url): + if task_result_state is None: + self.task_result_state = None + elif isinstance(task_result_state, str): + self.task_result_state = task_result_state + else: + self.task_result_state = task_result_state.value + self.notebook_path = notebook_path + self.notebook_result = notebook_result + self.notebook_run_page_url = notebook_run_page_url + + @classmethod + def from_job_output(cls, run: Run, dbclient: WorkspaceClient): + notebook_result = NotebookOutputResult.from_job_output(run, dbclient) + + return cls(run.state.life_cycle_state, run.tasks[0].notebook_task.notebook_path, + notebook_result, run.run_page_url) + + @property + def is_error(self) -> bool: + # task_result_state now contains lifecycle state + # TERMINATED is the normal completion state, other states indicate issues + err = self.task_result_state not in ['TERMINATED', 'SUCCESS', 'SUCCESS_WITH_FAILURES', None] + return err + + @property + def is_any_error(self): + if self.is_error: + logging.debug(f"is_error: {self}") + return True + if self.notebook_result.is_error: + logging.debug(f"self.notebook_result.is_error: {self}") + return True + if self.notebook_result.nutter_test_results is None: + logging.debug(f"self.notebook_result.nutter_test_results: {self}") + return True + + for test_case in self.notebook_result.nutter_test_results.results: + if not test_case.passed: + logging.debug(f"!test_case.passed: {test_case}, {self}") + return True + + return False + + class WorkspacePath(object): def __init__(self, notebooks, directories): self.notebooks = notebooks @@ -111,24 +130,20 @@ def __init__(self, notebooks, directories): self.test_notebooks = self._set_test_notebooks() @classmethod - def from_api_response(cls, objects): + def from_api_response(cls, 
objects: Iterator[ObjectInfo]): notebooks = cls._set_notebooks(objects) directories = cls._set_directories(objects) return cls(notebooks, directories) @classmethod - def _set_notebooks(cls, objects): - if 'objects' not in objects: - return [] - return [NotebookObject(object['path']) for object in objects['objects'] - if object['object_type'] == 'NOTEBOOK'] + def _set_notebooks(cls, objects: Iterator[ObjectInfo]): + return [NotebookObject(obj.path) for obj in objects + if obj.object_type == ObjectType.NOTEBOOK] @classmethod - def _set_directories(cls, objects): - if 'objects' not in objects: - return [] - return [Directory(object['path']) for object in objects['objects'] - if object['object_type'] == 'DIRECTORY'] + def _set_directories(cls, objects: Iterator[ObjectInfo]): + return [Directory(obj.path) for obj in objects + if obj.object_type == ObjectType.DIRECTORY] def _set_test_notebooks(self): return [notebook for notebook in self.notebooks @@ -151,12 +166,12 @@ def _get_notebook_name_from_path(self, path): segments = path.split('/') if len(segments) == 0: raise ValueError('Invalid path. Path must start /') - name = segments[len(segments)-1] + name = segments[-1] return name @property def is_test_notebook(self): - return utils.contains_test_prefix_or_surfix(self.name) + return utils.contains_test_prefix_or_suffix(self.name) class Directory(WorkspaceObject): diff --git a/common/authconfig.py b/common/authconfig.py deleted file mode 100644 index 607b3dd..0000000 --- a/common/authconfig.py +++ /dev/null @@ -1,56 +0,0 @@ -""" -Copyright (c) Microsoft Corporation. -Licensed under the MIT license. 
-""" - -import os -from abc import abstractmethod, ABCMeta - -def get_auth_config(): - """ - """ - - providers = (EnvVariableAuthConfigProvider(),) - - for provider in providers: - config = provider.get_auth_config() - if config is not None and config.is_valid: - return config - return None - -class DatabricksApiAuthConfigProvider(object): - """ - """ - - __metaclass__ = ABCMeta - - @abstractmethod - def get_auth_config(self): - pass - -class DatabricksApiAuthConfig(object): - def __init__(self, host, token, insecure): - self.host = host - self.token = token - self.insecure = insecure - - @property - def is_valid(self): - if self.host == '' or self.token == '': - return False - - return self.host is not None and self.token is not None - -class EnvVariableAuthConfigProvider(DatabricksApiAuthConfigProvider): - """ - Loads token auth configuration from environment variables. - """ - - def get_auth_config(self): - host = os.environ.get('DATABRICKS_HOST') - token = os.environ.get('DATABRICKS_TOKEN') - insecure = os.environ.get('DATABRICKS_INSECURE') - config = DatabricksApiAuthConfig(host, token, insecure) - if config.is_valid: - return config - return None diff --git a/common/httpretrier.py b/common/httpretrier.py deleted file mode 100644 index f0ec28f..0000000 --- a/common/httpretrier.py +++ /dev/null @@ -1,47 +0,0 @@ -""" -Copyright (c) Microsoft Corporation. -Licensed under the MIT license. -""" - -import logging -from time import sleep -from requests.exceptions import HTTPError - - -class HTTPRetrier(object): - def __init__(self, max_retries=20, delay=30): - self._max_retries = max_retries - self._delay = delay - self._tries = 0 - - def execute(self, function, *args, **kwargs): - waitfor = self._delay - retry = True - self._tries = 0 - while retry: - try: - retry = self._tries < self._max_retries - logging.debug( - 'Executing function with HTTP retry policy. 
Max tries:{} delay:{}' - .format(self._max_retries, self._delay)) - - return function(*args, **kwargs) - except HTTPError as exc: - logging.debug("Error: {0}".format(str(exc))) - if not retry: - raise - if isinstance(exc.response.status_code, int): - if exc.response.status_code < 500: - if not self._is_invalid_state_response(exc.response): - raise - if retry: - logging.debug( - 'Retrying in {0}s, {1} of {2} retries' - .format(str(waitfor), str(self._tries+1), str(self._max_retries))) - sleep(waitfor) - self._tries = self._tries + 1 - - def _is_invalid_state_response(self, response): - if response.status_code == 400: - return 'INVALID_STATE' in response.text - return False diff --git a/common/resultreports.py b/common/resultreports.py index a6bd5fa..d658c6a 100644 --- a/common/resultreports.py +++ b/common/resultreports.py @@ -32,12 +32,14 @@ def has_data(self): def write(self): pass - def _validate_add_results(self, notebook_path, test_result): + @staticmethod + def _validate_add_results(notebook_path, test_result): if not isinstance(test_result, TestResults): raise ValueError('Expected an instance of TestResults') if notebook_path is None or notebook_path == '': raise ValueError("Invalid notebook path") + class TagsReportRow(object): def __init__(self, notebook_name, test_result): self.notebook_name = notebook_name @@ -48,7 +50,8 @@ def __init__(self, notebook_name, test_result): self.duration = test_result.execution_time self.tags = self._to_tag_string(test_result.tags) - def _to_tag_string(self, tags): + @staticmethod + def _to_tag_string(tags): logging.debug(tags) if tags is None: return '' @@ -63,6 +66,7 @@ def to_string(self): self.test_name, self.passed_str, self.duration) return str_value + class TagsReportWriter(TestResultsReportWriter): def __init__(self): super().__init__() @@ -105,7 +109,8 @@ def add_result(self, notebook_path, test_result): t_suite = self._to_junitxml(notebook_path, test_result) self.all_test_suites.append(t_suite) - def 
_to_junitxml(self, notebook_path, test_result): + @staticmethod + def _to_junitxml(notebook_path, test_result): tsuite = TestSuite("nutter") for t_result in test_result.results: fail_error = None diff --git a/common/resultsview.py b/common/resultsview.py index a531800..b310585 100644 --- a/common/resultsview.py +++ b/common/resultsview.py @@ -9,6 +9,8 @@ from .stringwriter import StringWriter from .api import TestNotebook +RESULTS_SEPARATOR = '-' * 55 + def get_run_results_views(exec_results): if not isinstance(exec_results, list): @@ -60,12 +62,12 @@ def __init__(self, listresults): def get_view(self): writer = StringWriter() - writer.write_line('{}'.format('\nTests Found')) - writer.write_line('-' * 55) + writer.write_line('\nTests Found') + writer.write_line(RESULTS_SEPARATOR) for list_result in self.list_results: writer.write(list_result.get_view()) - writer.write_line('-' * 55) + writer.write_line(RESULTS_SEPARATOR) return writer.to_string() @@ -87,7 +89,7 @@ def from_test_notebook(cls, test_notebook): return cls(test_notebook.name, test_notebook.path) def get_view(self): - return "Name:\t{}\nPath:\t{}\n\n".format(self.name, self.path) + return f"Name:\t{self.name}\nPath:\t{self.path}\n\n" @property def total(self): @@ -146,16 +148,14 @@ def _get_test_results(self, result): def get_view(self): sw = StringWriter() - sw.write_line("Notebook: {} - Lifecycle State: {}, Result: {}".format( - self.notebook_path, self.task_result_state, self.notebook_result_state)) + sw.write_line(f"Notebook: {self.notebook_path} - Lifecycle State: {self.task_result_state}, Result: {self.notebook_result_state}") sw.write_line('Run Page URL: {}'.format(self.notebook_run_page_url)) sw.write_line("=" * 60) if len(self.test_cases_views) == 0: sw.write_line("No test cases were returned.") - sw.write_line("Notebook output: {}".format( - self.raw_notebook_output)) + sw.write_line(f"Notebook output: {self.raw_notebook_output}") sw.write_line("=" * 60) return sw.to_string() @@ -181,15 +181,14 
@@ def get_view(self): return sw.to_string() - def __to_testresults(self, exit_output): + @staticmethod + def __to_testresults(exit_output): if not exit_output: return None try: return TestResults().deserialize(exit_output) except Exception as ex: - error = 'error while creating result from {}. Error: {}'.format( - ex, exit_output) - logging.debug(error) + logging.debug('error while creating result from {}. Error: {}', ex, exit_output) return None @property @@ -206,6 +205,8 @@ def failing_tests(self): class TestCaseResultView(ResultsView): + __test__ = False # Tell pytest this is not a test class + def __init__(self, nutter_test_results): if not isinstance(nutter_test_results, TestResult): @@ -222,10 +223,9 @@ def __init__(self, nutter_test_results): def get_view(self): sw = StringWriter() - time = '{} seconds'.format(self.execution_time) - sw.write_line('{} ({})'.format(self.test_case, time)) + sw.write_line(f'{self.test_case} ({self.execution_time} seconds)') - if (self.passed): + if self.passed: return sw.to_string() sw.write_line("") diff --git a/common/scheduler.py b/common/scheduler.py index a572cbc..b45a48e 100644 --- a/common/scheduler.py +++ b/common/scheduler.py @@ -8,6 +8,7 @@ from threading import Thread from queue import Queue + def get_scheduler(num_of_workers): return Scheduler(num_of_workers) @@ -16,7 +17,7 @@ class Scheduler(object): def __init__(self, num_of_workers): if num_of_workers < 1 or num_of_workers > 15: raise ValueError( - 'Number of workers is invalid. It must be a value bettwen 1 and 15') + 'Number of workers is invalid. 
It must be a value between 1 and 15') self._num_of_workers = num_of_workers self._in_queue = Queue() self._out_queue = Queue() @@ -108,15 +109,14 @@ def run(self): if function_exe is None: logging.debug("Function Handler Stopped") break - logging.debug('Function Handler: Execute for {}'.format(function_exe)) + logging.debug(f'Function Handler: Execute for {function_exe}') result = function_exe.execute() logging.debug('Function Handler: Execute called.') self._out_queue.put(FunctionResult(result, None)) except Exception as ex: self._out_queue.put(FunctionResult(None, ex)) - logging.debug('Function Handler. Exception in function. Error {} {}' - .format(str(ex), ex is None)) + logging.debug(f'Function Handler. Exception in function. Error {ex} {ex is None}') finally: self._in_queue.task_done() self.set_done() diff --git a/common/pickleserializable.py b/common/serializabledata.py similarity index 78% rename from common/pickleserializable.py rename to common/serializabledata.py index a8475c5..2da6256 100644 --- a/common/pickleserializable.py +++ b/common/serializabledata.py @@ -5,7 +5,8 @@ from abc import abstractmethod, ABCMeta -class PickleSerializable(): + +class SerializableData: __metaclass__ = ABCMeta @abstractmethod @@ -13,5 +14,5 @@ def serialize(self): pass @abstractmethod - def deserialize(self): + def deserialize(self, pickle_string): pass diff --git a/common/statuseventhandler.py b/common/statuseventhandler.py index 8a7cb8d..2108bb4 100644 --- a/common/statuseventhandler.py +++ b/common/statuseventhandler.py @@ -14,7 +14,7 @@ class StatusEventsHandler(object): def __init__(self, handler): self._event_queue = Queue() - self._processor = Processor(handler, self._event_queue,) + self._processor = Processor(handler, self._event_queue, ) self._processor.daemon = True self._processor.start() @@ -25,6 +25,7 @@ def add_event(self, event, data): def wait(self): self._event_queue.join() + class StatusEvent(object): def __init__(self, event, data): if not 
isinstance(event, Enum): @@ -34,12 +35,14 @@ def __init__(self, event, data): self.event = event self.data = data + class EventHandler(ABC): @abstractmethod def handle(self, queue): pass + class Processor(Thread): def __init__(self, handler, event_queue): self._handler = handler diff --git a/common/stringwriter.py b/common/stringwriter.py index 42bf900..f73a86d 100644 --- a/common/stringwriter.py +++ b/common/stringwriter.py @@ -3,6 +3,7 @@ Licensed under the MIT license. """ + class StringWriter(): def __init__(self): self.result = "" diff --git a/common/testexecresults.py b/common/testexecresults.py index 485847e..bea39d7 100644 --- a/common/testexecresults.py +++ b/common/testexecresults.py @@ -8,7 +8,9 @@ from .testresult import TestResults -class TestExecResults(): +class TestExecResults: + __test__ = False # Tell pytest this is not a test class + def __init__(self, test_results): if not isinstance(test_results, TestResults): raise TypeError("test_results must be of type TestResults") @@ -26,7 +28,8 @@ def to_string(self): def exit(self, dbutils): dbutils.notebook.exit(self.test_results.serialize()) - def get_ExecuteNotebookResult(self, notebook_path, test_results): + @staticmethod + def get_ExecuteNotebookResult(notebook_path, test_results): notebook_result = NotebookOutputResult( 'N/A', None, test_results) diff --git a/common/testresult.py b/common/testresult.py index e21767e..9aa7467 100644 --- a/common/testresult.py +++ b/common/testresult.py @@ -1,92 +1,105 @@ -""" -Copyright (c) Microsoft Corporation. -Licensed under the MIT license. 
-""" - -import base64 -import pickle - -from py4j.protocol import Py4JJavaError - -from .pickleserializable import PickleSerializable - - -def get_test_results(): - return TestResults() - -class TestResults(PickleSerializable): - def __init__(self): - self.results = [] - self.test_cases = 0 - self.num_failures = 0 - self.total_execution_time = 0 - - def append(self, testresult): - if not isinstance(testresult, TestResult): - raise TypeError("Can only append TestResult to TestResults") - - self.results.append(testresult) - self.test_cases = self.test_cases + 1 - if (not testresult.passed): - self.num_failures = self.num_failures + 1 - - total_execution_time = self.total_execution_time + testresult.execution_time - self.total_execution_time = total_execution_time - - def serialize(self): - for i in self.results: - if isinstance(i.exception, Py4JJavaError): - i.exception = Exception(str(i.exception)) - bin_data = pickle.dumps(self) - return str(base64.encodebytes(bin_data), "utf-8") - - def deserialize(self, pickle_string): - bin_str = pickle_string.encode("utf-8") - decoded_bin_data = base64.decodebytes(bin_str) - return pickle.loads(decoded_bin_data) - - def passed(self): - for item in self.results: - if not item.passed: - return False - return True - - def __eq__(self, other): - if not isinstance(self, other.__class__): - return False - if len(self.results) != len(other.results): - return False - for item in other.results: - if not self.__item_in_list_equalto(item): - return False - - return True - - def __item_in_list_equalto(self, expected_item): - for item in self.results: - if (item == expected_item): - return True - - return False - -class TestResult: - def __init__(self, test_name, passed, - execution_time, tags, exception=None, stack_trace=""): - - if not isinstance(tags, list): - raise ValueError("tags must be a list") - self.passed = passed - self.exception = exception - self.stack_trace = stack_trace - self.test_name = test_name - self.execution_time = 
execution_time - self.tags = tags - - def __eq__(self, other): - if isinstance(self, other.__class__): - return self.test_name == other.test_name \ - and self.passed == other.passed \ - and type(self.exception) == type(other.exception) \ - and str(self.exception) == str(other.exception) - - return False +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. +""" + +import base64 +import gzip +import zlib +import jsonpickle + +from py4j.protocol import Py4JJavaError + +from .serializabledata import SerializableData + + +def get_test_results(): + return TestResults() + + +class TestResults(SerializableData): + __test__ = False # Tell pytest this is not a test class + + def __str__(self) -> str: + return jsonpickle.encode(self) + + def __init__(self): + self.results = [] + self.test_cases = 0 + self.num_failures = 0 + self.total_execution_time = 0 + + def append(self, test_result): + if not isinstance(test_result, TestResult): + raise TypeError("Can only append TestResult to TestResults") + + self.results.append(test_result) + self.test_cases = self.test_cases + 1 + if not test_result.passed: + self.num_failures = self.num_failures + 1 + + total_execution_time = self.total_execution_time + test_result.execution_time + self.total_execution_time = total_execution_time + + @staticmethod + def serialize_object(obj): + bin_data = zlib.compress(bytes(jsonpickle.encode(obj), 'utf-8')) + return str(base64.encodebytes(bin_data), "utf-8") + + def serialize(self): + for i in self.results: + if isinstance(i.exception, Py4JJavaError): + i.exception = Exception(str(i.exception)) + return self.serialize_object(self) + + def deserialize(self, pickle_string): + bin_str = pickle_string.encode("utf-8") + decoded_bin_data = base64.decodebytes(bin_str) + return jsonpickle.decode(zlib.decompress(decoded_bin_data)) + + def passed(self): + for item in self.results: + if not item.passed: + return False + return True + + def __eq__(self, other): + if not isinstance(self, 
other.__class__): + return False + if len(self.results) != len(other.results): + return False + for item in other.results: + if not self.__item_in_list_equalto(item): + return False + + return True + + def __item_in_list_equalto(self, expected_item): + for item in self.results: + if item == expected_item: + return True + + return False + + +class TestResult: + __test__ = False # Tell pytest this is not a test class + + def __init__(self, test_name, passed, execution_time, tags, exception=None, stack_trace=""): + if not isinstance(tags, list): + raise ValueError("tags must be a list") + self.passed = passed + self.exception = exception + self.stack_trace = stack_trace + self.test_name = test_name + self.execution_time = execution_time + self.tags = tags + + def __eq__(self, other): + if isinstance(self, other.__class__): + return self.test_name == other.test_name \ + and self.passed == other.passed \ + and type(self.exception) == type(other.exception) \ + and str(self.exception) == str(other.exception) + + return False diff --git a/common/utils.py b/common/utils.py index 0efe418..2abd990 100644 --- a/common/utils.py +++ b/common/utils.py @@ -3,23 +3,23 @@ Licensed under the MIT license. 
""" -def recursive_find(dict_instance, keys): - if not isinstance(keys, list): - raise ValueError("Expected list of keys") - if not isinstance(dict_instance, dict): - return None - if len(keys) == 0: - return None - key = keys[0] - value = dict_instance.get(key, None) - if value is None: - return None - if len(keys) == 1: - return value - return recursive_find(value, keys[1:len(keys)]) - -def contains_test_prefix_or_surfix(name): +import os + +__version__ = '0.2.0' + +BUILD_NUMBER_ENV_VAR = 'NUTTER_BUILD_NUMBER' + + +def get_nutter_version(): + build_number = os.environ.get(BUILD_NUMBER_ENV_VAR) + if build_number: + return f'{__version__}.{build_number}' + return __version__ + + +def contains_test_prefix_or_suffix(name): if name is None: return False - return name.lower().startswith('test_') or name.lower().endswith('_test') + lower_name = name.lower() + return lower_name.startswith('test_') or lower_name.endswith('_test') diff --git a/dev_requirements.txt b/dev_requirements.txt index 4d90e1b..cf5e99f 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,4 +1,4 @@ -pytest==5.0.1 +pytest==9.0.1 mock pytest-mock -pytest-cov \ No newline at end of file +pytest-cov diff --git a/requirements.txt b/requirements.txt index e0db8e8..fe87c81 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ -databricks-api -requests +databricks-sdk fire junit_xml py4j - +jsonpickle diff --git a/runtime/fixtureloader.py b/runtime/fixtureloader.py index 7c2ee23..07cebf6 100644 --- a/runtime/fixtureloader.py +++ b/runtime/fixtureloader.py @@ -46,7 +46,8 @@ def load_fixture(self, nutter_fixture): return self.__test_case_dictionary - def __is_test_method(self, attribute): + @staticmethod + def __is_test_method(attribute): if attribute.startswith("before_") or \ attribute.startswith("run_") or \ attribute.startswith("assertion_") or \ @@ -54,7 +55,8 @@ def __is_test_method(self, attribute): return True return False - def __set_method(self, case, name, func): + 
@staticmethod + def __set_method(case, name, func): if name.startswith("before_"): case.set_before(func) return case @@ -81,7 +83,8 @@ def __get_test_name(self, full_name): return name - def __remove_prefix(self, text, prefix): + @staticmethod + def __remove_prefix(text, prefix): if text.startswith(prefix): return text[len(prefix):] return text diff --git a/runtime/nutterfixture.py b/runtime/nutterfixture.py index 3db9f81..dc9d6fa 100644 --- a/runtime/nutterfixture.py +++ b/runtime/nutterfixture.py @@ -20,6 +20,7 @@ def tag_decorator(function): function.tag = the_tag return function + return tag_decorator @@ -42,9 +43,9 @@ def execute_tests(self): self.before_all() for key, value in self.__test_case_dict.items(): - logging.debug('Running test: {}'.format(key)) + logging.debug(f'Running test: {key}') test_result = value.execute_test() - logging.debug('Completed running test: {}'.format(key)) + logging.debug(f'Completed running test: {key}') self.test_results.append(test_result) if len(self.__test_case_dict) > 0 and self.__has_method("after_all"): @@ -66,9 +67,9 @@ def __load_fixture(self): raise InvalidTestFixtureException("Invalid Test Fixture") self.__test_case_dict = OrderedDict(sorted(test_case_dict.items(), key=lambda t: t[0])) - logging.debug("Found {} test cases".format(len(test_case_dict))) + logging.debug(f"Found {len(test_case_dict)} test cases") for key, value in self.__test_case_dict.items(): - logging.debug('Test Case: {}'.format(key)) + logging.debug(f'Test Case: {key}') def __has_method(self, method_name): method = getattr(self, method_name, None) @@ -81,6 +82,7 @@ class InvalidTestFixtureException(Exception): def __init__(self, message): super().__init__(message) + class InitializationException(Exception): def __init__(self, message): super().__init__(message) diff --git a/runtime/runner.py b/runtime/runner.py index c80c852..d2fb5d1 100644 --- a/runtime/runner.py +++ b/runtime/runner.py @@ -44,7 +44,8 @@ def execute(self): return 
self._collect_results(results) - def _collect_results(self, results): + @staticmethod + def _collect_results(results): """Collect all results in a single TestExecResults object.""" all_results = TestResults() diff --git a/runtime/testcase.py b/runtime/testcase.py index a4f80c4..eb2fa8d 100644 --- a/runtime/testcase.py +++ b/runtime/testcase.py @@ -1,104 +1,105 @@ -""" -Copyright (c) Microsoft Corporation. -Licensed under the MIT license. -""" - -import os -import time -import traceback -from common.testresult import TestResult - - -def get_testcase(test_name): - - tc = TestCase(test_name) - - return tc - - -class TestCase(): - ERROR_MESSAGE_ASSERTION_MISSING = """ TestCase does not contain an assertion function. - Please pass a function to set_assertion """ - - def __init__(self, test_name): - self.test_name = test_name - self.before = None - self.__before_set = False - self.run = None - self.__run_set = False - self.assertion = None - self.after = None - self.__after_set = False - self.invalid_message = "" - self.tags = [] - - def set_before(self, before): - self.before = before - self.__before_set = True - - def set_run(self, run): - self.run = run - self.__run_set = True - - def set_assertion(self, assertion): - self.assertion = assertion - - def set_after(self, after): - self.after = after - self.__after_set = True - - def execute_test(self): - start_time = time.perf_counter() - try: - if hasattr(self.run, "tag"): - if isinstance(self.run.tag, list): - self.tags.extend(self.run.tag) - else: - self.tags.append(self.run.tag) - if not self.is_valid(): - raise NoTestCasesFoundError( - "Both a run and an assertion are required for every test") - if self.__before_set and self.before is not None: - self.before() - if self.__run_set: - self.run() - self.assertion() - if self.__after_set and self.after is not None: - self.after() - - except Exception as exc: - return TestResult(self.test_name, False, - self.__get_elapsed_time(start_time), self.tags, - exc, 
traceback.format_exc()) - - return TestResult(self.test_name, True, - self.__get_elapsed_time(start_time), self.tags, None) - - def is_valid(self): - is_valid = True - - if self.assertion is None: - self.__add_message_to_error(self.ERROR_MESSAGE_ASSERTION_MISSING) - is_valid = False - - return is_valid - - def __get_elapsed_time(self, start_time): - end_time = time.perf_counter() - elapsed_time = end_time - start_time - return elapsed_time - - def __add_message_to_error(self, message): - if self.invalid_message: - self.invalid_message += os.linesep - - self.invalid_message += message - - def get_invalid_message(self): - self.is_valid() - - return self.invalid_message - - -class NoTestCasesFoundError(Exception): - pass +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. +""" + +import os +import time +import traceback +from common.testresult import TestResult + + +def get_testcase(test_name): + tc = TestCase(test_name) + + return tc + + +class TestCase(): + __test__ = False # Tell pytest this is not a test class + + ERROR_MESSAGE_ASSERTION_MISSING = """ TestCase does not contain an assertion function. 
+ Please pass a function to set_assertion """ + + def __init__(self, test_name): + self.test_name = test_name + self.before = None + self.__before_set = False + self.run = None + self.__run_set = False + self.assertion = None + self.after = None + self.__after_set = False + self.invalid_message = "" + self.tags = [] + + def set_before(self, before): + self.before = before + self.__before_set = True + + def set_run(self, run): + self.run = run + self.__run_set = True + + def set_assertion(self, assertion): + self.assertion = assertion + + def set_after(self, after): + self.after = after + self.__after_set = True + + def execute_test(self): + start_time = time.perf_counter() + try: + if hasattr(self.run, "tag"): + if isinstance(self.run.tag, list): + self.tags.extend(self.run.tag) + else: + self.tags.append(self.run.tag) + if not self.is_valid(): + raise NoTestCasesFoundError( + "Both a run and an assertion are required for every test") + if self.__before_set and self.before is not None: + self.before() + if self.__run_set: + self.run() + self.assertion() + if self.__after_set and self.after is not None: + self.after() + + except Exception as exc: + return TestResult(self.test_name, False, + self.__get_elapsed_time(start_time), self.tags, + exc, traceback.format_exc()) + + return TestResult(self.test_name, True, + self.__get_elapsed_time(start_time), self.tags, None) + + def is_valid(self): + is_valid = True + + if self.assertion is None: + self.__add_message_to_error(self.ERROR_MESSAGE_ASSERTION_MISSING) + is_valid = False + + return is_valid + + def __get_elapsed_time(self, start_time): + end_time = time.perf_counter() + elapsed_time = end_time - start_time + return elapsed_time + + def __add_message_to_error(self, message): + if self.invalid_message: + self.invalid_message += os.linesep + + self.invalid_message += message + + def get_invalid_message(self): + self.is_valid() + + return self.invalid_message + + +class NoTestCasesFoundError(Exception): + pass diff 
--git a/setup.py b/setup.py index 6015f04..2d0f27a 100644 --- a/setup.py +++ b/setup.py @@ -1,10 +1,11 @@ import setuptools -import cli.nuttercli as nuttercli +import common.utils as utils with open("README.md", "r") as fh: long_description = fh.read() -version = nuttercli.get_cli_version() +version = utils.get_nutter_version() + def parse_requirements(filename): """Load requirements from a pip requirements file.""" @@ -32,5 +33,5 @@ def parse_requirements(filename): "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ], - python_requires='>=3.7.0', + python_requires='>=3.8.0', ) diff --git a/tests/cli/test_nuttercli.py b/tests/cli/test_nuttercli.py index e9367a5..b5f47b0 100644 --- a/tests/cli/test_nuttercli.py +++ b/tests/cli/test_nuttercli.py @@ -5,22 +5,26 @@ import pytest import os -import json +from unittest.mock import Mock +from databricks.sdk.service.jobs import Run, RunTask, NotebookTask, RunState, \ + RunLifeCycleState, RunResultState, RunOutput, NotebookOutput + import cli.nuttercli as nuttercli from cli.nuttercli import NutterCLI from common.apiclientresults import ExecuteNotebookResult +from common.utils import BUILD_NUMBER_ENV_VAR import mock from common.testresult import TestResults, TestResult from cli.reportsman import ReportWriterManager, ReportWritersTypes, ReportWriters def test__get_cli_version__without_build__env_var__returns_value(): - version = nuttercli.get_cli_version() + version = nuttercli.get_nutter_version() assert version is not None def test__get_cli_header_value(): - version = nuttercli.get_cli_version() + version = nuttercli.get_nutter_version() header = 'Nutter Version {}\n'.format(version) header += '+' * 50 header += '\n' @@ -30,18 +34,18 @@ def test__get_cli_header_value(): def test__get_cli_version__with_build__env_var__returns_value(mocker): - version = nuttercli.get_cli_version() + version = nuttercli.get_nutter_version() build_number = '1.2.3' mocker.patch.dict( - os.environ, 
{nuttercli.BUILD_NUMBER_ENV_VAR: build_number}) - version_with_build_number = nuttercli.get_cli_version() + os.environ, {BUILD_NUMBER_ENV_VAR: build_number}) + version_with_build_number = nuttercli.get_nutter_version() assert version_with_build_number == '{}.{}'.format(version, build_number) def test__get_version_label__valid_string(mocker): mocker.patch.dict(os.environ, {'DATABRICKS_HOST': 'myhost'}) mocker.patch.dict(os.environ, {'DATABRICKS_TOKEN': 'mytoken'}) - version = nuttercli.get_cli_version() + version = nuttercli.get_nutter_version() expected = 'Nutter Version {}'.format(version) cli = NutterCLI() version_from_cli = cli._get_version_label() @@ -71,8 +75,11 @@ def test__run__pattern__display_results(mocker): def test__nutter_cli_ctor__handles__configurationexception_and_exits_1(mocker): - mocker.patch.dict(os.environ, {'DATABRICKS_HOST': ''}) - mocker.patch.dict(os.environ, {'DATABRICKS_TOKEN': ''}) + from databricks.sdk import WorkspaceClient + + # Mock WorkspaceClient to raise ValueError when credentials are missing + mocker.patch.dict(os.environ, {'DATABRICKS_HOST': '', 'DATABRICKS_TOKEN': ''}) + mocker.patch.object(WorkspaceClient, '__init__', side_effect=ValueError("cannot configure default credentials")) with pytest.raises(SystemExit) as mock_ex: cli = NutterCLI() @@ -135,32 +142,39 @@ def _get_cli_for_tests(mocker, result_state, life_cycle_state, notebook_result): def _get_run_test_response(result_state, life_cycle_state, notebook_result): - data_json = """ - {"notebook_output": - {"result": "IHaveReturned", "truncated": false}, - "metadata": - {"execution_duration": 15000, - "run_type": "SUBMIT_RUN", - "cleanup_duration": 0, - "number_in_job": 1, - "cluster_instance": - {"cluster_id": "0925-141d1222-narcs242", - "spark_context_id": "803963628344534476"}, - "creator_user_name": "abc@microsoft.com", - "task": {"notebook_task": {"notebook_path": "/test_mynotebook"}}, - "run_id": 7, "start_time": 1569887259173, - "job_id": 4, - "state": 
{"result_state": "SUCCESS", "state_message": "", - "life_cycle_state": "TERMINATED"}, "setup_duration": 2000, - "run_page_url": "https://westus2.azuredatabricks.net/?o=14702dasda6094293890#job/4/run/1", - "cluster_spec": {"existing_cluster_id": "0925-141122-narcs242"}, "run_name": "myrun"}} - """ - data_dict = json.loads(data_json) - data_dict['notebook_output']['result'] = notebook_result - data_dict['metadata']['state']['result_state'] = result_state - data_dict['metadata']['state']['life_cycle_state'] = life_cycle_state - - return ExecuteNotebookResult.from_job_output(data_dict) + # Create proper SDK objects instead of JSON + result_state_enum = getattr(RunResultState, result_state) if result_state else None + lifecycle_state_enum = getattr(RunLifeCycleState, life_cycle_state) + + run_info = Run( + tasks=[ + RunTask( + task_key="test_task", + notebook_task=NotebookTask(notebook_path="/test_mynotebook"), + run_id=2, + state=RunState( + life_cycle_state=lifecycle_state_enum, + result_state=result_state_enum, + state_message="" + ) + ) + ], + run_id=1, + run_page_url="https://westus2.azuredatabricks.net/?o=14702dasda6094293890#job/4/run/1", + state=RunState( + life_cycle_state=lifecycle_state_enum, + result_state=result_state_enum, + state_message="" + ), + ) + + # Create mock WorkspaceClient + mock_client = Mock() + mock_client.jobs.get_run_output.return_value = RunOutput( + notebook_output=NotebookOutput(result=notebook_result, truncated=False) + ) + + return ExecuteNotebookResult.from_job_output(run_info, mock_client) def _get_list_tests_response(): @@ -171,38 +185,186 @@ def _get_list_tests_response(): def _get_run_tests_response(result_state, life_cycle_state, notebook_result): - data_json = """ - {"notebook_output": - {"result": "IHaveReturned", "truncated": false}, - "metadata": - {"execution_duration": 15000, - "run_type": "SUBMIT_RUN", - "cleanup_duration": 0, - "number_in_job": 1, - "cluster_instance": - {"cluster_id": "0925-141d1222-narcs242", - 
"spark_context_id": "803963628344534476"}, - "creator_user_name": "abc@microsoft.com", - "task": {"notebook_task": {"notebook_path": "/test_mynotebook"}}, - "run_id": 7, "start_time": 1569887259173, - "job_id": 4, - "state": {"result_state": "SUCCESS", "state_message": "", - "life_cycle_state": "TERMINATED"}, "setup_duration": 2000, - "run_page_url": "https://westus2.azuredatabricks.net/?o=14702dasda6094293890#job/4/run/1", - "cluster_spec": {"existing_cluster_id": "0925-141122-narcs242"}, "run_name": "myrun"}} - """ - data_dict = json.loads(data_json) - data_dict['notebook_output']['result'] = notebook_result - data_dict['metadata']['state']['result_state'] = result_state - data_dict['metadata']['state']['life_cycle_state'] = life_cycle_state - - data_dict2 = json.loads(data_json) - data_dict2['notebook_output']['result'] = notebook_result - data_dict2['metadata']['state']['result_state'] = result_state - data_dict2['metadata']['task']['notebook_task']['notebook_path'] = '/test_mynotebook2' - data_dict2['metadata']['state']['life_cycle_state'] = life_cycle_state - + # Create proper SDK objects for two test results + result_state_enum = getattr(RunResultState, result_state) if result_state else None + lifecycle_state_enum = getattr(RunLifeCycleState, life_cycle_state) + + # First result + run_info1 = Run( + tasks=[ + RunTask( + task_key="test_task", + notebook_task=NotebookTask(notebook_path="/test_mynotebook"), + run_id=2, + state=RunState( + life_cycle_state=lifecycle_state_enum, + result_state=result_state_enum, + state_message="" + ) + ) + ], + run_id=1, + run_page_url="https://westus2.azuredatabricks.net/?o=14702dasda6094293890#job/4/run/1", + state=RunState( + life_cycle_state=lifecycle_state_enum, + result_state=result_state_enum, + state_message="" + ), + ) + + # Second result with different notebook path + run_info2 = Run( + tasks=[ + RunTask( + task_key="test_task", + notebook_task=NotebookTask(notebook_path="/test_mynotebook2"), + run_id=3, + 
state=RunState( + life_cycle_state=lifecycle_state_enum, + result_state=result_state_enum, + state_message="" + ) + ) + ], + run_id=2, + run_page_url="https://westus2.azuredatabricks.net/?o=14702dasda6094293890#job/4/run/1", + state=RunState( + life_cycle_state=lifecycle_state_enum, + result_state=result_state_enum, + state_message="" + ), + ) + + # Create mock WorkspaceClients + mock_client1 = Mock() + mock_client1.jobs.get_run_output.return_value = RunOutput( + notebook_output=NotebookOutput(result=notebook_result, truncated=False) + ) + + mock_client2 = Mock() + mock_client2.jobs.get_run_output.return_value = RunOutput( + notebook_output=NotebookOutput(result=notebook_result, truncated=False) + ) + results = [] - results.append(ExecuteNotebookResult.from_job_output(data_dict)) - results.append(ExecuteNotebookResult.from_job_output(data_dict2)) + results.append(ExecuteNotebookResult.from_job_output(run_info1, mock_client1)) + results.append(ExecuteNotebookResult.from_job_output(run_info2, mock_client2)) return results + + +def test__run__with_cluster_name__resolves_to_cluster_id(mocker): + test_results = TestResults().serialize() + cli = _get_cli_for_tests( + mocker, 'SUCCESS', 'TERMINATED', test_results) + + # Mock the get_cluster_id_by_name method + mocker.patch.object(cli._nutter.dbclient, 'get_cluster_id_by_name') + cli._nutter.dbclient.get_cluster_id_by_name.return_value = 'resolved-cluster-id' + + mocker.patch.object(cli, '_display_test_results') + cli.run('test_mynotebook2', cluster_name='my-cluster') + + # Verify that get_cluster_id_by_name was called with the correct name + cli._nutter.dbclient.get_cluster_id_by_name.assert_called_once_with('my-cluster') + assert cli._display_test_results.call_count == 1 + + +def test__run__with_both_cluster_id_and_name__exits_with_error(mocker): + test_results = TestResults().serialize() + cli = _get_cli_for_tests( + mocker, 'SUCCESS', 'TERMINATED', test_results) + + # Should exit with error when both cluster_id and 
cluster_name are provided + with pytest.raises(SystemExit) as mock_ex: + cli.run('test_mynotebook2', cluster_id='cluster-id', cluster_name='cluster-name') + + assert mock_ex.type == SystemExit + assert mock_ex.value.code == 1 + + +def test__run__with_neither_cluster_id_nor_name__exits_with_error(mocker): + test_results = TestResults().serialize() + cli = _get_cli_for_tests( + mocker, 'SUCCESS', 'TERMINATED', test_results) + + # Should exit with error when neither cluster_id nor cluster_name is provided + with pytest.raises(SystemExit) as mock_ex: + cli.run('test_mynotebook2') + + assert mock_ex.type == SystemExit + assert mock_ex.value.code == 1 + + +def test__run__with_cluster_name_resolution_fails__exits_with_error(mocker): + test_results = TestResults().serialize() + cli = _get_cli_for_tests( + mocker, 'SUCCESS', 'TERMINATED', test_results) + + # Mock the get_cluster_id_by_name method to raise ValueError + mocker.patch.object(cli._nutter.dbclient, 'get_cluster_id_by_name') + cli._nutter.dbclient.get_cluster_id_by_name.side_effect = ValueError("Cluster not found") + + # Should exit with error when cluster name resolution fails + with pytest.raises(SystemExit) as mock_ex: + cli.run('test_mynotebook2', cluster_name='nonexistent-cluster') + + assert mock_ex.type == SystemExit + assert mock_ex.value.code == 1 + + +def test__run__with_serverless__executes_successfully(mocker): + test_results = TestResults().serialize() + cli = _get_cli_for_tests( + mocker, 'SUCCESS', 'TERMINATED', test_results) + + mocker.patch.object(cli, '_display_test_results') + cli.run('test_mynotebook2', serverless=1) + + # Verify that run_test was called with serverless parameter + assert cli._nutter.run_test.call_count == 1 + call_kwargs = cli._nutter.run_test.call_args[1] + assert call_kwargs['serverless'] == 1 + assert call_kwargs['cluster_id'] is None + assert cli._display_test_results.call_count == 1 + + +def test__run__with_serverless_and_cluster_id__exits_with_error(mocker): + 
test_results = TestResults().serialize() + cli = _get_cli_for_tests( + mocker, 'SUCCESS', 'TERMINATED', test_results) + + # Should exit with error when both serverless and cluster_id are provided + with pytest.raises(SystemExit) as mock_ex: + cli.run('test_mynotebook2', cluster_id='cluster-123', serverless=1) + + assert mock_ex.type == SystemExit + assert mock_ex.value.code == 1 + + +def test__run__with_no_compute_option__exits_with_error(mocker): + test_results = TestResults().serialize() + cli = _get_cli_for_tests( + mocker, 'SUCCESS', 'TERMINATED', test_results) + + # Should exit with error when no compute option is provided + with pytest.raises(SystemExit) as mock_ex: + cli.run('test_mynotebook2') + + assert mock_ex.type == SystemExit + assert mock_ex.value.code == 1 + + +def test__run__pattern_with_serverless__executes_successfully(mocker): + test_results = TestResults().serialize() + cli = _get_cli_for_tests( + mocker, 'SUCCESS', 'TERMINATED', test_results) + + mocker.patch.object(cli, '_display_test_results') + cli.run('my*', serverless=1) + + # Verify that run_tests was called with serverless parameter + assert cli._nutter.run_tests.call_count == 1 + call_kwargs = cli._nutter.run_tests.call_args[1] + assert call_kwargs['serverless'] == 1 + assert call_kwargs['cluster_id'] is None + assert cli._display_test_results.call_count == 1 diff --git a/tests/cli/test_resultsvalidator.py b/tests/cli/test_resultsvalidator.py index 93f09fa..bf45e95 100644 --- a/tests/cli/test_resultsvalidator.py +++ b/tests/cli/test_resultsvalidator.py @@ -4,10 +4,13 @@ """ import pytest +from unittest.mock import Mock +from databricks.sdk.service.jobs import Run, RunTask, NotebookTask, RunState, \ + RunLifeCycleState, RunResultState, RunOutput, NotebookOutput + import common.testresult as testresult from common.apiclientresults import ExecuteNotebookResult from cli.resultsvalidator import ExecutionResultsValidator, TestCaseFailureException, JobExecutionFailureException, 
NotebookExecutionFailureException, InvalidNotebookOutputException -import json def test__validate__results_is_none__valueerror(): @@ -141,29 +144,36 @@ def test__validate__results_with_job_failure__throws_jobexecutionfailureexceptio def __get_ExecuteNotebookResult(result_state, life_cycle_state, notebook_result): - data_json = """ - {"notebook_output": - {"result": "IHaveReturned", "truncated": false}, - "metadata": - {"execution_duration": 15000, - "run_type": "SUBMIT_RUN", - "cleanup_duration": 0, - "number_in_job": 1, - "cluster_instance": - {"cluster_id": "0925-141d1222-narcs242", - "spark_context_id": "803963628344534476"}, - "creator_user_name": "abc@microsoft.com", - "task": {"notebook_task": {"notebook_path": "/test_mynotebook"}}, - "run_id": 7, "start_time": 1569887259173, - "job_id": 4, - "state": {"result_state": "SUCCESS", "state_message": "", - "life_cycle_state": "TERMINATED"}, "setup_duration": 2000, - "run_page_url": "https://westus2.azuredatabricks.net/?o=14702dasda6094293890#job/4/run/1", - "cluster_spec": {"existing_cluster_id": "0925-141122-narcs242"}, "run_name": "myrun"}} - """ - data_dict = json.loads(data_json) - data_dict['notebook_output']['result'] = notebook_result - data_dict['metadata']['state']['result_state'] = result_state - data_dict['metadata']['state']['life_cycle_state'] = life_cycle_state - - return ExecuteNotebookResult.from_job_output(data_dict) + # Create proper SDK objects instead of JSON + result_state_enum = getattr(RunResultState, result_state) if result_state else None + lifecycle_state_enum = getattr(RunLifeCycleState, life_cycle_state) + + run_info = Run( + tasks=[ + RunTask( + task_key="test_task", + notebook_task=NotebookTask(notebook_path="/test_mynotebook"), + run_id=2, + state=RunState( + life_cycle_state=lifecycle_state_enum, + result_state=result_state_enum, + state_message="" + ) + ) + ], + run_id=1, + run_page_url="https://westus2.azuredatabricks.net/?o=14702dasda6094293890#job/4/run/1", + state=RunState( + 
life_cycle_state=lifecycle_state_enum, + result_state=result_state_enum, + state_message="" + ), + ) + + # Create mock WorkspaceClient + mock_client = Mock() + mock_client.jobs.get_run_output.return_value = RunOutput( + notebook_output=NotebookOutput(result=notebook_result, truncated=False) + ) + + return ExecuteNotebookResult.from_job_output(run_info, mock_client) diff --git a/tests/databricks/test_apiclient.py b/tests/databricks/test_apiclient.py index 3f66f54..5ecc21c 100644 --- a/tests/databricks/test_apiclient.py +++ b/tests/databricks/test_apiclient.py @@ -7,14 +7,21 @@ from common import apiclient as client from common.apiclient import DatabricksAPIClient import os -import json + +from databricks.sdk.service.jobs import Run, RunTask, NotebookTask, RunState, \ + RunLifeCycleState, RunResultState, RunOutput, NotebookOutput +from databricks.sdk.service.workspace import ObjectType, ObjectInfo, Language +from databricks.sdk.service.compute import ClusterDetails def test__databricks_client__token_host_notset__clientfails(mocker): - mocker.patch.dict(os.environ, {'DATABRICKS_HOST': ''}) - mocker.patch.dict(os.environ, {'DATABRICKS_TOKEN': ''}) + from databricks.sdk import WorkspaceClient + + # Mock WorkspaceClient to raise ValueError when credentials are missing + mocker.patch.dict(os.environ, {'DATABRICKS_HOST': '', 'DATABRICKS_TOKEN': ''}) + mocker.patch.object(WorkspaceClient, '__init__', side_effect=ValueError("cannot configure default credentials")) - with pytest.raises(client.InvalidConfigurationException): + with pytest.raises(ValueError): dbclient = client.databricks_client() @@ -29,13 +36,14 @@ def test__databricks_client__token_host_set__clientreturns(mocker): def test__list_notebooks__onenotebook__okay(mocker): db = __get_client(mocker) - mocker.patch.object(db.inner_dbclient.workspace, 'list') + mocker.patch.object(db.dbclient.workspace, 'list') - objects = """{"objects":[ - {"object_type":"NOTEBOOK","path":"/nutfixjob","language":"PYTHON"}, - 
{"object_type":"DIRECTORY","path":"/ETL-Part-3-1.0.3"}]}""" + objects = [ + ObjectInfo(object_type=ObjectType.NOTEBOOK, path="/nutfixjob", language=Language.PYTHON), + ObjectInfo(object_type=ObjectType.DIRECTORY, path="/ETL-Part-3-1.0.3") + ] - db.inner_dbclient.workspace.list.return_value = json.loads(objects) + db.dbclient.workspace.list.return_value = iter(objects) notebooks = db.list_notebooks('/') @@ -44,12 +52,13 @@ def test__list_notebooks__onenotebook__okay(mocker): def test__list_notebooks__zeronotebook__okay(mocker): db = __get_client(mocker) - mocker.patch.object(db.inner_dbclient.workspace, 'list') + mocker.patch.object(db.dbclient.workspace, 'list') - objects = """{"objects":[ - {"object_type":"DIRECTORY","path":"/ETL-Part-3-1.0.3"}]}""" + objects = [ + ObjectInfo(object_type=ObjectType.DIRECTORY, path="/ETL-Part-3-1.0.3") + ] - db.inner_dbclient.workspace.list.return_value = json.loads(objects) + db.dbclient.workspace.list.return_value = iter(objects) notebooks = db.list_notebooks('/') @@ -93,11 +102,9 @@ def test__execute_notebook__nonecluster__valueerror(mocker): def test__execute_notebook__success__executeresult_has_run_url(mocker): run_page_url = "http://runpage" - output_data = __get_submit_run_response( + run_info, run_output = __get_submit_run_response( 'SUCCESS', 'TERMINATED', '', run_page_url) - run_id = {} - run_id['run_id'] = 1 - db = __get_client_for_execute_notebook(mocker, output_data, run_id) + db = __get_client_for_execute_notebook(mocker, run_info, run_output) result = db.execute_notebook('/mynotebook', 'clusterid') @@ -105,11 +112,9 @@ def test__execute_notebook__success__executeresult_has_run_url(mocker): def test__execute_notebook__failure__executeresult_has_run_url(mocker): run_page_url = "http://runpage" - output_data = __get_submit_run_response( - 'FAILURE', 'TERMINATED', '', run_page_url) - run_id = {} - run_id['run_id'] = 1 - db = __get_client_for_execute_notebook(mocker, output_data, run_id) + run_info, run_output = 
__get_submit_run_response( + 'FAILED', 'TERMINATED', '', run_page_url) + db = __get_client_for_execute_notebook(mocker, run_info, run_output) result = db.execute_notebook('/mynotebook', 'clusterid') @@ -117,10 +122,8 @@ def test__execute_notebook__failure__executeresult_has_run_url(mocker): def test__execute_notebook__terminatestate__success(mocker): - output_data = __get_submit_run_response('SUCCESS', 'TERMINATED', '') - run_id = {} - run_id['run_id'] = 1 - db = __get_client_for_execute_notebook(mocker, output_data, run_id) + run_info, run_output = __get_submit_run_response('SUCCESS', 'TERMINATED', '') + db = __get_client_for_execute_notebook(mocker, run_info, run_output) result = db.execute_notebook('/mynotebook', 'clusterid') @@ -128,10 +131,8 @@ def test__execute_notebook__terminatestate__success(mocker): def test__execute_notebook__skippedstate__resultstate_is_SKIPPED(mocker): - output_data = __get_submit_run_response('', 'SKIPPED', '') - run_id = {} - run_id['run_id'] = 1 - db = __get_client_for_execute_notebook(mocker, output_data, run_id) + run_info, run_output = __get_submit_run_response('', 'SKIPPED', '') + db = __get_client_for_execute_notebook(mocker, run_info, run_output) result = db.execute_notebook('/mynotebook', 'clusterid') @@ -139,10 +140,8 @@ def test__execute_notebook__skippedstate__resultstate_is_SKIPPED(mocker): def test__execute_notebook__internal_error_state__resultstate_is_INTERNAL_ERROR(mocker): - output_data = __get_submit_run_response('', 'INTERNAL_ERROR', '') - run_id = {} - run_id['run_id'] = 1 - db = __get_client_for_execute_notebook(mocker, output_data, run_id) + run_info, run_output = __get_submit_run_response('', 'INTERNAL_ERROR', '') + db = __get_client_for_execute_notebook(mocker, run_info, run_output) result = db.execute_notebook('/mynotebook', 'clusterid') @@ -150,10 +149,12 @@ def test__execute_notebook__internal_error_state__resultstate_is_INTERNAL_ERROR( def 
test__execute_notebook__timeout_1_sec_lcs_isrunning__timeoutexception(mocker): - output_data = __get_submit_run_response('', 'RUNNING', '') - run_id = {} - run_id['run_id'] = 1 - db = __get_client_for_execute_notebook(mocker, output_data, run_id) + run_info, run_output = __get_submit_run_response('', 'RUNNING', '') + db = __get_client(mocker) + + # Make submit_and_wait raise TimeOutException when called + mocker.patch.object(db.dbclient.jobs, 'submit_and_wait') + db.dbclient.jobs.submit_and_wait.side_effect = client.TimeOutException("Timeout waiting for job") with pytest.raises(client.TimeOutException): db.min_timeout = 1 @@ -161,10 +162,8 @@ def test__execute_notebook__timeout_1_sec_lcs_isrunning__timeoutexception(mocker def test__execute_notebook__timeout_greater_than_min__valueerror(mocker): - output_data = __get_submit_run_response('', 'RUNNING', '') - run_id = {} - run_id['run_id'] = 1 - db = __get_client_for_execute_notebook(mocker, output_data, run_id) + run_info, run_output = __get_submit_run_response('', 'RUNNING', '') + db = __get_client_for_execute_notebook(mocker, run_info, run_output) with pytest.raises(ValueError): db.min_timeout = 10 @@ -175,42 +174,45 @@ def test__execute_notebook__timeout_greater_than_min__valueerror(mocker): def __get_submit_run_response(task_result_state, life_cycle_state, result, run_page_url=default_run_page_url): - data_json = """ - {"notebook_output": - {"result": "IHaveReturned", "truncated": false}, - "metadata": - {"execution_duration": 15000, - "run_type": "SUBMIT_RUN", - "cleanup_duration": 0, - "number_in_job": 1, - "cluster_instance": - {"cluster_id": "0925-141d1222-narcs242", - "spark_context_id": "803963628344534476"}, - "creator_user_name": "abc@microsoft.com", - "task": {"notebook_task": {"notebook_path": "/mynotebook"}}, - "run_id": 7, "start_time": 1569887259173, - "job_id": 4, - "state": {"result_state": "SUCCESS", "state_message": "", - "life_cycle_state": "TERMINATED"}, "setup_duration": 2000, - 
"run_page_url": "https://westus2.azuredatabricks.net/?o=14702dasda6094293890#job/4/run/1", - "cluster_spec": {"existing_cluster_id": "0925-141122-narcs242"}, "run_name": "myrun"}} - """ - data_dict = json.loads(data_json) - data_dict['notebook_output']['result'] = result - data_dict['metadata']['state']['result_state'] = task_result_state - data_dict['metadata']['state']['life_cycle_state'] = life_cycle_state - data_dict['metadata']['run_page_url'] = run_page_url - - return json.dumps(data_dict) - - -def __get_client_for_execute_notebook(mocker, output_data, run_id): - db = __get_client(mocker) - mocker.patch.object(db.inner_dbclient.jobs, 'submit_run') - db.inner_dbclient.jobs.submit_run.return_value = run_id - mocker.patch.object(db.inner_dbclient.jobs, 'get_run_output') - db.inner_dbclient.jobs.get_run_output.return_value = json.loads( - output_data) + # Create proper SDK objects instead of JSON + result_state = getattr(RunResultState, task_result_state) if task_result_state else None + lifecycle_state = getattr(RunLifeCycleState, life_cycle_state) + + run_info = Run( + tasks=[ + RunTask( + task_key="test_task", + notebook_task=NotebookTask(notebook_path="/mynotebook"), + run_id=2, + state=RunState( + life_cycle_state=lifecycle_state, + result_state=result_state, + state_message="" + ) + ) + ], + run_id=1, + run_page_url=run_page_url, + state=RunState( + life_cycle_state=lifecycle_state, + result_state=result_state, + state_message="" + ), + ) + + notebook_output = RunOutput( + notebook_output=NotebookOutput(result=result, truncated=False) + ) + + return run_info, notebook_output + + +def __get_client_for_execute_notebook(mocker, run_info, run_output): + db = __get_client(mocker) + mocker.patch.object(db.dbclient.jobs, 'submit_and_wait') + db.dbclient.jobs.submit_and_wait.return_value = run_info + mocker.patch.object(db.dbclient.jobs, 'get_run_output') + db.dbclient.jobs.get_run_output.return_value = run_output return db @@ -220,3 +222,155 @@ def 
__get_client(mocker): mocker.patch.dict(os.environ, {'DATABRICKS_TOKEN': 'mytoken'}) return DatabricksAPIClient() + + +def test__get_cluster_id_by_name__cluster_found__returns_id(mocker): + db = __get_client(mocker) + mocker.patch.object(db.dbclient.clusters, 'list') + + # Mock clusters list response + clusters = [ + ClusterDetails(cluster_id='1234-567890-abcd', cluster_name='test-cluster'), + ClusterDetails(cluster_id='9876-543210-wxyz', cluster_name='other-cluster') + ] + db.dbclient.clusters.list.return_value = iter(clusters) + + cluster_id = db.get_cluster_id_by_name('test-cluster') + + assert cluster_id == '1234-567890-abcd' + + +def test__get_cluster_id_by_name__empty_name__raises_error(mocker): + db = __get_client(mocker) + + with pytest.raises(ValueError, match="empty cluster name"): + db.get_cluster_id_by_name('') + + +def test__get_cluster_id_by_name__cluster_not_found__raises_error(mocker): + db = __get_client(mocker) + mocker.patch.object(db.dbclient.clusters, 'list') + + # Mock clusters list response + clusters = [ + ClusterDetails(cluster_id='1234-567890-abcd', cluster_name='test-cluster'), + ClusterDetails(cluster_id='9876-543210-wxyz', cluster_name='other-cluster') + ] + db.dbclient.clusters.list.return_value = iter(clusters) + + with pytest.raises(ValueError, match="No cluster found with name 'nonexistent-cluster'"): + db.get_cluster_id_by_name('nonexistent-cluster') + + +def test__get_cluster_id_by_name__multiple_clusters_with_same_name__raises_error(mocker): + db = __get_client(mocker) + mocker.patch.object(db.dbclient.clusters, 'list') + + # Mock clusters list response with duplicate names + clusters = [ + ClusterDetails(cluster_id='1234-567890-abcd', cluster_name='duplicate-cluster'), + ClusterDetails(cluster_id='9876-543210-wxyz', cluster_name='duplicate-cluster') + ] + db.dbclient.clusters.list.return_value = iter(clusters) + + with pytest.raises(ValueError, match="Multiple clusters found with name 'duplicate-cluster'"): + 
db.get_cluster_id_by_name('duplicate-cluster') + + +def test__get_cluster_id_by_name__case_insensitive__returns_id(mocker): + db = __get_client(mocker) + + # Test lowercase + mocker.patch.object(db.dbclient.clusters, 'list') + clusters = [ + ClusterDetails(cluster_id='1234-567890-abcd', cluster_name='Test-Cluster'), + ClusterDetails(cluster_id='9876-543210-wxyz', cluster_name='other-cluster') + ] + db.dbclient.clusters.list.return_value = iter(clusters) + cluster_id = db.get_cluster_id_by_name('test-cluster') + assert cluster_id == '1234-567890-abcd' + + # Test uppercase + mocker.patch.object(db.dbclient.clusters, 'list') + clusters = [ + ClusterDetails(cluster_id='1234-567890-abcd', cluster_name='Test-Cluster'), + ClusterDetails(cluster_id='9876-543210-wxyz', cluster_name='other-cluster') + ] + db.dbclient.clusters.list.return_value = iter(clusters) + cluster_id = db.get_cluster_id_by_name('TEST-CLUSTER') + assert cluster_id == '1234-567890-abcd' + + # Test mixed case + mocker.patch.object(db.dbclient.clusters, 'list') + clusters = [ + ClusterDetails(cluster_id='1234-567890-abcd', cluster_name='Test-Cluster'), + ClusterDetails(cluster_id='9876-543210-wxyz', cluster_name='other-cluster') + ] + db.dbclient.clusters.list.return_value = iter(clusters) + cluster_id = db.get_cluster_id_by_name('TeSt-ClUsTeR') + assert cluster_id == '1234-567890-abcd' + + +def test__get_cluster_id_by_name__multiple_clusters_case_insensitive__raises_error(mocker): + db = __get_client(mocker) + mocker.patch.object(db.dbclient.clusters, 'list') + + # Mock clusters list response with duplicate names in different cases + clusters = [ + ClusterDetails(cluster_id='1234-567890-abcd', cluster_name='Test-Cluster'), + ClusterDetails(cluster_id='9876-543210-wxyz', cluster_name='test-cluster') + ] + db.dbclient.clusters.list.return_value = iter(clusters) + + with pytest.raises(ValueError, match="Multiple clusters found with name"): + db.get_cluster_id_by_name('TEST-CLUSTER') + + +def 
test__execute_notebook__with_serverless__success(mocker): + db = __get_client(mocker) + run_info, run_output = __get_submit_run_response('SUCCESS', 'TERMINATED', 'result') + mocker.patch.object(db.dbclient.jobs, 'submit_and_wait') + db.dbclient.jobs.submit_and_wait.return_value = run_info + mocker.patch.object(db.dbclient.jobs, 'get_run_output') + db.dbclient.jobs.get_run_output.return_value = run_output + + result = db.execute_notebook('/mynotebook', serverless=1, timeout=120) + + # Verify submit_and_wait was called with environments + assert db.dbclient.jobs.submit_and_wait.call_count == 1 + call_kwargs = db.dbclient.jobs.submit_and_wait.call_args[1] + assert 'environments' in call_kwargs + assert len(call_kwargs['environments']) == 1 + assert call_kwargs['environments'][0].environment_key == 'serverless' + assert call_kwargs['environments'][0].spec.environment_version == '1' + + # Verify task has environment_key instead of cluster_id + tasks = db.dbclient.jobs.submit_and_wait.call_args[1]['tasks'] + assert tasks[0].environment_key == 'serverless' + assert tasks[0].existing_cluster_id is None + + +def test__execute_notebook__no_compute__raises_error(mocker): + db = __get_client(mocker) + + with pytest.raises(ValueError, match="either cluster_id or serverless must be specified"): + db.execute_notebook('/mynotebook', timeout=120) + + +def test__execute_notebook__both_cluster_and_serverless__raises_error(mocker): + db = __get_client(mocker) + + with pytest.raises(ValueError, match="cannot specify both cluster_id and serverless"): + db.execute_notebook('/mynotebook', cluster_id='cluster-123', serverless=1, timeout=120) + + +def test__execute_notebook__serverless_not_integer__raises_error(mocker): + db = __get_client(mocker) + + # Test with string + with pytest.raises(ValueError, match="serverless must be an integer"): + db.execute_notebook('/mynotebook', serverless='1.0', timeout=120) + + # Test with float + with pytest.raises(ValueError, match="serverless must be an 
integer"): + db.execute_notebook('/mynotebook', serverless=1.0, timeout=120) diff --git a/tests/databricks/test_authconfig.py b/tests/databricks/test_authconfig.py deleted file mode 100644 index d202bf5..0000000 --- a/tests/databricks/test_authconfig.py +++ /dev/null @@ -1,55 +0,0 @@ -""" -Copyright (c) Microsoft Corporation. -Licensed under the MIT license. -""" - -import pytest -import os -from common import authconfig as auth - -def test_tokenhostset_okay(mocker): - mocker.patch.dict(os.environ,{'DATABRICKS_HOST':'host'}) - mocker.patch.dict(os.environ,{'DATABRICKS_TOKEN':'token'}) - - config = auth.get_auth_config() - # Assert - assert config != None - assert config.host == 'host' - assert config.token == 'token' - -def test_onlytokenset_none(mocker): - mocker.patch.dict(os.environ,{'DATABRICKS_HOST':''}) - mocker.patch.dict(os.environ,{'DATABRICKS_TOKEN':'token'}) - - config = auth.get_auth_config() - # Assert - assert config == None - -def test_tokenhostsetemtpy_none(mocker): - mocker.patch.dict(os.environ,{'DATABRICKS_HOST':''}) - mocker.patch.dict(os.environ,{'DATABRICKS_TOKEN':''}) - - config = auth.get_auth_config() - # Assert - assert config == None - -def test_onlyhostset_none(mocker): - mocker.patch.dict(os.environ,{'DATABRICKS_HOST':'host'}) - mocker.patch.dict(os.environ,{'DATABRICKS_TOKEN':''}) - - config = auth.get_auth_config() - # Assert - assert config == None - -def test_tokenhostinsecureset_okay(mocker): - mocker.patch.dict(os.environ,{'DATABRICKS_HOST':'host'}) - mocker.patch.dict(os.environ,{'DATABRICKS_TOKEN':'token'}) - mocker.patch.dict(os.environ,{'DATABRICKS_INSECURE':'insecure'}) - - config = auth.get_auth_config() - - # Assert - assert config != None - assert config.host == 'host' - assert config.token == 'token' - assert config.insecure == 'insecure' diff --git a/tests/databricks/test_httpretrier.py b/tests/databricks/test_httpretrier.py deleted file mode 100644 index 4e9fd56..0000000 --- a/tests/databricks/test_httpretrier.py +++ 
/dev/null @@ -1,115 +0,0 @@ -""" -Copyright (c) Microsoft Corporation. -Licensed under the MIT license. -""" - -import pytest -from common.httpretrier import HTTPRetrier -import requests -import io -from requests.exceptions import HTTPError -from databricks_api import DatabricksAPI - -def test__execute__no_exception__returns_value(): - retrier = HTTPRetrier() - value = 'hello' - - return_value = retrier.execute(_get_value, value) - - assert return_value == value - -def test__execute__no_exception_named_args__returns_value(): - retrier = HTTPRetrier() - value = 'hello' - - return_value = retrier.execute(_get_value, return_value = value) - - assert return_value == value - - -def test__execute__no_exception_named_args_set_first_arg__returns_value(): - retrier = HTTPRetrier() - value = 'hello' - - return_values = retrier.execute(_get_values, value1 = value) - - assert return_values[0] == value - assert return_values[1] is None - - -def test__execute__no_exception_named_args_set_second_arg__returns_value(): - retrier = HTTPRetrier() - value = 'hello' - - return_values = retrier.execute(_get_values, value2 = value) - - assert return_values[0] is None - assert return_values[1] == value - -def test__execute__raises_non_http_exception__exception_arises(mocker): - retrier = HTTPRetrier() - raiser = ExceptionRaiser(0, ValueError) - - with pytest.raises(ValueError): - return_value = retrier.execute(raiser.execute) - -def test__execute__raises_500_http_exception__retries_twice_and_raises(mocker): - retrier = HTTPRetrier(2,1) - - db = DatabricksAPI(host='HOST',token='TOKEN') - mock_request = mocker.patch.object(db.client.session, 'request') - mock_resp = requests.models.Response() - mock_resp.status_code = 500 - mock_request.return_value = mock_resp - - with pytest.raises(HTTPError): - return_value = retrier.execute(db.jobs.get_run_output, 1) - assert retrier._tries == 2 - -def test__execute__raises_invalid_state_http_exception__retries_twice_and_raises(mocker): - retrier = 
HTTPRetrier(2,1) - - db = DatabricksAPI(host='HOST',token='TOKEN') - mock_request = mocker.patch.object(db.client.session, 'request') - response_body = " { 'error_code': 'INVALID_STATE', 'message': 'Run result is empty. " + \ - " There may have been issues while saving or reading results.'} " - - mock_resp = requests.models.Response() - mock_resp.status_code = 400 - mock_resp.raw = io.BytesIO(bytes(response_body, 'utf-8')) - mock_request.return_value = mock_resp - - with pytest.raises(HTTPError): - return_value = retrier.execute(db.jobs.get_run_output, 1) - assert retrier._tries == 2 - -def test__execute__raises_403_http_exception__no_retries_and_raises(mocker): - retrier = HTTPRetrier(2,1) - - db = DatabricksAPI(host='HOST',token='TOKEN') - mock_request = mocker.patch.object(db.client.session, 'request') - mock_resp = requests.models.Response() - mock_resp.status_code = 403 - mock_request.return_value = mock_resp - - with pytest.raises(HTTPError): - return_value = retrier.execute(db.jobs.get_run_output, 1) - assert retrier._tries == 0 - -def _get_value(return_value): - return return_value - -def _get_values(value1=None, value2=None): - return value1, value2 - -class ExceptionRaiser(object): - def __init__(self, raise_after, exception): - self._raise_after = raise_after - self._called = 1 - self._exception = exception - - def execute(self): - if self._called > self._raise_after: - raise self._exception() - self._called = self._called + 1 - return self._called \ No newline at end of file diff --git a/tests/databricks/test_utils.py b/tests/databricks/test_utils.py deleted file mode 100644 index 5b001d1..0000000 --- a/tests/databricks/test_utils.py +++ /dev/null @@ -1,53 +0,0 @@ -""" -Copyright (c) Microsoft Corporation. -Licensed under the MIT license. 
-""" - -import pytest -import common.utils as utils - - -def test__recursive_find__2_levels_value__value(mocker): - keys = ["a", "b"] - test_dict = __get_test_dict() - value = utils.recursive_find(test_dict, keys) - - assert value == "c" - - -def test__recursive_find__3_levels_no_value__none(mocker): - keys = ["a", "b", "c"] - test_dict = __get_test_dict() - value = utils.recursive_find(test_dict, keys) - - assert value is None - - -def test__recursive_find__3_levels_value__value(mocker): - keys = ["a", "C", "D"] - test_dict = __get_test_dict() - value = utils.recursive_find(test_dict, keys) - - assert value == "E" - - -def test__recursive_find__3_levels_value__value(mocker): - keys = ["a", "C", "D"] - test_dict = __get_test_dict() - value = utils.recursive_find(test_dict, keys) - - assert value == "E" - - -def test__recursive_find__2_levels_dict__dict(mocker): - keys = ["a", "C"] - test_dict = __get_test_dict() - value = utils.recursive_find(test_dict, keys) - - assert isinstance(value, dict) - - -def __get_test_dict(): - test_dict = {"a": {"b": "c", "C": {"D": "E"}}, "1": {"2": {"3": "4"}}} - - return test_dict diff --git a/tests/nutter/test_api.py b/tests/nutter/test_api.py index c39c9c6..1dc36cc 100644 --- a/tests/nutter/test_api.py +++ b/tests/nutter/test_api.py @@ -6,6 +6,10 @@ import pytest import os import json + +from databricks.sdk.service.jobs import Run, Task, NotebookTask, RunTask, RunState, \ + RunLifeCycleState, RunResultState, RunOutput, NotebookOutput + from common.api import Nutter, TestNotebook, NutterStatusEvents import common.api as nutter_api from common.testresult import TestResults, TestResult @@ -15,10 +19,14 @@ from common.apiclient import WorkspacePath, DatabricksAPIClient from common.statuseventhandler import StatusEventsHandler, EventHandler, StatusEvent +from databricks.sdk.service.workspace import ObjectType, ObjectInfo, Language + + def test__workspacepath__empty_object_response__instance_is_created(): objects = {} workspace_path = 
WorkspacePath.from_api_response(objects) + def test__get_report_writer__junitxmlreportwriter__valid_instance(): writer = nutter_api.get_report_writer('JunitXMLReportWriter') @@ -32,14 +40,14 @@ def test__get_report_writer__tagsreportwriter__valid_instance(): def test__list_tests__twotest__okay(mocker): - nutter = _get_nutter(mocker) dbapi_client = _get_client(mocker) nutter.dbclient = dbapi_client mocker.patch.object(nutter.dbclient, 'list_objects') workspace_path_1 = _get_workspacepathobject( - [('NOTEBOOK', '/mynotebook'), ('NOTEBOOK', '/test_mynotebook'), ('NOTEBOOK', '/mynotebook_test')]) + [(ObjectType.NOTEBOOK, '/mynotebook'), (ObjectType.NOTEBOOK, '/test_mynotebook'), + (ObjectType.NOTEBOOK, '/mynotebook_test')]) nutter.dbclient.list_objects.return_value = workspace_path_1 @@ -49,15 +57,16 @@ def test__list_tests__twotest__okay(mocker): assert tests[0] == TestNotebook('test_mynotebook', '/test_mynotebook') assert tests[1] == TestNotebook('mynotebook_test', '/mynotebook_test') -def test__list_tests__twotest_in_folder__okay(mocker): +def test__list_tests__twotest_in_folder__okay(mocker): nutter = _get_nutter(mocker) dbapi_client = _get_client(mocker) nutter.dbclient = dbapi_client mocker.patch.object(nutter.dbclient, 'list_objects') workspace_path_1 = _get_workspacepathobject( - [('NOTEBOOK', '/folder/mynotebook'), ('NOTEBOOK', '/folder/test_mynotebook'), ('NOTEBOOK', '/folder/mynotebook_test')]) + [(ObjectType.NOTEBOOK, '/folder/mynotebook'), (ObjectType.NOTEBOOK, '/folder/test_mynotebook'), + (ObjectType.NOTEBOOK, '/folder/mynotebook_test')]) nutter.dbclient.list_objects.return_value = workspace_path_1 @@ -69,43 +78,24 @@ def test__list_tests__twotest_in_folder__okay(mocker): assert tests[1] == TestNotebook( 'mynotebook_test', '/folder/mynotebook_test') -@pytest.mark.skip('No longer needed') -def test__list_tests__response_without_root_object__okay(mocker): - - nutter = _get_nutter(mocker) - dbapi_client = _get_client(mocker) - nutter.dbclient = dbapi_client 
- mocker.patch.object(nutter.dbclient, 'list_objects') - - objects = """{"objects":[ - {"object_type":"NOTEBOOK","path":"/mynotebook","language":"PYTHON"}, - {"object_type":"NOTEBOOK","path":"/test_mynotebook","language":"PYTHON"}]}""" - - nutter.dbclient.list_notebooks.return_value = WorkspacePath(json.loads(objects)[ - 'objects']) - - tests = nutter.list_tests("/") - - assert len(tests) == 1 - assert tests[0] == TestNotebook('test_mynotebook', '/test_mynotebook') - def test__list_tests__twotest_uppercase_name__okay(mocker): - nutter = _get_nutter(mocker) dbapi_client = _get_client(mocker) nutter.dbclient = dbapi_client mocker.patch.object(nutter.dbclient, 'list_objects') workspace_path_1 = _get_workspacepathobject( - [('NOTEBOOK', '/mynotebook'), ('NOTEBOOK', '/TEST_mynote'), ('NOTEBOOK', '/mynote_TEST')]) + [(ObjectType.NOTEBOOK, '/mynotebook'), (ObjectType.NOTEBOOK, '/TEST_mynote'), (ObjectType.NOTEBOOK, '/mynote_TEST')]) nutter.dbclient.list_objects.return_value = workspace_path_1 tests = nutter.list_tests("/") assert len(tests) == 2 - assert tests == [TestNotebook('TEST_mynote', '/TEST_mynote'), TestNotebook('mynote_TEST', '/mynote_TEST')] + assert tests == [TestNotebook('TEST_mynote', '/TEST_mynote'), + TestNotebook('mynote_TEST', '/mynote_TEST')] + def test__list_tests__nutterstatusevents_testlisting_sequence_is_fired(mocker): event_handler = TestEventHandler() @@ -115,7 +105,7 @@ def test__list_tests__nutterstatusevents_testlisting_sequence_is_fired(mocker): mocker.patch.object(nutter.dbclient, 'list_objects') workspace_path_1 = _get_workspacepathobject( - [('NOTEBOOK', '/mynotebook'), ('NOTEBOOK', '/TEST_mynote'), ('NOTEBOOK', '/mynote_TEST')]) + [(ObjectType.NOTEBOOK, '/mynotebook'), (ObjectType.NOTEBOOK, '/TEST_mynote'), (ObjectType.NOTEBOOK, '/mynote_TEST')]) nutter.dbclient.list_objects.return_value = workspace_path_1 @@ -127,6 +117,7 @@ def test__list_tests__nutterstatusevents_testlisting_sequence_is_fired(mocker): assert status_event.event == 
NutterStatusEvents.TestsListingResults assert status_event.data == 2 + def test__list_tests_recursively__1test1dir2test__3_tests(mocker): nutter = _get_nutter(mocker) dbapi_client = _get_client(mocker) @@ -134,8 +125,9 @@ def test__list_tests_recursively__1test1dir2test__3_tests(mocker): mocker.patch.object(nutter.dbclient, 'list_objects') workspace_path_1 = _get_workspacepathobject( - [('NOTEBOOK', '/test_1'), ('DIRECTORY', '/p')]) - workspace_path_2 = _get_workspacepathobject([('NOTEBOOK', '/p/test_1'), ('NOTEBOOK', '/p/2_test')]) + [(ObjectType.NOTEBOOK, '/test_1'), (ObjectType.DIRECTORY, '/p')]) + workspace_path_2 = _get_workspacepathobject( + [(ObjectType.NOTEBOOK, '/p/test_1'), (ObjectType.NOTEBOOK, '/p/2_test')]) nutter.dbclient.list_objects.side_effect = [workspace_path_1, workspace_path_2] @@ -147,6 +139,7 @@ def test__list_tests_recursively__1test1dir2test__3_tests(mocker): assert expected == tests assert nutter.dbclient.list_objects.call_count == 2 + def test__list_tests_recursively__2test1dir2test__4_tests(mocker): nutter = _get_nutter(mocker) dbapi_client = _get_client(mocker) @@ -154,9 +147,9 @@ def test__list_tests_recursively__2test1dir2test__4_tests(mocker): mocker.patch.object(nutter.dbclient, 'list_objects') workspace_path_1 = _get_workspacepathobject( - [('NOTEBOOK', '/test_1'), ('NOTEBOOK', '/3_test'), ('DIRECTORY', '/p')]) + [(ObjectType.NOTEBOOK, '/test_1'), (ObjectType.NOTEBOOK, '/3_test'), (ObjectType.DIRECTORY, '/p')]) workspace_path_2 = _get_workspacepathobject( - [('NOTEBOOK', '/p/test_1'), ('NOTEBOOK', '/p/test_2')]) + [(ObjectType.NOTEBOOK, '/p/test_1'), (ObjectType.NOTEBOOK, '/p/test_2')]) nutter.dbclient.list_objects.side_effect = [ workspace_path_1, workspace_path_2] @@ -171,6 +164,7 @@ def test__list_tests_recursively__2test1dir2test__4_tests(mocker): assert expected == tests assert nutter.dbclient.list_objects.call_count == 2 + def test__list_tests_recursively__1test1test1dir1dir__2_test(mocker): nutter = _get_nutter(mocker) 
dbapi_client = _get_client(mocker) @@ -178,8 +172,8 @@ def test__list_tests_recursively__1test1test1dir1dir__2_test(mocker): mocker.patch.object(nutter.dbclient, 'list_objects') workspace_path_1 = _get_workspacepathobject( - [('NOTEBOOK', '/test_1'), ('NOTEBOOK', '/2_test'), ('DIRECTORY', '/p')]) - workspace_path_2 = _get_workspacepathobject([('DIRECTORY', '/p/c')]) + [(ObjectType.NOTEBOOK, '/test_1'), (ObjectType.NOTEBOOK, '/2_test'), (ObjectType.DIRECTORY, '/p')]) + workspace_path_2 = _get_workspacepathobject([(ObjectType.DIRECTORY, '/p/c')]) workspace_path_3 = _get_workspacepathobject([]) nutter.dbclient.list_objects.side_effect = [ @@ -192,30 +186,31 @@ def test__list_tests_recursively__1test1test1dir1dir__2_test(mocker): assert expected == tests assert nutter.dbclient.list_objects.call_count == 3 + def test__list_tests__notest__empty_list(mocker): nutter = _get_nutter(mocker) dbapi_client = _get_client(mocker) nutter.dbclient = dbapi_client _mock_dbclient_list_objects(mocker, dbapi_client, [ - ('NOTEBOOK', '/my'), ('NOTEBOOK', '/my2')]) + (ObjectType.NOTEBOOK, '/my'), (ObjectType.NOTEBOOK, '/my2')]) results = nutter.list_tests("/") assert len(results) == 0 - -def test__run_tests__twomatch_three_tests___nutterstatusevents_testlisting_scheduling_execution_sequence_is_fired(mocker): +def test__run_tests__twomatch_three_tests___nutterstatusevents_testlisting_scheduling_execution_sequence_is_fired( + mocker): event_handler = TestEventHandler() nutter = _get_nutter(mocker, event_handler) test_results = TestResults() - test_results.append(TestResult('case',True, 10,[])) - submit_response = _get_submit_run_response('SUCCESS', 'TERMINATED', test_results.serialize()) - dbapi_client = _get_client_for_execute_notebook(mocker, submit_response) + test_results.append(TestResult('case', True, 10, [])) + submit_response, run_output = _get_submit_run_response('SUCCESS', 'TERMINATED', test_results.serialize()) + dbapi_client = _get_client_for_execute_notebook(mocker, 
submit_response, run_output) nutter.dbclient = dbapi_client _mock_dbclient_list_objects(mocker, dbapi_client, [( - 'NOTEBOOK', '/test_my'), ('NOTEBOOK', '/test_abc'), ('NOTEBOOK', '/my_test')]) + ObjectType.NOTEBOOK, '/test_my'), (ObjectType.NOTEBOOK, '/test_abc'), (ObjectType.NOTEBOOK, '/my_test')]) results = nutter.run_tests("/my*", "cluster") @@ -252,18 +247,19 @@ def test__run_tests__twomatch_three_tests___nutterstatusevents_testlisting_sched status_event = event_handler.get_item() assert status_event.event == NutterStatusEvents.TestExecutionResult - assert status_event.data #True if success + assert status_event.data # True if success + def test__run_tests__twomatch__okay(mocker): nutter = _get_nutter(mocker) - submit_response = _get_submit_run_response('SUCCESS', 'TERMINATED', '') - dbapi_client = _get_client_for_execute_notebook(mocker, submit_response) + submit_response, run_output = _get_submit_run_response('SUCCESS', 'TERMINATED', '') + dbapi_client = _get_client_for_execute_notebook(mocker, submit_response, run_output) nutter.dbclient = dbapi_client _mock_dbclient_list_objects(mocker, dbapi_client, [ - ('NOTEBOOK', '/test_my'), - ('NOTEBOOK', '/my'), - ('NOTEBOOK', '/my_test')]) + (ObjectType.NOTEBOOK, '/test_my'), + (ObjectType.NOTEBOOK, '/my'), + (ObjectType.NOTEBOOK, '/my_test')]) results = nutter.run_tests("/my*", "cluster") @@ -275,81 +271,87 @@ def test__run_tests__twomatch__okay(mocker): result = results[1] assert result.task_result_state == 'TERMINATED' + def test__run_tests_recursively__2test1dir3test__5_tests(mocker): nutter = _get_nutter(mocker) - submit_response = _get_submit_run_response('SUCCESS', 'TERMINATED', '') - dbapi_client = _get_client_for_execute_notebook(mocker, submit_response) + submit_response, run_output = _get_submit_run_response('SUCCESS', 'TERMINATED', '') + dbapi_client = _get_client_for_execute_notebook(mocker, submit_response, run_output) nutter.dbclient = dbapi_client mocker.patch.object(nutter.dbclient, 'list_objects') 
workspace_path_1 = _get_workspacepathobject( - [('NOTEBOOK', '/test_1'), ('NOTEBOOK', '/2_test'), ('DIRECTORY', '/p')]) + [(ObjectType.NOTEBOOK, '/test_1'), (ObjectType.NOTEBOOK, '/2_test'), (ObjectType.DIRECTORY, '/p')]) workspace_path_2 = _get_workspacepathobject( - [('NOTEBOOK', '/p/test_1'), ('NOTEBOOK', '/p/test_2'), ('NOTEBOOK', '/p/3_test')]) + [(ObjectType.NOTEBOOK, '/p/test_1'), (ObjectType.NOTEBOOK, '/p/test_2'), (ObjectType.NOTEBOOK, '/p/3_test')]) nutter.dbclient.list_objects.side_effect = [ workspace_path_1, workspace_path_2] - tests = nutter.run_tests('/','cluster', 120, 1, True) + tests = nutter.run_tests('/', 'cluster', 120, 1, True) assert len(tests) == 5 + def test__run_tests_recursively__1dir1dir3test__3_tests(mocker): nutter = _get_nutter(mocker) - submit_response = _get_submit_run_response('SUCCESS', 'TERMINATED', '') - dbapi_client = _get_client_for_execute_notebook(mocker, submit_response) + submit_response, run_output = _get_submit_run_response('SUCCESS', 'TERMINATED', '') + dbapi_client = _get_client_for_execute_notebook(mocker, submit_response, run_output) nutter.dbclient = dbapi_client mocker.patch.object(nutter.dbclient, 'list_objects') workspace_path_1 = _get_workspacepathobject( - [('DIRECTORY', '/p')]) + [(ObjectType.DIRECTORY, '/p')]) workspace_path_2 = _get_workspacepathobject( - [('DIRECTORY', '/c')]) + [(ObjectType.DIRECTORY, '/c')]) workspace_path_3 = _get_workspacepathobject( - [('NOTEBOOK', '/p/c/test_1'), ('NOTEBOOK', '/p/c/test_2'), ('NOTEBOOK', '/p/c/3_test')]) + [(ObjectType.NOTEBOOK, '/p/c/test_1'), (ObjectType.NOTEBOOK, '/p/c/test_2'), (ObjectType.NOTEBOOK, '/p/c/3_test')]) nutter.dbclient.list_objects.side_effect = [ workspace_path_1, workspace_path_2, workspace_path_3] - tests = nutter.run_tests('/','cluster', 120, 1, True) + tests = nutter.run_tests('/', 'cluster', 120, 1, True) assert len(tests) == 3 + def test__run_tests__twomatch__is_uppercase__okay(mocker): nutter = _get_nutter(mocker) - submit_response = 
_get_submit_run_response('SUCCESS', 'TERMINATED', '') - dbapi_client = _get_client_for_execute_notebook(mocker, submit_response) + submit_response, run_output = _get_submit_run_response('SUCCESS', 'TERMINATED', '') + dbapi_client = _get_client_for_execute_notebook(mocker, submit_response, run_output) nutter.dbclient = dbapi_client _mock_dbclient_list_objects(mocker, dbapi_client, [( - 'NOTEBOOK', '/TEST_my'), ('NOTEBOOK', '/my'), ('NOTEBOOK', '/my_TEST')]) + ObjectType.NOTEBOOK, '/TEST_my'), (ObjectType.NOTEBOOK, '/my'), (ObjectType.NOTEBOOK, '/my_TEST')]) results = nutter.run_tests("/my*", "cluster") assert len(results) == 2 assert results[0].task_result_state == 'TERMINATED' + def test__run_tests__nomatch_case_sensitive__okay(mocker): nutter = _get_nutter(mocker) - submit_response = _get_submit_run_response('SUCCESS', 'TERMINATED', '') - dbapi_client = _get_client_for_execute_notebook(mocker, submit_response) + submit_response, run_output = _get_submit_run_response('SUCCESS', 'TERMINATED', '') + dbapi_client = _get_client_for_execute_notebook(mocker, submit_response, run_output) nutter.dbclient = dbapi_client _mock_dbclient_list_objects(mocker, dbapi_client, [( - 'NOTEBOOK', '/test_MY'), ('NOTEBOOK', '/my'), ('NOTEBOOK', '/MY_test')]) + ObjectType.NOTEBOOK, '/test_MY'), (ObjectType.NOTEBOOK, '/my'), (ObjectType.NOTEBOOK, '/MY_test')]) results = nutter.run_tests("/my*", "cluster") assert len(results) == 0 + def test__run_tests__fourmatches_with_pattern__okay(mocker): - submit_response = _get_submit_run_response('SUCCESS', 'TERMINATED', '') - dbapi_client = _get_client_for_execute_notebook(mocker, submit_response) + submit_response, run_output = _get_submit_run_response('SUCCESS', 'TERMINATED', '') + dbapi_client = _get_client_for_execute_notebook(mocker, submit_response, run_output) nutter = _get_nutter(mocker) nutter.dbclient = dbapi_client _mock_dbclient_list_objects(mocker, dbapi_client, [( - 'NOTEBOOK', '/test_my'), ('NOTEBOOK', '/test_my2'), ('NOTEBOOK', 
'/my_test'), ('NOTEBOOK', '/my3_test')]) + ObjectType.NOTEBOOK, '/test_my'), (ObjectType.NOTEBOOK, '/test_my2'), (ObjectType.NOTEBOOK, '/my_test'), + (ObjectType.NOTEBOOK, '/my3_test')]) results = nutter.run_tests("/my*", "cluster") @@ -359,47 +361,50 @@ def test__run_tests__fourmatches_with_pattern__okay(mocker): assert results[2].task_result_state == 'TERMINATED' assert results[3].task_result_state == 'TERMINATED' + def test__run_tests__with_invalid_pattern__valueerror(mocker): - submit_response = _get_submit_run_response('SUCCESS', 'TERMINATED', '') - dbapi_client = _get_client_for_execute_notebook(mocker, submit_response) + submit_response, run_output = _get_submit_run_response('SUCCESS', 'TERMINATED', '') + dbapi_client = _get_client_for_execute_notebook(mocker, submit_response, run_output) nutter = _get_nutter(mocker) nutter.dbclient = dbapi_client _mock_dbclient_list_objects(mocker, dbapi_client, [( - 'NOTEBOOK', '/test_my'), ('NOTEBOOK', '/test_my2')]) + ObjectType.NOTEBOOK, '/test_my'), (ObjectType.NOTEBOOK, '/test_my2')]) with pytest.raises(ValueError): results = nutter.run_tests("/my/(", "cluster") def test__run_tests__nomatches__okay(mocker): - submit_response = _get_submit_run_response('SUCCESS', 'TERMINATED', '') - dbapi_client = _get_client_for_execute_notebook(mocker, submit_response) + submit_response, run_output = _get_submit_run_response('SUCCESS', 'TERMINATED', '') + dbapi_client = _get_client_for_execute_notebook(mocker, submit_response, run_output) nutter = _get_nutter(mocker) nutter.dbclient = dbapi_client _mock_dbclient_list_objects(mocker, dbapi_client, [( - 'NOTEBOOK', '/test_my'), ('NOTEBOOK', '/test_my2'), ('NOTEBOOK', '/my_test'), ('NOTEBOOK', '/my3_test')]) + ObjectType.NOTEBOOK, '/test_my'), (ObjectType.NOTEBOOK, '/test_my2'), (ObjectType.NOTEBOOK, '/my_test'), + (ObjectType.NOTEBOOK, '/my3_test')]) results = nutter.run_tests("/abc*", "cluster") assert len(results) == 0 + def test__to_testresults__none_output__none(mocker): output = 
None - result = nutter_api.to_testresults(output) + result = nutter_api.to_test_results(output) assert result is None def test__to_testresults__non_pickle_output__none(mocker): output = 'NOT A PICKLE' - result = nutter_api.to_testresults(output) + result = nutter_api.to_test_results(output) assert result is None def test__to_testresults__pickle_output__testresult(mocker): output = TestResults().serialize() - result = nutter_api.to_testresults(output) + result = nutter_api.to_test_results(output) assert isinstance(result, TestResults) @@ -411,6 +416,8 @@ def test__to_testresults__pickle_output__testresult(mocker): ('abc'), ('abc*'), ] + + @pytest.mark.parametrize('pattern', patterns) def test__testnamepatternmatcher_ctor_valid_pattern__instance(pattern): pattern_matcher = TestNamePatternMatcher(pattern) @@ -423,6 +430,8 @@ def test__testnamepatternmatcher_ctor_valid_pattern__instance(pattern): ('*'), (None), ] + + @pytest.mark.parametrize('pattern', all_patterns) def test__testnamepatternmatcher_ctor_valid_all_pattern__pattern_is_none(pattern): pattern_matcher = TestNamePatternMatcher(pattern) @@ -436,6 +445,8 @@ def test__testnamepatternmatcher_ctor_valid_all_pattern__pattern_is_none(pattern ('tt*'), ('e^6'), ] + + @pytest.mark.parametrize('pattern', reg_patterns) def test__testnamepatternmatcher_ctor_valid_regex_pattern__pattern_is_pattern(pattern): pattern_matcher = TestNamePatternMatcher(pattern) @@ -450,16 +461,18 @@ def test__testnamepatternmatcher_ctor_valid_regex_pattern__pattern_is_pattern(pa ('a', [TestNotebook("a_test", "/a_test")], 1), ('*', [TestNotebook("test_a", "/test_a"), TestNotebook("test_b", "/test_b")], 2), ('*', [TestNotebook("a_test", "/a_test"), TestNotebook("b_test", "/b_test")], 2), - ('b*',[TestNotebook("test_a", "/test_a"), TestNotebook("test_b", "/test_b")], 1), - ('b*',[TestNotebook("a_test", "/a_test"), TestNotebook("b_test", "/b_test")], 1), - ('b*',[TestNotebook("test_ba", "/test_ba"), TestNotebook("test_b", "/test_b")], 2), - 
('b*',[TestNotebook("ba_test", "/ba_test"), TestNotebook("b_test", "/b_test")], 2), - ('c*',[TestNotebook("test_a", "/test_a"), TestNotebook("test_b", "/test_b")], 0), - ('c*',[TestNotebook("a_test", "/a_test"), TestNotebook("b_test", "/b_test")], 0), + ('b*', [TestNotebook("test_a", "/test_a"), TestNotebook("test_b", "/test_b")], 1), + ('b*', [TestNotebook("a_test", "/a_test"), TestNotebook("b_test", "/b_test")], 1), + ('b*', [TestNotebook("test_ba", "/test_ba"), TestNotebook("test_b", "/test_b")], 2), + ('b*', [TestNotebook("ba_test", "/ba_test"), TestNotebook("b_test", "/b_test")], 2), + ('c*', [TestNotebook("test_a", "/test_a"), TestNotebook("test_b", "/test_b")], 0), + ('c*', [TestNotebook("a_test", "/a_test"), TestNotebook("b_test", "/b_test")], 0), ] -@pytest.mark.parametrize('pattern, list_results, expected_count', filter_patterns) -def test__filter_by_pattern__valid_scenarios__result_len_is_expected_count(pattern, list_results, expected_count): + +@pytest.mark.parametrize('pattern, list_results, expected_count', filter_patterns) +def test__filter_by_pattern__valid_scenarios__result_len_is_expected_count(pattern, list_results, + expected_count): pattern_matcher = TestNamePatternMatcher(pattern) filtered = pattern_matcher.filter_by_pattern(list_results) @@ -470,52 +483,40 @@ def test__filter_by_pattern__valid_scenarios__result_len_is_expected_count(patte ('('), ('--)'), ] + + @pytest.mark.parametrize('pattern', invalid_patterns) def test__testnamepatternmatcher_ctor__invali_pattern__valueerror(pattern): - with pytest.raises(ValueError): pattern_matcher = TestNamePatternMatcher(pattern) def _get_submit_run_response(result_state, life_cycle_state, result): - data_json = """ - {"notebook_output": - {"result": "IHaveReturned", "truncated": false}, - "metadata": - {"execution_duration": 15000, - "run_type": "SUBMIT_RUN", - "cleanup_duration": 0, - "number_in_job": 1, - "cluster_instance": - {"cluster_id": "0925-141d1222-narcs242", - "spark_context_id": 
"803963628344534476"}, - "creator_user_name": "abc@microsoft.com", - "task": {"notebook_task": {"notebook_path": "/mynotebook"}}, - "run_id": 7, "start_time": 1569887259173, - "job_id": 4, - "state": {"result_state": "SUCCESS", "state_message": "", - "life_cycle_state": "TERMINATED"}, "setup_duration": 2000, - "run_page_url": "https://westus2.azuredatabricks.net/?o=14702dasda6094293890#job/4/run/1", - "cluster_spec": {"existing_cluster_id": "0925-141122-narcs242"}, "run_name": "myrun"}} - """ - data_dict = json.loads(data_json) - data_dict['notebook_output']['result'] = result - data_dict['metadata']['state']['result_state'] = result_state - data_dict['metadata']['state']['life_cycle_state'] = life_cycle_state - - return json.dumps(data_dict) - - -def _get_client_for_execute_notebook(mocker, output_data): - run_id = {} - run_id['run_id'] = 1 - + run_info = Run( + tasks=[ + RunTask(task_key="test_task", + notebook_task=NotebookTask("/mynotebook"), + run_id=2, + state=RunState(life_cycle_state=getattr(RunLifeCycleState, life_cycle_state), + result_state=getattr(RunResultState, result_state), + state_message="")) + ], + run_id=1, + run_page_url="https://westus2.azuredatabricks.net/?o=14702dasda6094293890#job/4/run/1", + state=RunState(life_cycle_state=getattr(RunLifeCycleState, life_cycle_state), + result_state=getattr(RunResultState, result_state), + state_message=""), + ) + notebook_output = RunOutput(notebook_output=NotebookOutput(result=result, truncated=False)) + return run_info, notebook_output + + +def _get_client_for_execute_notebook(mocker, run_info, run_output): db = _get_client(mocker) - mocker.patch.object(db.inner_dbclient.jobs, 'submit_run') - db.inner_dbclient.jobs.submit_run.return_value = run_id - mocker.patch.object(db.inner_dbclient.jobs, 'get_run_output') - db.inner_dbclient.jobs.get_run_output.return_value = json.loads( - output_data) + mocker.patch.object(db.dbclient.jobs, 'submit_and_wait') + db.dbclient.jobs.submit_and_wait.return_value = 
run_info + mocker.patch.object(db.dbclient.jobs, 'get_run_output') + db.dbclient.jobs.get_run_output.return_value = run_output return db @@ -527,7 +528,7 @@ def _get_client(mocker): return DatabricksAPIClient() -def _get_nutter(mocker, event_handler = None): +def _get_nutter(mocker, event_handler=None): mocker.patch.dict(os.environ, {'DATABRICKS_HOST': 'myhost'}) mocker.patch.dict(os.environ, {'DATABRICKS_TOKEN': 'mytoken'}) @@ -542,20 +543,14 @@ def _mock_dbclient_list_objects(mocker, dbclient, objects): def _get_workspacepathobject(objects): - objects_list = [] - for object in objects: - item = {} - item['object_type'] = object[0] - item['path'] = object[1] - item['language'] = 'PYTHON' - objects_list.append(item) - - root_obj = {'objects': objects_list} - - return WorkspacePath.from_api_response(root_obj) + objects_list = [ObjectInfo(path=obj[1], language=Language.PYTHON, object_type=obj[0]) + for obj in objects] + return WorkspacePath.from_api_response(objects_list) class TestEventHandler(EventHandler): + __test__ = False # Tell pytest this is not a test class + def __init__(self): self._queue = None super().__init__() @@ -567,4 +562,3 @@ def get_item(self): item = self._queue.get() self._queue.task_done() return item - diff --git a/tests/nutter/test_apiclientresults.py b/tests/nutter/test_apiclientresults.py index 4928483..54016e9 100644 --- a/tests/nutter/test_apiclientresults.py +++ b/tests/nutter/test_apiclientresults.py @@ -4,64 +4,68 @@ """ import pytest -import json +from unittest.mock import Mock +from databricks.sdk.service.jobs import Run, RunTask, NotebookTask, RunState, \ + RunLifeCycleState, RunResultState, RunOutput, NotebookOutput + from common.api import Nutter, TestNotebook, NutterStatusEvents import common.api as nutter_api from common.testresult import TestResults, TestResult from common.apiclientresults import ExecuteNotebookResult, NotebookOutputResult + def test__is_any_error__not_terminated__true(): - exec_result = 
_get_run_test_response('', 'SKIPPED','') + exec_result = _get_run_test_response('', 'SKIPPED', '') assert exec_result.is_any_error def test__is_any_error__terminated_not_success__true(): - exec_result = _get_run_test_response('FAILED', 'TERMINATED','') + exec_result = _get_run_test_response('FAILED', 'TERMINATED', '') assert exec_result.is_any_error def test__is_any_error__terminated_success_invalid_results__true(): - exec_result = _get_run_test_response('SUCCESS', 'TERMINATED','') + exec_result = _get_run_test_response('SUCCESS', 'TERMINATED', '') assert exec_result.is_any_error def test__is_any_error__terminated_success_valid_results_with_failure__true(): test_results = TestResults() - test_results.append(TestResult('case',False, 10,[])) - exec_result = _get_run_test_response('SUCCESS', 'TERMINATED',test_results.serialize()) + test_results.append(TestResult('case', False, 10, [])) + exec_result = _get_run_test_response('SUCCESS', 'TERMINATED', test_results.serialize()) assert exec_result.is_any_error - def test__is_any_error__terminated_success_valid_results_with_no_failure__false(): test_results = TestResults() - test_results.append(TestResult('case',True, 10,[])) - exec_result = _get_run_test_response('SUCCESS', 'TERMINATED',test_results.serialize()) + test_results.append(TestResult('case', True, 10, [])) + exec_result = _get_run_test_response('SUCCESS', 'TERMINATED', test_results.serialize()) assert not exec_result.is_any_error - def test__is_any_error__terminated_success_2_valid_results_with_no_failure__false(): test_results = TestResults() - test_results.append(TestResult('case',True, 10,[])) - test_results.append(TestResult('case2',True, 10,[])) - exec_result = _get_run_test_response('SUCCESS', 'TERMINATED',test_results.serialize()) + test_results.append(TestResult('case', True, 10, [])) + test_results.append(TestResult('case2', True, 10, [])) + exec_result = _get_run_test_response('SUCCESS', 'TERMINATED', test_results.serialize()) assert not 
exec_result.is_any_error + def test__is_any_error__terminated_success_2_results_1_invalid__true(): test_results = TestResults() - test_results.append(TestResult('case',True, 10,[])) - test_results.append(TestResult('case2',False, 10,[])) - exec_result = _get_run_test_response('SUCCESS', 'TERMINATED',test_results.serialize()) + test_results.append(TestResult('case', True, 10, [])) + test_results.append(TestResult('case2', False, 10, [])) + exec_result = _get_run_test_response('SUCCESS', 'TERMINATED', test_results.serialize()) assert exec_result.is_any_error + def test__is_run_from_notebook__result_state_NA__returns_true(): # Arrange nbr = NotebookOutputResult('N/A', None, None) @@ -69,9 +73,10 @@ def test__is_run_from_notebook__result_state_NA__returns_true(): # Act is_run_from_notebook = nbr.is_run_from_notebook - #Assert + # Assert assert True == is_run_from_notebook + def test__is_error__is_run_from_notebook_true__returns_false(): # Arrange nbr = NotebookOutputResult('N/A', None, None) @@ -79,34 +84,41 @@ def test__is_error__is_run_from_notebook_true__returns_false(): # Act is_error = nbr.is_error - #Assert + # Assert assert False == is_error -def _get_run_test_response(result_state, life_cycle_state, notebook_result): - data_json = """ - {"notebook_output": - {"result": "IHaveReturned", "truncated": false}, - "metadata": - {"execution_duration": 15000, - "run_type": "SUBMIT_RUN", - "cleanup_duration": 0, - "number_in_job": 1, - "cluster_instance": - {"cluster_id": "0925-141d1222-narcs242", - "spark_context_id": "803963628344534476"}, - "creator_user_name": "abc@microsoft.com", - "task": {"notebook_task": {"notebook_path": "/test_mynotebook"}}, - "run_id": 7, "start_time": 1569887259173, - "job_id": 4, - "state": {"result_state": "SUCCESS", "state_message": "", - "life_cycle_state": "TERMINATED"}, "setup_duration": 2000, - "run_page_url": "https://westus2.azuredatabricks.net/?o=14702dasda6094293890#job/4/run/1", - "cluster_spec": {"existing_cluster_id": 
"0925-141122-narcs242"}, "run_name": "myrun"}} - """ - data_dict = json.loads(data_json) - data_dict['notebook_output']['result'] = notebook_result - data_dict['metadata']['state']['result_state'] = result_state - data_dict['metadata']['state']['life_cycle_state'] = life_cycle_state - - return ExecuteNotebookResult.from_job_output(data_dict) +def _get_run_test_response(result_state, life_cycle_state, notebook_result): + # Create proper SDK objects instead of JSON + result_state_enum = getattr(RunResultState, result_state) if result_state else None + lifecycle_state_enum = getattr(RunLifeCycleState, life_cycle_state) + + run_info = Run( + tasks=[ + RunTask( + task_key="test_task", + notebook_task=NotebookTask(notebook_path="/test_mynotebook"), + run_id=2, + state=RunState( + life_cycle_state=lifecycle_state_enum, + result_state=result_state_enum, + state_message="" + ) + ) + ], + run_id=1, + run_page_url="https://westus2.azuredatabricks.net/?o=14702dasda6094293890#job/4/run/1", + state=RunState( + life_cycle_state=lifecycle_state_enum, + result_state=result_state_enum, + state_message="" + ), + ) + + # Create mock WorkspaceClient + mock_client = Mock() + mock_client.jobs.get_run_output.return_value = RunOutput( + notebook_output=NotebookOutput(result=notebook_result, truncated=False) + ) + + return ExecuteNotebookResult.from_job_output(run_info, mock_client) diff --git a/tests/nutter/test_resultreports.py b/tests/nutter/test_resultreports.py index 0f2033a..6ee781c 100644 --- a/tests/nutter/test_resultreports.py +++ b/tests/nutter/test_resultreports.py @@ -8,6 +8,7 @@ from common.resultreports import JunitXMLReportWriter from common.resultreports import TagsReportWriter + def test_junitxmlreportwriter_add_result__invalid_params__raises_valueerror(): writer = JunitXMLReportWriter() diff --git a/tests/nutter/test_resultsview.py b/tests/nutter/test_resultsview.py index 865c1ae..489ee44 100644 --- a/tests/nutter/test_resultsview.py +++ b/tests/nutter/test_resultsview.py 
@@ -3,15 +3,19 @@ Licensed under the MIT license. """ -import json import pytest -from common.resultsview import RunCommandResultsView, TestCaseResultView, ListCommandResultView, ListCommandResultsView +from unittest.mock import Mock +from databricks.sdk.service.jobs import Run, RunTask, NotebookTask, RunState, \ + RunLifeCycleState, RunResultState, RunOutput, NotebookOutput + +from common.resultsview import RunCommandResultsView, TestCaseResultView, ListCommandResultView, \ + ListCommandResultsView from common.apiclientresults import ExecuteNotebookResult from common.testresult import TestResults, TestResult from common.api import TestNotebook -def test__add_exec_result__vaid_instance__isadded(mocker): +def test__add_exec_result__vaid_instance__isadded(mocker): test_results = TestResults().serialize() notebook_results = __get_ExecuteNotebookResult( 'SUCCESS', 'TERMINATED', test_results) @@ -23,7 +27,6 @@ def test__add_exec_result__vaid_instance__isadded(mocker): def test__add_exec_result__vaid_instance_invalid_output__isadded(mocker): - test_results = "NO PICKLE" notebook_results = __get_ExecuteNotebookResult( 'SUCCESS', 'TERMINATED', test_results) @@ -39,7 +42,6 @@ def test__add_exec_result__vaid_instance_invalid_output__isadded(mocker): def test__add_exec_result__vaid_instance_invalid_output__no_test_case_view(mocker): - test_results = "NO PICKLE" notebook_results = __get_ExecuteNotebookResult( 'SUCCESS', 'TERMINATED', test_results) @@ -51,7 +53,6 @@ def test__add_exec_result__vaid_instance_invalid_output__no_test_case_view(mocke def test__add_exec_result__vaid_instance__test_case_view(mocker): - test_results = TestResults() test_case = TestResult("mycase", True, 10, []) test_results.append(test_case) @@ -76,7 +77,6 @@ def test__add_exec_result__vaid_instance__test_case_view(mocker): def test__add_exec_result__vaid_instance_two_test_cases__two_test_case_view(mocker): - test_results = TestResults() test_case = TestResult("mycase", True, 10, []) 
test_results.append(test_case) @@ -115,11 +115,11 @@ def test__get_view__for_testcase_failed__returns_correct_string(mocker): stack_trace = "Stack Trace" exception = AssertionError("1 == 2") test_case = TestResult("mycase", False, 5.43, [ - 'tag1', 'tag2'], exception, stack_trace) + 'tag1', 'tag2'], exception, stack_trace) test_case_result_view = TestCaseResultView(test_case) expected_view = "mycase (5.43 seconds)\n\n" + \ - stack_trace + "\n\n" + "AssertionError: 1 == 2" + "\n" + stack_trace + "\n\n" + "AssertionError: 1 == 2" + "\n" # Act view = test_case_result_view.get_view() @@ -128,8 +128,8 @@ def test__get_view__for_testcase_failed__returns_correct_string(mocker): assert expected_view == view -def test__get_view__for_run_command_result_with_passing_test_case__shows_test_result_under_passing(mocker): - +def test__get_view__for_run_command_result_with_passing_test_case__shows_test_result_under_passing( + mocker): test_results = TestResults() test_case = TestResult("mycase", True, 10, []) test_results.append(test_case) @@ -139,7 +139,7 @@ def test__get_view__for_run_command_result_with_passing_test_case__shows_test_re notebook_results = __get_ExecuteNotebookResult( 'SUCCESS', 'TERMINATED', serialized_results) - #expected_view = 'Name: \t/test_mynotebook\nNotebook Exec Result:\tTERMINATED \nTests Cases:\nCase:\tmycase\n\n\tPASSED\n\t\n\t\n\tDuration: 10\n\nCase:\tmycase2\n\n\tPASSED\n\t\n\t\n\tDuration: 10\n\n\n----------------------------------------\n' + # expected_view = 'Name: \t/test_mynotebook\nNotebook Exec Result:\tTERMINATED \nTests Cases:\nCase:\tmycase\n\n\tPASSED\n\t\n\t\n\tDuration: 10\n\nCase:\tmycase2\n\n\tPASSED\n\t\n\t\n\tDuration: 10\n\n\n----------------------------------------\n' expected_view = '\nNotebook: /test_mynotebook - Lifecycle State: TERMINATED, Result: SUCCESS\n' expected_view += 'Run Page URL: {}\n'.format(notebook_results.notebook_run_page_url) expected_view += '============================================================\n' @@ 
-157,13 +157,14 @@ def test__get_view__for_run_command_result_with_passing_test_case__shows_test_re assert expected_view == view -def test__get_view__for_run_command_result_with_failing_test_case__shows_test_result_under_failing(mocker): +def test__get_view__for_run_command_result_with_failing_test_case__shows_test_result_under_failing( + mocker): test_results = TestResults() stack_trace = "Stack Trace" exception = AssertionError("1 == 2") test_case = TestResult("mycase", False, 5.43, [ - 'tag1', 'tag2'], exception, stack_trace) + 'tag1', 'tag2'], exception, stack_trace) test_case_result_view = TestCaseResultView(test_case) test_results.append(test_case) @@ -178,9 +179,9 @@ def test__get_view__for_run_command_result_with_failing_test_case__shows_test_re serialized_results = test_results.serialize() notebook_results = __get_ExecuteNotebookResult( - 'FAILURE', 'TERMINATED', serialized_results) + 'FAILED', 'TERMINATED', serialized_results) - expected_view = '\nNotebook: /test_mynotebook - Lifecycle State: TERMINATED, Result: FAILURE\n' + expected_view = '\nNotebook: /test_mynotebook - Lifecycle State: TERMINATED, Result: FAILED\n' expected_view += 'Run Page URL: {}\n'.format(notebook_results.notebook_run_page_url) expected_view += '============================================================\n' expected_view += 'FAILING TESTS\n' @@ -201,14 +202,14 @@ def test__get_view__for_run_command_result_with_failing_test_case__shows_test_re assert expected_view == view + def test__get_view__for_list_command__with_tests_Found__shows_listing(mocker): - test_notebook1 = TestNotebook('test_one','/test_one') - test_notebook2 = TestNotebook('test_two','/test_two') + test_notebook1 = TestNotebook('test_one', '/test_one') + test_notebook2 = TestNotebook('test_two', '/test_two') test_notebooks = [test_notebook1, test_notebook2] list_result_view1 = ListCommandResultView.from_test_notebook(test_notebook1) list_result_view2 = ListCommandResultView.from_test_notebook(test_notebook2) - 
expected_view = '\nTests Found\n' expected_view += '-------------------------------------------------------\n' expected_view += list_result_view1.get_view() @@ -222,13 +223,12 @@ def test__get_view__for_list_command__with_tests_Found__shows_listing(mocker): assert view == expected_view - -def test__get_view__for_run_command_result_with_one_passing_one_failing__shows_failing_then_passing(mocker): - +def test__get_view__for_run_command_result_with_one_passing_one_failing__shows_failing_then_passing( + mocker): stack_trace = "Stack Trace" exception = AssertionError("1 == 2") test_case = TestResult("mycase", False, 5.43, [ - 'tag1', 'tag2'], exception, stack_trace) + 'tag1', 'tag2'], exception, stack_trace) test_case_result_view = TestCaseResultView(test_case) test_results = TestResults() @@ -236,9 +236,9 @@ def test__get_view__for_run_command_result_with_one_passing_one_failing__shows_f serialized_results = test_results.serialize() notebook_results = __get_ExecuteNotebookResult( - 'FAILURE', 'TERMINATED', serialized_results) + 'FAILED', 'TERMINATED', serialized_results) - expected_view = '\nNotebook: /test_mynotebook - Lifecycle State: TERMINATED, Result: FAILURE\n' + expected_view = '\nNotebook: /test_mynotebook - Lifecycle State: TERMINATED, Result: FAILED\n' expected_view += 'Run Page URL: {}\n'.format(notebook_results.notebook_run_page_url) expected_view += '============================================================\n' expected_view += 'FAILING TESTS\n' @@ -256,29 +256,36 @@ def test__get_view__for_run_command_result_with_one_passing_one_failing__shows_f def __get_ExecuteNotebookResult(result_state, life_cycle_state, notebook_result): - data_json = """ - {"notebook_output": - {"result": "IHaveReturned", "truncated": false}, - "metadata": - {"execution_duration": 15000, - "run_type": "SUBMIT_RUN", - "cleanup_duration": 0, - "number_in_job": 1, - "cluster_instance": - {"cluster_id": "0925-141d1222-narcs242", - "spark_context_id": "803963628344534476"}, - 
"creator_user_name": "abc@microsoft.com", - "task": {"notebook_task": {"notebook_path": "/test_mynotebook"}}, - "run_id": 7, "start_time": 1569887259173, - "job_id": 4, - "state": {"result_state": "SUCCESS", "state_message": "", - "life_cycle_state": "TERMINATED"}, "setup_duration": 2000, - "run_page_url": "https://westus2.azuredatabricks.net/?o=14702dasda6094293890#job/4/run/1", - "cluster_spec": {"existing_cluster_id": "0925-141122-narcs242"}, "run_name": "myrun"}} - """ - data_dict = json.loads(data_json) - data_dict['notebook_output']['result'] = notebook_result - data_dict['metadata']['state']['result_state'] = result_state - data_dict['metadata']['state']['life_cycle_state'] = life_cycle_state - - return ExecuteNotebookResult.from_job_output(data_dict) + # Create proper SDK objects instead of JSON + result_state_enum = getattr(RunResultState, result_state) if result_state else None + lifecycle_state_enum = getattr(RunLifeCycleState, life_cycle_state) + + run_info = Run( + tasks=[ + RunTask( + task_key="test_task", + notebook_task=NotebookTask(notebook_path="/test_mynotebook"), + run_id=2, + state=RunState( + life_cycle_state=lifecycle_state_enum, + result_state=result_state_enum, + state_message="" + ) + ) + ], + run_id=1, + run_page_url="https://westus2.azuredatabricks.net/?o=14702dasda6094293890#job/4/run/1", + state=RunState( + life_cycle_state=lifecycle_state_enum, + result_state=result_state_enum, + state_message="" + ), + ) + + # Create mock WorkspaceClient + mock_client = Mock() + mock_client.jobs.get_run_output.return_value = RunOutput( + notebook_output=NotebookOutput(result=notebook_result, truncated=False) + ) + + return ExecuteNotebookResult.from_job_output(run_info, mock_client) diff --git a/tests/nutter/test_scheduler.py b/tests/nutter/test_scheduler.py index f502d4d..fc1c084 100644 --- a/tests/nutter/test_scheduler.py +++ b/tests/nutter/test_scheduler.py @@ -26,8 +26,12 @@ def 
test__run_and_wait__1_function_1_worker_exception__result_is_none_and_except (2, 2, {'this': 'this'}), (2, 2, ('this', 'that')), ] + + @pytest.mark.parametrize('num_of_funcs, num_of_workers, func_return_value', params) -def test__run_and_wait__X_functions_X_workers_x_value__results_are_okay(num_of_funcs, num_of_workers, func_return_value): +def test__run_and_wait__X_functions_X_workers_x_value__results_are_okay(num_of_funcs, + num_of_workers, + func_return_value): func_scheduler = scheduler.get_scheduler(num_of_workers) for i in range(0, num_of_funcs): diff --git a/tests/nutter/test_statuseventhandler.py b/tests/nutter/test_statuseventhandler.py index 54b8c4f..7039e83 100644 --- a/tests/nutter/test_statuseventhandler.py +++ b/tests/nutter/test_statuseventhandler.py @@ -7,6 +7,7 @@ import enum from common.statuseventhandler import StatusEventsHandler, EventHandler, StatusEvent + def test__add_event_and_wait__1_event__handler_receives_it(): test_handler = TestEventHandler() status_handler = StatusEventsHandler(test_handler) @@ -18,6 +19,7 @@ def test__add_event_and_wait__1_event__handler_receives_it(): assert item.event == TestStatusEvent.AnEvent assert item.data == 'added' + def test__add_event_and_wait__2_event2__handler_receives_them(): test_handler = TestEventHandler() status_handler = StatusEventsHandler(test_handler) @@ -34,7 +36,10 @@ def test__add_event_and_wait__2_event2__handler_receives_them(): assert item2.event == TestStatusEvent.AnEvent assert item2.data == 'added' + class TestEventHandler(EventHandler): + __test__ = False # Tell pytest this is not a test class + def __init__(self): self._queue = None super().__init__() @@ -47,5 +52,7 @@ def get_item(self): self._queue.task_done() return item + class TestStatusEvent(enum.Enum): - AnEvent = 1 \ No newline at end of file + __test__ = False # Tell pytest this is not a test class + AnEvent = 1 diff --git a/tests/nutter/test_testexecresults.py b/tests/nutter/test_testexecresults.py index 5c4d57f..4341256 
100644 --- a/tests/nutter/test_testexecresults.py +++ b/tests/nutter/test_testexecresults.py @@ -7,10 +7,12 @@ from common.testexecresults import TestExecResults from common.testresult import TestResults, TestResult + def test__ctor__test_results_not_correct_type__raises_type_error(): with pytest.raises(TypeError): test_exec_result = TestExecResults("invalidtype") + def test__to_string__valid_test_results__creates_view_from_test_results_and_returns(mocker): # Arrange test_results = TestResults() @@ -26,17 +28,20 @@ def test__to_string__valid_test_results__creates_view_from_test_results_and_retu mocker.patch.object(test_exec_result.runcommand_results_view, 'add_exec_result') mocker.patch.object(test_exec_result.runcommand_results_view, 'get_view') test_exec_result.runcommand_results_view.get_view.return_value = "expectedview" - + # Act view = test_exec_result.to_string() # Assert test_exec_result.get_ExecuteNotebookResult.assert_called_once_with("", test_results) - test_exec_result.runcommand_results_view.add_exec_result.assert_called_once_with(notebook_result) + test_exec_result.runcommand_results_view.add_exec_result.assert_called_once_with( + notebook_result) test_exec_result.runcommand_results_view.get_view.assert_called_once_with() assert view == "expectedview" -def test__to_string__valid_test_results_run_from_notebook__creates_view_from_test_results_and_returns(mocker): + +def test__to_string__valid_test_results_run_from_notebook__creates_view_from_test_results_and_returns( + mocker): # Arrange test_results = TestResults() test_results.append(TestResult("test1", True, 10, [])) @@ -52,6 +57,7 @@ def test__to_string__valid_test_results_run_from_notebook__creates_view_from_tes assert "test1" in view assert "test2" in view + def test__exit__valid_test_results__serializes_test_results_and_passes_to_dbutils_exit(mocker): # Arrange test_results = TestResults() @@ -74,10 +80,12 @@ def test__exit__valid_test_results__serializes_test_results_and_passes_to_dbutil assert 
True == dbutils_stub.notebook.exit_called assert serialized_data == dbutils_stub.notebook.data_passed + class DbUtilsStub: def __init__(self): self.notebook = NotebookStub() + class NotebookStub(): def __init__(self): self.exit_called = False @@ -85,4 +93,4 @@ def __init__(self): def exit(self, data): self.exit_called = True - self.data_passed = data \ No newline at end of file + self.data_passed = data diff --git a/tests/nutter/test_testresult.py b/tests/nutter/test_testresult.py index bd4b29e..1c85fe9 100644 --- a/tests/nutter/test_testresult.py +++ b/tests/nutter/test_testresult.py @@ -4,8 +4,6 @@ """ import base64 -import json -import pickle import mock import pytest @@ -21,16 +19,18 @@ def test__testresults_append__type_not_testresult__throws_error(): with pytest.raises(TypeError): test_results.append("Test") + def test__testresults_append__type_testresult__appends_testresult(): # Arrange test_results = TestResults() # Act test_results.append(TestResult("Test Name", True, 1, [])) - + # Assert assert len(test_results.results) == 1 + def test__eq__test_results_not_equal__are_not_equal(): # Arrange test_results = TestResults() @@ -45,6 +45,7 @@ def test__eq__test_results_not_equal__are_not_equal(): are_not_equal = test_results != test_results1 assert are_not_equal == True + def test__deserialize__no_constraints__is_serializable_and_deserializable(): # Arrange test_results = TestResults() @@ -58,27 +59,30 @@ def test__deserialize__no_constraints__is_serializable_and_deserializable(): assert test_results == deserialized_data -def test__deserialize__empty_pickle_data__throws_exception(): + +def test__deserialize__empty_data__throws_exception(): # Arrange test_results = TestResults() - invalid_pickle = "" + invalid_data = "" # Act / Assert with pytest.raises(Exception): - test_results.deserialize(invalid_pickle) + test_results.deserialize(invalid_data) -def test__deserialize__invalid_pickle_data__throws_Exception(): + +def 
test__deserialize__invalid_pickle_data__throws_exception(): # Arrange test_results = TestResults() - invalid_pickle = "test" + invalid_data = "test" # Act / Assert with pytest.raises(Exception): - test_results.deserialize(invalid_pickle) + test_results.deserialize(invalid_data) + -def test__deserialize__p4jjavaerror__is_serializable_and_deserializable(): +def test__deserialize__py4j_javaerror__is_serializable_and_deserializable(): # Arrange test_results = TestResults() @@ -107,6 +111,7 @@ def test__eq__test_results_equal_but_not_same_ref__are_equal(): # Act / Assert assert test_results == test_results1 + def test__num_tests__5_test_cases__is_5(): # Arrange test_results = TestResults() @@ -119,6 +124,7 @@ def test__num_tests__5_test_cases__is_5(): # Act / Assert assert 5 == test_results.test_cases + def test__num_failures__5_test_cases_4_failures__is_4(): # Arrange test_results = TestResults() @@ -131,6 +137,7 @@ def test__num_failures__5_test_cases_4_failures__is_4(): # Act / Assert assert 4 == test_results.num_failures + def test__total_execution_time__5_test_cases__is_sum_of_execution_times(): # Arrange test_results = TestResults() @@ -143,22 +150,6 @@ def test__total_execution_time__5_test_cases__is_sum_of_execution_times(): # Act / Assert assert 32.990534 == test_results.total_execution_time -def test__serialize__result_data__is_base64_str(): - test_results = TestResults() - serialized_data = test_results.serialize() - serialized_bin_data = base64.encodebytes(pickle.dumps(test_results)) - - assert serialized_data == str(serialized_bin_data, "utf-8") - - -def test__deserialize__data_is_base64_str__can_deserialize(): - test_results = TestResults() - serialized_bin_data = pickle.dumps(test_results) - serialized_str = str(base64.encodebytes(serialized_bin_data), "utf-8") - test_results_from_data = TestResults().deserialize(serialized_str) - - assert test_results == test_results_from_data - def get_mock_gateway_client(): mock_client = mock.Mock() diff --git 
a/tests/runtime/test_nutterfixture.py b/tests/runtime/test_nutterfixture.py index 5c81412..1fa06ea 100644 --- a/tests/runtime/test_nutterfixture.py +++ b/tests/runtime/test_nutterfixture.py @@ -1,393 +1,395 @@ -""" -Copyright (c) Microsoft Corporation. -Licensed under the MIT license. -""" - -import pytest -from runtime.nutterfixture import NutterFixture, tag, InvalidTestFixtureException, InitializationException -from runtime.testcase import TestCase -from runtime.fixtureloader import FixtureLoader -from common.testresult import TestResult, TestResults -from tests.runtime.testnutterfixturebuilder import TestNutterFixtureBuilder -from common.apiclientresults import ExecuteNotebookResult -import sys - -def test__ctor__creates_fixture_loader(): - # Arrange / Act - fix = SimpleTestFixture() - - # Assert - assert fix.data_loader is not None - -def test__execute_tests__calls_load_fixture_on_fixture_loader(mocker): - # Arrange - fix = SimpleTestFixture() - - mocker.patch.object(fix.data_loader, 'load_fixture') - - # Act - fix.execute_tests() - - # Assert - fix.data_loader.load_fixture.assert_called_once_with(fix) - -def test__execute_tests__data_loader_returns_none__throws_invalidfixtureexception(mocker): - # Arrange - fix = SimpleTestFixture() - - mocker.patch.object(fix.data_loader, 'load_fixture') - fix.data_loader.load_fixture.return_value = None - - # Act / Assert - with pytest.raises(InvalidTestFixtureException): - fix.execute_tests() - -def test__execute_tests__data_loader_returns_empty_dictionary__returns_empty_results(mocker): - # Arrange - fix = SimpleTestFixture() - - mocker.patch.object(fix.data_loader, 'load_fixture') - fix.data_loader.load_fixture.return_value = {} - - # Act - test_exec_results = fix.execute_tests() - - # Assert - assert len(test_exec_results.test_results.results) == 0 - -def test__execute_tests__before_all_set_and_data_loader_returns_empty_dictionary__does_not_call_before_all(mocker): - # Arrange - fix = SimpleTestFixture() - - 
mocker.patch.object(fix.data_loader, 'load_fixture') - fix.data_loader.load_fixture.return_value = {} - fix.before_all = lambda self: 1 == 1 - - mocker.patch.object(fix, 'before_all') - - # Act - test_results = fix.execute_tests() - - # Assert - fix.before_all.assert_not_called() - -def test__execute_tests__before_all_none_and_data_loader_returns_empty_dictionary__does_not_call_before_all(mocker): - # Arrange - fix = SimpleTestFixture() - - mocker.patch.object(fix.data_loader, 'load_fixture') - fix.data_loader.load_fixture.return_value = {} - fix.before_all = None - - mocker.patch.object(fix, 'before_all') - - # Act - test_results = fix.execute_tests() - - # Assert - fix.before_all.assert_not_called() - -def test__execute_tests__before_all_set_and_data_loader_returns_dictionary_with_testcases__calls_before_all(mocker): - # Arrange - fix = SimpleTestFixture() - - mocker.patch.object(fix.data_loader, 'load_fixture') - - tc = __get_test_case("TestName", fix.run_test, fix.assertion_test) - fix.before_all = lambda self: 1 == 1 - mocker.patch.object(fix, 'before_all') - - test_case_dict = { - "test": tc - } - - fix.data_loader.load_fixture.return_value = test_case_dict - - # Act - fix.execute_tests() - - # Assert - fix.before_all.assert_called_once_with() - -def test__execute_tests__after_all_set_and_data_loader_returns_empty_dictionary__does_not_call_after_all(mocker): - # Arrange - fix = SimpleTestFixture() - - mocker.patch.object(fix.data_loader, 'load_fixture') - fix.data_loader.load_fixture.return_value = {} - fix.after_all = lambda self: 1 == 1 - - mocker.patch.object(fix, 'after_all') - - # Act - test_results = fix.execute_tests() - - # Assert - fix.after_all.assert_not_called() - -def test__execute_tests__after_all_none_and_data_loader_returns_empty_dictionary__does_not_call_after_all(mocker): - # Arrange - fix = SimpleTestFixture() - - mocker.patch.object(fix.data_loader, 'load_fixture') - fix.data_loader.load_fixture.return_value = {} - fix.after_all = None - - 
mocker.patch.object(fix, 'after_all') - - # Act - test_results = fix.execute_tests() - - # Assert - fix.after_all.assert_not_called() - -def test__execute_tests__after_all_set_and_data_loader_returns_dictionary_with_testcases__calls_after_all(mocker): - # Arrange - fix = SimpleTestFixture() - - mocker.patch.object(fix.data_loader, 'load_fixture') - - tc = __get_test_case("TestName", fix.run_test, fix.assertion_test) - fix.after_all = lambda self: 1 == 1 - mocker.patch.object(fix, 'after_all') - - test_case_dict = { - "test": tc - } - - fix.data_loader.load_fixture.return_value = test_case_dict - - # Act - fix.execute_tests() - - # Assert - fix.after_all.assert_called_once_with() - -def test__execute_tests__data_loader_returns_dictionary_with_testcases__iterates_over_dictionary_and_calls_execute(mocker): - # Arrange - fix = SimpleTestFixture() - mocker.patch.object(fix.data_loader, 'load_fixture') - - tc = __get_test_case("TestName", fix.run_test, fix.assertion_test) - mocker.patch.object(tc, 'execute_test') - tc.execute_test.return_value = TestResult("TestName", True, 1, []) - tc1 = __get_test_case("TestName", fix.run_test, fix.assertion_test) - mocker.patch.object(tc1, 'execute_test') - tc1.execute_test.return_value = TestResult("TestName", True, 1, []) - - test_case_dict = { - "test": tc, - "test1": tc1 - } - - fix.data_loader.load_fixture.return_value = test_case_dict - - # Act - fix.execute_tests() - - # Assert - tc.execute_test.assert_called_once_with() - tc1.execute_test.assert_called_once_with() - -def test__execute_tests__returns_test_result__calls_append_on_testresults(mocker): - # Arrange - fix = SimpleTestFixture() - mocker.patch.object(fix.test_results, 'append') - - tc = __get_test_case("TestName", lambda: 1 == 1, lambda: 1 == 1) - - test_case_dict = { - "test": tc - } - mocker.patch.object(fix.data_loader, 'load_fixture') - fix.data_loader.load_fixture.return_value = test_case_dict - - # Act - result = fix.execute_tests() - - # Assert - 
fix.test_results.append.assert_called_once_with(mocker.ANY) - -def test__execute_tests__two_test_cases__returns_test_results_with_2_test_results(mocker): - # Arrange - fix = SimpleTestFixture() - - tc = __get_test_case("TestName", lambda: 1 == 1, lambda: 1 == 1) - tc1 = __get_test_case("TestName1", lambda: 1 == 1, lambda: 1 == 1) - - test_case_dict = { - "TestName": tc, - "TestName1": tc1 - } - - mocker.patch.object(fix.data_loader, 'load_fixture') - fix.data_loader.load_fixture.return_value = test_case_dict - - # Act - result = fix.execute_tests() - - # Assert - assert len(result.test_results.results) == 2 - -def test__execute_tests__test_names_not_in_order_in_class__tests_executed_in_alphabetical_order(): - # Arrange - fix = OutOfOrderTestFixture() - - # Act - fix.execute_tests() - - # Assert - assert '1wxyz' == fix.get_method_order() - -def test__execute_tests__subclass_init_does_not_call_NutterFixture_init__throws_InitializationException(): - # Arrange - fix = TestFixtureThatDoesNotCallBaseCtor() - - # Act - with pytest.raises(InitializationException): - fix.execute_tests() - -def test__run_test_method__has_list_tag_decorator__list_set_on_method(): - # Arrange - class Wrapper(NutterFixture): - tag_list = ["tag1", "tag2"] - @tag(tag_list) - def run_test(self): - lambda: 1 == 1 - - test_name = "test" - tag_list = ["tag1", "tag2"] - - test_fixture = TestNutterFixtureBuilder() \ - .with_name("MyClass") \ - .with_assertion(test_name) \ - .with_run(test_name, Wrapper.run_test) \ - .build() - - # Act / Assert - assert tag_list == test_fixture.run_test.tag - -def test__run_test_method__has_str_tag_decorator__str_set_on_method(): - # Arrange - class Wrapper(NutterFixture): - tag_str = "mytag" - @tag(tag_str) - def run_test(self): - lambda: 1 == 1 - - test_name = "test" - test_fixture = TestNutterFixtureBuilder() \ - .with_name("MyClass") \ - .with_assertion(test_name) \ - .with_run(test_name, Wrapper.run_test) \ - .build() - - # Act / Assert - assert "mytag" == 
test_fixture.run_test.tag - -def test__run_test_method__has_tag_decorator_not_list__raises_value_error(): - # Arrange - with pytest.raises(ValueError): - class Wrapper(NutterFixture): - tag_invalid = {} - @tag(tag_invalid) - def run_test(self): - lambda: 1 == 1 - -def test__run_test_method__has_tag_decorator_not_listhas_invalid_tag_decorator_none__raises_value_error(): - # Arrange - with pytest.raises(ValueError): - class Wrapper(NutterFixture): - tag_invalid = None - @tag(tag_invalid) - def run_test(self): - lambda: 1 == 1 - -def test__non_run_test_method__valid_tag_on_non_run_method__raises_value_error(): - # Arrange - with pytest.raises(ValueError): - class Wrapper(NutterFixture): - tag_valid = "mytag" - @tag(tag_valid) - def assertion_test(self): - lambda: 1 == 1 - -def __get_test_case(name, setrun, setassert): - tc = TestCase(name) - if setrun != None: - tc.set_run(setrun) - tc.set_assertion(setassert) - - return tc - -def test__run_test_method__has_invalid_tag_decorator_not_list_or_str_using_class_not_builder__raises_value_error(): - # Arrange - simple_test_fixture = SimpleTestFixture() - - # Act / Assert - with pytest.raises(ValueError): - simple_test_fixture.run_test_with_invalid_decorator() - -def test__run_test_method__has_valid_tag_decorator_in_class__tag_set_on_method(): - # Arrange - simple_test_fixture = SimpleTestFixture() - - # Act / Assert - assert "mytag" == simple_test_fixture.run_test_with_valid_decorator.tag - -class SimpleTestFixture(NutterFixture): - - def before_test(self): - pass - - def run_test(self): - pass - - def assertion_test(self): - assert 1 == 1 - - def after_test(self): - pass - - @tag("mytag") - def run_test_with_valid_decorator(self): - pass - - @tag - def run_test_with_invalid_decorator(self): - pass - -class OutOfOrderTestFixture(NutterFixture): - def __init__(self): - super(OutOfOrderTestFixture, self).__init__() - self.__method_order = '' - - def assertion_y(self): - self.__method_order += 'y' - assert 1 == 1 - - def 
assertion_z(self): - self.__method_order += 'z' - assert 1 == 1 - - def assertion_1(self): - self.__method_order += '1' - assert 1 == 1 - - def assertion_w(self): - self.__method_order += 'w' - assert 1 == 1 - - def assertion_x(self): - self.__method_order += 'x' - assert 1 == 1 - - def get_method_order(self): - return self.__method_order - -class TestFixtureThatDoesNotCallBaseCtor(NutterFixture): - def __init__(self): - pass - - def assertion_test_case(self): - assert 1 == 1 +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. +""" + +import pytest +from runtime.nutterfixture import NutterFixture, tag, InvalidTestFixtureException, InitializationException +from runtime.testcase import TestCase +from runtime.fixtureloader import FixtureLoader +from common.testresult import TestResult, TestResults +from tests.runtime.testnutterfixturebuilder import TestNutterFixtureBuilder +from common.apiclientresults import ExecuteNotebookResult +import sys + +def test__ctor__creates_fixture_loader(): + # Arrange / Act + fix = SimpleTestFixture() + + # Assert + assert fix.data_loader is not None + +def test__execute_tests__calls_load_fixture_on_fixture_loader(mocker): + # Arrange + fix = SimpleTestFixture() + + mocker.patch.object(fix.data_loader, 'load_fixture') + + # Act + fix.execute_tests() + + # Assert + fix.data_loader.load_fixture.assert_called_once_with(fix) + +def test__execute_tests__data_loader_returns_none__throws_invalidfixtureexception(mocker): + # Arrange + fix = SimpleTestFixture() + + mocker.patch.object(fix.data_loader, 'load_fixture') + fix.data_loader.load_fixture.return_value = None + + # Act / Assert + with pytest.raises(InvalidTestFixtureException): + fix.execute_tests() + +def test__execute_tests__data_loader_returns_empty_dictionary__returns_empty_results(mocker): + # Arrange + fix = SimpleTestFixture() + + mocker.patch.object(fix.data_loader, 'load_fixture') + fix.data_loader.load_fixture.return_value = {} + + # Act + 
test_exec_results = fix.execute_tests() + + # Assert + assert len(test_exec_results.test_results.results) == 0 + +def test__execute_tests__before_all_set_and_data_loader_returns_empty_dictionary__does_not_call_before_all(mocker): + # Arrange + fix = SimpleTestFixture() + + mocker.patch.object(fix.data_loader, 'load_fixture') + fix.data_loader.load_fixture.return_value = {} + fix.before_all = lambda self: 1 == 1 + + mocker.patch.object(fix, 'before_all') + + # Act + test_results = fix.execute_tests() + + # Assert + fix.before_all.assert_not_called() + +def test__execute_tests__before_all_none_and_data_loader_returns_empty_dictionary__does_not_call_before_all(mocker): + # Arrange + fix = SimpleTestFixture() + + mocker.patch.object(fix.data_loader, 'load_fixture') + fix.data_loader.load_fixture.return_value = {} + fix.before_all = None + + mocker.patch.object(fix, 'before_all') + + # Act + test_results = fix.execute_tests() + + # Assert + fix.before_all.assert_not_called() + +def test__execute_tests__before_all_set_and_data_loader_returns_dictionary_with_testcases__calls_before_all(mocker): + # Arrange + fix = SimpleTestFixture() + + mocker.patch.object(fix.data_loader, 'load_fixture') + + tc = __get_test_case("TestName", fix.run_test, fix.assertion_test) + fix.before_all = lambda self: 1 == 1 + mocker.patch.object(fix, 'before_all') + + test_case_dict = { + "test": tc + } + + fix.data_loader.load_fixture.return_value = test_case_dict + + # Act + fix.execute_tests() + + # Assert + fix.before_all.assert_called_once_with() + +def test__execute_tests__after_all_set_and_data_loader_returns_empty_dictionary__does_not_call_after_all(mocker): + # Arrange + fix = SimpleTestFixture() + + mocker.patch.object(fix.data_loader, 'load_fixture') + fix.data_loader.load_fixture.return_value = {} + fix.after_all = lambda self: 1 == 1 + + mocker.patch.object(fix, 'after_all') + + # Act + test_results = fix.execute_tests() + + # Assert + fix.after_all.assert_not_called() + +def 
test__execute_tests__after_all_none_and_data_loader_returns_empty_dictionary__does_not_call_after_all(mocker): + # Arrange + fix = SimpleTestFixture() + + mocker.patch.object(fix.data_loader, 'load_fixture') + fix.data_loader.load_fixture.return_value = {} + fix.after_all = None + + mocker.patch.object(fix, 'after_all') + + # Act + test_results = fix.execute_tests() + + # Assert + fix.after_all.assert_not_called() + +def test__execute_tests__after_all_set_and_data_loader_returns_dictionary_with_testcases__calls_after_all(mocker): + # Arrange + fix = SimpleTestFixture() + + mocker.patch.object(fix.data_loader, 'load_fixture') + + tc = __get_test_case("TestName", fix.run_test, fix.assertion_test) + fix.after_all = lambda self: 1 == 1 + mocker.patch.object(fix, 'after_all') + + test_case_dict = { + "test": tc + } + + fix.data_loader.load_fixture.return_value = test_case_dict + + # Act + fix.execute_tests() + + # Assert + fix.after_all.assert_called_once_with() + +def test__execute_tests__data_loader_returns_dictionary_with_testcases__iterates_over_dictionary_and_calls_execute(mocker): + # Arrange + fix = SimpleTestFixture() + mocker.patch.object(fix.data_loader, 'load_fixture') + + tc = __get_test_case("TestName", fix.run_test, fix.assertion_test) + mocker.patch.object(tc, 'execute_test') + tc.execute_test.return_value = TestResult("TestName", True, 1, []) + tc1 = __get_test_case("TestName", fix.run_test, fix.assertion_test) + mocker.patch.object(tc1, 'execute_test') + tc1.execute_test.return_value = TestResult("TestName", True, 1, []) + + test_case_dict = { + "test": tc, + "test1": tc1 + } + + fix.data_loader.load_fixture.return_value = test_case_dict + + # Act + fix.execute_tests() + + # Assert + tc.execute_test.assert_called_once_with() + tc1.execute_test.assert_called_once_with() + +def test__execute_tests__returns_test_result__calls_append_on_testresults(mocker): + # Arrange + fix = SimpleTestFixture() + mocker.patch.object(fix.test_results, 'append') + + tc = 
__get_test_case("TestName", lambda: 1 == 1, lambda: 1 == 1) + + test_case_dict = { + "test": tc + } + mocker.patch.object(fix.data_loader, 'load_fixture') + fix.data_loader.load_fixture.return_value = test_case_dict + + # Act + result = fix.execute_tests() + + # Assert + fix.test_results.append.assert_called_once_with(mocker.ANY) + +def test__execute_tests__two_test_cases__returns_test_results_with_2_test_results(mocker): + # Arrange + fix = SimpleTestFixture() + + tc = __get_test_case("TestName", lambda: 1 == 1, lambda: 1 == 1) + tc1 = __get_test_case("TestName1", lambda: 1 == 1, lambda: 1 == 1) + + test_case_dict = { + "TestName": tc, + "TestName1": tc1 + } + + mocker.patch.object(fix.data_loader, 'load_fixture') + fix.data_loader.load_fixture.return_value = test_case_dict + + # Act + result = fix.execute_tests() + + # Assert + assert len(result.test_results.results) == 2 + +def test__execute_tests__test_names_not_in_order_in_class__tests_executed_in_alphabetical_order(): + # Arrange + fix = OutOfOrderTestFixture() + + # Act + fix.execute_tests() + + # Assert + assert '1wxyz' == fix.get_method_order() + +def test__execute_tests__subclass_init_does_not_call_NutterFixture_init__throws_InitializationException(): + # Arrange + fix = TestFixtureThatDoesNotCallBaseCtor() + + # Act + with pytest.raises(InitializationException): + fix.execute_tests() + +def test__run_test_method__has_list_tag_decorator__list_set_on_method(): + # Arrange + class Wrapper(NutterFixture): + tag_list = ["tag1", "tag2"] + @tag(tag_list) + def run_test(self): + lambda: 1 == 1 + + test_name = "test" + tag_list = ["tag1", "tag2"] + + test_fixture = TestNutterFixtureBuilder() \ + .with_name("MyClass") \ + .with_assertion(test_name) \ + .with_run(test_name, Wrapper.run_test) \ + .build() + + # Act / Assert + assert tag_list == test_fixture.run_test.tag + +def test__run_test_method__has_str_tag_decorator__str_set_on_method(): + # Arrange + class Wrapper(NutterFixture): + tag_str = "mytag" + 
@tag(tag_str) + def run_test(self): + lambda: 1 == 1 + + test_name = "test" + test_fixture = TestNutterFixtureBuilder() \ + .with_name("MyClass") \ + .with_assertion(test_name) \ + .with_run(test_name, Wrapper.run_test) \ + .build() + + # Act / Assert + assert "mytag" == test_fixture.run_test.tag + +def test__run_test_method__has_tag_decorator_not_list__raises_value_error(): + # Arrange + with pytest.raises(ValueError): + class Wrapper(NutterFixture): + tag_invalid = {} + @tag(tag_invalid) + def run_test(self): + lambda: 1 == 1 + +def test__run_test_method__has_tag_decorator_not_listhas_invalid_tag_decorator_none__raises_value_error(): + # Arrange + with pytest.raises(ValueError): + class Wrapper(NutterFixture): + tag_invalid = None + @tag(tag_invalid) + def run_test(self): + lambda: 1 == 1 + +def test__non_run_test_method__valid_tag_on_non_run_method__raises_value_error(): + # Arrange + with pytest.raises(ValueError): + class Wrapper(NutterFixture): + tag_valid = "mytag" + @tag(tag_valid) + def assertion_test(self): + lambda: 1 == 1 + +def __get_test_case(name, setrun, setassert): + tc = TestCase(name) + if setrun != None: + tc.set_run(setrun) + tc.set_assertion(setassert) + + return tc + +def test__run_test_method__has_invalid_tag_decorator_not_list_or_str_using_class_not_builder__raises_value_error(): + # Arrange + simple_test_fixture = SimpleTestFixture() + + # Act / Assert + with pytest.raises(ValueError): + simple_test_fixture.run_test_with_invalid_decorator() + +def test__run_test_method__has_valid_tag_decorator_in_class__tag_set_on_method(): + # Arrange + simple_test_fixture = SimpleTestFixture() + + # Act / Assert + assert "mytag" == simple_test_fixture.run_test_with_valid_decorator.tag + +class SimpleTestFixture(NutterFixture): + + def before_test(self): + pass + + def run_test(self): + pass + + def assertion_test(self): + assert 1 == 1 + + def after_test(self): + pass + + @tag("mytag") + def run_test_with_valid_decorator(self): + pass + + @tag + def 
run_test_with_invalid_decorator(self): + pass + +class OutOfOrderTestFixture(NutterFixture): + def __init__(self): + super(OutOfOrderTestFixture, self).__init__() + self.__method_order = '' + + def assertion_y(self): + self.__method_order += 'y' + assert 1 == 1 + + def assertion_z(self): + self.__method_order += 'z' + assert 1 == 1 + + def assertion_1(self): + self.__method_order += '1' + assert 1 == 1 + + def assertion_w(self): + self.__method_order += 'w' + assert 1 == 1 + + def assertion_x(self): + self.__method_order += 'x' + assert 1 == 1 + + def get_method_order(self): + return self.__method_order + +class TestFixtureThatDoesNotCallBaseCtor(NutterFixture): + __test__ = False # Tell pytest this is not a test class + + def __init__(self): + pass + + def assertion_test_case(self): + assert 1 == 1 diff --git a/tests/runtime/testnutterfixturebuilder.py b/tests/runtime/testnutterfixturebuilder.py index 2dfcdb5..d6b9fa9 100644 --- a/tests/runtime/testnutterfixturebuilder.py +++ b/tests/runtime/testnutterfixturebuilder.py @@ -6,6 +6,8 @@ from runtime.nutterfixture import NutterFixture class TestNutterFixtureBuilder(): + __test__ = False # Tell pytest this is not a test class + def __init__(self): self.attributes = {} self.class_name = "ImplementingClass"