19 changes: 19 additions & 0 deletions google/cloud/dataproc_magics/__init__.py
@@ -0,0 +1,19 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .magics import DataprocMagics


def load_ipython_extension(ipython):
ipython.register_magics(DataprocMagics)
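A minimal notebook usage sketch (an assumption for illustration: it presumes the package is installed and importable as google.cloud.dataproc_magics, the module path laid out in this diff):

%load_ext google.cloud.dataproc_magics
%dp_spark_pip install pandas numpy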
73 changes: 73 additions & 0 deletions google/cloud/dataproc_magics/magics.py
@@ -0,0 +1,73 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataproc magic implementations."""

import shlex
from IPython.core.magic import (Magics, magics_class, line_magic)
from google.cloud.dataproc_spark_connect import DataprocSparkSession


@magics_class
class DataprocMagics(Magics):

def __init__(
self,
shell,
**kwargs,
):
super().__init__(shell, **kwargs)

def _parse_command(self, args):
if not args or args[0] != "install":
print("Usage: %dp_spark_pip install <package1> <package2> ...")
return

        # 'install' is dropped via args[1:]; also skip flags (e.g. -U), which are not currently supported
packages = [pkg for pkg in args[1:] if not pkg.startswith("-")]
return packages

@line_magic
def dp_spark_pip(self, line):
"""
Custom magic to install pip packages as Spark Connect artifacts.
Usage: %dp_spark_pip install pandas numpy
"""
try:
packages = self._parse_command(shlex.split(line))

if not packages:
print("No packages specified.")
return

sessions = [
obj
for obj in self.shell.user_ns.values()
if isinstance(obj, DataprocSparkSession)
]

if not sessions:
print(
"No active Spark Sessions found. Please create one first."
)
return

print("Installing packages: %s", packages)
for session in sessions:
for package in packages:
session.addArtifacts(package, pypi=True)

print("Packages successfully added as artifacts.")
except Exception as e:
print(f"Failed to add artifacts: {e}")
79 changes: 79 additions & 0 deletions google/cloud/dataproc_spark_connect/magics.py
@@ -0,0 +1,79 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataproc magic implementations."""

import shlex
from IPython.core.magic import (Magics, magics_class, line_magic)
from google.cloud.dataproc_spark_connect import DataprocSparkSession


@magics_class
class DataprocMagics(Magics):

def __init__(
self,
shell,
**kwargs,
):
super().__init__(shell, **kwargs)

def _parse_command(self, args):
if not args or args[0] != "install":
print("Usage: %dp_spark_pip install <package1> <package2> ...")
return

        # 'install' is dropped via args[1:]; also skip flags (e.g. -U), which are not currently supported
packages = [pkg for pkg in args[1:] if not pkg.startswith("-")]
return packages

@line_magic
def dp_spark_pip(self, line):
"""
Custom magic to install pip packages as Spark Connect artifacts.
Usage: %dp_spark_pip install pandas numpy
"""
try:
packages = self._parse_command(shlex.split(line))

if not packages:
print("No packages specified.")
return

sessions = [
obj
for obj in self.shell.user_ns.values()
if isinstance(obj, DataprocSparkSession)
]

if not sessions:
print(
"No active Spark Sessions found. Please create one first."
)
return

print(f"Installing packages: {packages}")
for session in sessions:
for package in packages:
session.addArtifacts(package, pypi=True)

print("Packages successfully added as artifacts.")
except Exception as e:
print(f"Failed to add artifacts: {e}")


# To register the magic
def load_ipython_extension(ipython):
ipython.register_magics(DataprocMagics)
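Since this copy defines its own load_ipython_extension, a notebook could presumably load it directly from the submodule path shown above (an assumption based only on the module location in this diff):

%load_ext google.cloud.dataproc_spark_connect.magics
%dp_spark_pip install pandas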
206 changes: 206 additions & 0 deletions tests/integration/dataproc_magics/test_magics.py
@@ -0,0 +1,206 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import certifi
from unittest import mock

from google.cloud.dataproc_spark_connect import DataprocSparkSession


_SERVICE_ACCOUNT_KEY_FILE_ = "service_account_key.json"


@pytest.fixture(params=[None, "3.0"])
def image_version(request):
return request.param


@pytest.fixture
def test_project():
return os.getenv("GOOGLE_CLOUD_PROJECT")


@pytest.fixture
def test_region():
return os.getenv("GOOGLE_CLOUD_REGION")


def is_ci_environment():
"""Detect if running in CI environment."""
return os.getenv("CI") == "true" or os.getenv("GITHUB_ACTIONS") == "true"


@pytest.fixture
def auth_type(request):
"""Auto-detect authentication type based on environment.

CI environment (CI=true or GITHUB_ACTIONS=true): Uses SERVICE_ACCOUNT
Local environment: Uses END_USER_CREDENTIALS
Test parametrization can still override this default.
"""
# Allow test parametrization to override
if hasattr(request, "param"):
return request.param

# Auto-detect based on environment
if is_ci_environment():
return "SERVICE_ACCOUNT"
else:
return "END_USER_CREDENTIALS"


@pytest.fixture
def test_subnet():
return os.getenv("DATAPROC_SPARK_CONNECT_SUBNET")


@pytest.fixture
def test_subnetwork_uri(test_subnet):
    # DATAPROC_SPARK_CONNECT_SUBNET is expected to be the full subnetwork URI, matching how a user would specify it in their project
return test_subnet


@pytest.fixture
def os_environment(auth_type, image_version, test_project, test_region):
original_environment = dict(os.environ)
if os.path.isfile(_SERVICE_ACCOUNT_KEY_FILE_):
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = (
_SERVICE_ACCOUNT_KEY_FILE_
)
os.environ["DATAPROC_SPARK_CONNECT_AUTH_TYPE"] = auth_type
if auth_type == "END_USER_CREDENTIALS":
os.environ.pop("DATAPROC_SPARK_CONNECT_SERVICE_ACCOUNT", None)
    # Point TLS verification at certifi's CA bundle to avoid certificate errors
os.environ["SSL_CERT_FILE"] = certifi.where()
os.environ["REQUESTS_CA_BUNDLE"] = certifi.where()
yield os.environ
os.environ.clear()
os.environ.update(original_environment)


@pytest.fixture
def connect_session(test_project, test_region, os_environment):
session = (
DataprocSparkSession.builder.projectId(test_project)
.location(test_region)
.getOrCreate()
)
yield session
# Clean up the session after each test to prevent resource conflicts
try:
session.stop()
except Exception:
# Ignore cleanup errors to avoid masking the actual test failure
pass


# Tests for magics.py
@pytest.fixture
def ipython_shell(connect_session):
"""Provides an IPython shell with a DataprocSparkSession in user_ns."""
pytest.importorskip("IPython", reason="IPython not available")
try:
from IPython.terminal.interactiveshell import TerminalInteractiveShell
from google.cloud.dataproc_spark_connect import magics

shell = TerminalInteractiveShell.instance()
shell.user_ns = {"spark": connect_session}

# Load magics
magics.load_ipython_extension(shell)

yield shell
finally:
from IPython.terminal.interactiveshell import TerminalInteractiveShell

TerminalInteractiveShell.clear_instance()


def test_dp_spark_pip_magic_loads(ipython_shell):
"""Test that %dp_spark_pip magic is registered."""
assert "dp_spark_pip" in ipython_shell.magics_manager.magics["line"]


@mock.patch.object(DataprocSparkSession, "addArtifacts")
def test_dp_spark_pip_install_single_package(
mock_add_artifacts, ipython_shell, capsys
):
"""Test installing a single package with %dp_spark_pip."""
ipython_shell.run_line_magic("dp_spark_pip", "install pandas")
mock_add_artifacts.assert_called_once_with("pandas", pypi=True)
captured = capsys.readouterr()
assert "Installing packages: " in captured.out
assert "Packages successfully added as artifacts." in captured.out


@mock.patch.object(DataprocSparkSession, "addArtifacts")
def test_dp_spark_pip_install_multiple_packages_with_flags(
mock_add_artifacts, ipython_shell, capsys
):
"""Test installing multiple packages with flags like -U."""
ipython_shell.run_line_magic(
"dp_spark_pip", "install -U numpy scikit-learn"
)
calls = [
mock.call("numpy", pypi=True),
mock.call("scikit-learn", pypi=True),
]
mock_add_artifacts.assert_has_calls(calls, any_order=True)
assert mock_add_artifacts.call_count == 2
captured = capsys.readouterr()
assert "Installing packages: " in captured.out
assert "Packages successfully added as artifacts." in captured.out


def test_dp_spark_pip_no_install_command(ipython_shell, capsys):
"""Test usage message when 'install' is missing."""
ipython_shell.run_line_magic("dp_spark_pip", "pandas")
captured = capsys.readouterr()
assert (
"Usage: %dp_spark_pip install <package1> <package2> ..." in captured.out
)
assert "No packages specified." in captured.out


def test_dp_spark_pip_no_packages(ipython_shell, capsys):
"""Test message when no packages are specified."""
ipython_shell.run_line_magic("dp_spark_pip", "install")
captured = capsys.readouterr()
assert "No packages specified." in captured.out


@mock.patch.object(DataprocSparkSession, "addArtifacts")
def test_dp_spark_pip_no_session(mock_add_artifacts, ipython_shell, capsys):
"""Test message when no Spark session is active."""
ipython_shell.user_ns = {} # Remove spark session from namespace
ipython_shell.run_line_magic("dp_spark_pip", "install pandas")
captured = capsys.readouterr()
assert "No active Spark Sessions found." in captured.out
mock_add_artifacts.assert_not_called()


@mock.patch.object(
DataprocSparkSession,
"addArtifacts",
side_effect=Exception("Install failed"),
)
def test_dp_spark_pip_install_failure(
mock_add_artifacts, ipython_shell, capsys
):
"""Test error message on installation failure."""
ipython_shell.run_line_magic("dp_spark_pip", "install bad-package")
mock_add_artifacts.assert_called_once_with("bad-package", pypi=True)
captured = capsys.readouterr()
assert "Failed to add artifacts: Install failed" in captured.out