diff --git a/.gitignore b/.gitignore index e01e346..116cb70 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ .vscode/ build/ +*.egg-info/ *.pyc +*.so diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..1a5c573 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "torchsparse/backend/third_party/sparsehash"] + path = torchsparse/backend/third_party/sparsehash + url = https://github.com/sparsehash/sparsehash.git diff --git a/cython_setup.py b/cython_setup.py index 524bce0..9d879e4 100644 --- a/cython_setup.py +++ b/cython_setup.py @@ -80,7 +80,6 @@ "tqdm", "typing-extensions", "wheel", - "rootpath", "attributedict", ], cmdclass={"build_ext": BuildExtension}, diff --git a/requirements.txt b/requirements.txt index 27bc0e2..41ec64e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,3 @@ tqdm typing-extensions wheel attributedict -rootpath diff --git a/setup.py b/setup.py index eeecf22..b954115 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,7 @@ import glob import os +from pathlib import Path +from subprocess import run import torch import torch.cuda @@ -10,12 +12,11 @@ CppExtension, CUDAExtension, ) +from torchsparse.version import __version__ -# from torchsparse import __version__ +print("torchsparse version:", __version__) -version_file = open("./torchsparse/version.py") -version = version_file.read().split("'")[1] -print("torchsparse version:", version) +build_ext = BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=True) if (torch.cuda.is_available() and CUDA_HOME is not None) or ( os.getenv("FORCE_CUDA", "0") == "1" @@ -34,14 +35,24 @@ sources.append(fpath) extension_type = CUDAExtension if device == "cuda" else CppExtension +current_dir = Path(__file__).parent.resolve() +sparsehash_dir = current_dir / "torchsparse" / "backend" / "third_party" / "sparsehash" +sparsehash_dir_inc = sparsehash_dir / "src" +sparseconfig_path = sparsehash_dir_inc / "sparsehash" / "internal" / "sparseconfig.h" + +if not sparseconfig_path.exists(): + print("Generating sparseconfig.h ...") + run(["./configure"], cwd=sparsehash_dir, check=True) + run(["make", "src/sparsehash/internal/sparseconfig.h"], cwd=sparsehash_dir, check=True) + extra_compile_args = { - "cxx": ["-g", "-O3", "-fopenmp", "-lgomp"], - "nvcc": ["-O3", "-std=c++17"], + "cxx": ["-O3", "-fopenmp", "-lgomp", f"-I{sparsehash_dir_inc}"], + "nvcc": ["-O3"], } setup( name="torchsparse", - version=version, + version=__version__, packages=find_packages(), ext_modules=[ extension_type( @@ -49,19 +60,27 @@ ) ], url="https://github.com/mit-han-lab/torchsparse", + include_package_data=True, install_requires=[ + "ninja", "numpy", "backports.cached_property", "tqdm", "typing-extensions", "wheel", - "rootpath", "torch", "torchvision" ], - dependency_links=[ - 'https://download.pytorch.org/whl/cu118' - ], - cmdclass={"build_ext": BuildExtension}, + cmdclass={"build_ext": build_ext}, zip_safe=False, ) + +for f in [ + "Makefile", + "config.log", + "config.status", + "src/config.h", + "src/sparsehash/internal/sparseconfig.h", + "src/stamp-h1", +]: + (sparsehash_dir / f).unlink(missing_ok=True) diff --git a/torchsparse/backend/third_party/sparsehash b/torchsparse/backend/third_party/sparsehash new file mode 160000 index 0000000..1dffea3 --- /dev/null +++ b/torchsparse/backend/third_party/sparsehash @@ -0,0 +1 @@ +Subproject commit 1dffea3d917445d70d33d0c7492919fc4408fe5c diff --git a/torchsparse/nn/functional/conv/utils/collections.py b/torchsparse/nn/functional/conv/utils/collections.py index b828645..6a69bfa 100644 --- a/torchsparse/nn/functional/conv/utils/collections.py +++ b/torchsparse/nn/functional/conv/utils/collections.py @@ -25,8 +25,6 @@ # IMPORT # -------------------------------------- -import rootpath - import collections from . import compat diff --git a/torchsparse/nn/functional/conv/utils/compat.py b/torchsparse/nn/functional/conv/utils/compat.py index 02d9bd5..7bcbe1b 100644 --- a/torchsparse/nn/functional/conv/utils/compat.py +++ b/torchsparse/nn/functional/conv/utils/compat.py @@ -25,8 +25,6 @@ # DEPS # -------------------------------------- -import rootpath - # @see https://github.com/benjaminp/six/blob/master/six.py import sys