Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 107 additions & 47 deletions src/idl_gen_python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,15 +208,9 @@ class PythonStubGenerator {

std::string EnumType(const EnumDef& enum_def, Imports* imports) const {
imports->Import("typing");
const Import& import =
imports->Import(ModuleFor(&enum_def), namer_.Type(enum_def));
std::ignore = imports->Import(ModuleFor(&enum_def), namer_.Type(enum_def));

std::string result = "";
for (const EnumVal* val : enum_def.Vals()) {
if (!result.empty()) result += ", ";
result += import.name + "." + namer_.Variant(*val);
}
return "typing.Literal[" + result + "]";
return namer_.Type(enum_def);
}

std::string TypeOf(const Type& type, Imports* imports) const {
Expand Down Expand Up @@ -530,7 +524,7 @@ class PythonStubGenerator {
StructBuilderArgs(*struct_def, "", imports, &args);

stub << '\n';
stub << "def Create" + namer_.Type(*struct_def)
stub << "def Create" + namer_.Function(*struct_def)
<< "(builder: flatbuffers.Builder";
for (const std::string& arg : args) {
stub << ", " << arg;
Expand Down Expand Up @@ -610,24 +604,31 @@ class PythonStubGenerator {
stub << "class " << namer_.Type(*enum_def);
imports->Export(ModuleFor(enum_def), namer_.Type(*enum_def));

imports->Import("typing", "cast");
imports->Import("typing", "Final");

if (version_.major == 3) {
imports->Import("enum", "IntEnum");
stub << "(IntEnum)";
if (parser_.opts.python_typing) {
if (enum_def->attributes.Lookup("bit_flags")) {
imports->Import("enum", "IntFlag");
stub << "(IntFlag)";
} else {
imports->Import("enum", "IntEnum");
stub << "(IntEnum)";
}
} else {
stub << "(object)";
}

stub << ":\n";
for (const EnumVal* val : enum_def->Vals()) {
stub << " " << namer_.Variant(*val) << " = cast("
<< ScalarType(enum_def->underlying_type.base_type) << ", ...)\n";
stub << " " << namer_.Variant(*val) << ": Final["
<< namer_.Type(*enum_def) << "]\n";
}
stub << " def __new__(cls, value: int) -> " << namer_.Type(*enum_def)
<< ": ...\n";

if (parser_.opts.generate_object_based_api & enum_def->is_union) {
imports->Import("flatbuffers", "table");
stub << "def " << namer_.Function(*enum_def)
stub << "\ndef " << namer_.Function(*enum_def)
<< "Creator(union_type: " << EnumType(*enum_def, imports)
<< ", table: table.Table) -> " << UnionType(*enum_def, imports)
<< ": ...\n";
Expand Down Expand Up @@ -721,7 +722,21 @@ class PythonGenerator : public BaseGenerator {
// Begin enum code with a class declaration.
void BeginEnum(const EnumDef& enum_def, std::string* code_ptr) const {
auto& code = *code_ptr;
code += "class " + namer_.Type(enum_def) + "(object):\n";

code += "class " + namer_.Type(enum_def);

python::Version version{parser_.opts.python_version};
if (version.major == 3) {
if (enum_def.attributes.Lookup("bit_flags")) {
code += "(IntFlag)";
} else {
code += "(IntEnum)";
}
} else {
code += "(object)";
}

code += ":\n";
}

// Starts a new line and then indents.
Expand Down Expand Up @@ -852,7 +867,7 @@ class PythonGenerator : public BaseGenerator {
std::string getter = GenGetter(field.value.type);
GenReceiver(struct_def, code_ptr);
code += namer_.Method(field);
code += "(self):";
code += "(self):"; // TODO: add typing
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

forgotten todo?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not forgotten, just ran out of steam last night. Haha :)

code += OffsetPrefix(field);
getter += "o + self._tab.Pos)";
auto is_bool = IsBool(field.value.type.base_type);
Expand Down Expand Up @@ -1685,6 +1700,13 @@ class PythonGenerator : public BaseGenerator {
auto& field = **it;
if (field.deprecated) continue;

// include import for enum type if used in this struct, we want type
// information, and we want modern enums.
if (IsEnum(field.value.type) && parser_.opts.python_typing) {
imports.insert(ImportMapEntry{GenPackageReference(field.value.type),
namer_.Type(*field.value.type.enum_def)});
}

GenStructAccessor(struct_def, field, code_ptr, imports);
}

Expand Down Expand Up @@ -1739,6 +1761,12 @@ class PythonGenerator : public BaseGenerator {
} else if (IsFloat(base_type)) {
return float_const_gen_.GenFloatConstant(field);
} else if (IsInteger(base_type)) {
// wrap the default value in the enum constructor to aid type hinting
python::Version version{parser_.opts.python_version};
if (version.major == 3 && IsEnum(field.value.type)) {
auto enum_type = namer_.Type(*field.value.type.enum_def);
return enum_type + "(" + field.value.constant + ")";
}
return field.value.constant;
} else {
// For string, struct, and table.
Expand Down Expand Up @@ -1865,11 +1893,16 @@ class PythonGenerator : public BaseGenerator {
break;
}
default:
// Scalar or sting fields.
field_type = GetBasePythonTypeForScalarAndString(base_type);
if (field.IsScalarOptional()) {
import_typing_list.insert("Optional");
field_type = "Optional[" + field_type + "]";
// Scalar or string fields.
python::Version version{parser_.opts.python_version};
if (version.major == 3 && IsEnum(field.value.type)) {
field_type = namer_.Type(*field.value.type.enum_def);
} else {
field_type = GetBasePythonTypeForScalarAndString(base_type);
if (field.IsScalarOptional()) {
import_typing_list.insert("Optional");
field_type = "Optional[" + field_type + "]";
}
}
break;
}
Expand Down Expand Up @@ -2647,6 +2680,12 @@ class PythonGenerator : public BaseGenerator {

std::string GenFieldTy(const FieldDef& field) const {
if (IsScalar(field.value.type.base_type) || IsArray(field.value.type)) {
python::Version version{parser_.opts.python_version};
if (version.major == 3) {
if (IsEnum(field.value.type)) {
return namer_.Type(*field.value.type.enum_def);
}
}
const std::string ty = GenTypeBasic(field.value.type);
if (ty.find("int") != std::string::npos) {
return "int";
Expand Down Expand Up @@ -2761,7 +2800,8 @@ class PythonGenerator : public BaseGenerator {
bool generate() {
std::string one_file_code;
ImportMap one_file_imports;
if (!generateEnums(&one_file_code)) return false;

if (!generateEnums(&one_file_code, one_file_imports)) return false;
if (!generateStructs(&one_file_code, one_file_imports)) return false;

if (parser_.opts.one_file) {
Expand All @@ -2776,7 +2816,8 @@ class PythonGenerator : public BaseGenerator {
}

private:
bool generateEnums(std::string* one_file_code) const {
bool generateEnums(std::string* one_file_code,
ImportMap& one_file_imports) const {
for (auto it = parser_.enums_.vec.begin(); it != parser_.enums_.vec.end();
++it) {
auto& enum_def = **it;
Expand All @@ -2786,10 +2827,28 @@ class PythonGenerator : public BaseGenerator {
GenUnionCreator(enum_def, &enumcode);
}

python::Version version{parser_.opts.python_version};
if (parser_.opts.one_file && !enumcode.empty()) {
if (version.major == 3) {
if (enum_def.attributes.Lookup("bit_flags")) {
one_file_imports.insert({"enum", "IntFlag"});
} else {
one_file_imports.insert({"enum", "IntEnum"});
}
}

*one_file_code += enumcode + "\n\n";
} else {
ImportMap imports;

if (version.major == 3) {
if (enum_def.attributes.Lookup("bit_flags")) {
imports.insert({"enum", "IntFlag"});
} else {
imports.insert({"enum", "IntEnum"});
}
}

const std::string mod =
namer_.File(enum_def, SkipFile::SuffixAndExtension);

Expand Down Expand Up @@ -2835,49 +2894,50 @@ class PythonGenerator : public BaseGenerator {
}

// Begin by declaring namespace and imports.
void BeginFile(const std::string& name_space_name, const bool needs_imports,
std::string* code_ptr, const std::string& mod,
const ImportMap& imports) const {
void BeginFile(const std::string& name_space_name,
const bool needs_default_imports, std::string* code_ptr,
const std::string& mod, const ImportMap& imports) const {
auto& code = *code_ptr;
code = code + "# " + FlatBuffersGeneratedWarning() + "\n\n";
code += "# namespace: " + name_space_name + "\n\n";

if (needs_imports) {
const std::string local_import = "." + mod;

if (needs_default_imports) {
code += "import flatbuffers\n";
if (parser_.opts.python_gen_numpy) {
code += "from flatbuffers.compat import import_numpy\n";
}
if (parser_.opts.python_typing) {
code += "from typing import Any\n";

for (auto import_entry : imports) {
// If we have a file called, say, "MyType.py" and in it we have a
// class "MyType", we can generate imports -- usually when we
// have a type that contains arrays of itself -- of the type
// "from .MyType import MyType", which Python can't resolve. So
// if we are trying to import ourself, we skip.
if (import_entry.first != local_import) {
code += "from " + import_entry.first + " import " +
import_entry.second + "\n";
}
}
}
if (parser_.opts.python_gen_numpy) {
code += "np = import_numpy()\n\n";
}
for (auto import_entry : imports) {
const std::string local_import = "." + mod;

// If we have a file called, say, "MyType.py" and in it we have a
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems a separate fix? Usually good to keep those in separate PRs, though no biggie keeping it here for now.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh I think this is totally unrelated -- just a result of my poor branch management it seems. will remove.

// class "MyType", we can generate imports -- usually when we
// have a type that contains arrays of itself -- of the type
// "from .MyType import MyType", which Python can't resolve. So
// if we are trying to import ourself, we skip.
if (import_entry.first != local_import) {
code += "from " + import_entry.first + " import " +
import_entry.second + "\n";
}
}

if (needs_default_imports && parser_.opts.python_gen_numpy) {
code += "np = import_numpy()\n\n";
}
}

// Save out the generated code for a Python Table type.
bool SaveType(const std::string& defname, const Namespace& ns,
const std::string& classcode, const ImportMap& imports,
const std::string& mod, bool needs_imports) const {
const std::string& mod, bool needs_default_imports) const {
if (classcode.empty()) return true;

std::string code = "";
BeginFile(LastNamespacePart(ns), needs_imports, &code, mod, imports);
BeginFile(LastNamespacePart(ns), needs_default_imports, &code, mod,
imports);
code += classcode;

const std::string directories =
Expand Down
2 changes: 1 addition & 1 deletion tests/MyGame/Example/ArrayStruct.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,5 @@ class ArrayStructT(object):
def _UnPack(self, arrayStruct: ArrayStruct) -> None: ...
def Pack(self, builder: flatbuffers.Builder) -> None: ...

def CreateArrayStruct(builder: flatbuffers.Builder, a: float, b: int, c: int, d_a: int, d_b: typing.Literal[TestEnum.A, TestEnum.B, TestEnum.C], d_c: typing.Literal[TestEnum.A, TestEnum.B, TestEnum.C], d_d: int, e: int, f: int) -> uoffset: ...
def CreateArrayStruct(builder: flatbuffers.Builder, a: float, b: int, c: int, d_a: int, d_b: TestEnum, d_c: TestEnum, d_d: int, e: int, f: int) -> uoffset: ...

1 change: 1 addition & 0 deletions tests/MyGame/Example/NestedStruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import flatbuffers
from flatbuffers.compat import import_numpy
from typing import Any
from MyGame.Example.TestEnum import TestEnum
np = import_numpy()

class NestedStruct(object):
Expand Down
14 changes: 7 additions & 7 deletions tests/MyGame/Example/NestedStruct.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ class NestedStruct(object):
def AAsNumpy(self) -> np.ndarray: ...
def ALength(self) -> int: ...
def AIsNone(self) -> bool: ...
def B(self) -> typing.Literal[TestEnum.A, TestEnum.B, TestEnum.C]: ...
def C(self, i: int) -> typing.Literal[TestEnum.A, TestEnum.B, TestEnum.C]: ...
def B(self) -> TestEnum: ...
def C(self, i: int) -> TestEnum: ...
def CAsNumpy(self) -> np.ndarray: ...
def CLength(self) -> int: ...
def CIsNone(self) -> bool: ...
Expand All @@ -28,14 +28,14 @@ class NestedStruct(object):
def DIsNone(self) -> bool: ...
class NestedStructT(object):
a: typing.List[int]
b: typing.Literal[TestEnum.A, TestEnum.B, TestEnum.C]
c: typing.List[typing.Literal[TestEnum.A, TestEnum.B, TestEnum.C]]
b: TestEnum
c: typing.List[TestEnum]
d: typing.List[int]
def __init__(
self,
a: typing.List[int] | None = ...,
b: typing.Literal[TestEnum.A, TestEnum.B, TestEnum.C] = ...,
c: typing.List[typing.Literal[TestEnum.A, TestEnum.B, TestEnum.C]] | None = ...,
b: TestEnum = ...,
c: typing.List[TestEnum] | None = ...,
d: typing.List[int] | None = ...,
) -> None: ...
@classmethod
Expand All @@ -47,5 +47,5 @@ class NestedStructT(object):
def _UnPack(self, nestedStruct: NestedStruct) -> None: ...
def Pack(self, builder: flatbuffers.Builder) -> None: ...

def CreateNestedStruct(builder: flatbuffers.Builder, a: int, b: typing.Literal[TestEnum.A, TestEnum.B, TestEnum.C], c: typing.Literal[TestEnum.A, TestEnum.B, TestEnum.C], d: int) -> uoffset: ...
def CreateNestedStruct(builder: flatbuffers.Builder, a: int, b: TestEnum, c: TestEnum, d: int) -> uoffset: ...

15 changes: 9 additions & 6 deletions tests/MyGame/Example/NestedUnion/Any.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@ import numpy as np
import typing
from MyGame.Example.NestedUnion.TestSimpleTableWithEnum import TestSimpleTableWithEnum
from MyGame.Example.NestedUnion.Vec3 import Vec3
from enum import IntEnum
from flatbuffers import table
from typing import cast
from typing import Final

uoffset: typing.TypeAlias = flatbuffers.number_types.UOffsetTFlags.py_type

class Any(object):
NONE = cast(int, ...)
Vec3 = cast(int, ...)
TestSimpleTableWithEnum = cast(int, ...)
def AnyCreator(union_type: typing.Literal[Any.NONE, Any.Vec3, Any.TestSimpleTableWithEnum], table: table.Table) -> typing.Union[None, Vec3, TestSimpleTableWithEnum]: ...
class Any(IntEnum):
NONE: Final[Any]
Vec3: Final[Any]
TestSimpleTableWithEnum: Final[Any]
def __new__(cls, value: int) -> Any: ...

def AnyCreator(union_type: Any, table: table.Table) -> typing.Union[None, Vec3, TestSimpleTableWithEnum]: ...

12 changes: 7 additions & 5 deletions tests/MyGame/Example/NestedUnion/Color.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ import flatbuffers
import numpy as np

import typing
from typing import cast
from enum import IntFlag
from typing import Final

uoffset: typing.TypeAlias = flatbuffers.number_types.UOffsetTFlags.py_type

class Color(object):
Red = cast(int, ...)
Green = cast(int, ...)
Blue = cast(int, ...)
class Color(IntFlag):
Red: Final[Color]
Green: Final[Color]
Blue: Final[Color]
def __new__(cls, value: int) -> Color: ...

1 change: 1 addition & 0 deletions tests/MyGame/Example/NestedUnion/NestedUnionTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import flatbuffers
from flatbuffers.compat import import_numpy
from typing import Any
from MyGame.Example.NestedUnion.Any import Any
from flatbuffers.table import Table
from typing import Optional
np = import_numpy()
Expand Down
Loading
Loading