diff --git a/renderapi/msgspec_json.py b/renderapi/msgspec_json.py new file mode 100644 index 0000000..b110d31 --- /dev/null +++ b/renderapi/msgspec_json.py @@ -0,0 +1,49 @@ +import io +import msgspec +import numpy +from typing import Any, Type + + +def is_binary(f): + return isinstance(f, (io.RawIOBase, io.BufferedIOBase)) + + +def render_encode_hook(obj: Any) -> Any: + if isinstance(obj, numpy.integer): + return int(obj) + if isinstance(obj, numpy.floating): + return float(obj) + to_dict = getattr(obj, "to_dict", None) + if callable(to_dict): + return obj.to_dict() + else: + try: + return dict(obj) + except TypeError as e: + return obj.__dict__ + +render_enc = msgspec.json.Encoder(enc_hook=render_encode_hook) + + +class MsgSpecRenderJson: + @staticmethod + def loads(j, *args, **kwargs): + return msgspec.json.decode(j) + + @staticmethod + def dumps(o, *args, **kwargs): + return render_enc.encode(o).decode() + + @staticmethod + def dump(o, fp, *args, **kwargs): + if is_binary(fp): + fp.write(render_enc.encode(o)) + else: + fp.write(render_enc.encode(o).decode()) + + @staticmethod + def load(fp, *args, **kwargs): + if is_binary(fp): + msgspec.json.decode(fp.read()) + else: + msgspec.json.decode(fp.read()) diff --git a/renderapi/utils.py b/renderapi/utils.py index c46b2fb..88031f4 100644 --- a/renderapi/utils.py +++ b/renderapi/utils.py @@ -18,12 +18,15 @@ from .errors import RenderError -# use ujson if installed for faster json try: - import ujson as requests_json + from .msgspec_json import MsgSpecRenderJson + default_json = MsgSpecRenderJson except ImportError: - import json as requests_json -requests.models.complexjson = requests_json + default_json = json + +# default_json = MsgSpecRenderJson + +requests.models.complexjson = default_json class NullHandler(logging.Handler): @@ -104,7 +107,7 @@ def post_json(session, request_url, d, params=None): headers = {"content-type": "application/json"} if d is not None: - payload = json.dumps(d) + payload = default_json.dumps(d) else: payload = None headers['Accept'] = "application/json" @@ -235,7 +238,7 @@ def renderdumps(obj, *args, **kwargs): serialized object """ cls_ = kwargs.pop('cls', RenderEncoder) - return json.dumps(obj, *args, cls=cls_, **kwargs) + return default_json.dumps(obj, *args, cls=cls_, **kwargs) def renderdump(obj, *args, **kwargs): @@ -251,7 +254,7 @@ def renderdump(obj, *args, **kwargs): json.dump kwargs """ cls_ = kwargs.pop('cls', RenderEncoder) - return json.dump(obj, *args, cls=cls_, **kwargs) + return default_json.dump(obj, *args, cls=cls_, **kwargs) def renderdump_temp(obj, *args, **kwargs): diff --git a/test/test_utils.py b/test/test_utils.py index a2500e6..ec45c03 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -3,7 +3,9 @@ import renderapi import pytest import numpy as np -import ujson +import msgspec + +from renderapi.msgspec_json import MsgSpecRenderJson def cross_py23_reload(module): @@ -13,27 +15,27 @@ def cross_py23_reload(module): importlib.reload(module) -@pytest.mark.parametrize("use_ujson", [True, False]) -def test_json_load(use_ujson): - if not use_ujson: +@pytest.mark.parametrize("use_msgspec", [True, False]) +def test_json_load(use_msgspec): + if not use_msgspec: try: import builtins except ImportError: import __builtin__ as builtins realimport = builtins.__import__ - def noujson_import(name, globals=None, locals=None, - fromlist=(), level=0): - if 'ujson' in name: + def no_msgspec_import(name, globals=None, locals=None, + fromlist=(), level=0): + if 'msgspec' in name: raise ImportError return realimport(name, globals, locals, fromlist, level) - builtins.__import__ = noujson_import + builtins.__import__ = no_msgspec_import cross_py23_reload(renderapi.utils) - assert (renderapi.utils.requests_json is ujson - if use_ujson else renderapi.utils.requests_json is json) + assert (renderapi.utils.default_json is MsgSpecRenderJson + if use_msgspec else renderapi.utils.default_json is json) assert ( - renderapi.utils.requests.models.complexjson is ujson - if use_ujson else renderapi.utils.requests.models.complexjson is json) + renderapi.utils.requests.models.complexjson is MsgSpecRenderJson + if use_msgspec else renderapi.utils.requests.models.complexjson is json) def test_jbool(): @@ -45,7 +47,7 @@ def test_jbool(): def test_renderdumps_simple(): s = renderapi.utils.renderdumps({'a': 1}) - assert(s == '{"a": 1}') + assert(s in ['{"a": 1}', '{"a":1}']) s = renderapi.utils.renderdumps(5) assert(s == '5') diff --git a/test_requirements.txt b/test_requirements.txt index 74e88cc..4746927 100644 --- a/test_requirements.txt +++ b/test_requirements.txt @@ -7,5 +7,5 @@ pytest-pep8>=1.0.6 pytest-xdist>=1.14 flake8>=3.0.4 pylint>=1.5.4 -ujson +msgspec jinja2