diff --git a/gcpde/bq.py b/gcpde/bq.py index 8f37b23..ad2693c 100644 --- a/gcpde/bq.py +++ b/gcpde/bq.py @@ -55,13 +55,22 @@ def check_table(self, dataset: str, table: str) -> bool: True if the table exists, False otherwise. """ - table_ref = self._create_table_reference(dataset, table) try: - self.client.get_table(table=table_ref) + self.get_table(dataset, table) return True except NotFound: return False + def get_table(self, dataset: str, table: str) -> bigquery.Table: + """Get a table from bigquery. + + Args: + dataset: dataset name. + table: table name. + """ + table_ref = self._create_table_reference(dataset, table) + return self.client.get_table(table=table_ref) + def create_table( self, dataset: str, @@ -171,6 +180,28 @@ class BigQueryClientException(Exception): """Base exception for connection or command errors.""" +class BigQuerySchemaMismatchException(Exception): + """Exception for schema mismatch.""" + + def __init__( + self, + message: str, + source_schema: list[bigquery.SchemaField], + target_schema: list[bigquery.SchemaField], + ): + super().__init__(message) + self.message = message + self.source_schema = source_schema + self.target_schema = target_schema + + def __str__(self) -> str: + return ( + f"{self.message}\n" + f"Source schema: {self.source_schema}\n" + f"Target schema: {self.target_schema}" + ) + + def delete_table( dataset: str, table: str, @@ -343,6 +374,89 @@ def create_or_replace_table_as( logger.info("Command executed!") +def upsert_table_from_records( + dataset: str, + table: str, + records: ListJsonType, + key_field: str, + json_key: dict[str, str] | None = None, + insert_chunk_size: int | None = None, + client: BigQueryClient | None = None, +) -> None: + """Upsert records into a table. + + This function performs an upsert (update/insert) operation by: + 1. Creating a temporary table with the new records + 2. using MERGE statement to update/insert records + 3. Cleaning up temporary table + + Args: + dataset: dataset name. 
+ table: table name. + records: records to be upserted. + key_field: field used to match records for update. + json_key: json key with gcp credentials. + insert_chunk_size: chunk size for batch inserts. + client: client to connect to BigQuery. + + Raises: + BigQuerySchemaMismatchException: if schema of new records doesn't match table. + BigQueryClientException: if schema cannot be inferred from records. + """ + if not records: + logger.warning("No records to create a table from! (empty collection given)") + return + + client = client or BigQueryClient(json_key=json_key or {}) + tmp_table = table + "_tmp" + + create_table_from_records( + dataset=dataset, + table=tmp_table, + records=records, + overwrite=True, + json_key=json_key, + client=client, + chunk_size=insert_chunk_size, + ) + + tmp_table_schema_bq = client.get_table(dataset, tmp_table).schema + table_schema_bq = client.get_table(dataset, table).schema + + if table_schema_bq != tmp_table_schema_bq: + logger.info("Cleaning up temporary table...") + delete_table(dataset=dataset, table=tmp_table, client=client) + + raise BigQuerySchemaMismatchException( + message="New data schema does not match table schema", + source_schema=tmp_table_schema_bq, + target_schema=table_schema_bq, + ) + + update_statement = ", ".join( + [f"{field.name} = source.{field.name}" for field in table_schema_bq] + ) + table_fields = ", ".join([field.name for field in table_schema_bq]) + + merge_command_sql = ( + f"MERGE INTO {dataset}.{table} AS target " + f"USING {dataset}.{tmp_table} AS source " + f"ON source.{key_field} = target.{key_field} " + f"WHEN MATCHED THEN " + f"UPDATE SET {update_statement} " + f"WHEN NOT MATCHED THEN " + f"INSERT ({table_fields}) " + f"VALUES ({table_fields})" + ) + + logger.info(f"Running `{merge_command_sql}`...") + client.run_command(command_sql=merge_command_sql) + logger.info("Command executed!") + + logger.info("Cleaning up temporary table...") + delete_table(dataset=dataset, table=tmp_table, 
client=client) + + def replace_table( dataset: str, table: str, diff --git a/tests/unit/test_bq.py b/tests/unit/test_bq.py index d4766da..8fa1836 100644 --- a/tests/unit/test_bq.py +++ b/tests/unit/test_bq.py @@ -1,8 +1,9 @@ -from unittest.mock import Mock +from unittest.mock import Mock, patch import pandas as pd import pytest from google.api_core.exceptions import NotFound +from google.cloud.bigquery import DatasetReference, SchemaField, TableReference from gcpde import bq @@ -87,6 +88,18 @@ def test_run_command(self, bq_client: bq.BigQueryClient): # assert bq_client.client.query.assert_called_once_with(command) + def test_get_table(self, bq_client: bq.BigQueryClient): + # arrange + dataset_ref = DatasetReference(bq_client.client.project, self.dataset) + expected_table_ref = TableReference(dataset_ref, self.table) + + # act + result = bq_client.get_table(dataset=self.dataset, table=self.table) + + # assert + bq_client.client.get_table.assert_called_once_with(table=expected_table_ref) + assert result is not None + def test_create_table_from_records(): # arrange @@ -283,3 +296,125 @@ def test_create_table_from_query(): command_sql="create or replace table dataset.table as select * from table", timeout=10, ) + + +@patch("gcpde.bq.delete_table") +@patch("gcpde.bq.create_table_from_records") +def test_upsert_table_from_records(mock_create_table, mock_delete_table): + # arrange + mock_client = Mock(spec_set=bq.BigQueryClient) + table_tmp = "table_tmp" + table = "table" + dataset = "dataset" + + table_mock = Mock(schema=[]) + mock_client.get_table.return_value = table_mock + + schema_json = [ + {"name": "id", "type": "INTEGER", "mode": "NULLABLE"}, + {"name": "name", "type": "STRING", "mode": "NULLABLE"}, + ] + table_mock.schema = [SchemaField.from_api_repr(field) for field in schema_json] + + command_sql = ( + "MERGE INTO dataset.table AS target " + "USING dataset.table_tmp AS source " + "ON source.id = target.id " + "WHEN MATCHED THEN " + "UPDATE SET id = source.id, 
name = source.name " + "WHEN NOT MATCHED THEN " + "INSERT (id, name) " + "VALUES (id, name)" + ) + + # act + bq.upsert_table_from_records( + dataset=dataset, + table=table, + records=[{"id": 1, "name": "test"}, {"id": 2, "name": "test2"}], + key_field="id", + client=mock_client, + insert_chunk_size=None, + ) + + # assert + mock_create_table.assert_called_once_with( + dataset=dataset, + table=table_tmp, + records=[{"id": 1, "name": "test"}, {"id": 2, "name": "test2"}], + overwrite=True, + json_key=None, + client=mock_client, + chunk_size=None, + ) + + mock_delete_table.assert_called_once_with( + dataset=dataset, table=table_tmp, client=mock_client + ) + + mock_client.run_command.assert_called_with(command_sql=command_sql) + + for call in mock_delete_table.call_args_list: + assert call.kwargs.get("table") != table + + +@patch("gcpde.bq.delete_table") +@patch("gcpde.bq.create_table_from_records") +def test_upsert_table_from_records_no_records(mock_create_table, mock_delete_table): + # arrange + mock_client = Mock(spec_set=bq.BigQueryClient) + + # act + bq.upsert_table_from_records( + dataset="dataset", table="table", records=[], key_field="id", client=mock_client + ) + + # assert + mock_create_table.assert_not_called() + mock_delete_table.assert_not_called() + + +@patch("gcpde.bq.delete_table") +def test_upsert_table_from_records_schema_mismatch(mock_delete_table): + # arrange + mock_client = Mock(spec_set=bq.BigQueryClient) + + table_mock = Mock() + temp_table_mock = Mock() + table_mock.schema = [{"name": "uuid", "type": "STRING", "mode": "NULLABLE"}] + temp_table_mock.schema = [{"name": "id", "type": "INTEGER", "mode": "NULLABLE"}] + mock_client.get_table = ( + lambda dataset, table: table_mock if table == "table" else temp_table_mock + ) + + schema_json = [{"name": "uuid", "type": "STRING", "mode": "NULLABLE"}] + table_mock.schema = [SchemaField.from_api_repr(field) for field in schema_json] + + # act/assert + with pytest.raises(bq.BigQuerySchemaMismatchException): 
+ bq.upsert_table_from_records( + dataset="dataset", + table="table", + records=[{"id": 1}], + key_field="id", + client=mock_client, + ) + + assert mock_delete_table.call_count == 2 + + +def test_big_query_schema_mismatch_exception(): + # arrange + source_schema = [{"name": "id"}] + target_schema = [{"name": "id"}] + + # act + exception = bq.BigQuerySchemaMismatchException( + message="message", source_schema=source_schema, target_schema=target_schema + ) + + # assert + assert ( + str(exception) + == "message\nSource schema: [{'name': 'id'}]\nTarget schema: [{'name': 'id'}]" + )