
a-rahimi/tensor-compiler


The Two Week Tensor Compiler

This is a compiler for a PyTorch-like language. You construct a program in this language by stringing together calls to a library. The compiler then converts this program to machine code, running optimization passes along the way.

For example, you might write a Python program like this:

import twcompiler.tensor as twct

def func(a: twct.Tensor, b: twct.Tensor, c: twct.Tensor) -> twct.Tensor:
    return twct.Softmax(twct.LayerNorm(a @ b + c))

a = twct.Placeholder(30, 40, dtype=twct.int8)
b = twct.Placeholder(40, 30, dtype=twct.int8)
c = twct.Placeholder(30, dtype=twct.int8)
r = func(a, b, c)

opcodes = twct.lower(r)

runtime = twct.Runtime()
result_tensor = runtime.execute(
    opcodes, twct.random(30, 40), twct.random(40, 30), twct.random(30))

This program builds up a computation and stores it in the Python variable r. The tensors a, b, and c are placeholder objects that describe the shapes and types of the tensors involved. The object r is a TensorOp object that describes the sequence of operations to perform on the input tensors. The call twct.lower compiles r into a set of machine instructions for the target device, presumably a neural net accelerator of some kind, and returns those instructions, which can later be run on actual tensors.
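To make the deferred-execution idea concrete, here is a minimal sketch of how such a graph can be built with Python operator overloading. The Node class and the placeholder helper are illustrative stand-ins, not the actual twcompiler classes:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
    op: str
    inputs: tuple = ()
    shape: tuple = ()

    def __matmul__(self, other):
        # (m, k) @ (k, n) -> (m, n); record the op instead of computing it.
        assert self.shape[1] == other.shape[0]
        return Node("matmul", (self, other), (self.shape[0], other.shape[1]))

    def __add__(self, other):
        return Node("add", (self, other), self.shape)

def placeholder(*shape):
    return Node("placeholder", (), shape)

a = placeholder(30, 40)
b = placeholder(40, 30)
c = placeholder(30, 30)  # broadcasting details omitted for brevity
r = a @ b + c
# r is a plain data structure describing the computation, not a tensor value.
```

Lowering then becomes a traversal of this data structure rather than of live Python code.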

The program above resembles a PyTorch or JAX program, and the compilation step also resembles what happens in those packages. But the resemblance ends there. Under the hood, PyTorch and JAX compile tensor operations down to a narrower tensor language, like ONNX or XLA. Then a separate compiler translates these to machine instructions. For example, the operations might get translated to Triton, and the Triton compiler might then lower the tensor operations into tile operations and then into CUDA code. Then the CUDA compiler generates PTX. Here, one compiler handles the entire chain.

Why Do This

I'm trying to ramp up quickly on tensor compilers. I considered ramping up with LLVM, which provides similar abstractions to this project. After reading the MLIR docs, I concluded three things: 1) MLIR is the right framework to build such a compiler, 2) lowering to machine code on a new architecture would require rolling my own backend for LLVM and reading a ton of documentation for MLIR, and 3) I'd learn a lot more rolling my own in two weeks, and I could always start over with MLIR afterward.

The Layers of the Two Week Compiler

The compiler is built from several languages, each designed to support the next higher-level language.

  • Tensor Lang: This language resembles familiar tensor libraries like JAX and PyTorch (minus the automatic differentiation). It provides a library of tensor operations. These operations involve dynamically creating new tensors. The compiler automatically injects code to eliminate temporary tensors, using in-place operations where possible and deallocating a tensor as soon as it can prove the tensor is no longer needed.

  • Vector Lang: This language provides operations on dynamically allocated one-dimensional vectors whose length and type are known at compile time. Such vectors are primarily used to describe the shapes and strides of tensors. It provides facilities to dynamically allocate, copy, modify, and automatically garbage-collect these vectors. The language is implemented in Scalar Lang, described below.

  • Tile Schedules: This language decomposes tensor operations into a sequence of tile operations. It provides convenient looping mechanisms to express common scheduling patterns: the order in which tiles should be computed, the dependencies between tiles (including which operations can be parallelized), and ways to handle boundary conditions and padding in tiles.

  • Scalar Lang: The tile layer is written in Scalar Lang, a functional programming language with lazy semantics. It provides a rudimentary set of operations for arithmetic, memory access, flow control, and looping. Both the schedules for tiles and the operations on tiles are written in Scalar Lang, which is designed to be easy to lower to machine code.

  • Runtime simulator: To test and debug the compiler, I use a simple opcode execution engine. It mimics a real instruction set machine with memory.

Because the entire toolchain is written in Python, it is easy to mix Scalar Lang throughout the code, including end-user code. Loop fusion, for example, is a matter of fusing the loops of a Tile Schedule by modifying the Scalar Lang compute graph.

Scalar Lang

Scalar Lang is a small layer on top of machine code. It is similar to Static Single Assignment form in that it provides the abstraction of an infinite register machine. It offers scalar arithmetic and memory-movement operations, hooks for vector operations, an If statement, and a Loop structure.

Here's a simple Scalar Lang program that adds two numbers:

r = sl.Halt(sl.LoadImmediate(10) + sl.LoadImmediate(20))

You can compile it with opcodes = sl.lower(r). The resulting program just exits with status code 30.

Here is a more complicated program:

r = sl.Halt(sum(sl.LoadImmediate(i) for i in range(20)))

This program contains a for loop and a call to the sum() function. These are Python constructs, not Scalar Lang constructs, but we can use them here because Scalar Lang programs are written in Python. This program is completely equivalent to

r = sl.Halt(sl.LoadImmediate(0) + sl.LoadImmediate(1) + ... + sl.LoadImmediate(19))
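The unrolling happens at graph-construction time: Python's sum() simply calls __add__ on the operation objects, leaving a chain of add nodes behind. Here is a minimal sketch of the mechanism, where Expr and load_immediate are hypothetical stand-ins for the Scalar Lang classes:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Expr:
    op: str
    args: tuple = ()

    def __add__(self, other):
        # Adding two expressions records an add node instead of computing.
        return Expr("add", (self, other))

def load_immediate(v):
    return Expr("imm", (v,))

# Python's sum() starts from the int 0 by default, so seed it with a node.
nodes = [load_immediate(i) for i in range(20)]
r = sum(nodes[1:], nodes[0])

def count_adds(e):
    # Count the add nodes left behind by the unrolled loop.
    if e.op != "add":
        return 0
    return 1 + sum(count_adds(a) for a in e.args)
```

count_adds(r) finds 19 add nodes: the Python loop has vanished, leaving only the fully unrolled graph.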

Scalar Lang is a functional programming language, so we have to jump through minor hoops to express operations that have side effects, like operations that store values in memory:

memory = sl.InitialMemory()
memory = sl.Store(memory, 0xA0, 10)
memory = sl.Store(memory, 0xA1, 20)
v = sl.Load(memory, 0xA0) + sl.Load(memory, 0xA1)
r = sl.Store(memory, 0xA2, v)

This program stores the value 10 at memory address 0xA0, and the value 20 at memory address 0xA1. It then reads these values and adds them, storing the result at memory address 0xA2. We avoided the need for side effects by introducing a memory object that appears to keep a log of memory operations. In reality, this object is compiled out of existence, but we have to keep it around in the program text.

This way of representing side effects gives us lazy semantics. Consider this variant of the above program:

memory = sl.InitialMemory()
memory = sl.Store(memory, 0xA0, 10)
memory_new = sl.Store(memory, 0xA1, 20)
memory_new = sl.Store(memory_new, 0xA1, -20)
v = sl.Load(memory, 0xA0) + sl.Load(memory, 0xA1)
memory = sl.Store(memory, 0xA2, v)

We've defined two memory objects here. If you ignore the names of these objects and treat the program as a sequential list of store and load operations, you'd expect to see the value -10 at 0xA2 (we store 10 at 0xA0, then store 20 at 0xA1 and immediately overwrite it with -20, compute the sum of 10 and -20, and store the result at 0xA2). But in fact, the value stored at 0xA2 is 10. That's because -20 is never stored at 0xA1: memory_new doesn't appear in any subsequent operation.
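One way to see why the dead stores vanish: the program is a DAG of operations, and the compiler only emits what is reachable from the final result. A small sketch of this idea, using plain tuples as hypothetical graph nodes:

```python
# Each operation is a tuple: ("store", memory, address, value), etc.
def store(memory, address, value):
    return ("store", memory, address, value)

def load(memory, address):
    return ("load", memory, address)

initial = ("initial_memory",)
memory = store(initial, 0xA0, 10)
memory_new = store(memory, 0xA1, 20)
memory_new = store(memory_new, 0xA1, -20)
v = ("add", load(memory, 0xA0), load(memory, 0xA1))
root = store(memory, 0xA2, v)

def reachable(node, seen=None, out=None):
    # Walk the DAG from the program root, collecting every node once.
    if seen is None:
        seen, out = set(), []
    if isinstance(node, tuple) and id(node) not in seen:
        seen.add(id(node))
        out.append(node)
        for arg in node[1:]:
            reachable(arg, seen, out)
    return out

store_addresses = [n[2] for n in reachable(root) if n[0] == "store"]
```

Both stores to 0xA1 hang off memory_new, which root never references, so a walk from root never visits them and they are never compiled.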

These lazy semantics become more important in If statements and Loops. The If statement works like the C ternary operator:

phi = sl.If(sl.LoadImmediate(1), sl.LoadImmediate(2), sl.LoadImmediate(3))
r = sl.Halt(phi)

The first argument to If is a condition. The second argument is the result to return if the condition is true, and the third is the result to return if the condition is false. This program ends with status code 2.

Notice first that the second and third arguments are operations, not blocks of code. Because Scalar Lang programs are written in Python, the following two programs are completely equivalent:

sl.If(
    sl.LoadImmediate(1)-sl.LoadImmediate(1),
    sl.LoadImmediate(2) - sl.LoadImmediate(3),
    sl.Load(memory, 0xA0)
)

and

then = sl.LoadImmediate(2) 
then = then - sl.LoadImmediate(3)
otherwise = sl.Load(memory, 0xA0)
sl.If(sl.LoadImmediate(1)-sl.LoadImmediate(1), then, otherwise)

Scalar Lang is oblivious to how the expressions are constructed. It just cares about the graph of operations. Now consider two branches that have side effects:

memory = sl.InitialMemory()
memory_then = sl.Store(memory, 0xA1, 1)
memory_otherwise = sl.Store(memory, 0xA2, 2)
phi = sl.If(
   sl.LoadImmediate(0), memory_then, memory_otherwise
)

Here, each branch writes a different value to memory, but because we've written the memory operations in a functional form, there is no side effect to be concerned about.

To make the whole thing more concrete, here are the opcodes the compiler generates for the above program:

  0             LoadImmediate(dst=r0, value=0)
  1             Negate(dst=r1, src=r0)
  2             JumpIf(condition=r1, target_program_counter=8)
  3             LoadImmediate(dst=r2, value=161)
  4             LoadImmediate(dst=r3, value=1)
  5             Store(dst_address=r2, src=r3)
  6             Move(dst=r2, src=r3)
  7             Jump(target_program_counter=12)
  8             LoadImmediate(dst=r3, value=162)
  9             LoadImmediate(dst=r4, value=2)
 10             Store(dst_address=r3, src=r4)
 11             Move(dst=r2, src=r4)
 12             Halt(status=r2)

Loops work in a similar way.

iteration, sum = sl.LoopVariables(-10, -20)
iteration, sum = sl.While(
    iteration < 5, (iteration, iteration + 1), (sum, sum + 3)
)
r = sl.Halt(iteration)

The first line initializes the loop variables, assigning them the values -10 and -20 respectively. The While statement checks the condition (passed as the first argument), and assigns the new values (passed as the second element of each tuple) to the loop variables (passed as the first element of each tuple). It repeats this process until the condition is false. Here is a more complicated loop that computes the Fibonacci sequence and stores it in memory. It uses a sugared form of While called For:

BASE = 0xA0
memory = sl.InitialMemory()
memory = sl.Store(memory, BASE, 1)
memory = sl.Store(memory, BASE + 1, 2)

i, memory = sl.For(
    # A tuple of loop variables. These variables are reassigned at each
    # iteration of the loop and returned when the loop condition fails.
    (2, memory),

    # The stopping condition. The loop stops when this evaluates to false.
    lambda i, _: i < 10,

    # The loop body. It returns a tuple that has as many elements as there are
    # loop variables.
    lambda i, memory: (
        i + 1,
        sl.Store(
            memory,
            BASE + i,
            memory[BASE + i - 2] + memory[BASE + i - 1],
        ),
    ),
)

r = sl.Halt(i)
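For reference, here is what the For loop above computes, written as equivalent plain Python (with a dict standing in for the memory object):

```python
BASE = 0xA0
memory = {BASE: 1, BASE + 1: 2}

# Plain-Python equivalent of the For loop: the loop variables (i, memory)
# are rebound on each iteration until the condition fails.
i = 2
while i < 10:
    memory[BASE + i] = memory[BASE + i - 2] + memory[BASE + i - 1]
    i += 1

fib = [memory[BASE + k] for k in range(10)]
# fib == [1, 2, 3, 5, 8, 13, 21, 34, 55, 89], and i ends at 10.
```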
