Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion allo/ir/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ShapedType,
IntegerType,
F32Type,
F64Type,
UnitAttr,
IntegerAttr,
StringAttr,
Expand Down Expand Up @@ -3103,14 +3104,16 @@ class ExternalModule:
if hasattr(arg, "result") and hasattr(arg.result, "type"):
arg_types.append(arg.result.type)
if all(
isinstance(arg_type, (F32Type, IntegerType)) for arg_type in arg_types
isinstance(arg_type, (F32Type, F64Type, IntegerType))
for arg_type in arg_types
):
opcls = {
"exp": math_d.ExpOp,
"log": math_d.LogOp,
"log2": math_d.Log2Op,
"log10": math_d.Log10Op,
"sqrt": math_d.SqrtOp,
"sin": math_d.SinOp,
"cos": math_d.CosOp,
"tan": math_d.TanOp,
"tanh": math_d.TanhOp,
Expand Down
65 changes: 65 additions & 0 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,71 @@ def kernel(x: int32[16], y: float32[16]) -> int32:
assert allo_result == expected


def test_sin_float32():
def kernel(A: float32[10]) -> float32[10]:
B: float32[10]
for i in range(10):
B[i] = allo.sin(A[i])
return B

s = allo.customize(kernel)
print(s.module)
mod = s.build()
A = np.random.rand(10).astype(np.float32)
B = mod(A)
np.testing.assert_allclose(B, np.sin(A), rtol=1e-5)


def test_sin_float64():
def kernel(A: float64[10]) -> float64[10]:
B: float64[10]
for i in range(10):
B[i] = allo.sin(A[i])
return B

s = allo.customize(kernel)
print(s.module)
mod = s.build()
A = np.random.rand(10).astype(np.float64)
B = mod(A)
assert B.dtype == np.float64
np.testing.assert_allclose(B, np.sin(A), rtol=1e-5)


def test_float64_math_ops():
def kernel(A: float64[10]) -> float64[10]:
B: float64[10]
for i in range(10):
B[i] = allo.exp(A[i]) + allo.log(A[i]) + allo.sqrt(A[i])
return B

s = allo.customize(kernel)
print(s.module)
mod = s.build()
A = np.random.uniform(0.1, 2.0, size=10).astype(np.float64)
B = mod(A)
assert B.dtype == np.float64
expected = np.exp(A) + np.log(A) + np.sqrt(A)
np.testing.assert_allclose(B, expected, rtol=1e-5)


def test_float64_trig_ops():
def kernel(A: float64[10]) -> float64[10]:
B: float64[10]
for i in range(10):
B[i] = allo.sin(A[i]) + allo.cos(A[i]) + allo.tan(A[i])
return B

s = allo.customize(kernel)
print(s.module)
mod = s.build()
A = np.random.uniform(-1.0, 1.0, size=10).astype(np.float64)
B = mod(A)
assert B.dtype == np.float64
expected = np.sin(A) + np.cos(A) + np.tan(A)
np.testing.assert_allclose(B, expected, rtol=1e-5)


######################################################################
# Legacy tests
######################################################################
Expand Down