From 832deac08df9e30d7b0486968ddc5123a3cde740 Mon Sep 17 00:00:00 2001 From: Niansong Zhang Date: Fri, 6 Feb 2026 14:03:19 -0500 Subject: [PATCH 1/2] [Builder] Add F64Type support and SinOp for math operations Add F64Type to the type check for scalar math operations so that float64 arguments are recognized alongside float32 and integer types. Also add math.SinOp to the supported operation map. These changes enable math operations (exp, log, sqrt, sin, cos, tan, etc.) to work with float64 types, and add sin() support for all float types. Co-Authored-By: Claude Opus 4.6 --- allo/ir/builder.py | 5 +++- tests/test_types.py | 62 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/allo/ir/builder.py b/allo/ir/builder.py index 35b1374049..e2844a7ac7 100644 --- a/allo/ir/builder.py +++ b/allo/ir/builder.py @@ -20,6 +20,7 @@ ShapedType, IntegerType, F32Type, + F64Type, UnitAttr, IntegerAttr, StringAttr, @@ -3103,7 +3104,8 @@ class ExternalModule: if hasattr(arg, "result") and hasattr(arg.result, "type"): arg_types.append(arg.result.type) if all( - isinstance(arg_type, (F32Type, IntegerType)) for arg_type in arg_types + isinstance(arg_type, (F32Type, F64Type, IntegerType)) + for arg_type in arg_types ): opcls = { "exp": math_d.ExpOp, @@ -3111,6 +3113,7 @@ class ExternalModule: "log2": math_d.Log2Op, "log10": math_d.Log10Op, "sqrt": math_d.SqrtOp, + "sin": math_d.SinOp, "cos": math_d.CosOp, "tan": math_d.TanOp, "tanh": math_d.TanhOp, diff --git a/tests/test_types.py b/tests/test_types.py index 46df5dae3c..725e4d8c28 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -628,6 +628,68 @@ def kernel(x: int32[16], y: float32[16]) -> int32: assert allo_result == expected +def test_sin_float32(): + def kernel(A: float32[10]) -> float32[10]: + B: float32[10] + for i in range(10): + B[i] = allo.sin(A[i]) + return B + + s = allo.customize(kernel) + print(s.module) + mod = s.build() + A = np.random.rand(10).astype(np.float32) + B = mod(A) + np.testing.assert_allclose(B, np.sin(A), rtol=1e-5) + + +def test_sin_float64(): + def kernel(A: float64[10]) -> float64[10]: + B: float64[10] + for i in range(10): + B[i] = allo.sin(A[i]) + return B + + s = allo.customize(kernel) + print(s.module) + mod = s.build() + A = np.random.rand(10).astype(np.float64) + B = mod(A) + np.testing.assert_allclose(B, np.sin(A), rtol=1e-5) + + +def test_float64_math_ops(): + def kernel(A: float64[10]) -> float64[10]: + B: float64[10] + for i in range(10): + B[i] = allo.exp(A[i]) + allo.log(A[i]) + allo.sqrt(A[i]) + return B + + s = allo.customize(kernel) + print(s.module) + mod = s.build() + A = np.random.uniform(0.1, 2.0, size=10).astype(np.float64) + B = mod(A) + expected = np.exp(A) + np.log(A) + np.sqrt(A) + np.testing.assert_allclose(B, expected, rtol=1e-5) + + +def test_float64_trig_ops(): + def kernel(A: float64[10]) -> float64[10]: + B: float64[10] + for i in range(10): + B[i] = allo.sin(A[i]) + allo.cos(A[i]) + allo.tan(A[i]) + return B + + s = allo.customize(kernel) + print(s.module) + mod = s.build() + A = np.random.uniform(-1.0, 1.0, size=10).astype(np.float64) + B = mod(A) + expected = np.sin(A) + np.cos(A) + np.tan(A) + np.testing.assert_allclose(B, expected, rtol=1e-5) + + ###################################################################### # Legacy tests ###################################################################### From 3f0fef46878f4464ad202538de6311be89cf425e Mon Sep 17 00:00:00 2001 From: Niansong Zhang Date: Fri, 6 Feb 2026 14:11:25 -0500 Subject: [PATCH 2/2] [Tests] Assert float64 dtype in test outputs Add dtype assertions to float64 tests to ensure the pipeline does not silently downcast to float32. Co-Authored-By: Claude Opus 4.6 --- tests/test_types.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_types.py b/tests/test_types.py index 725e4d8c28..259bd1af35 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -655,6 +655,7 @@ def kernel(A: float64[10]) -> float64[10]: mod = s.build() A = np.random.rand(10).astype(np.float64) B = mod(A) + assert B.dtype == np.float64 np.testing.assert_allclose(B, np.sin(A), rtol=1e-5) @@ -670,6 +671,7 @@ def kernel(A: float64[10]) -> float64[10]: mod = s.build() A = np.random.uniform(0.1, 2.0, size=10).astype(np.float64) B = mod(A) + assert B.dtype == np.float64 expected = np.exp(A) + np.log(A) + np.sqrt(A) np.testing.assert_allclose(B, expected, rtol=1e-5) @@ -686,6 +688,7 @@ def kernel(A: float64[10]) -> float64[10]: mod = s.build() A = np.random.uniform(-1.0, 1.0, size=10).astype(np.float64) B = mod(A) + assert B.dtype == np.float64 expected = np.sin(A) + np.cos(A) + np.tan(A) np.testing.assert_allclose(B, expected, rtol=1e-5)