diff --git a/google/cloud/dataproc_magics/__init__.py b/google/cloud/dataproc_magics/__init__.py
new file mode 100644
index 0000000..a348eb8
--- /dev/null
+++ b/google/cloud/dataproc_magics/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .magics import DataprocMagics
+
+
+def load_ipython_extension(ipython):
+    ipython.register_magics(DataprocMagics)
diff --git a/google/cloud/dataproc_magics/magics.py b/google/cloud/dataproc_magics/magics.py
new file mode 100644
index 0000000..44c62f9
--- /dev/null
+++ b/google/cloud/dataproc_magics/magics.py
@@ -0,0 +1,73 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Dataproc magic implementations."""
+
+import shlex
+from IPython.core.magic import (Magics, magics_class, line_magic)
+from google.cloud.dataproc_spark_connect import DataprocSparkSession
+
+
+@magics_class
+class DataprocMagics(Magics):
+
+    def __init__(
+        self,
+        shell,
+        **kwargs,
+    ):
+        super().__init__(shell, **kwargs)
+
+    def _parse_command(self, args):
+        if not args or args[0] != "install":
+            print("Usage: %dp_spark_pip install ...")
+            return
+
+        # filter out 'install' and the flags (not currently supported)
+        packages = [pkg for pkg in args[1:] if not pkg.startswith("-")]
+        return packages
+
+    @line_magic
+    def dp_spark_pip(self, line):
+        """
+        Custom magic to install pip packages as Spark Connect artifacts.
+        Usage: %dp_spark_pip install pandas numpy
+        """
+        try:
+            packages = self._parse_command(shlex.split(line))
+
+            if not packages:
+                print("No packages specified.")
+                return
+
+            sessions = [
+                obj
+                for obj in self.shell.user_ns.values()
+                if isinstance(obj, DataprocSparkSession)
+            ]
+
+            if not sessions:
+                print(
+                    "No active Spark Sessions found. Please create one first."
+                )
+                return
+
+            print(f"Installing packages: {packages}")
+            for session in sessions:
+                for package in packages:
+                    session.addArtifacts(package, pypi=True)
+
+            print("Packages successfully added as artifacts.")
+        except Exception as e:
+            print(f"Failed to add artifacts: {e}")
diff --git a/google/cloud/dataproc_spark_connect/magics.py b/google/cloud/dataproc_spark_connect/magics.py
new file mode 100644
index 0000000..0581699
--- /dev/null
+++ b/google/cloud/dataproc_spark_connect/magics.py
@@ -0,0 +1,79 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Dataproc magic implementations."""
+
+import shlex
+from IPython.core.magic import (Magics, magics_class, line_magic)
+from pyspark.sql import SparkSession
+from google.cloud.dataproc_spark_connect import DataprocSparkSession
+
+
+@magics_class
+class DataprocMagics(Magics):
+
+    def __init__(
+        self,
+        shell,
+        **kwargs,
+    ):
+        super().__init__(shell, **kwargs)
+
+    def _parse_command(self, args):
+        if not args or args[0] != "install":
+            print("Usage: %dp_spark_pip install ...")
+            return
+
+        # filter out 'install' and the flags (not currently supported)
+        packages = [pkg for pkg in args[1:] if not pkg.startswith("-")]
+        return packages
+
+    @line_magic
+    def dp_spark_pip(self, line):
+        """
+        Custom magic to install pip packages as Spark Connect artifacts.
+        Usage: %dp_spark_pip install pandas numpy
+        """
+        try:
+            packages = self._parse_command(shlex.split(line))
+
+            if not packages:
+                print("No packages specified.")
+                return
+
+            sessions = [
+                obj
+                for obj in self.shell.user_ns.values()
+                if isinstance(obj, DataprocSparkSession)
+            ]
+
+            if not sessions:
+                print(
+                    "No active Spark Sessions found. Please create one first."
+                )
+                return
+
+            print(f"Installing packages: {packages}")
+            for session in sessions:
+                for package in packages:
+                    session.addArtifacts(package, pypi=True)
+
+            print("Packages successfully added as artifacts.")
+        except Exception as e:
+            print(f"Failed to add artifacts: {e}")
+
+
+# To register the magic
+def load_ipython_extension(ipython):
+    ipython.register_magics(DataprocMagics)
diff --git a/tests/integration/dataproc_magics/__init__.py b/tests/integration/dataproc_magics/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/integration/dataproc_magics/test_magics.py b/tests/integration/dataproc_magics/test_magics.py
new file mode 100644
index 0000000..1cbc9a3
--- /dev/null
+++ b/tests/integration/dataproc_magics/test_magics.py
@@ -0,0 +1,206 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import pytest
+import certifi
+from unittest import mock
+
+from google.cloud.dataproc_spark_connect import DataprocSparkSession
+
+
+_SERVICE_ACCOUNT_KEY_FILE_ = "service_account_key.json"
+
+
+@pytest.fixture(params=[None, "3.0"])
+def image_version(request):
+    return request.param
+
+
+@pytest.fixture
+def test_project():
+    return os.getenv("GOOGLE_CLOUD_PROJECT")
+
+
+@pytest.fixture
+def test_region():
+    return os.getenv("GOOGLE_CLOUD_REGION")
+
+
+def is_ci_environment():
+    """Detect if running in a CI environment."""
+    return os.getenv("CI") == "true" or os.getenv("GITHUB_ACTIONS") == "true"
+
+
+@pytest.fixture
+def auth_type(request):
+    """Auto-detect authentication type based on environment.
+
+    CI environment (CI=true or GITHUB_ACTIONS=true): Uses SERVICE_ACCOUNT
+    Local environment: Uses END_USER_CREDENTIALS
+    Test parametrization can still override this default.
+    """
+    # Allow test parametrization to override
+    if hasattr(request, "param"):
+        return request.param
+
+    # Auto-detect based on environment
+    if is_ci_environment():
+        return "SERVICE_ACCOUNT"
+    else:
+        return "END_USER_CREDENTIALS"
+
+
+@pytest.fixture
+def test_subnet():
+    return os.getenv("DATAPROC_SPARK_CONNECT_SUBNET")
+
+
+@pytest.fixture
+def test_subnetwork_uri(test_subnet):
+    # Make DATAPROC_SPARK_CONNECT_SUBNET the full URI to align with how a user would specify it in the project
+    return test_subnet
+
+
+@pytest.fixture
+def os_environment(auth_type, image_version, test_project, test_region):
+    original_environment = dict(os.environ)
+    if os.path.isfile(_SERVICE_ACCOUNT_KEY_FILE_):
+        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = (
+            _SERVICE_ACCOUNT_KEY_FILE_
+        )
+    os.environ["DATAPROC_SPARK_CONNECT_AUTH_TYPE"] = auth_type
+    if auth_type == "END_USER_CREDENTIALS":
+        os.environ.pop("DATAPROC_SPARK_CONNECT_SERVICE_ACCOUNT", None)
+    # Add SSL certificate fix
+    os.environ["SSL_CERT_FILE"] = certifi.where()
+    os.environ["REQUESTS_CA_BUNDLE"] = certifi.where()
+    yield os.environ
+    os.environ.clear()
+    os.environ.update(original_environment)
+
+
+@pytest.fixture
+def connect_session(test_project, test_region, os_environment):
+    session = (
+        DataprocSparkSession.builder.projectId(test_project)
+        .location(test_region)
+        .getOrCreate()
+    )
+    yield session
+    # Clean up the session after each test to prevent resource conflicts
+    try:
+        session.stop()
+    except Exception:
+        # Ignore cleanup errors to avoid masking the actual test failure
+        pass
+
+
+# Tests for magics.py
+@pytest.fixture
+def ipython_shell(connect_session):
+    """Provides an IPython shell with a DataprocSparkSession in user_ns."""
+    pytest.importorskip("IPython", reason="IPython not available")
+    try:
+        from IPython.terminal.interactiveshell import TerminalInteractiveShell
+        from google.cloud.dataproc_spark_connect import magics
+
+        shell = TerminalInteractiveShell.instance()
+        shell.user_ns = {"spark": connect_session}
+
+        # Load magics
+        magics.load_ipython_extension(shell)
+
+        yield shell
+    finally:
+        from IPython.terminal.interactiveshell import TerminalInteractiveShell
+
+        TerminalInteractiveShell.clear_instance()
+
+
+def test_dp_spark_pip_magic_loads(ipython_shell):
+    """Test that %dp_spark_pip magic is registered."""
+    assert "dp_spark_pip" in ipython_shell.magics_manager.magics["line"]
+
+
+@mock.patch.object(DataprocSparkSession, "addArtifacts")
+def test_dp_spark_pip_install_single_package(
+    mock_add_artifacts, ipython_shell, capsys
+):
+    """Test installing a single package with %dp_spark_pip."""
+    ipython_shell.run_line_magic("dp_spark_pip", "install pandas")
+    mock_add_artifacts.assert_called_once_with("pandas", pypi=True)
+    captured = capsys.readouterr()
+    assert "Installing packages: " in captured.out
+    assert "Packages successfully added as artifacts." in captured.out
+
+
+@mock.patch.object(DataprocSparkSession, "addArtifacts")
+def test_dp_spark_pip_install_multiple_packages_with_flags(
+    mock_add_artifacts, ipython_shell, capsys
+):
+    """Test installing multiple packages with flags like -U."""
+    ipython_shell.run_line_magic(
+        "dp_spark_pip", "install -U numpy scikit-learn"
+    )
+    calls = [
+        mock.call("numpy", pypi=True),
+        mock.call("scikit-learn", pypi=True),
+    ]
+    mock_add_artifacts.assert_has_calls(calls, any_order=True)
+    assert mock_add_artifacts.call_count == 2
+    captured = capsys.readouterr()
+    assert "Installing packages: " in captured.out
+    assert "Packages successfully added as artifacts." in captured.out
+
+
+def test_dp_spark_pip_no_install_command(ipython_shell, capsys):
+    """Test usage message when 'install' is missing."""
+    ipython_shell.run_line_magic("dp_spark_pip", "pandas")
+    captured = capsys.readouterr()
+    assert (
+        "Usage: %dp_spark_pip install ..." in captured.out
+    )
+    assert "No packages specified." in captured.out
+
+
+def test_dp_spark_pip_no_packages(ipython_shell, capsys):
+    """Test message when no packages are specified."""
+    ipython_shell.run_line_magic("dp_spark_pip", "install")
+    captured = capsys.readouterr()
+    assert "No packages specified." in captured.out
+
+
+@mock.patch.object(DataprocSparkSession, "addArtifacts")
+def test_dp_spark_pip_no_session(mock_add_artifacts, ipython_shell, capsys):
+    """Test message when no Spark session is active."""
+    ipython_shell.user_ns = {}  # Remove spark session from namespace
+    ipython_shell.run_line_magic("dp_spark_pip", "install pandas")
+    captured = capsys.readouterr()
+    assert "No active Spark Sessions found." in captured.out
+    mock_add_artifacts.assert_not_called()
+
+
+@mock.patch.object(
+    DataprocSparkSession,
+    "addArtifacts",
+    side_effect=Exception("Install failed"),
+)
+def test_dp_spark_pip_install_failure(
+    mock_add_artifacts, ipython_shell, capsys
+):
+    """Test error message on installation failure."""
+    ipython_shell.run_line_magic("dp_spark_pip", "install bad-package")
+    mock_add_artifacts.assert_called_once_with("bad-package", pypi=True)
+    captured = capsys.readouterr()
+    assert "Failed to add artifacts: Install failed" in captured.out
diff --git a/tests/unit/dataproc_magics/__init__.py b/tests/unit/dataproc_magics/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/unit/dataproc_magics/test_magics.py b/tests/unit/dataproc_magics/test_magics.py
new file mode 100644
index 0000000..0b8a52a
--- /dev/null
+++ b/tests/unit/dataproc_magics/test_magics.py
@@ -0,0 +1,125 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import io
+import unittest
+from contextlib import redirect_stdout
+from unittest import mock
+
+from google.cloud.dataproc_spark_connect import DataprocSparkSession
+from google.cloud.dataproc_magics import DataprocMagics
+from IPython.core.interactiveshell import InteractiveShell
+from traitlets.config import Config
+
+
+class DataprocMagicsTest(unittest.TestCase):
+
+    def setUp(self):
+        self.shell = mock.create_autospec(InteractiveShell, instance=True)
+        self.shell.user_ns = {}
+        self.shell.config = Config()
+        self.magics = DataprocMagics(shell=self.shell)
+
+    def test_parse_command_valid(self):
+        packages = self.magics._parse_command(["install", "pandas", "numpy"])
+        self.assertEqual(packages, ["pandas", "numpy"])
+
+    def test_parse_command_with_flags(self):
+        packages = self.magics._parse_command(
+            ["install", "-U", "pandas", "--upgrade", "numpy"]
+        )
+        self.assertEqual(packages, ["pandas", "numpy"])
+
+    def test_parse_command_no_install(self):
+        packages = self.magics._parse_command(["other", "pandas"])
+        self.assertIsNone(packages)
+
+    def test_dp_spark_pip_invalid_command(self):
+        f = io.StringIO()
+        with redirect_stdout(f):
+            self.magics.dp_spark_pip("foo bar")
+        output = f.getvalue()
+        self.assertIn("Usage: %dp_spark_pip install", output)
+        self.assertIn("No packages specified", output)
+
+    def test_dp_spark_pip_no_session(self):
+        f = io.StringIO()
+        with redirect_stdout(f):
+            self.magics.dp_spark_pip("install pandas")
+        self.assertIn("No active Spark Sessions found", f.getvalue())
+
+    def test_dp_spark_pip_no_packages_specified(self):
+        f = io.StringIO()
+        with redirect_stdout(f):
+            self.magics.dp_spark_pip("install")
+        self.assertIn("No packages specified", f.getvalue())
+
+    def test_dp_spark_pip_install_packages_single_session(self):
+        mock_session = mock.Mock(spec=DataprocSparkSession)
+        self.shell.user_ns["spark"] = mock_session
+
+        f = io.StringIO()
+        with redirect_stdout(f):
+            self.magics.dp_spark_pip("install pandas numpy")
+
+        mock_session.addArtifacts.assert_has_calls(
+            [
+                mock.call("pandas", pypi=True),
+                mock.call("numpy", pypi=True),
+            ]
+        )
+        self.assertEqual(mock_session.addArtifacts.call_count, 2)
+        self.assertIn("Packages successfully added as artifacts.", f.getvalue())
+
+    def test_dp_spark_pip_install_packages_multiple_sessions(self):
+        mock_session1 = mock.Mock(spec=DataprocSparkSession)
+        mock_session2 = mock.Mock(spec=DataprocSparkSession)
+        self.shell.user_ns["spark1"] = mock_session1
+        self.shell.user_ns["spark2"] = mock_session2
+        self.shell.user_ns["not_a_session"] = 5
+
+        f = io.StringIO()
+        with redirect_stdout(f):
+            self.magics.dp_spark_pip("install pandas")
+
+        mock_session1.addArtifacts.assert_called_once_with("pandas", pypi=True)
+        mock_session2.addArtifacts.assert_called_once_with("pandas", pypi=True)
+        self.assertIn("Packages successfully added as artifacts.", f.getvalue())
+
+    def test_dp_spark_pip_add_artifacts_fails(self):
+        mock_session = mock.Mock(spec=DataprocSparkSession)
+        mock_session.addArtifacts.side_effect = Exception("Failed")
+        self.shell.user_ns["spark"] = mock_session
+
+        f = io.StringIO()
+        with redirect_stdout(f):
+            self.magics.dp_spark_pip("install pandas")
+
+        mock_session.addArtifacts.assert_called_once_with("pandas", pypi=True)
+        self.assertIn("Failed to add artifacts: Failed", f.getvalue())
+
+    def test_dp_spark_pip_with_flags(self):
+        mock_session = mock.Mock(spec=DataprocSparkSession)
+        self.shell.user_ns["spark"] = mock_session
+
+        f = io.StringIO()
+        with redirect_stdout(f):
+            self.magics.dp_spark_pip("install -U pandas")
pandas") + + mock_session.addArtifacts.assert_called_once_with("pandas", pypi=True) + self.assertIn("Packages successfully added as artifacts.", f.getvalue()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_pypi_artifacts.py b/tests/unit/test_pypi_artifacts.py index 0d578e9..22ef360 100644 --- a/tests/unit/test_pypi_artifacts.py +++ b/tests/unit/test_pypi_artifacts.py @@ -26,7 +26,7 @@ def test_valid_inputs(): def test_bad_format(self): with self.assertRaisesRegex( InvalidRequirement, - "Expected end or semicolon \(after name and no valid version specifier\).*", + r"Expected semicolon \(after name with no version specifier\) or end", ): PyPiArtifacts({"pypi://spacy:23"})