diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 095d0add..ab761e3d 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -131,10 +131,11 @@ jobs: --junitxml=artifacts/tests/results.xml \ --cov=./backends \ --cov=./cli \ - --cov=./objects \ - --cov=./sdk \ --cov=./clients \ --cov=./jupyter \ + --cov=./objects \ + --cov=./sdk \ + --cov=./transformers \ --cov-report=term \ --cov-report=xml:artifacts/tests/coverage.xml \ --cov-report=html:artifacts/tests/coverage.html \ diff --git a/.vscode/launch.json b/.vscode/launch.json index 08555fb2..a0e4d2f2 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -4,6 +4,12 @@ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", "configurations": [ + { + "name": "Python: Attach using Process Id", + "type": "python", + "request": "attach", + "processId": "${command:pickProcess}" + }, { "name": "Python: Current File", "type": "python", diff --git a/backends/durable_functions/execute_steps_workflow/execute_steps_workflow.py b/backends/durable_functions/execute_steps_workflow/execute_steps_workflow.py index 2ae31644..32fc2f6a 100644 --- a/backends/durable_functions/execute_steps_workflow/execute_steps_workflow.py +++ b/backends/durable_functions/execute_steps_workflow/execute_steps_workflow.py @@ -34,13 +34,6 @@ def _execute_steps_workflow(context: df.DurableOrchestrationContext): # TODO: Add unit test to resume from non-zero ID id = input.get("idin", 0) - # TODO Determine which ones can be executed in parallel and construct the appropriate DAG - # executions = [] - # for step in steps: - # execution = context.call_activity(EXECUTE_STEP_ACTIVITY_NAME, step) - # executions.append(execution) - # results = yield context.task_all(executions) - # Execute all steps one after the other # Note: Assuming that the list of steps is in order of required (sequential) execution results = [] diff --git 
a/backends/durable_functions/start_steps_workflow/start_steps_workflow.py b/backends/durable_functions/start_steps_workflow/start_steps_workflow.py index 0d5b913b..b796b306 100644 --- a/backends/durable_functions/start_steps_workflow/start_steps_workflow.py +++ b/backends/durable_functions/start_steps_workflow/start_steps_workflow.py @@ -26,7 +26,8 @@ async def start_steps_workflow(req: func.HttpRequest, starter: str) -> func.Http stats = {} response_payload = { "result": result, - "stats": stats + "stats": stats, + "HTTP_params": req.params, } return http_utils.generate_response(response_payload, status_code) diff --git a/clients/durable_functions_client.py b/clients/durable_functions_client.py index 8234c3db..02d47533 100644 --- a/clients/durable_functions_client.py +++ b/clients/durable_functions_client.py @@ -20,9 +20,6 @@ def __init__( user: str, start_state_id: int = 0 ): - if backend_host.endswith('/'): - backend_host = backend_host[-1] - self.backend_host = backend_host self.user = user self.next_state_id = start_state_id self.connect_timeout_sec = 10 @@ -30,11 +27,16 @@ def __init__( self.status_query_timeout_sec = 10 self.retry_interval_sec = 1 self.session = requests.Session() - self.workflow_url = f"{self.backend_host}/api/orchestrators/{EXECUTE_WORKFLOW_ACTIVITY_NAME}" + self.set_workflow_url(backend_host) def __del__(self): self.session.close() + def set_workflow_url(self, backend_host: str): + if backend_host.endswith('/'): + backend_host = backend_host[:-1] + self.workflow_url = f"{backend_host}/api/orchestrators/{EXECUTE_WORKFLOW_ACTIVITY_NAME}" + def execute_notebook(self, notebook_path: str) -> List[dict]: """ Execute a given notebook and return the output of each Step. 
diff --git a/jupyter/kernel/context.py b/jupyter/kernel/context.py index 084905a7..12a036c0 100644 --- a/jupyter/kernel/context.py +++ b/jupyter/kernel/context.py @@ -5,3 +5,4 @@ from clients.durable_functions_client import DurableFunctionsClient from objects.step import Step +from backends.common.serialization_utils import deserialize_obj diff --git a/jupyter/kernel/kernel_logger.py b/jupyter/kernel/kernel_logger.py new file mode 100644 index 00000000..3dcffd1f --- /dev/null +++ b/jupyter/kernel/kernel_logger.py @@ -0,0 +1,42 @@ +import logging + + +class Colors: + """Class with a few pre-defined ANSI colors for cleaner output. + The list was extracted from: + https://gist.github.com/rene-d/9e584a7dd2935d0f461904b9f2950007 + """ + HEADER = '\033[95m' + BLUE = '\033[94m' + GREEN = '\033[92m' + YELLOW = '\033[93m' + RED = '\033[91m' + ENDC = '\033[0m' + LIGHT_GRAY = "\033[0;37m" + DARK_GRAY = "\033[1;30m" + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + + +class KernelLogger: + def __init__(self, prefix=''): + self.prefix = prefix + self.log_level = logging.DEBUG + + def log(self, msg, msg_log_level=logging.INFO): + """Log to the output which is visible to the user. + """ + if msg_log_level >= self.log_level: + if msg_log_level == logging.DEBUG: + print(self._get_log_start(), f"{Colors.LIGHT_GRAY} {msg} {Colors.ENDC}") + elif msg_log_level == logging.WARNING: + print(self._get_log_start(), f"{Colors.YELLOW} {msg} {Colors.ENDC}") + elif msg_log_level == logging.ERROR: + print(self._get_log_start(), f"{Colors.RED} {msg} {Colors.ENDC}") + else: + print(self._get_log_start(), msg) + + def _get_log_start(self): + """Get the initial part of the log. 
+ """ + return f"{Colors.BLUE} {Colors.BOLD} {self.prefix} {Colors.ENDC} {Colors.ENDC}" diff --git a/jupyter/kernel/magics/config_magic.py b/jupyter/kernel/magics/config_magic.py new file mode 100644 index 00000000..96258a56 --- /dev/null +++ b/jupyter/kernel/magics/config_magic.py @@ -0,0 +1,86 @@ +from metakernel import Magic + + +class ConfigMagic(Magic): + def line_config(self, line): + self._handle_magic(line) + + def cell_config(self, line): + self._handle_magic(line) + + def _print_error(self, line, exception): + self.kernel.Error(f"Config failed: {line} - {exception}") + + def _handle_magic(self, line): + try: + self._handle_magic_core(line) + except Exception as exception: + self._print_error(line, exception) + + def _handle_magic_core(self, line): + tokens = line.split(' ') + command = tokens[0] + if command == "debug": + arg = tokens[1] + if arg == "enable": + self.kernel.debug_mode = True + elif arg == "disable": + self.kernel.debug_mode = False + else: + raise Exception(f"Unknown arg: {arg}") + elif command == "user": + arg = tokens[1] + if arg == "set": + user = tokens[2] + self.kernel.set_user(user) + elif arg == "get": + self.kernel.logger.log(f"Current user is: {self.kernel.get_user()}") + else: + raise Exception(f"Unknown arg: {arg}") + elif command == "state": + arg = tokens[1] + if arg == "set": + state_str = tokens[2] + state = int(state_str) + self.kernel.set_next_state_id(state) + elif arg == "get": + self.kernel.logger.log(f"Next state ID is: {self.kernel.get_next_state_id()}") + else: + raise Exception(f"Unknown arg: {arg}") + elif command == "host": + arg = tokens[1] + if arg == "set": + host = tokens[2] + self.kernel.set_backend_host(host) + elif arg == "get": + self.kernel.logger.log(f"Backend URL: {self.kernel.get_backend_host()}") + else: + raise Exception(f"Unknown arg: {arg}") + else: + raise Exception(f"Invalid command: {command}") + + +def register_magics(kernel): + kernel.register_magics(ConfigMagic) + + +def 
register_ipython_magics(): + from metakernel import IPythonKernel + from metakernel.utils import add_docs + from IPython.core.magic import register_line_magic, register_cell_magic + kernel = IPythonKernel() + magic = ConfigMagic(kernel) + # Make magics callable: + kernel.line_magics["config"] = magic + kernel.cell_magics["config"] = magic + + @register_line_magic + @add_docs(magic.line_config.__doc__) + def config(line): + kernel.call_magic("%config " + line) + + @register_cell_magic + @add_docs(magic.cell_config.__doc__) + def config(line, cell): + magic.code = cell + magic.cell_config(line) diff --git a/jupyter/kernel/same_kernel.py b/jupyter/kernel/same_kernel.py index 4fcede4d..76d6fef8 100644 --- a/jupyter/kernel/same_kernel.py +++ b/jupyter/kernel/same_kernel.py @@ -1,7 +1,12 @@ from __future__ import print_function from context import Step -from clients.durable_functions_client import DurableFunctionsClient +from context import deserialize_obj +from context import DurableFunctionsClient +from kernel_logger import KernelLogger from metakernel import MetaKernel +from IPython.core.interactiveshell import InteractiveShell +from six import reraise +import logging import sys @@ -27,18 +32,43 @@ class SAMEKernel(MetaKernel): "language": "python", "name": "same_kernel" } + _interactive_shell = None _same_client = None + _logger = None + _debug_mode = True + + @property + def debug_mode(self): + return self._debug_mode + + @debug_mode.setter + def debug_mode(self, value): + self._debug_mode = value + + @property + def interactive_shell(self): + if self._interactive_shell: + return self._interactive_shell + self._interactive_shell = InteractiveShell() + return self._interactive_shell @property def same_client(self): if self._same_client: return self._same_client backend_host = "http://localhost:7071" - user = "gochaudh" + user = "default" start_state_id = 0 self._same_client = DurableFunctionsClient(backend_host, user, start_state_id) return self._same_client + 
@property + def logger(self): + if self._logger: + return self._logger + self._logger = KernelLogger(prefix='SAME') + return self._logger + def get_usage(self): return ("This is the SAME Python Kernel") @@ -57,24 +87,50 @@ def get_variable(self, name): return python_magic.env.get(name, None) def do_execute_direct(self, code): + try: + return self._do_execute_direct_core(code) + except: + self.interactive_shell.showtraceback() + + def _do_execute_direct_core(self, code): code_stripped = code.strip() step : Step = Step(code=code_stripped) steps = [step] + outputs = self.same_client.execute_steps(steps) + self.logger.log(outputs, logging.DEBUG) + assert len(outputs) == 1 output = outputs[0] result = output["result"] - stdout = result["stdout"] - stderr = result["stderr"] - exec_result = result["exec_result"] - if stdout and stdout != "": - print(stdout) - if exec_result and exec_result != "": - print(exec_result) - if stderr and stderr != "": - print(stderr, file=sys.stderr) - # TODO: Do this in DEBUG mode only. 
- print(result) + + status = result["status"] + if status == "success": + stdout = result["stdout"] + stderr = result["stderr"] + exec_result = result["exec_result"] + if stdout and stdout != "": + self.Print(stdout) + if exec_result and exec_result != "": + self.Print(exec_result) + if stderr and stderr != "": + self.Error(stderr, file=sys.stderr) + elif status == "fail": + reason = result["reason"] + info = result["info"] + if reason == "exception": + exception_base64 = info["exception"] + exception = deserialize_obj(exception_base64) + if type(exception) is tuple: + # This comes from sys.exc_info() + exception_tuple = exception + exception_class = exception_tuple[0] + exception_value = exception_tuple[1] + exception_traceback = exception_tuple[2] + # TODO clean the stack trace + reraise(exception_class, exception_value, exception_traceback) + else: + raise exception return exec_result def get_completions(self, info): @@ -85,6 +141,27 @@ def get_kernel_help_on(self, info, level=0, none_on_fail=False): python_magic = self.line_magics['python'] return python_magic.get_help_on(info, level, none_on_fail) + def set_user(self, user: str): + self.same_client.user = user + self.logger.log(f"Set user to: {user}") + + def set_next_state_id(self, id: int): + self.same_client.next_state_id = id + self.logger.log(f"Set next state ID to: {id}") + + def set_backend_host(self, backend_host: str): + self.same_client.set_workflow_url(backend_host) + self.logger.log(f"Set backend host to: {backend_host}") + + def get_user(self): + return self.same_client.user + + def get_next_state_id(self): + return self.same_client.next_state_id + + def get_backend_host(self): + return self.same_client.workflow_url + if __name__ == '__main__': SAMEKernel.run_as_main() diff --git a/objects/step.py b/objects/step.py index 3c435873..befc2eff 100644 --- a/objects/step.py +++ b/objects/step.py @@ -1,5 +1,6 @@ from __future__ import annotations from .json_serializable_object import 
JSONSerializableObject +from typing import Optional from uuid import uuid4 @@ -13,20 +14,20 @@ def __init__( name: str = "same_step_unset", cache_value: str = "P0D", environment_name: str = "default", - tags: list = [], + tags: Optional[list] = None, index: int = -1, code: str = "", - parameters: list = [], - packages_to_install: list = [] + parameters: Optional[list] = None, + packages_to_install: Optional[list] = None ): self.name = name self.cache_value = cache_value self.environment_name = environment_name - self.tags = tags + self.tags = tags if tags is not None else [] self.index = index self.code = code - self.parameters = parameters - self.packages_to_install = packages_to_install + self.parameters = parameters if parameters is not None else [] + self.packages_to_install = packages_to_install if packages_to_install is not None else [] @property def name(self): diff --git a/poetry.lock b/poetry.lock index 266476c5..a6c90754 100644 --- a/poetry.lock +++ b/poetry.lock @@ -89,6 +89,14 @@ dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pytest", "sphinx", "furo", "wh docs = ["sphinx", "furo"] tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pytest"] +[[package]] +name = "astor" +version = "0.8.1" +description = "Read/rewrite/write Python ASTs" +category = "main" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" + [[package]] name = "async-timeout" version = "3.0.1" @@ -677,6 +685,30 @@ sdist = ["setuptools-rust (>=0.11.4)"] ssh = ["bcrypt (>=3.1.5)"] test = ["pytest (>=6.0)", "pytest-cov", "pytest-subtests", "pytest-xdist", "pretend", "iso8601", "pytz", "hypothesis (>=1.11.4,!=3.79.2)"] +[[package]] +name = "dask" +version = "2021.10.0" +description = "Parallel PyData with Task Scheduling" +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +cloudpickle = ">=1.1.1" +fsspec = ">=0.6.0" +packaging = ">=20.0" +partd = ">=0.3.10" +pyyaml = "*" +toolz = ">=0.8.2" + +[package.extras] +array = ["numpy 
(>=1.18)"] +complete = ["bokeh (>=1.0.0,!=2.0.0)", "distributed (==2021.10.0)", "jinja2", "numpy (>=1.18)", "pandas (>=1.0)"] +dataframe = ["numpy (>=1.18)", "pandas (>=1.0)"] +diagnostics = ["bokeh (>=1.0.0,!=2.0.0)", "jinja2"] +distributed = ["distributed (==2021.10.0)"] +test = ["pytest", "pytest-rerunfailures", "pytest-xdist", "pre-commit"] + [[package]] name = "debugpy" version = "1.5.0" @@ -837,6 +869,36 @@ mccabe = ">=0.6.0,<0.7.0" pycodestyle = ">=2.7.0,<2.8.0" pyflakes = ">=2.3.0,<2.4.0" +[[package]] +name = "fsspec" +version = "2021.10.1" +description = "File-system specification" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.extras] +abfs = ["adlfs"] +adl = ["adlfs"] +arrow = ["pyarrow (>=1)"] +dask = ["dask", "distributed"] +dropbox = ["dropboxdrivefs", "requests", "dropbox"] +entrypoints = ["importlib-metadata"] +fuse = ["fusepy"] +gcs = ["gcsfs"] +git = ["pygit2"] +github = ["requests"] +gs = ["gcsfs"] +gui = ["panel"] +hdfs = ["pyarrow (>=1)"] +http = ["requests", "aiohttp"] +libarchive = ["libarchive-c"] +oci = ["ocifs"] +s3 = ["s3fs"] +sftp = ["paramiko"] +smb = ["smbprotocol"] +ssh = ["paramiko"] + [[package]] name = "furl" version = "2.1.3" @@ -1393,6 +1455,14 @@ websocket-client = ">=0.32.0,<0.40.0 || >0.40.0,<0.41.0 || >=0.43.0" [package.extras] adal = ["adal (>=1.0.2)"] +[[package]] +name = "locket" +version = "0.2.1" +description = "File-based locks for Python for Linux and Windows" +category = "main" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" + [[package]] name = "markdown-it-py" version = "1.1.0" @@ -1737,6 +1807,26 @@ python-versions = ">=3.6" [package.dependencies] pyparsing = ">=2.0.2" +[[package]] +name = "pandas" +version = "1.3.4" +description = "Powerful data structures for data analysis, time series, and statistics" +category = "main" +optional = false +python-versions = ">=3.7.1" + +[package.dependencies] +numpy = [ + {version = ">=1.17.3", markers = 
"platform_machine != \"aarch64\" and platform_machine != \"arm64\" and python_version < \"3.10\""}, + {version = ">=1.19.2", markers = "platform_machine == \"aarch64\" and python_version < \"3.10\""}, + {version = ">=1.20.0", markers = "platform_machine == \"arm64\" and python_version < \"3.10\""}, +] +python-dateutil = ">=2.7.3" +pytz = ">=2017.3" + +[package.extras] +test = ["hypothesis (>=3.58)", "pytest (>=6.0)", "pytest-xdist"] + [[package]] name = "pandocfilters" version = "1.5.0" @@ -1757,6 +1847,21 @@ python-versions = ">=3.6" qa = ["flake8 (==3.8.3)", "mypy (==0.782)"] testing = ["docopt", "pytest (<6.0.0)"] +[[package]] +name = "partd" +version = "1.2.0" +description = "Appendable key-value storage" +category = "main" +optional = false +python-versions = ">=3.5" + +[package.dependencies] +locket = "*" +toolz = "*" + +[package.extras] +complete = ["numpy (>=1.9.0)", "pandas (>=0.19.0)", "pyzmq", "blosc"] + [[package]] name = "path" version = "16.2.0" @@ -2013,7 +2118,7 @@ python-versions = ">=3.5" [[package]] name = "pyjwt" -version = "2.2.0" +version = "2.3.0" description = "JSON Web Token implementation in Python" category = "main" optional = false @@ -2484,6 +2589,14 @@ category = "dev" optional = false python-versions = ">=3.6" +[[package]] +name = "toolz" +version = "0.11.1" +description = "List processing tools and functional utilities" +category = "main" +optional = false +python-versions = ">=3.5" + [[package]] name = "tornado" version = "6.1" @@ -2621,7 +2734,7 @@ multidict = ">=4.0" [metadata] lock-version = "1.1" python-versions = "~3.8" -content-hash = "db677c737335c94a7d38fd8b8b2997b38c0df8ea87db65cd1725b5183dc7ab4d" +content-hash = "7860dfa37a198a844968101934ed4d1854a3169666b1f2cad4b87ccea2ec05ae" [metadata.files] absl-py = [ @@ -2696,6 +2809,10 @@ argon2-cffi = [ {file = "argon2_cffi-21.1.0-pp37-pypy37_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = 
"sha256:165cadae5ac1e26644f5ade3bd9c18d89963be51d9ea8817bd671006d7909057"}, {file = "argon2_cffi-21.1.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:566ffb581bbd9db5562327aee71b2eda24a1c15b23a356740abe3c011bbe0dcb"}, ] +astor = [ + {file = "astor-0.8.1-py2.py3-none-any.whl", hash = "sha256:070a54e890cefb5b3739d19f30f5a5ec840ffc9c50ffa7d23cc9fc1a38ebbfc5"}, + {file = "astor-0.8.1.tar.gz", hash = "sha256:6a6effda93f4e1ce9f618779b2dd1d9d84f1e32812c23a29b3fff6fd7f63fa5e"}, +] async-timeout = [ {file = "async-timeout-3.0.1.tar.gz", hash = "sha256:0c3c816a028d47f659d6ff5c745cb2acf1f966da1fe5c19c77a70282b25f4c5f"}, {file = "async_timeout-3.0.1-py3-none-any.whl", hash = "sha256:4291ca197d287d274d0b6cb5d6f8f8f82d434ed288f962539ff18cc9012f9ea3"}, @@ -2991,6 +3108,10 @@ cryptography = [ {file = "cryptography-3.4.8-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:cd65b60cfe004790c795cc35f272e41a3df4631e2fb6b35aa7ac6ef2859d554e"}, {file = "cryptography-3.4.8.tar.gz", hash = "sha256:94cc5ed4ceaefcbe5bf38c8fba6a21fc1d365bb8fb826ea1688e3370b2e24a1c"}, ] +dask = [ + {file = "dask-2021.10.0-py3-none-any.whl", hash = "sha256:b678d802ea8126c0168a1b429182b7dc01ade8ad84cfd62b804c52c4c29b6fc3"}, + {file = "dask-2021.10.0.tar.gz", hash = "sha256:a98b2e44acaad369bb21f79fc92f756532acfe62c3aeba3982006b7339d0d2e3"}, +] debugpy = [ {file = "debugpy-1.5.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:098753d30232d1e4264eee37e1ddd5d106dc5c4bc6d8d7f4dadad9e44736cd48"}, {file = "debugpy-1.5.0-cp310-cp310-win32.whl", hash = "sha256:33e8a9b4949be8b4f5fcfff07e24bd63c565060659f1c79773c08d19eee012f2"}, @@ -3070,6 +3191,10 @@ flake8 = [ {file = "flake8-3.9.2-py2.py3-none-any.whl", hash = "sha256:bf8fd333346d844f616e8d47905ef3a3384edae6b4e9beb0c5101e25e3110907"}, {file = "flake8-3.9.2.tar.gz", hash = "sha256:07528381786f2a6237b061f6e96610a4167b226cb926e2aa2b6b1d78057c576b"}, ] +fsspec = [ + {file = "fsspec-2021.10.1-py3-none-any.whl", 
hash = "sha256:7164a488f3f5bf6a0fb39674978b756dda84e011a5db411a79791b7c38a36ff7"}, + {file = "fsspec-2021.10.1.tar.gz", hash = "sha256:c245626e3cb8de5cd91485840b215a385fa6f2b0f6ab87978305e99e2d842753"}, +] furl = [ {file = "furl-2.1.3-py2.py3-none-any.whl", hash = "sha256:9ab425062c4217f9802508e45feb4a83e54324273ac4b202f1850363309666c0"}, {file = "furl-2.1.3.tar.gz", hash = "sha256:5a6188fe2666c484a12159c18be97a1977a71d632ef5bb867ef15f54af39cc4e"}, @@ -3253,6 +3378,10 @@ kubernetes = [ {file = "kubernetes-18.20.0-py2.py3-none-any.whl", hash = "sha256:ff31ec17437293e7d4e1459f1228c42d27c7724dfb56b4868aba7a901a5b72c9"}, {file = "kubernetes-18.20.0.tar.gz", hash = "sha256:0c72d00e7883375bd39ae99758425f5e6cb86388417cf7cc84305c211b2192cf"}, ] +locket = [ + {file = "locket-0.2.1-py2.py3-none-any.whl", hash = "sha256:12b6ada59d1f50710bca9704dbadd3f447dbf8dac6664575c1281cadab8e6449"}, + {file = "locket-0.2.1.tar.gz", hash = "sha256:3e1faba403619fe201552f083f1ecbf23f550941bc51985ac6ed4d02d25056dd"}, +] markdown-it-py = [ {file = "markdown-it-py-1.1.0.tar.gz", hash = "sha256:36be6bb3ad987bfdb839f5ba78ddf094552ca38ccbd784ae4f74a4e1419fc6e3"}, {file = "markdown_it_py-1.1.0-py3-none-any.whl", hash = "sha256:98080fc0bc34c4f2bcf0846a096a9429acbd9d5d8e67ed34026c03c61c464389"}, @@ -3504,6 +3633,29 @@ packaging = [ {file = "packaging-21.0-py3-none-any.whl", hash = "sha256:c86254f9220d55e31cc94d69bade760f0847da8000def4dfe1c6b872fd14ff14"}, {file = "packaging-21.0.tar.gz", hash = "sha256:7dc96269f53a4ccec5c0670940a4281106dd0bb343f47b7471f779df49c2fbe7"}, ] +pandas = [ + {file = "pandas-1.3.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:372d72a3d8a5f2dbaf566a5fa5fa7f230842ac80f29a931fb4b071502cf86b9a"}, + {file = "pandas-1.3.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d99d2350adb7b6c3f7f8f0e5dfb7d34ff8dd4bc0a53e62c445b7e43e163fce63"}, + {file = "pandas-1.3.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = 
"sha256:c2646458e1dce44df9f71a01dc65f7e8fa4307f29e5c0f2f92c97f47a5bf22f5"}, + {file = "pandas-1.3.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5298a733e5bfbb761181fd4672c36d0c627320eb999c59c65156c6a90c7e1b4f"}, + {file = "pandas-1.3.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22808afb8f96e2269dcc5b846decacb2f526dd0b47baebc63d913bf847317c8f"}, + {file = "pandas-1.3.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b528e126c13816a4374e56b7b18bfe91f7a7f6576d1aadba5dee6a87a7f479ae"}, + {file = "pandas-1.3.4-cp37-cp37m-win32.whl", hash = "sha256:fe48e4925455c964db914b958f6e7032d285848b7538a5e1b19aeb26ffaea3ec"}, + {file = "pandas-1.3.4-cp37-cp37m-win_amd64.whl", hash = "sha256:eaca36a80acaacb8183930e2e5ad7f71539a66805d6204ea88736570b2876a7b"}, + {file = "pandas-1.3.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:42493f8ae67918bf129869abea8204df899902287a7f5eaf596c8e54e0ac7ff4"}, + {file = "pandas-1.3.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a388960f979665b447f0847626e40f99af8cf191bce9dc571d716433130cb3a7"}, + {file = "pandas-1.3.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ba0aac1397e1d7b654fccf263a4798a9e84ef749866060d19e577e927d66e1b"}, + {file = "pandas-1.3.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f567e972dce3bbc3a8076e0b675273b4a9e8576ac629149cf8286ee13c259ae5"}, + {file = "pandas-1.3.4-cp38-cp38-win32.whl", hash = "sha256:c1aa4de4919358c5ef119f6377bc5964b3a7023c23e845d9db7d9016fa0c5b1c"}, + {file = "pandas-1.3.4-cp38-cp38-win_amd64.whl", hash = "sha256:dd324f8ee05925ee85de0ea3f0d66e1362e8c80799eb4eb04927d32335a3e44a"}, + {file = "pandas-1.3.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d47750cf07dee6b55d8423471be70d627314277976ff2edd1381f02d52dbadf9"}, + {file = 
"pandas-1.3.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d1dc09c0013d8faa7474574d61b575f9af6257ab95c93dcf33a14fd8d2c1bab"}, + {file = "pandas-1.3.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10e10a2527db79af6e830c3d5842a4d60383b162885270f8cffc15abca4ba4a9"}, + {file = "pandas-1.3.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:35c77609acd2e4d517da41bae0c11c70d31c87aae8dd1aabd2670906c6d2c143"}, + {file = "pandas-1.3.4-cp39-cp39-win32.whl", hash = "sha256:003ba92db58b71a5f8add604a17a059f3068ef4e8c0c365b088468d0d64935fd"}, + {file = "pandas-1.3.4-cp39-cp39-win_amd64.whl", hash = "sha256:a51528192755f7429c5bcc9e80832c517340317c861318fea9cea081b57c9afd"}, + {file = "pandas-1.3.4.tar.gz", hash = "sha256:a2aa18d3f0b7d538e21932f637fbfe8518d085238b429e4790a35e1e44a96ffc"}, +] pandocfilters = [ {file = "pandocfilters-1.5.0-py2.py3-none-any.whl", hash = "sha256:33aae3f25fd1a026079f5d27bdd52496f0e0803b3469282162bafdcbdf6ef14f"}, {file = "pandocfilters-1.5.0.tar.gz", hash = "sha256:0b679503337d233b4339a817bfc8c50064e2eff681314376a47cb582305a7a38"}, @@ -3512,6 +3664,10 @@ parso = [ {file = "parso-0.8.2-py2.py3-none-any.whl", hash = "sha256:a8c4922db71e4fdb90e0d0bc6e50f9b273d3397925e5e60a717e719201778d22"}, {file = "parso-0.8.2.tar.gz", hash = "sha256:12b83492c6239ce32ff5eed6d3639d6a536170723c6f3f1506869f1ace413398"}, ] +partd = [ + {file = "partd-1.2.0-py3-none-any.whl", hash = "sha256:5c3a5d70da89485c27916328dc1e26232d0e270771bd4caef4a5124b6a457288"}, + {file = "partd-1.2.0.tar.gz", hash = "sha256:aa67897b84d522dcbc86a98b942afab8c6aa2f7f677d904a616b74ef5ddbc3eb"}, +] path = [ {file = "path-16.2.0-py3-none-any.whl", hash = "sha256:340054c5bb459fc9fd40e7eb6768c5989f3e599d18224238465b5333bc8faa7d"}, {file = "path-16.2.0.tar.gz", hash = "sha256:2de925e8d421f93bcea80d511b81accfb6a7e6b249afa4a5559557b0cf817097"}, @@ -3689,8 +3845,8 @@ pygments = [ {file = 
"Pygments-2.10.0.tar.gz", hash = "sha256:f398865f7eb6874156579fdf36bc840a03cab64d1cde9e93d68f46a425ec52c6"}, ] pyjwt = [ - {file = "PyJWT-2.2.0-py3-none-any.whl", hash = "sha256:b0ed5824c8ecc5362e540c65dc6247567db130c4226670bf7699aec92fb4dae1"}, - {file = "PyJWT-2.2.0.tar.gz", hash = "sha256:a0b9a3b4e5ca5517cac9f1a6e9cd30bf1aa80be74fcdf4e28eded582ecfcfbae"}, + {file = "PyJWT-2.3.0-py3-none-any.whl", hash = "sha256:e0c4bb8d9f0af0c7f5b1ec4c5036309617d03d56932877f2f7a0beeb5318322f"}, + {file = "PyJWT-2.3.0.tar.gz", hash = "sha256:b888b4d56f06f6dcd777210c334e69c737be74755d3e5e9ee3fe67dc18a0ee41"}, ] pyopenssl = [ {file = "pyOpenSSL-20.0.1-py2.py3-none-any.whl", hash = "sha256:818ae18e06922c066f777a33f1fca45786d85edfe71cd043de6379337a7f274b"}, @@ -4001,6 +4157,10 @@ tomli = [ {file = "tomli-1.2.1-py3-none-any.whl", hash = "sha256:8dd0e9524d6f386271a36b41dbf6c57d8e32fd96fd22b6584679dc569d20899f"}, {file = "tomli-1.2.1.tar.gz", hash = "sha256:a5b75cb6f3968abb47af1b40c1819dc519ea82bcc065776a866e8d74c5ca9442"}, ] +toolz = [ + {file = "toolz-0.11.1-py3-none-any.whl", hash = "sha256:1bc473acbf1a1db4e72a1ce587be347450e8f08324908b8a266b486f408f04d5"}, + {file = "toolz-0.11.1.tar.gz", hash = "sha256:c7a47921f07822fe534fb1c01c9931ab335a4390c782bd28c6bcc7c2f71f3fbf"}, +] tornado = [ {file = "tornado-6.1-cp35-cp35m-macosx_10_9_x86_64.whl", hash = "sha256:d371e811d6b156d82aa5f9a4e08b58debf97c302a35714f6f45e35139c332e32"}, {file = "tornado-6.1-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:0d321a39c36e5f2c4ff12b4ed58d41390460f798422c4504e09eb5678e09998c"}, diff --git a/pyproject.toml b/pyproject.toml index 7e7ed790..d42d26ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,9 @@ azureml-train-automl-client = {version = "^1.34.0", python = ">=3.8,<3.9" } azureml-dataprep-native = "38.0.0" tblib = "^1.7.0" metakernel = "^0.27.5" +astor = "^0.8.1" +pandas = "^1.3.4" +dask = "^2021.10.0" [tool.poetry.dev-dependencies] pip-tools = "^6.2.0" diff --git a/requirements.txt 
b/requirements.txt index 2a9ad827..85fd906a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ aiohttp==3.7.4.post0; python_version >= "3.6" and python_version < "4" applicationinsights==0.11.10; python_version >= "3.8" and python_version < "3.9" appnope==0.1.2; sys_platform == "darwin" and python_version >= "3.7" and platform_system == "Darwin" argon2-cffi==21.1.0; python_version >= "3.6" +astor==0.8.1; (python_version >= "2.7" and python_full_version < "3.0.0") or (python_full_version >= "3.4.0") async-timeout==3.0.1; python_version >= "3.6" and python_version < "4" and python_full_version >= "3.5.3" attrs==21.2.0; python_version >= "3.6" and python_version < "4.0" and python_full_version >= "3.6.1" azure-common==1.1.27; python_version >= "3.8" and python_version < "3.9" @@ -46,6 +47,7 @@ cloudpickle==1.6.0; python_version >= "3.8" and python_full_version >= "3.6.1" a colorama==0.4.4; python_version >= "3.7" and python_full_version < "3.0.0" and sys_platform == "win32" or sys_platform == "win32" and python_version >= "3.7" and python_full_version >= "3.5.0" contextlib2==21.6.0; python_version >= "3.8" and python_version < "3.9" cryptography==3.4.8; python_version >= "3.8" and python_version < "3.9" and (python_version >= "3.8" and python_full_version < "3.0.0" and python_version < "3.9" or python_version >= "3.8" and python_version < "3.9" and python_full_version >= "3.5.0") +dask==2021.10.0; python_version >= "3.7" debugpy==1.5.0; python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" and python_version >= "3.7" decorator==5.1.0; python_version >= "3.7" defusedxml==0.7.1; python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" and python_version >= "3.7" @@ -57,6 +59,7 @@ docstring-parser==0.11; python_version >= "3.6" and python_full_version >= "3.6. 
dotnetcore2==2.1.21; python_version >= "3.8" and python_version < "3.9" entrypoints==0.3; python_full_version >= "3.6.1" and python_version >= "3.7" fire==0.4.0; python_full_version >= "3.6.1" +fsspec==2021.10.1; python_version >= "3.7" furl==2.1.3; python_version >= "3.6" and python_version < "4" google-api-core==1.31.2; python_version >= "3.6" and python_full_version >= "3.6.1" google-api-python-client==1.12.8; python_full_version >= "3.6.1" @@ -91,6 +94,7 @@ kfp-pipeline-spec==0.1.11; python_full_version >= "3.6.1" kfp-server-api==1.7.0; python_full_version >= "3.6.1" kfp==1.8.4; python_full_version >= "3.6.1" kubernetes==18.20.0 +locket==0.2.1; python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.4.0" and python_version >= "3.7" markdown-it-py==1.1.0; python_version >= "3.6" and python_version < "4.0" markupsafe==2.0.1; python_version >= "3.7" matplotlib-inline==0.1.3; python_version >= "3.7" @@ -112,8 +116,10 @@ numpy==1.21.2; python_version >= "3.7" and python_version < "3.11" oauthlib==3.1.1; python_version >= "3.6" and python_full_version >= "3.6.1" orderedmultidict==1.0.1; python_version >= "3.6" and python_version < "4" packaging==21.0; python_version >= "3.7" and python_full_version >= "3.6.1" +pandas==1.3.4; python_full_version >= "3.7.1" pandocfilters==1.5.0; python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.4.0" and python_version >= "3.7" parso==0.8.2; python_version >= "3.7" +partd==1.2.0; python_version >= "3.7" pathspec==0.9.0; python_version >= "3.8" and python_full_version < "3.0.0" and python_version < "3.9" or python_version >= "3.8" and python_version < "3.9" and python_full_version >= "3.5.0" pexpect==4.8.0; sys_platform != "win32" and python_version >= "3.7" pickleshare==0.7.5; python_version >= "3.7" @@ -129,16 +135,16 @@ pyasn1==0.4.8; python_version >= "3.8" and python_version < "3.9" and python_ful pycparser==2.20; python_version >= "3.8" and python_full_version < 
"3.0.0" and python_version < "3.9" or python_version >= "3.8" and python_version < "3.9" and python_full_version >= "3.4.0" pydantic==1.8.2; python_full_version >= "3.6.1" pygments==2.10.0; python_version >= "3.7" -pyjwt==2.2.0; python_version >= "3.8" and python_version < "3.9" +pyjwt==2.3.0; python_version >= "3.8" and python_version < "3.9" pyopenssl==20.0.1; python_version >= "3.8" and python_full_version < "3.0.0" and python_version < "3.9" or python_version >= "3.8" and python_version < "3.9" and python_full_version >= "3.5.0" pyparsing==2.4.7; python_full_version >= "3.6.1" and python_version >= "3.7" pyrsistent==0.18.0; python_version >= "3.6" and python_full_version >= "3.6.1" python-box==5.4.1; python_version >= "3.6" -python-dateutil==2.8.2; python_version >= "3.8" and python_version < "3.9" and python_full_version >= "3.6.1" and (python_version >= "3.8" and python_full_version < "3.0.0" and python_version < "3.9" or python_version >= "3.8" and python_version < "3.9" and python_full_version >= "3.3.0") -pytz==2021.3; python_version >= "3.8" and python_version < "3.9" and python_full_version >= "3.6.1" +python-dateutil==2.8.2; python_version >= "3.8" and python_version < "3.9" and python_full_version >= "3.7.1" and (python_version >= "3.8" and python_full_version < "3.0.0" and python_version < "3.9" or python_version >= "3.8" and python_version < "3.9" and python_full_version >= "3.3.0") +pytz==2021.3; python_version >= "3.8" and python_version < "3.9" and python_full_version >= "3.7.1" pywin32==227; sys_platform == "win32" pywinpty==1.1.4; os_name == "nt" and python_version >= "3.6" -pyyaml==5.4.1; python_version >= "3.6" and python_version < "4.0" and python_full_version >= "3.6.1" +pyyaml==5.4.1; python_version >= "3.6" and python_version < "4.0" and python_full_version >= "3.6.1" and (python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.6.0" and python_version >= "3.7") pyzmq==22.3.0; python_full_version >= "3.6.1" 
and python_version >= "3.7" qtconsole==5.1.1; python_version >= "3.6" qtpy==1.11.2; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.6.0" and python_version >= "3.6" @@ -158,6 +164,7 @@ termcolor==1.1.0; python_full_version >= "3.6.1" terminado==0.12.1; python_version >= "3.6" testpath==0.5.0; python_version >= "3.7" toml==0.10.2; python_version >= "3.6" and python_full_version < "3.0.0" and python_version < "4.0" or python_version >= "3.6" and python_version < "4.0" and python_full_version >= "3.3.0" +toolz==0.11.1; python_version >= "3.7" tornado==6.1; python_full_version >= "3.6.1" and python_version >= "3.7" traitlets==5.1.0; python_version >= "3.7" and python_version < "4.0" and python_full_version >= "3.6.1" typing-extensions==3.10.0.2; python_version < "3.9" and python_full_version >= "3.6.1" and python_version >= "3.6" diff --git a/test/transformers/__init__.py b/test/transformers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/transformers/context.py b/test/transformers/context.py new file mode 100644 index 00000000..0c24cbe9 --- /dev/null +++ b/test/transformers/context.py @@ -0,0 +1,7 @@ +import sys +sys.path.insert(0, "../..") +sys.path.insert(0, "..") + + +from transformers.pandas_to_dask_transformer import PandasToDaskTransformer +from objects.step import Step diff --git a/test/transformers/test_pandas_to_dask_transformer.py b/test/transformers/test_pandas_to_dask_transformer.py new file mode 100644 index 00000000..3a41fb3b --- /dev/null +++ b/test/transformers/test_pandas_to_dask_transformer.py @@ -0,0 +1,119 @@ +from .context import PandasToDaskTransformer +from .context import Step + + +class TestPandasToDaskTransformer: + def test_read_csv(self): + """ + Transform pandas.read_csv into dask.dataframe.read_csv. + This is a 1:1 mapping where the calls to pandas module are replaced with dask.dataframe module. 
+ """ + test_csv_input_path = "test/transformers/testdata/table1.csv" + # Before transformation, the namespaces need to have the imported modules + pre_transform_code = """import pandas as pd""" + user_ns = {} + global_ns = {} + exec(pre_transform_code, user_ns, global_ns) + # Transform the use of Pandas to Dask + transform_code = f"""df = pd.read_csv('{test_csv_input_path}') +print(df) +""" + step = Step(code=transform_code) + transformer = PandasToDaskTransformer(user_ns, global_ns) + transformer.transform_step(step) + # Verify that the transformed code includes the Dask DataFrame and replaced the use of Pandas with Dask + expected_code = f"""import dask.dataframe as dd +df = dd.read_csv('{test_csv_input_path}') +print(df) +""" + assert expected_code == step.code + assert 1 == len(step.packages_to_install) + assert ['dask.dataframe'] == step.packages_to_install + + def test_sum(self): + """ + Transform df.sum() on a Pandas DataFrame into a df.sum().compute() on a Dask DataFrame. + """ + test_csv_input_path = "test/transformers/testdata/table1.csv" + # Before transformation, the namespaces need to have the imported modules + pre_transform_code = """import pandas as pd""" + user_ns = {} + global_ns = {} + exec(pre_transform_code, user_ns, global_ns) + # Transform the use of Pandas to Dask + transform_code = f""" +df = pd.read_csv('{test_csv_input_path}') +""" + step_1 = Step(code=transform_code) + transformer = PandasToDaskTransformer(user_ns, global_ns) + transformer.transform_step(step_1) + # Verify that the transformed code includes the Dask DataFrame and replaced the use of Pandas with Dask + expected_code = f"""import dask.dataframe as dd +df = dd.read_csv('{test_csv_input_path}') +""" + assert expected_code == step_1.code + assert 1 == len(step_1.packages_to_install) + assert ['dask.dataframe'] == step_1.packages_to_install + # Execute the translated code so that the df variable is in the namespaces + exec(step_1.code, user_ns, global_ns) + # Transform again 
so the Pandas sum() is replaced with Dask sum().compute() + transform_code = """x = df.sum() +print(x) +""" + step_2 = Step(code=transform_code) + transformer = PandasToDaskTransformer(user_ns, global_ns) + transformer.transform_step(step_2) + # Verify that sum() is now replaced with sum().compute() + expected_code = """x = df.sum().compute() +print(x) +""" + assert expected_code == step_2.code + assert 0 == len(step_2.packages_to_install) + + def test_no_translation_sum_call(self): + """ + Since we transform Pandas sum(), make sure that non-Pandas sum() is not transformed. + """ + pre_transform_code = """x = [1, 2]""" + user_ns = {} + global_ns = {} + exec(pre_transform_code, user_ns, global_ns) + transform_code = """y = sum(x) +print(y) +""" + step = Step(code=transform_code) + transformer = PandasToDaskTransformer() + transformer.transform_step(step) + # There should be no difference (no transformation should happen) + expected_code = transform_code + assert expected_code == step.code + assert 0 == len(step.packages_to_install) + + def test_no_translation_sum_attribute(self): + """ + Since we transform Pandas sum(), make sure that non-Pandas sum() is not transformed. 
+ """ + pre_transform_code = """class MyList: + + def __init__(self, lst): + self.lst = lst + + def sum(self): + return sum(self.lst) + + +x = MyList([1, 2]) + """ + user_ns = {} + global_ns = {} + exec(pre_transform_code, user_ns, global_ns) + transform_code = """y = x.sum() +print(y) +""" + step = Step(code=transform_code) + transformer = PandasToDaskTransformer() + transformer.transform_step(step) + # There should be no difference (no transformation should happen) + expected_code = transform_code + assert expected_code == step.code + assert 0 == len(step.packages_to_install) diff --git a/test/transformers/testdata/table1.csv b/test/transformers/testdata/table1.csv new file mode 100644 index 00000000..6be18720 --- /dev/null +++ b/test/transformers/testdata/table1.csv @@ -0,0 +1,5 @@ +a,b,c,d +1,2,3,4 +5,6,7,8 +1,0,1,0 +0,1,0,1 diff --git a/transformers/__init__.py b/transformers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/transformers/context.py b/transformers/context.py new file mode 100644 index 00000000..435b4c0e --- /dev/null +++ b/transformers/context.py @@ -0,0 +1,5 @@ +import sys +sys.path.insert(0, "..") + + +from objects.step import Step diff --git a/transformers/pandas_to_dask_transformer.py b/transformers/pandas_to_dask_transformer.py new file mode 100644 index 00000000..eb991bbd --- /dev/null +++ b/transformers/pandas_to_dask_transformer.py @@ -0,0 +1,167 @@ +from typing import Any, Optional +from .context import Step +from .transformer import Transformer +import ast +import astor +import dask.dataframe +import inspect + + +class PandasToDaskTransformer(ast.NodeTransformer, Transformer): + def __init__( + self, + user_namespace: Optional[dict] = None, + global_namespace: Optional[dict] = None + ): + super().__init__() + self.user_namespace = user_namespace if user_namespace is not None else {} + self.global_namespace = global_namespace if global_namespace is not None else {} + # Import statement for Dask + 
self._dask_dataframe_import_name = 'dask.dataframe' + self._dask_dataframe_import_asname = 'dd' + # Import statement for Pandas + self._pandas_import_name = 'pandas' + self._pandas_import_asname = 'pd' + # Supported Pandas functions that can be 1:1 mapped to Dask + self._pandas_to_dask_functions = ['read_csv'] + # Pandas functions for which we need to call .compute when translated to Dask + self._pandas_to_dask_with_compute_functions = ['sum'] + # Function call for .compute + self._dask_compute = 'compute' + + def _import_dask_dataframe(self) -> None: + """ + Import Dask DataFrame into the environment. + """ + if self._dask_dataframe_import_name not in self._required_imports: + self._required_imports[self._dask_dataframe_import_name] = self._dask_dataframe_import_asname + + def _get_val_from_namespaces(self, id: str) -> Any: + """ + Get the object associated to a particular key from user/global namespaces. + Preference is given to user namespace. + Returns None if an object is not found in either namespace. + """ + # Look up the value for the variable from the user namespace + val_from_namespace = self.user_namespace.get(id, None) + if val_from_namespace is None: + # If not found, look up the value for the variable from the global namespace + val_from_namespace = self.global_namespace.get(id, None) + return val_from_namespace + + def _is_one_to_one_mapping_valid(self, node: ast.AST) -> bool: + """ + Checks if it is valid to convert the given function call from Pandas to Dask. 
+ """ + attr = node.func.attr + id = node.func.value.id + if attr not in self._pandas_to_dask_functions: + return False + if id != self._pandas_import_name and id != self._pandas_import_asname: + return False + val_from_namespace = self._get_val_from_namespaces(id) + if val_from_namespace is None: + return False + if not inspect.ismodule(val_from_namespace): + return False + if val_from_namespace.__name__ != self._pandas_import_name: + return False + return True + + def _is_compute_required(self, node: ast.AST) -> bool: + """ + Checks if the Pandas operation being performed requires .compute() when translated to Dask. + """ + attr = node.func.attr + id = node.func.value.id + if attr not in self._pandas_to_dask_with_compute_functions: + return False + val_from_namespace = self._get_val_from_namespaces(id) + if val_from_namespace is None: + return False + if not isinstance(val_from_namespace, dask.dataframe.core.DataFrame): + return False + return True + + def _try_translate(self, node: ast.AST) -> Optional[ast.AST]: + """ + If the code in the given node requires any translation, it will create an updated node and return that. + If no translation was performed, it will return None. 
+ """ + updated_node = None + if isinstance(node.func, ast.Attribute): + if self._is_one_to_one_mapping_valid(node): + updated_node = ast.Call( + func=ast.Attribute( + value=ast.Name( + id=self._dask_dataframe_import_asname, + ctx=node.func.value.ctx + ), + attr=node.func.attr, + ctx=node.func.ctx + ), + args=node.args, + keywords=node.keywords + ) + # Requires importing the dask.dataframe class as we are converting Pandas to Dask dataframe + self._import_dask_dataframe() + elif self._is_compute_required(node): + updated_node = ast.Call( + func=ast.Attribute( + value=ast.Call( + func=ast.Attribute( + value=ast.Name( + id=node.func.value.id, + ctx=node.func.ctx + ), + attr=node.func.attr, + ctx=node.func.ctx + ), + args=node.args, + keywords=node.keywords + ), + attr=self._dask_compute, + ctx=node.func.ctx + ), + args=[], + keywords=[] + ) + return updated_node + + def visit_Call(self, node: ast.Call) -> ast.Call: + """ + Visit (and translate, if applicable) the Call nodes in the AST. + """ + translated_node = self._try_translate(node) + if translated_node is not None: + # Translation resulted in updating the node, so replace it + return translated_node + # No translation happened, keep the original node + return node + + def transform_step(self, step: Step) -> None: + """ + Transform a given Step's source code into a semantically equivalent source code such that all supported Pandas + operations are replaced by equivalent Dask operations. + Note: The transformation is applied to the input Step. 
+ Example: + 1) pd.read_csv --> dd.read_csv + 2) x = df.sum() --> x = df.sum().compute() + """ + super().transform_step(step) + # Parse the code into its AST + tree = ast.parse(step.code) + # Run the transformation + updated_tree = self.visit(tree) + # Update node locations in the tree after potential changes due to transformation + ast.fix_missing_locations(updated_tree) + # Add any import statements that are required after transformation + self.perform_imports(updated_tree) + # Convert back the AST into source code + updated_code = astor.to_source(updated_tree) + # Update the source code in the Step + step.code = updated_code + # Update the packages required for the Step + for package in self._required_imports.keys(): + if package not in step.packages_to_install: + step.packages_to_install.append(package) diff --git a/transformers/transformer.py b/transformers/transformer.py new file mode 100644 index 00000000..307e3878 --- /dev/null +++ b/transformers/transformer.py @@ -0,0 +1,20 @@ +from .context import Step +import ast + + +class Transformer: + """ + TODO: Make this an abstract class and decorate methods accordingly. + """ + def __init__(self): + self._required_imports = {} + + def transform_step(self, step: Step) -> Step: + pass + + def perform_imports(self, ast_root_node): + # TODO annotate + for import_name, import_asname in self._required_imports.items(): + import_node = ast.Import(names=[ast.alias(name=import_name, asname=import_asname)]) + ast_root_node.body.insert(0, import_node) + ast.fix_missing_locations(ast_root_node)