diff --git a/gcpde/bq.py b/gcpde/bq.py index ad2693c..b7c61cf 100644 --- a/gcpde/bq.py +++ b/gcpde/bq.py @@ -13,7 +13,7 @@ from google.oauth2.service_account import Credentials from loguru import logger -from gcpde.types import ListJsonType +from gcpde.types import BigQuerySchema, ListJsonType FIVE_MINUTES = 1 * 60 * 5 @@ -75,7 +75,7 @@ def create_table( self, dataset: str, table: str, - schema: list[dict[str, str]], + schema: BigQuerySchema, ) -> None: """Create a new table on bigquery. @@ -84,8 +84,9 @@ def create_table( Args: dataset: dataset name. table: table name. - schema: json dict schema for the table. - https://cloud.google.com/bigquery/docs/schemas#creating_a_json_schema_file + schema: BigQuerySchema for the table. Use + gcpde.bq.get_schema_from_json to create it from a list of + dictionaries. Raises: google.cloud.exceptions.Conflict: if the table already exists. @@ -94,7 +95,7 @@ def create_table( table_ref = self._create_table_reference(dataset, table) table_obj = bigquery.Table( table_ref=table_ref, - schema=[bigquery.SchemaField.from_api_repr(field) for field in schema], + schema=schema, ) self.client.create_table(table_obj) @@ -186,8 +187,8 @@ class BigQuerySchemaMismatchException(Exception): def __init__( self, message: str, - source_schema: list[bigquery.SchemaField], - target_schema: list[bigquery.SchemaField], + source_schema: BigQuerySchema, + target_schema: BigQuerySchema, ): super().__init__(message) self.message = message @@ -202,6 +203,19 @@ def __str__(self) -> str: ) +def get_schema_from_json(schema: list[dict[str, str]]) -> BigQuerySchema: + """Get a schema from a list of dictionaries. + + Args: + schema: list of dictionaries representing the schema. + ref: https://cloud.google.com/bigquery/docs/schemas#creating_a_json_schema_file + + Returns: + BigQuerySchema + """ + return [bigquery.SchemaField.from_api_repr(field) for field in schema] + + def delete_table( dataset: str, table: str, @@ -228,7 +242,7 @@ def delete_table( return -def _create_schema_from_records(records: ListJsonType) -> list[dict[str, str]]: +def _create_schema_from_records(records: ListJsonType) -> BigQuerySchema: generator = SchemaGenerator( input_format="dict", keep_nulls=True, @@ -241,9 +255,10 @@ def _create_schema_from_records(records: ListJsonType) -> list[dict[str, str]]: raise BigQueryClientException( f"Can't infer schema from records, error: {error_logs}" ) - output: list[dict[str, str]] = generator.flatten_schema(schema_map) + output_json: list[dict[str, str]] = generator.flatten_schema(schema_map) + output_api: BigQuerySchema = get_schema_from_json(output_json) logger.debug("Schema generator complete!") - return output + return output_api @tenacity.retry( @@ -258,7 +273,7 @@ def _create_schema_from_records(records: ListJsonType) -> list[dict[str, str]]: def create_table( dataset: str, table: str, - schema: list[dict[str, str]] | None = None, + schema: BigQuerySchema | None = None, schema_from_records: ListJsonType | None = None, json_key: dict[str, str] | None = None, client: BigQueryClient | None = None, @@ -270,8 +285,9 @@ def create_table( Args: dataset: dataset name. table: table name. - schema: json dict schema for the table. - https://cloud.google.com/bigquery/docs/schemas#creating_a_json_schema_file + schema: BigQuerySchema for the table. Use + gcpde.bq.get_schema_from_json to create it from a list of + dictionaries. schema_from_records: infer schema from a records sample. json_key: json key with gcp credentials. client: client to connect to gcp. @@ -382,6 +398,7 @@ def upsert_table_from_records( json_key: dict[str, str] | None = None, insert_chunk_size: int | None = None, client: BigQueryClient | None = None, + use_target_schema: bool = True, ) -> None: """Upsert records into a table. @@ -390,6 +407,9 @@ def upsert_table_from_records( 2. using MERGE statement to update/insert records 3. Cleaning up temporary table + > If the target table doesn't exist, it will be created using + inferred schema from records. + Args: dataset: dataset name. table: table name. @@ -398,17 +418,33 @@ def upsert_table_from_records( json_key: json key with gcp credentials. insert_chunk_size: chunk size for batch inserts. client: client to connect to BigQuery. + use_target_schema: whether to use the schema of the target table or + infer from records for the temporary table. Raises: BigQuerySchemaMismatchException: if schema of new records doesn't match table. BigQueryClientException: if schema cannot be inferred from records. """ + client = client or BigQueryClient(json_key=json_key or {}) + tmp_table = table + "_tmp" + if not records: logger.warning("No records to create a table from! (empty collection given)") return - client = client or BigQueryClient(json_key=json_key or {}) - tmp_table = table + "_tmp" + try: + table_schema_bq = client.get_table(dataset, table).schema + except NotFound: + create_table_from_records( + dataset=dataset, + table=table, + records=records, + overwrite=False, + json_key=json_key, + client=client, + chunk_size=insert_chunk_size, + ) + return create_table_from_records( dataset=dataset, @@ -418,10 +454,14 @@ def upsert_table_from_records( json_key=json_key, client=client, chunk_size=insert_chunk_size, + schema=table_schema_bq if use_target_schema else None, ) - tmp_table_schema_bq = client.get_table(dataset, tmp_table).schema - table_schema_bq = client.get_table(dataset, table).schema + tmp_table_schema_bq = ( + table_schema_bq + if use_target_schema + else client.get_table(dataset, tmp_table).schema + ) if table_schema_bq != tmp_table_schema_bq: logger.info("Cleaning up temporary table...") @@ -429,8 +469,8 @@ def upsert_table_from_records( raise BigQuerySchemaMismatchException( message="New data schema does not match table schema", - source_schema=table_schema_bq, - target_schema=tmp_table_schema_bq, + source_schema=tmp_table_schema_bq, + target_schema=table_schema_bq, ) update_statement = ", ".join( @@ -461,7 +501,7 @@ def replace_table( dataset: str, table: str, records: ListJsonType, - schema: list[dict[str, str]] | None = None, + schema: BigQuerySchema | None = None, chunk_size: int | None = None, json_key: dict[str, str] | None = None, client: BigQueryClient | None = None, @@ -471,7 +511,13 @@ def replace_table( tmp_table = table + "_tmp" delete_table(dataset=dataset, table=tmp_table, client=client) - create_table(dataset=dataset, table=tmp_table, schema=schema, client=client) + create_table( + dataset=dataset, + table=tmp_table, + schema_from_records=records, + schema=schema, + client=client, + ) insert( dataset=dataset, table=tmp_table, @@ -496,6 +542,7 @@ def create_table_from_records( json_key: dict[str, str] | None = None, client: BigQueryClient | None = None, chunk_size: int | None = None, + schema: BigQuerySchema | None = None, ) -> None: """Create or replace a table from a collection of records. @@ -507,12 +554,14 @@ def create_table_from_records( json_key: json key with gcp credentials. client: client to connect to gcp. chunk_size: chunk size number to send to GCP API. + schema: BigQuerySchema for the table. Use + gcpde.bq.get_schema_from_json to create it from a list of + dictionaries. """ if not records: logger.warning("No records to create a table from! (empty collection given)") return - schema = _create_schema_from_records(records=records or []) client = client or BigQueryClient(json_key=json_key or {}) if overwrite: @@ -526,7 +575,13 @@ def create_table_from_records( ) return - create_table(dataset=dataset, table=table, schema=schema, client=client) + create_table( + dataset=dataset, + table=table, + schema_from_records=records, + schema=schema, + client=client, + ) insert( dataset=dataset, table=table, diff --git a/gcpde/types.py b/gcpde/types.py index ea5a7a8..dbf7c33 100644 --- a/gcpde/types.py +++ b/gcpde/types.py @@ -2,4 +2,8 @@ from typing import Any +from google.cloud import bigquery + ListJsonType = list[dict[str, Any]] + +BigQuerySchema = list[bigquery.SchemaField] diff --git a/tests/unit/test_bq.py b/tests/unit/test_bq.py index 8fa1836..1f828a2 100644 --- a/tests/unit/test_bq.py +++ b/tests/unit/test_bq.py @@ -112,33 +112,37 @@ def test_create_table_from_records(): {"json_col": {"col3": "abc"}}, ] target_schema = [ - { - "name": "id", - "type": "INTEGER", - "mode": "NULLABLE", - }, - { - "name": "json_col", - "type": "RECORD", - "mode": "NULLABLE", - "fields": [ - { - "name": "col1", - "type": "INTEGER", - "mode": "NULLABLE", - }, - { - "name": "col2", - "type": "BOOLEAN", - "mode": "NULLABLE", - }, - { - "name": "col3", - "type": "STRING", - "mode": "NULLABLE", - }, - ], - }, + SchemaField.from_api_repr( + { + "name": "id", + "type": "INTEGER", + "mode": "NULLABLE", + } + ), + SchemaField.from_api_repr( + { + "name": "json_col", + "type": "RECORD", + "mode": "NULLABLE", + "fields": [ + { + "name": "col1", + "type": "INTEGER", + "mode": "NULLABLE", + }, + { + "name": "col2", + "type": "BOOLEAN", + "mode": "NULLABLE", + }, + { + "name": "col3", + "type": "STRING", + "mode": "NULLABLE", + }, + ], + } + ), ] # act @@ -171,11 +175,13 @@ def test_create_table_from_records_overwrite_false(): mock_client = Mock(spec_set=bq.BigQueryClient) input_records = [{"id": 1}] target_schema = [ - { - "name": "id", - "type": "INTEGER", - "mode": "NULLABLE", - } + SchemaField.from_api_repr( + { + "name": "id", + "type": "INTEGER", + "mode": "NULLABLE", + } + ) ] # act @@ -346,6 +352,7 @@ def test_upsert_table_from_records(mock_create_table, mock_delete_table): json_key=None, client=mock_client, chunk_size=None, + schema=table_mock.schema, ) mock_delete_table.assert_called_once_with( @@ -398,11 +405,39 @@ def test_upsert_table_from_records_schema_mismatch(mock_delete_table): records=[{"id": 1}], key_field="id", client=mock_client, + use_target_schema=False, ) mock_delete_table.call_count == 2 +@patch("gcpde.bq.create_table_from_records") +def test_upsert_table_from_records_missing_target_table(mock_create_table): + # arrange + mock_client = Mock(spec_set=bq.BigQueryClient) + mock_client.get_table.side_effect = NotFound("") + + # act + bq.upsert_table_from_records( + dataset="dataset", + table="table", + records=[{"id": 1}], + key_field="id", + client=mock_client, + ) + + # assert + mock_create_table.assert_called_once_with( + dataset="dataset", + table="table", + records=[{"id": 1}], + overwrite=False, + json_key=None, + client=mock_client, + chunk_size=None, + ) + + def test_big_query_schema_mismatch_exception(): # arrange source_schema = [{"name": "id"}] @@ -418,3 +453,23 @@ def test_big_query_schema_mismatch_exception(): str(exception) == "message\nSource schema: [{'name': 'id'}]\nTarget schema: [{'name': 'id'}]" ) + + +def test_get_schema_from_json(): + # arrange + schema_json = [ + {"name": "id", "type": "INTEGER", "mode": "NULLABLE"}, + {"name": "name", "type": "STRING", "mode": "REQUIRED"}, + ] + + # act + result = bq.get_schema_from_json(schema_json) + + # assert + assert len(result) == 2 + assert result[0].name == "id" + assert result[0].field_type == "INTEGER" + assert result[0].mode == "NULLABLE" + assert result[1].name == "name" + assert result[1].field_type == "STRING" + assert result[1].mode == "REQUIRED"