# Simple Test
# (GitHub wiki page in a fork of apache/tvm; edited by Zhixun Tan on
# Dec 15, 2017 -- 1 revision.)
from __future__ import absolute_import, print_function
import tvm
import numpy as np
# Symbolic vector length. The kernel is compiled once against this variable.
# The concrete length comes from the arrays passed at call time below.
# It is named n_var so it is not shadowed by the concrete `n` further down.
n_var = tvm.var("n")
A = tvm.placeholder((n_var,), name='A')
B = tvm.placeholder((n_var,), name='B')
# Element-wise vector addition: C[i] = A[i] + B[i].
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")

# Default schedule for C, with its single loop axis bound to threadIdx.x.
# NOTE(review): presumably this binding is what the OpenGL target needs
# in order to emit a shader; confirm against the backend.
s = tvm.create_schedule(C.op)
s[C].bind(C.op.axis[0], tvm.thread_axis("threadIdx.x"))

# Print the lowered IR, compile for the "opengl" target, and dump the GL
# source held by the first imported module of the built module.
print(tvm.lower(s, [A, B, C], simple_mode=True))
fadd_gl = tvm.build(s, [A, B, C], "opengl", name="myadd")
print("------opengl code------")
print(fadd_gl.imported_modules[0].get_source(fmt="gl"))

# Execute on OpenGL device 0 with a concrete length.
ctx = tvm.opengl(0)
n = 10
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
fadd_gl(a, b, c)

# The device result must match the same addition done by NumPy on the host.
np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())