diff --git a/spanner_orm/admin/column.py b/spanner_orm/admin/column.py index ab42d57..7d0858b 100644 --- a/spanner_orm/admin/column.py +++ b/spanner_orm/admin/column.py @@ -23,6 +23,7 @@ _string_pattern = re.compile(r"^STRING\([0-9]+\)+$") _string_array_pattern = re.compile(r"^ARRAY+$") +_bytes_pattern = re.compile(r"^BYTES\([0-9]+\)+$") class ColumnSchema(schema.InformationSchema): @@ -51,6 +52,8 @@ def field_type(self) -> Type[field.FieldType]: return field.String elif bool(_string_array_pattern.match(self.spanner_type)): return field.StringArray + elif bool(_bytes_pattern.match(self.spanner_type)): + return field.Bytes raise error.SpannerError( "No corresponding Type for {}".format(self.spanner_type) diff --git a/spanner_orm/field.py b/spanner_orm/field.py index f395d1e..3cab38a 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -186,6 +186,23 @@ def validate_type(value) -> None: raise error.ValidationError("{} is not of type str".format(value)) +class Bytes(FieldType): + """Represents a bytes type.""" + + @staticmethod + def ddl(size="MAX") -> str: + return "BYTES({size})".format(size=size) + + @staticmethod + def grpc_type() -> type_pb2.Type: + return type_pb2.Type(code=type_pb2.BYTES) + + @staticmethod + def validate_type(value) -> None: + if not isinstance(value, bytes): + raise error.ValidationError("{} is not of type bytes".format(value)) + + class Date(FieldType): """Represents a date type.""" @@ -333,6 +350,7 @@ def validate_type(value: Any) -> None: Integer, Float, String, + Bytes, Date, Timestamp, StringArray, diff --git a/spanner_orm/tests/models.py b/spanner_orm/tests/models.py index 377500d..c00edec 100644 --- a/spanner_orm/tests/models.py +++ b/spanner_orm/tests/models.py @@ -112,6 +112,8 @@ class UnittestModel(model.Model): field.Timestamp, nullable=True, allow_commit_timestamp=True ) date = field.Field(field.Date, nullable=True) + bytes_ = field.Field(field.Bytes, nullable=True) + bytes_2 = field.Field(field.Bytes, nullable=True, size=2048) bool_array = field.Field(field.BoolArray, nullable=True) int_array = field.Field(field.IntegerArray, nullable=True) float_array = field.Field(field.FloatArray, nullable=True) diff --git a/spanner_orm/tests/update_test.py b/spanner_orm/tests/update_test.py index cca439e..41e03c0 100644 --- a/spanner_orm/tests/update_test.py +++ b/spanner_orm/tests/update_test.py @@ -110,6 +110,8 @@ def test_create_table(self, get_model): " timestamp TIMESTAMP NOT NULL," " timestamp_2 TIMESTAMP OPTIONS (allow_commit_timestamp=true)," " date DATE," + " bytes_ BYTES(MAX)," + " bytes_2 BYTES(2048)," " bool_array ARRAY," " int_array ARRAY," " float_array ARRAY," @@ -117,6 +119,7 @@ def test_create_table(self, get_model): " string_array ARRAY," " string_array_2 ARRAY) PRIMARY KEY (int_, float_, string)" ) + self.assertEqual(test_update.ddl(), test_model_ddl) @mock.patch("spanner_orm.admin.metadata.SpannerMetadata.model") @@ -180,6 +183,7 @@ def test_create_table_no_model(self, get_model): " string_array ARRAY," " string_array_2 ARRAY) PRIMARY KEY (int_, float_, string)" ) + self.assertEqual(test_update.ddl(), test_model_ddl) @mock.patch("spanner_orm.admin.metadata.SpannerMetadata.model")