Add support for dp-spark-pip magic to allow synchronous package installation #176
Open
tim-u wants to merge 6 commits into main from add-dp-spark-pip (base: main)
Changes from all commits (6 commits):
b640f53  Add support for dp-spark-pip magic to allow synchronous installation (tim-u)
df7e1f9  Formatting and fixing an unrelated broken test (tim-u)
a605d5c  Add integration tests (tim-u)
bed689a  formatting (tim-u)
cd4b2b0  Moved the magic to a separate module (tim-u)
563a78c  Apply suggestion from @gemini-code-assist[bot] (tim-u)
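What the new magic looks like in use, based on the docstring in the diffs below. This is a sketch, not part of the PR: it assumes the extension has already been loaded into the running IPython kernel and that project and region are resolved from the environment.

```python
from google.cloud.dataproc_spark_connect import DataprocSparkSession

# The magic looks up sessions in the notebook namespace, so one must exist
# first (project/region assumed to come from the environment here):
spark = DataprocSparkSession.builder.getOrCreate()

# Then, in a notebook cell:
#   %dp_spark_pip install pandas numpy
# Each package is added to every active DataprocSparkSession in the
# namespace via session.addArtifacts(package, pypi=True).
```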
New file (19 added lines):

```python
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .magics import DataprocMagics


def load_ipython_extension(ipython):
    ipython.register_magics(DataprocMagics)
```
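A sketch of wiring this entry point up by hand, mirroring what the integration-test fixture later in this PR does (the module path `google.cloud.dataproc_spark_connect.magics` is taken from those tests, not stated in this file):

```python
# Run inside an IPython session or notebook; get_ipython() is None elsewhere.
from IPython import get_ipython

from google.cloud.dataproc_spark_connect import magics

# Equivalent to %load_ext: registers the DataprocMagics class, making
# %dp_spark_pip available in the current IPython session.
magics.load_ipython_extension(get_ipython())
```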
New file (73 added lines):

```python
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataproc magic implementations."""

import shlex

from IPython.core.magic import Magics, magics_class, line_magic

from google.cloud.dataproc_spark_connect import DataprocSparkSession


@magics_class
class DataprocMagics(Magics):

    def __init__(self, shell, **kwargs):
        super().__init__(shell, **kwargs)

    def _parse_command(self, args):
        if not args or args[0] != "install":
            print("Usage: %dp_spark_pip install <package1> <package2> ...")
            return

        # Filter out 'install' and the flags (not currently supported).
        packages = [pkg for pkg in args[1:] if not pkg.startswith("-")]
        return packages

    @line_magic
    def dp_spark_pip(self, line):
        """
        Custom magic to install pip packages as Spark Connect artifacts.
        Usage: %dp_spark_pip install pandas numpy
        """
        try:
            packages = self._parse_command(shlex.split(line))

            if not packages:
                print("No packages specified.")
                return

            sessions = [
                obj
                for obj in self.shell.user_ns.values()
                if isinstance(obj, DataprocSparkSession)
            ]

            if not sessions:
                print(
                    "No active Spark Sessions found. Please create one first."
                )
                return

            print(f"Installing packages: {packages}")
            for session in sessions:
                for package in packages:
                    session.addArtifacts(package, pypi=True)

            print("Packages successfully added as artifacts.")
        except Exception as e:
            print(f"Failed to add artifacts: {e}")
```
New file (79 added lines):

```python
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataproc magic implementations."""

import shlex

from IPython.core.magic import Magics, magics_class, line_magic

from google.cloud.dataproc_spark_connect import DataprocSparkSession


@magics_class
class DataprocMagics(Magics):

    def __init__(self, shell, **kwargs):
        super().__init__(shell, **kwargs)

    def _parse_command(self, args):
        if not args or args[0] != "install":
            print("Usage: %dp_spark_pip install <package1> <package2> ...")
            return

        # Filter out 'install' and the flags (not currently supported).
        packages = [pkg for pkg in args[1:] if not pkg.startswith("-")]
        return packages

    @line_magic
    def dp_spark_pip(self, line):
        """
        Custom magic to install pip packages as Spark Connect artifacts.
        Usage: %dp_spark_pip install pandas numpy
        """
        try:
            packages = self._parse_command(shlex.split(line))

            if not packages:
                print("No packages specified.")
                return

            sessions = [
                obj
                for obj in self.shell.user_ns.values()
                if isinstance(obj, DataprocSparkSession)
            ]

            if not sessions:
                print(
                    "No active Spark Sessions found. Please create one first."
                )
                return

            print(f"Installing packages: {packages}")
            for session in sessions:
                for package in packages:
                    session.addArtifacts(package, pypi=True)

            print("Packages successfully added as artifacts.")
        except Exception as e:
            print(f"Failed to add artifacts: {e}")


# To register the magic
def load_ipython_extension(ipython):
    ipython.register_magics(DataprocMagics)
```
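The argument handling is easiest to see with a worked example. `_parse_command` strips the `install` subcommand and drops flags rather than applying them:

```python
import shlex

# Same logic as _parse_command above, applied to a sample command line.
args = shlex.split("install -U numpy scikit-learn")
packages = [pkg for pkg in args[1:] if not pkg.startswith("-")]
print(packages)  # ['numpy', 'scikit-learn'] -- the -U flag is ignored
```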
Empty file.
New file (206 added lines):

```python
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import certifi
from unittest import mock

from google.cloud.dataproc_spark_connect import DataprocSparkSession


_SERVICE_ACCOUNT_KEY_FILE_ = "service_account_key.json"


@pytest.fixture(params=[None, "3.0"])
def image_version(request):
    return request.param


@pytest.fixture
def test_project():
    return os.getenv("GOOGLE_CLOUD_PROJECT")


@pytest.fixture
def test_region():
    return os.getenv("GOOGLE_CLOUD_REGION")


def is_ci_environment():
    """Detect if running in CI environment."""
    return os.getenv("CI") == "true" or os.getenv("GITHUB_ACTIONS") == "true"


@pytest.fixture
def auth_type(request):
    """Auto-detect authentication type based on environment.

    CI environment (CI=true or GITHUB_ACTIONS=true): uses SERVICE_ACCOUNT.
    Local environment: uses END_USER_CREDENTIALS.
    Test parametrization can still override this default.
    """
    # Allow test parametrization to override.
    if hasattr(request, "param"):
        return request.param

    # Auto-detect based on environment.
    if is_ci_environment():
        return "SERVICE_ACCOUNT"
    else:
        return "END_USER_CREDENTIALS"


@pytest.fixture
def test_subnet():
    return os.getenv("DATAPROC_SPARK_CONNECT_SUBNET")


@pytest.fixture
def test_subnetwork_uri(test_subnet):
    # DATAPROC_SPARK_CONNECT_SUBNET holds the full URI, matching how a user
    # would specify it in the project.
    return test_subnet


@pytest.fixture
def os_environment(auth_type, image_version, test_project, test_region):
    original_environment = dict(os.environ)
    if os.path.isfile(_SERVICE_ACCOUNT_KEY_FILE_):
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = (
            _SERVICE_ACCOUNT_KEY_FILE_
        )
    os.environ["DATAPROC_SPARK_CONNECT_AUTH_TYPE"] = auth_type
    if auth_type == "END_USER_CREDENTIALS":
        os.environ.pop("DATAPROC_SPARK_CONNECT_SERVICE_ACCOUNT", None)
    # Add SSL certificate fix.
    os.environ["SSL_CERT_FILE"] = certifi.where()
    os.environ["REQUESTS_CA_BUNDLE"] = certifi.where()
    yield os.environ
    os.environ.clear()
    os.environ.update(original_environment)


@pytest.fixture
def connect_session(test_project, test_region, os_environment):
    session = (
        DataprocSparkSession.builder.projectId(test_project)
        .location(test_region)
        .getOrCreate()
    )
    yield session
    # Clean up the session after each test to prevent resource conflicts.
    try:
        session.stop()
    except Exception:
        # Ignore cleanup errors to avoid masking the actual test failure.
        pass


# Tests for magics.py
@pytest.fixture
def ipython_shell(connect_session):
    """Provides an IPython shell with a DataprocSparkSession in user_ns."""
    pytest.importorskip("IPython", reason="IPython not available")
    try:
        from IPython.terminal.interactiveshell import TerminalInteractiveShell
        from google.cloud.dataproc_spark_connect import magics

        shell = TerminalInteractiveShell.instance()
        shell.user_ns = {"spark": connect_session}

        # Load magics
        magics.load_ipython_extension(shell)

        yield shell
    finally:
        from IPython.terminal.interactiveshell import TerminalInteractiveShell

        TerminalInteractiveShell.clear_instance()


def test_dp_spark_pip_magic_loads(ipython_shell):
    """Test that %dp_spark_pip magic is registered."""
    assert "dp_spark_pip" in ipython_shell.magics_manager.magics["line"]


@mock.patch.object(DataprocSparkSession, "addArtifacts")
def test_dp_spark_pip_install_single_package(
    mock_add_artifacts, ipython_shell, capsys
):
    """Test installing a single package with %dp_spark_pip."""
    ipython_shell.run_line_magic("dp_spark_pip", "install pandas")
    mock_add_artifacts.assert_called_once_with("pandas", pypi=True)
    captured = capsys.readouterr()
    assert "Installing packages: " in captured.out
    assert "Packages successfully added as artifacts." in captured.out


@mock.patch.object(DataprocSparkSession, "addArtifacts")
def test_dp_spark_pip_install_multiple_packages_with_flags(
    mock_add_artifacts, ipython_shell, capsys
):
    """Test installing multiple packages with flags like -U."""
    ipython_shell.run_line_magic(
        "dp_spark_pip", "install -U numpy scikit-learn"
    )
    calls = [
        mock.call("numpy", pypi=True),
        mock.call("scikit-learn", pypi=True),
    ]
    mock_add_artifacts.assert_has_calls(calls, any_order=True)
    assert mock_add_artifacts.call_count == 2
    captured = capsys.readouterr()
    assert "Installing packages: " in captured.out
    assert "Packages successfully added as artifacts." in captured.out


def test_dp_spark_pip_no_install_command(ipython_shell, capsys):
    """Test usage message when 'install' is missing."""
    ipython_shell.run_line_magic("dp_spark_pip", "pandas")
    captured = capsys.readouterr()
    assert (
        "Usage: %dp_spark_pip install <package1> <package2> ..." in captured.out
    )
    assert "No packages specified." in captured.out


def test_dp_spark_pip_no_packages(ipython_shell, capsys):
    """Test message when no packages are specified."""
    ipython_shell.run_line_magic("dp_spark_pip", "install")
    captured = capsys.readouterr()
    assert "No packages specified." in captured.out


@mock.patch.object(DataprocSparkSession, "addArtifacts")
def test_dp_spark_pip_no_session(mock_add_artifacts, ipython_shell, capsys):
    """Test message when no Spark session is active."""
    ipython_shell.user_ns = {}  # Remove spark session from namespace
    ipython_shell.run_line_magic("dp_spark_pip", "install pandas")
    captured = capsys.readouterr()
    assert "No active Spark Sessions found." in captured.out
    mock_add_artifacts.assert_not_called()


@mock.patch.object(
    DataprocSparkSession,
    "addArtifacts",
    side_effect=Exception("Install failed"),
)
def test_dp_spark_pip_install_failure(
    mock_add_artifacts, ipython_shell, capsys
):
    """Test error message on installation failure."""
    ipython_shell.run_line_magic("dp_spark_pip", "install bad-package")
    mock_add_artifacts.assert_called_once_with("bad-package", pypi=True)
    captured = capsys.readouterr()
    assert "Failed to add artifacts: Install failed" in captured.out
```
Empty file.