-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsetup.py
More file actions
172 lines (146 loc) · 5.1 KB
/
setup.py
File metadata and controls
172 lines (146 loc) · 5.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#!/usr/bin/env python3
"""
Setup script for volresample package with Cython extensions.
This setup.py builds optimized Cython extensions for 3D volume resampling
with OpenMP parallelization and architecture-specific optimizations.
"""
import os
import platform
import sys
import numpy as np
from setuptools import Extension, find_packages, setup
from setuptools.command.build_ext import build_ext
# Cython is optional: when it cannot be imported, no extension modules are
# built and a warning is printed instead.
try:
    from Cython.Build import cythonize
except ImportError:
    CYTHON_AVAILABLE = False
    print("WARNING: Cython not available. Extensions will not be built.")
else:
    CYTHON_AVAILABLE = True
class BuildExtWithArchDetection(build_ext):
    """``build_ext`` that prints a short build-environment banner before building."""

    def run(self):
        """Print host machine, Python and NumPy versions, then run the normal build."""
        rule = "=" * 70
        banner_lines = (
            f"\n{rule}",
            f"Building volresample Cython extensions for: {platform.machine().lower()}",
            f"Python: {sys.version}",
            f"NumPy: {np.__version__}",
            f"{rule}\n",
        )
        for text in banner_lines:
            print(text)
        super().run()
def get_cython_extensions():
"""
Create Cython extension modules with architecture-appropriate flags.
"""
if not CYTHON_AVAILABLE:
return []
# Detect target architecture
machine = platform.machine().lower()
is_arm = machine in ["arm64", "aarch64", "armv7l", "armv8"]
is_x86 = machine in ["x86_64", "amd64", "i386", "i686"]
# Allow override via environment variable (for cross-compilation)
build_arch = os.environ.get("CIBW_ARCHS", machine)
if build_arch in ["ARM64", "aarch64"]:
is_arm = True
is_x86 = False
elif build_arch in ["x86_64", "AMD64"]:
is_x86 = True
is_arm = False
extensions = []
# Source file for resampling
resampling_source = "src/volresample/_resample.pyx"
# Common settings
include_dirs = [np.get_include()]
define_macros = [("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")]
# Compiler flags - platform dependent
if sys.platform == "win32":
# Windows with MSVC
extra_compile_args_base = ["/O2", "/arch:AVX2"]
extra_link_args_base = []
openmp_compile = ["/openmp"]
openmp_link = []
elif sys.platform == "darwin":
# macOS with Apple Clang — needs special OpenMP handling
avx_flags = []
if is_x86:
avx_flags = [
"-mavx2",
"-mfma",
"-ftree-vectorize",
"-ffast-math",
]
extra_compile_args_base = ["-O3"] + avx_flags
extra_link_args_base = []
# Apple Clang does not support -fopenmp directly; use -Xclang -fopenmp
openmp_compile = ["-Xclang", "-fopenmp"]
openmp_link = ["-lomp"]
else:
# Linux with GCC/Clang
avx_flags = []
if is_x86:
avx_flags = [
"-mavx2",
"-mfma",
"-ftree-vectorize",
"-ffast-math",
]
extra_compile_args_base = ["-O3"] + avx_flags
extra_link_args_base = []
openmp_compile = ["-fopenmp"]
openmp_link = ["-fopenmp"]
print(f"\nBuilding extensions for: {machine}")
print(f" is_x86: {is_x86}")
print(f" is_arm: {is_arm}")
if is_x86:
print(" AVX2/FMA optimizations: ENABLED")
# Build the volresample extension
print(" - Building volresample._resample")
extensions.append(
Extension(
name="volresample._resample",
sources=[resampling_source],
include_dirs=include_dirs,
extra_compile_args=extra_compile_args_base + openmp_compile,
extra_link_args=extra_link_args_base + openmp_link,
define_macros=define_macros,
)
)
print(f"\nTotal extensions to build: {len(extensions)}\n")
return extensions
def main():
    """Collect the extension modules, cythonize them, and call ``setup()``."""
    ext_modules = []
    extensions = get_cython_extensions()

    if extensions and CYTHON_AVAILABLE:
        # When several Extension objects share a name, keep the first one.
        by_name = {}
        for ext in extensions:
            by_name.setdefault(ext.name, ext)

        directives = {
            "language_level": "3",
            "boundscheck": False,
            "wraparound": False,
            "cdivision": True,
            "initializedcheck": False,
            "nonecheck": False,
        }
        ext_modules = cythonize(
            list(by_name.values()),
            compiler_directives=directives,
            nthreads=int(os.environ.get("CYTHON_NTHREADS", "1")),
        )

    # Package metadata (name, version, dependencies, ...) is declared in
    # pyproject.toml; this call only supplies the build-time extension config.
    setup(
        packages=find_packages(where="src"),
        package_dir={"": "src"},
        ext_modules=ext_modules,
        package_data={"volresample": ["*.pyi", "py.typed"]},
        cmdclass={"build_ext": BuildExtWithArchDetection},
        zip_safe=False,
    )
# Run the setup configuration only when this file is executed, not when it
# is imported.
if __name__ == "__main__":
    main()