Skip to content
Zhixun Tan edited this page Dec 15, 2017 · 1 revision
from __future__ import absolute_import, print_function

import tvm
import numpy as np

# Describe the computation C[i] = A[i] + B[i]. The length `n` is a TVM
# variable rather than a concrete integer, and both inputs are declared
# with it as their shape.
n = tvm.var("n")
A = tvm.placeholder(n, name="A")
B = tvm.placeholder(n, name="B")
C = tvm.compute(
    A.shape,
    lambda i: A[i] + B[i],
    name="C",
)

# Create a schedule for C and bind its first (and only) axis to the
# "threadIdx.x" thread axis.
s = tvm.create_schedule(C.op)
stage = s[C]
outer_axis = C.op.axis[0]
stage.bind(outer_axis, tvm.thread_axis("threadIdx.x"))

# Show what tvm.lower produces for this schedule before building it.
lowered = tvm.lower(s, [A, B, C], simple_mode=True)
print(lowered)

# Build the scheduled computation for the "opengl" target under the
# function name "myadd".
fadd_gl = tvm.build(s, [A, B, C], "opengl", name="myadd")

# Print the source held by the first imported module, requested in "gl"
# format.
print("------opengl code------")
gl_module = fadd_gl.imported_modules[0]
print(gl_module.get_source(fmt="gl"))

# Run the built function on OpenGL device 0 and check it against NumPy.
ctx = tvm.opengl(0)

# NOTE: this rebinds `n` from the TVM variable used in the declaration to a
# concrete test length; A, B and C were already built with the former.
n = 10

# Random inputs and a zeroed output buffer, each cast to the dtype of the
# matching tensor and placed on `ctx`.
a_host = np.random.uniform(size=n).astype(A.dtype)
a = tvm.nd.array(a_host, ctx)
b_host = np.random.uniform(size=n).astype(B.dtype)
b = tvm.nd.array(b_host, ctx)
c_host = np.zeros(n, dtype=C.dtype)
c = tvm.nd.array(c_host, ctx)

fadd_gl(a, b, c)

# Compare the device result with the element-wise sum computed by NumPy.
actual = c.asnumpy()
expected = a.asnumpy() + b.asnumpy()
np.testing.assert_allclose(actual, expected)

Clone this wiki locally