diff --git a/docs/configuration.rst b/docs/configuration.rst index 8718e79..a92a2d9 100644 --- a/docs/configuration.rst +++ b/docs/configuration.rst @@ -69,4 +69,39 @@ The only rule is that names cannot have a "." in them, you will see why below. Create a structure that fits your usecase, by environment, by client (if you are a consultant)...etc. +Secret values from env or SSM +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Any field in your config can be sourced from environment variables or AWS SSM +Parameter Store by using a small dictionary instead of a literal value. + +Environment variable example:: + + { + "dbs":{ + "my_db": { + "wrapper": "MysqlDB", + "host": "mysql-master.123fakestreet.net", + "user": "", + "password": { "env": "MYSQL_PASSWORD" } + } + } + } + +SSM Parameter Store example:: + + { + "dbs":{ + "my_databricks": { + "wrapper": "DatabricksSQLWarehouseDB", + "hostname": "adb-1234567890123456.7.azuredatabricks.net", + "http_path": "/sql/1.0/warehouses/abc123def456", + "access_token": { "ssm": "/prod/databricks/access_token" } + } + } + } + +Optional decrypt flag (defaults to true):: + + { "access_token": { "ssm": "/prod/databricks/access_token", "decrypt": true } } diff --git a/link/_secrets.py b/link/_secrets.py index 4145e6a..9898fa8 100644 --- a/link/_secrets.py +++ b/link/_secrets.py @@ -52,3 +52,21 @@ def get_secret(key): return json.loads(base64.b64decode(get_secret_value_response['SecretBinary'])) +def get_ssm_parameter(name, decrypt=True): + import boto3 + from botocore.exceptions import ClientError, NoRegionError + + session = boto3.session.Session() + try: + client = session.client(service_name="ssm") + except NoRegionError: + print("Warning, no default region set, defaulting to us-east-1. Please set a default region in either your aws config file or via environment variable AWS_DEFAULT_REGION") + client = session.client(service_name="ssm", region_name=DEFAULT_REGION) + + try: + resp = client.get_parameter(Name=name, WithDecryption=decrypt) + except Exception as e: + raise e + else: + return resp.get("Parameter", {}).get("Value") + diff --git a/link/link.py b/link/link.py index 15a215d..aa61d95 100644 --- a/link/link.py +++ b/link/link.py @@ -421,11 +421,29 @@ def __call__(self, wrap_name=None, *kargs, **kwargs): """ Get a wrapper given the name or some arguments """ + def resolve_value(value): + if isinstance(value, dict): + if 'ssm' in value: + return _secrets.get_ssm_parameter( + value.get('ssm'), + decrypt=value.get('decrypt', True) + ) + if 'env' in value: + return os.getenv(value.get('env'), value.get('default')) + return dict([(k, resolve_value(v)) for k, v in value.items()]) + if isinstance(value, (list, tuple)): + return [resolve_value(v) for v in value] + return value + wrap_config = {} if wrap_name: wrap_config = self.config(wrap_name) + # resolve env/ssm placeholders before any wrapper-specific handling + if isinstance(wrap_config, dict): + wrap_config = resolve_value(wrap_config) + #if they are using the aws secret manager, let's pull username nad #password from there if AWS_SECRETMANAGER_KEY in wrap_config: diff --git a/link/wrappers/dbwrappers.py b/link/wrappers/dbwrappers.py index f08ead8..6ec16a6 100644 --- a/link/wrappers/dbwrappers.py +++ b/link/wrappers/dbwrappers.py @@ -611,6 +611,54 @@ def transaction(self): raise +class DatabricksSQLWarehouseDB(DBConnectionWrapper): + + def __init__(self, wrap_name=None, hostname=None, host=None, http_path=None, + access_token=None, catalog=None, schema=None): + """ + A connection to a Databricks SQL Warehouse. Requires databricks-sql-connector. + + :param hostname: Databricks workspace hostname (without protocol) + :param host: Alias for hostname + :param http_path: HTTP path for the SQL Warehouse + :param access_token: Databricks personal access token + :param catalog: Optional catalog to use + :param schema: Optional schema to use + """ + self.hostname = hostname or host + self.http_path = http_path + self.access_token = access_token + self.catalog = catalog + self.schema = schema + super(DatabricksSQLWarehouseDB, self).__init__(wrap_name=wrap_name) + + def create_connection(self): + try: + import databricks.sql as dbsql + except ImportError: + raise Exception("databricks-sql-connector is required for DatabricksSQLWarehouseDB") + + if not self.hostname: + raise Exception("hostname is required for DatabricksSQLWarehouseDB") + if not self.http_path: + raise Exception("http_path is required for DatabricksSQLWarehouseDB") + if not self.access_token: + raise Exception("access_token is required for DatabricksSQLWarehouseDB") + + conn = dbsql.connect( + server_hostname=self.hostname, + http_path=self.http_path, + access_token=self.access_token, + ) + + if self.catalog is not None: + conn.cursor().execute("USE CATALOG {}".format(self.catalog)) + if self.schema is not None: + conn.cursor().execute("USE SCHEMA {}".format(self.schema)) + + return conn + + class RedshiftDB(DBConnectionWrapper): def __init__(self, wrap_name=None, user=None, password=None, diff --git a/setup.py b/setup.py index 506dbca..92023af 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ dir = os.path.split(os.path.abspath(__file__))[0] #import all of this version information -__version__ = '2.1.4' +__version__ = '2.1.5' __author__ = 'David Buonasera' __license__ = 'Apache 2.0' __copyright__ = 'Copyright 2019 David Buonasera'