101 changes: 78 additions & 23 deletions gcpde/bq.py
@@ -13,7 +13,7 @@
from google.oauth2.service_account import Credentials
from loguru import logger

from gcpde.types import ListJsonType
from gcpde.types import BigQuerySchema, ListJsonType

FIVE_MINUTES = 1 * 60 * 5

@@ -75,7 +75,7 @@ def create_table(
self,
dataset: str,
table: str,
schema: list[dict[str, str]],
schema: BigQuerySchema,
) -> None:
"""Create a new table on bigquery.

@@ -84,8 +84,9 @@ def create_table(
Args:
dataset: dataset name.
table: table name.
schema: json dict schema for the table.
https://cloud.google.com/bigquery/docs/schemas#creating_a_json_schema_file
schema: BigQuerySchema for the table. Use
gcpde.bq.get_schema_from_json to create it from a list of
dictionaries.

Raises:
google.cloud.exceptions.Conflict: if the table already exists.
@@ -94,7 +94,7 @@
table_ref = self._create_table_reference(dataset, table)
table_obj = bigquery.Table(
table_ref=table_ref,
schema=[bigquery.SchemaField.from_api_repr(field) for field in schema],
schema=schema,
)
self.client.create_table(table_obj)

@@ -186,8 +187,8 @@ class BigQuerySchemaMismatchException(Exception):
def __init__(
self,
message: str,
source_schema: list[bigquery.SchemaField],
target_schema: list[bigquery.SchemaField],
source_schema: BigQuerySchema,
target_schema: BigQuerySchema,
):
super().__init__(message)
self.message = message
@@ -202,6 +203,19 @@ def __str__(self) -> str:
)
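
A small, self-contained sketch of what the exception carries, using hand-built `SchemaField` lists. Accessing `source_schema`/`target_schema` as attributes is an assumption based on the constructor parameters; the corresponding assignments are collapsed in this view.

```python
from google.cloud import bigquery

from gcpde.bq import BigQuerySchemaMismatchException

records_schema = [bigquery.SchemaField("id", "INTEGER"), bigquery.SchemaField("name", "STRING")]
table_schema = [bigquery.SchemaField("id", "INTEGER"), bigquery.SchemaField("name", "INTEGER")]

try:
    raise BigQuerySchemaMismatchException(
        message="New data schema does not match table schema",
        source_schema=records_schema,
        target_schema=table_schema,
    )
except BigQuerySchemaMismatchException as exc:
    print(exc)          # __str__ combines the message with both schemas
    print(exc.message)  # stored by __init__ above
    # Assumed attributes; their assignments are not visible in this diff view.
    print(exc.source_schema, exc.target_schema)
```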


def get_schema_from_json(schema: list[dict[str, str]]) -> BigQuerySchema:
"""Get a schema from a list of dictionaries.

Args:
schema: list of dictionaries representing the schema.
ref: https://cloud.google.com/bigquery/docs/schemas#creating_a_json_schema_file

Returns:
BigQuerySchema
"""
return [bigquery.SchemaField.from_api_repr(field) for field in schema]
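
A quick sketch of the new helper with an illustrative two-field schema; the field dictionaries follow the JSON schema format linked above.

```python
from gcpde.bq import get_schema_from_json

json_schema = [
    {"name": "id", "type": "INTEGER", "mode": "REQUIRED"},
    {"name": "created_at", "type": "TIMESTAMP", "mode": "NULLABLE"},
]

bq_schema = get_schema_from_json(json_schema)
# Each entry is now a google.cloud.bigquery.SchemaField, i.e. a BigQuerySchema.
print(bq_schema[0].name, bq_schema[0].field_type)  # prints: id INTEGER
```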


def delete_table(
dataset: str,
table: str,
@@ -228,7 +242,7 @@ def delete_table(
return


def _create_schema_from_records(records: ListJsonType) -> list[dict[str, str]]:
def _create_schema_from_records(records: ListJsonType) -> BigQuerySchema:
generator = SchemaGenerator(
input_format="dict",
keep_nulls=True,
@@ -241,9 +255,10 @@ def _create_schema_from_records(records: ListJsonType) -> list[dict[str, str]]:
raise BigQueryClientException(
f"Can't infer schema from records, error: {error_logs}"
)
output: list[dict[str, str]] = generator.flatten_schema(schema_map)
output_json: list[dict[str, str]] = generator.flatten_schema(schema_map)
output_api: BigQuerySchema = get_schema_from_json(output_json)
logger.debug("Schema generator complete!")
return output
return output_api
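
For context, a hedged sketch of the same inference path in isolation. It assumes the `bigquery-schema-generator` package that `SchemaGenerator` comes from and that `deduce_schema` returns a `(schema_map, error_logs)` tuple, as the code above implies; the records and any generator options beyond the two shown are illustrative.

```python
from bigquery_schema_generator.generate_schema import SchemaGenerator

from gcpde.bq import get_schema_from_json

records = [
    {"id": 1, "name": "alice", "signup_date": "2024-01-01"},
    {"id": 2, "name": "bob", "signup_date": None},
]

generator = SchemaGenerator(input_format="dict", keep_nulls=True)
schema_map, error_logs = generator.deduce_schema(records)
if error_logs:
    raise ValueError(f"Can't infer schema from records, error: {error_logs}")

# flatten_schema yields JSON-style dicts, which the new helper converts
# into a BigQuerySchema (list of SchemaField).
bq_schema = get_schema_from_json(generator.flatten_schema(schema_map))
```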


@tenacity.retry(
@@ -258,7 +273,7 @@ def _create_schema_from_records(records: ListJsonType) -> list[dict[str, str]]:
def create_table(
dataset: str,
table: str,
schema: list[dict[str, str]] | None = None,
schema: BigQuerySchema | None = None,
schema_from_records: ListJsonType | None = None,
json_key: dict[str, str] | None = None,
client: BigQueryClient | None = None,
@@ -270,8 +285,9 @@
Args:
dataset: dataset name.
table: table name.
schema: json dict schema for the table.
https://cloud.google.com/bigquery/docs/schemas#creating_a_json_schema_file
schema: BigQuerySchema for the table. Use
gcpde.bq.get_schema_from_json to create it from a list of
dictionaries.
schema_from_records: infer schema from a records sample.
json_key: json key with gcp credentials.
client: client to connect to gcp.
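
Roughly, callers now have two ways to provide a schema. A minimal sketch with illustrative dataset/table names and default credentials; any parameters not shown keep their defaults.

```python
from gcpde import bq

# Option 1: an explicit BigQuerySchema built from JSON-style field dictionaries.
bq.create_table(
    dataset="my_dataset",
    table="users",
    schema=bq.get_schema_from_json(
        [{"name": "id", "type": "INTEGER", "mode": "REQUIRED"}]
    ),
)

# Option 2: infer the schema from a sample of records.
bq.create_table(
    dataset="my_dataset",
    table="events",
    schema_from_records=[{"id": 1, "event": "signup"}],
)
```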
@@ -382,6 +398,7 @@ def upsert_table_from_records(
json_key: dict[str, str] | None = None,
insert_chunk_size: int | None = None,
client: BigQueryClient | None = None,
use_target_schema: bool = True,
) -> None:
"""Upsert records into a table.

@@ -390,6 +407,9 @@
2. using MERGE statement to update/insert records
3. Cleaning up temporary table

> If the target table doesn't exist, it will be created using
inferred schema from records.

Args:
dataset: dataset name.
table: table name.
@@ -398,17 +418,33 @@
json_key: json key with gcp credentials.
insert_chunk_size: chunk size for batch inserts.
client: client to connect to BigQuery.
use_target_schema: whether to use the schema of the target table or
infer from records for the temporary table.

Raises:
BigQuerySchemaMismatchException: if schema of new records doesn't match table.
BigQueryClientException: if schema cannot be inferred from records.
"""
client = client or BigQueryClient(json_key=json_key or {})
tmp_table = table + "_tmp"

if not records:
logger.warning("No records to create a table from! (empty collection given)")
return

client = client or BigQueryClient(json_key=json_key or {})
tmp_table = table + "_tmp"
try:
table_schema_bq = client.get_table(dataset, table).schema
except NotFound:
create_table_from_records(
dataset=dataset,
table=table,
records=records,
overwrite=False,
json_key=json_key,
client=client,
chunk_size=insert_chunk_size,
)
return

create_table_from_records(
dataset=dataset,
@@ -418,19 +454,23 @@ def upsert_table_from_records(
json_key=json_key,
client=client,
chunk_size=insert_chunk_size,
schema=table_schema_bq if use_target_schema else None,
)

tmp_table_schema_bq = client.get_table(dataset, tmp_table).schema
table_schema_bq = client.get_table(dataset, table).schema
tmp_table_schema_bq = (
table_schema_bq
if use_target_schema
else client.get_table(dataset, tmp_table).schema
)

if table_schema_bq != tmp_table_schema_bq:
logger.info("Cleaning up temporary table...")
delete_table(dataset=dataset, table=tmp_table, client=client)

raise BigQuerySchemaMismatchException(
message="New data schema does not match table schema",
source_schema=table_schema_bq,
target_schema=tmp_table_schema_bq,
source_schema=tmp_table_schema_bq,
target_schema=table_schema_bq,
)

update_statement = ", ".join(
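
With `use_target_schema=True` (the default) the temporary table is created with the target table's schema, so the comparison above only fails when the records are genuinely incompatible; the MERGE step then updates matching rows and inserts new ones. As a standalone, hedged illustration of how such an update clause can be assembled (this is not the collapsed code below; the `id` merge key and table names are hypothetical):

```python
from google.cloud import bigquery

schema = [
    bigquery.SchemaField("id", "INTEGER"),
    bigquery.SchemaField("name", "STRING"),
    bigquery.SchemaField("updated_at", "TIMESTAMP"),
]

# Build "target.col = source.col" pairs for every non-key column.
update_statement = ", ".join(
    f"target.{field.name} = source.{field.name}"
    for field in schema
    if field.name != "id"  # hypothetical merge key
)

merge_sql = (
    "MERGE `my_dataset.my_table` AS target "
    "USING `my_dataset.my_table_tmp` AS source "
    "ON target.id = source.id "
    f"WHEN MATCHED THEN UPDATE SET {update_statement} "
    "WHEN NOT MATCHED THEN INSERT ROW"
)
print(merge_sql)
```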
@@ -461,7 +501,7 @@ def replace_table(
dataset: str,
table: str,
records: ListJsonType,
schema: list[dict[str, str]] | None = None,
schema: BigQuerySchema | None = None,
chunk_size: int | None = None,
json_key: dict[str, str] | None = None,
client: BigQueryClient | None = None,
@@ -471,7 +511,13 @@

tmp_table = table + "_tmp"
delete_table(dataset=dataset, table=tmp_table, client=client)
create_table(dataset=dataset, table=tmp_table, schema=schema, client=client)
create_table(
dataset=dataset,
table=tmp_table,
schema_from_records=records,
schema=schema,
client=client,
)
insert(
dataset=dataset,
table=tmp_table,
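
A minimal illustrative call: with `schema=None` the temporary table's schema is now inferred from the records via `schema_from_records`, and an explicit `BigQuerySchema` can be passed instead. Dataset, table, and records here are placeholders.

```python
from gcpde import bq

bq.replace_table(
    dataset="my_dataset",
    table="daily_snapshot",
    records=[{"id": 1, "value": 3.14}, {"id": 2, "value": 2.72}],
    schema=None,  # or bq.get_schema_from_json([...])
)
```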
@@ -496,6 +542,7 @@ def create_table_from_records(
json_key: dict[str, str] | None = None,
client: BigQueryClient | None = None,
chunk_size: int | None = None,
schema: BigQuerySchema | None = None,
) -> None:
"""Create or replace a table from a collection of records.

@@ -507,12 +554,14 @@
json_key: json key with gcp credentials.
client: client to connect to gcp.
chunk_size: chunk size number to send to GCP API.
schema: BigQuerySchema for the table. Use
gcpde.bq.get_schema_from_json to create it from a list of
dictionaries.

"""
if not records:
logger.warning("No records to create a table from! (empty collection given)")
return
schema = _create_schema_from_records(records=records or [])

client = client or BigQueryClient(json_key=json_key or {})
if overwrite:
@@ -526,7 +575,13 @@
)
return

create_table(dataset=dataset, table=table, schema=schema, client=client)
create_table(
dataset=dataset,
table=table,
schema_from_records=records,
schema=schema,
client=client,
)
insert(
dataset=dataset,
table=table,
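
A sketch of the new optional `schema` parameter, for example reusing an existing table's schema instead of re-inferring it from the records, which is what `upsert_table_from_records` now does when `use_target_schema=True`. Names and records are illustrative.

```python
from gcpde import bq

client = bq.BigQueryClient(json_key={})  # illustrative; real credentials needed
records = [{"id": 3, "name": "carol"}]

# Reuse the schema of an existing table for the new one.
existing_schema = client.get_table("my_dataset", "my_table").schema

bq.create_table_from_records(
    dataset="my_dataset",
    table="my_table_copy",
    records=records,
    overwrite=True,
    client=client,
    schema=existing_schema,
)
```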
4 changes: 4 additions & 0 deletions gcpde/types.py
@@ -2,4 +2,8 @@

from typing import Any

from google.cloud import bigquery

ListJsonType = list[dict[str, Any]]

BigQuerySchema = list[bigquery.SchemaField]
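
For completeness, the alias simply names the schema representation the google-cloud-bigquery client already uses, so hand-written schemas type-check as well. A minimal illustration:

```python
from google.cloud import bigquery

from gcpde.types import BigQuerySchema

schema: BigQuerySchema = [
    bigquery.SchemaField("id", "INTEGER", mode="REQUIRED"),
    bigquery.SchemaField("payload", "STRING", mode="NULLABLE"),
]
```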