
Step 1: Write a Matrix Class

tomstewart89 edited this page Jan 18, 2022 · 13 revisions

Writing a Matrix Class (6aeb24d)

Most non-linear solvers boil down to repeatedly solving systems of linear equations, and this one is no exception. So first things first, let's define a class to represent a 2D matrix. Just to be clear, that's a rectangular grid of numbers with a fixed number of rows and columns.

If you like templates, then your first run at implementing such a class might look something like this:

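#include <array>

template <int Rows, int Cols>
class Matrix
{
   public:
    // Elements are stored contiguously in row-major order
    std::array<double, Rows * Cols> storage;

    // Access element (i, j); j defaults to 0 so column vectors can be indexed with one argument
    double& operator()(int i, int j = 0) { return storage[i * Cols + j]; }
};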

Which is a pretty good start. Using just that simple definition, we can instantiate matrices of arbitrary sizes and access their elements like so:

Matrix<3,5> M;
M(0,3) = 500.0;
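To make the indexing concrete, here's a minimal sketch of the row-major arithmetic that operator() relies on. The helper flat_index is a hypothetical name used only for illustration; it isn't part of the solver.

```cpp
// Hypothetical helper (not part of the solver): the offset into `storage`
// that Matrix::operator() computes for element (i, j) of a matrix with `cols` columns.
constexpr int flat_index(int i, int j, int cols) { return i * cols + j; }

// In a 3x5 matrix, element (0, 3) sits at offset 3 and element (2, 4) is the last slot, 14.
static_assert(flat_index(0, 3, 5) == 3, "row 0 is stored first");
static_assert(flat_index(2, 4, 5) == 14, "the last element occupies the last slot");
```

So M(0,3) = 500.0 above writes to storage[3]. Note that nothing here checks that i and j are in range.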

A Curiously Recurring Matrix Class

As we'll see as we get further into writing this solver, it's often helpful to define more abstract matrices; for example ones that return 0.0 for every element or even ones that represent a block of a larger matrix.

For that reason, it'd be nice for that operator() above to be polymorphic, i.e. to let us define matrices that do different things when operator() is called but still have the look and feel of the matrix we defined above.

One straightforward way to do that is to use a virtual function, but a more fun (type 2 fun, that is) approach is to use an idiom called the Curiously Recurring Template Pattern (CRTP). CRTP lets us define polymorphic functions that are resolved at compile time rather than at runtime, which avoids the overhead of looking up a function in a vtable.
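For comparison, here's roughly what the virtual function route might look like. This is only a sketch of the alternative we're not taking; the names VirtualMatrix, DenseMatrix, VirtualZeros and top_left are all made up for illustration.

```cpp
#include <array>

// Hypothetical runtime-polymorphic alternative (not what this solver uses):
// every element access goes through a virtual call.
struct VirtualMatrix
{
    virtual ~VirtualMatrix() = default;
    virtual double get(int i, int j) const = 0;
};

template <int Rows, int Cols>
struct DenseMatrix : VirtualMatrix
{
    std::array<double, Rows * Cols> storage{};  // zero-initialized

    double get(int i, int j) const override { return storage[i * Cols + j]; }
};

template <int Rows, int Cols>
struct VirtualZeros : VirtualMatrix
{
    double get(int, int) const override { return 0.0; }
};

// Works for any VirtualMatrix, but which get() runs is decided at runtime.
inline double top_left(const VirtualMatrix &m) { return m.get(0, 0); }
```

With CRTP we keep the same "one interface, many matrix types" feel, but the compiler knows the concrete type at every call site.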

To incorporate CRTP into our Matrix we define a MatrixBase class like so (bear with me here):

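template <typename DerivedType, int Rows, int Cols>
struct MatrixBase
{
    // Writable access: forward the call to the derived class's operator()
    double &operator()(int i, int j = 0)
    {
        return static_cast<DerivedType *>(this)->operator()(i, j);
    }

    // Read-only access, needed when a matrix is passed by const reference
    double operator()(int i, int j = 0) const
    {
        return static_cast<const DerivedType *>(this)->operator()(i, j);
    }
};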

And then we have our Matrix class inherit from that class like so:

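template <int Rows, int Cols>
class Matrix : public MatrixBase<Matrix<Rows, Cols>, Rows, Cols>
{
   public:
    std::array<double, Rows * Cols> storage;

    double &operator()(int i, int j = 0) { return storage[i * Cols + j]; }

    // Read-only overload so that a const Matrix can still produce its elements
    double operator()(int i, int j = 0) const { return storage[i * Cols + j]; }
};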

Note that the type passed as the DerivedType template parameter to MatrixBase is the Matrix itself! This, surprisingly, is totally legal C++ code. Now, thanks to inheritance, we can refer to a Matrix via a MatrixBase pointer or reference, and when that MatrixBase needs to produce one of its elements, it can just downcast itself to its DerivedType and call its operator().
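To see that downcast at work, here's a small sketch of a function that only knows about MatrixBase yet still ends up writing into a Matrix's storage. The helper fill is a hypothetical name, and I've value-initialized storage so the example starts from zeros.

```cpp
#include <array>

template <typename DerivedType, int Rows, int Cols>
struct MatrixBase
{
    double &operator()(int i, int j = 0)
    {
        return static_cast<DerivedType *>(this)->operator()(i, j);
    }
};

template <int Rows, int Cols>
class Matrix : public MatrixBase<Matrix<Rows, Cols>, Rows, Cols>
{
   public:
    std::array<double, Rows * Cols> storage{};  // zero-initialized for this example

    double &operator()(int i, int j = 0) { return storage[i * Cols + j]; }
};

// Hypothetical helper: sets every element of any writable matrix type to `value`.
// It never names Matrix; the compiler deduces DerivedType from the argument.
template <typename DerivedType, int Rows, int Cols>
void fill(MatrixBase<DerivedType, Rows, Cols> &m, double value)
{
    for (int i = 0; i < Rows; ++i)
    {
        for (int j = 0; j < Cols; ++j)
        {
            m(i, j) = value;  // resolved to DerivedType::operator() at compile time
        }
    }
}
```

Calling fill(M, 2.5) on a Matrix<3,5> M sets all fifteen elements, without a single virtual call.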

The upshot for the Matrix class is basically nothing. It still works just like this:

Matrix<3,5> M;
M(0,3) = 500.0;

But! Now it's fairly easy to define matrices that do other things than return elements from an array. For example here's a matrix type Zeros that (aptly) just returns zero for every element:

template <int Rows, int Cols = 1>
class Zeros : public MatrixBase<Zeros<Rows, Cols>, Rows, Cols>
{
   public:
    double operator()(int i, int j = 0) const { return 0.0; }
};

Similarly, here's a matrix type that allows us to manipulate the elements within a submatrix of some parent matrix. Note that Blocks share the memory of their parent so any modifications in the block will also be reflected in the parent.

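template <typename RefType, int Rows, int Cols>
class Block : public MatrixBase<Block<RefType, Rows, Cols>, Rows, Cols>
{
    RefType &parent_;
    const int row_offset_;
    const int col_offset_;

   public:
    explicit Block(RefType &parent, int row_offset, int col_offset)
        : parent_(parent), row_offset_(row_offset), col_offset_(col_offset) {}

    // Element (i, j) of the block is element (i + row_offset, j + col_offset) of the parent
    double &operator()(int i, int j) { return parent_(i + row_offset_, j + col_offset_); }

    // Read-only overload so that a Block can be passed by const reference
    double operator()(int i, int j) const { return parent_(i + row_offset_, j + col_offset_); }
};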
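Here's a minimal sketch of that memory sharing in action. The classes are repeated so the snippet stands on its own, storage is value-initialized for the example, and write_through_block is a hypothetical demo function.

```cpp
#include <array>

template <typename DerivedType, int Rows, int Cols>
struct MatrixBase
{
    double &operator()(int i, int j = 0)
    {
        return static_cast<DerivedType *>(this)->operator()(i, j);
    }
};

template <int Rows, int Cols>
class Matrix : public MatrixBase<Matrix<Rows, Cols>, Rows, Cols>
{
   public:
    std::array<double, Rows * Cols> storage{};  // zero-initialized for this example

    double &operator()(int i, int j = 0) { return storage[i * Cols + j]; }
};

template <typename RefType, int Rows, int Cols>
class Block : public MatrixBase<Block<RefType, Rows, Cols>, Rows, Cols>
{
    RefType &parent_;
    const int row_offset_;
    const int col_offset_;

   public:
    explicit Block(RefType &parent, int row_offset, int col_offset)
        : parent_(parent), row_offset_(row_offset), col_offset_(col_offset) {}

    double &operator()(int i, int j) { return parent_(i + row_offset_, j + col_offset_); }
};

// Hypothetical demo: write through a 2x2 block anchored at (1, 1) of a 3x3 parent.
inline double write_through_block()
{
    Matrix<3, 3> parent;
    Block<Matrix<3, 3>, 2, 2> corner(parent, 1, 1);

    corner(0, 0) = 7.0;   // same memory as parent(1, 1)
    return parent(1, 1);  // 7.0
}
```

Because the Block only holds a reference, it must not outlive the matrix it points into.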

Matrix, Zeros and Block objects can all be referred to via a reference to MatrixBase. This means that we're free to define functions that accept two MatrixBases (and, say, add them together) without those functions ever needing to know that the second MatrixBase& is actually a Zeros&. In the next section we'll implement a few such functions so that we can start to do more interesting things with matrices than just get their elements.
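As a taste of what that looks like, here's a sketch of a function that takes two MatrixBases by const reference and doesn't care whether either one is a Zeros. It assumes MatrixBase and Matrix provide const overloads of operator(), and elementwise_dot is a hypothetical helper rather than one of the solver's operators.

```cpp
#include <array>

template <typename DerivedType, int Rows, int Cols>
struct MatrixBase
{
    // Assumed read-only overload, forwarded to the derived type
    double operator()(int i, int j = 0) const
    {
        return static_cast<const DerivedType *>(this)->operator()(i, j);
    }
};

template <int Rows, int Cols>
class Matrix : public MatrixBase<Matrix<Rows, Cols>, Rows, Cols>
{
   public:
    std::array<double, Rows * Cols> storage{};  // zero-initialized for this example

    double &operator()(int i, int j = 0) { return storage[i * Cols + j]; }
    double operator()(int i, int j = 0) const { return storage[i * Cols + j]; }
};

template <int Rows, int Cols = 1>
class Zeros : public MatrixBase<Zeros<Rows, Cols>, Rows, Cols>
{
   public:
    double operator()(int, int = 0) const { return 0.0; }
};

// Hypothetical helper: sums the elementwise products of two same-sized matrices.
template <typename AType, typename BType, int Rows, int Cols>
double elementwise_dot(const MatrixBase<AType, Rows, Cols> &A,
                       const MatrixBase<BType, Rows, Cols> &B)
{
    double total = 0.0;
    for (int i = 0; i < Rows; ++i)
    {
        for (int j = 0; j < Cols; ++j)
        {
            total += A(i, j) * B(i, j);
        }
    }
    return total;
}
```

Passing a Zeros<2,2> as the second argument compiles and runs exactly like passing a Matrix<2,2>; the result just happens to be zero.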

Defining Some Matrix Operations

At this point we have some nice Matrix-like classes, but no way to do anything with them. Let's change that by defining operator overloads so that we can add, subtract, multiply, etc. matrices like so:

Matrix<3,5> A;
Matrix<3,5> B;
Matrix<3,5> C = A + B * 2.0;

It's tempting to go crazy and define all the operators, but since this project is meant to be tinyoptimizer, let's abstain and define only the ones that the solver actually needs. As it turns out, those operators are as follows:

  • Matrix Multiplication
  • Matrix Addition
  • Matrix Subtraction
  • Elementwise Multiplication by a Scalar
  • Elementwise Subtraction by a Scalar
  • Euclidean Norm

Defining these operators is pretty repetitive, so let's just walk through matrix multiplication and leave the rest as an exercise for the reader. I'll assume you're familiar with matrix multiplication; if not, then I'm sure Wikipedia will do a better job of explaining it than I would.

In any case, in code, we can define an operator* between two MatrixBases as follows. Let's start with the prototype:

template <typename AType, typename BType, int ARows, int ACols, int BCols>
Matrix<ARows, BCols> operator*(const MatrixBase<AType, ARows, ACols> &A,
                               const MatrixBase<BType, ACols, BCols> &B);

The first thing to note here is that this function deals in MatrixBases rather than any specific matrix type. We let the function take the template arguments AType and BType, and the compiler figures out which matrix types we're using at compile time.

Secondly, note that the operator isn't defined when the number of columns of A doesn't match the number of rows of B, because ACols appears in both parameter types. So if we try to multiply MatrixBases with incompatible dimensions, we'll get a somewhat cryptic compiler error to the effect of:

error: no match for ‘operator*’ (operand types are ‘Matrix<3, 5>’ and ‘Matrix<4, 7>’)

From there the implementation is pretty straightforward. We initialize the output Matrix C with zeros, then fill in its (i,j)th element with the dot product between row i of A and column j of B:

{
    // Relies on Matrix being constructible from any MatrixBase (constructor not shown on this page)
    Matrix<ARows, BCols> C = Zeros<ARows, BCols>();

    for (int i = 0; i < ARows; ++i)
    {
        for (int j = 0; j < BCols; ++j)
        {
            for (int k = 0; k < ACols; ++k)
            {
                C(i, j) += A(i, k) * B(k, j);
            }
        }
    }
    return C;
}

With this operator defined, we can now declare two matrices (even those abstract ones we described earlier) and multiply them together like so:

Matrix<5,5> A;
Block<Matrix<5, 5>, 5, 2> B(A, 0, 3); // take the last two columns of `A` and call it `B`
Matrix<5,2> C = A * B;
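Putting the pieces together, here's a self-contained sketch of the multiplication above on a smaller example. It assumes const overloads of operator() on MatrixBase, Matrix and Block, and it sidesteps the Zeros initialization by value-initializing storage instead. The function multiply_by_last_column is a hypothetical demo.

```cpp
#include <array>

template <typename DerivedType, int Rows, int Cols>
struct MatrixBase
{
    // Assumed read-only overload, forwarded to the derived type
    double operator()(int i, int j = 0) const
    {
        return static_cast<const DerivedType *>(this)->operator()(i, j);
    }
};

template <int Rows, int Cols>
class Matrix : public MatrixBase<Matrix<Rows, Cols>, Rows, Cols>
{
   public:
    std::array<double, Rows * Cols> storage{};  // zero-initialized

    double &operator()(int i, int j = 0) { return storage[i * Cols + j]; }
    double operator()(int i, int j = 0) const { return storage[i * Cols + j]; }
};

template <typename RefType, int Rows, int Cols>
class Block : public MatrixBase<Block<RefType, Rows, Cols>, Rows, Cols>
{
    RefType &parent_;
    const int row_offset_;
    const int col_offset_;

   public:
    explicit Block(RefType &parent, int row_offset, int col_offset)
        : parent_(parent), row_offset_(row_offset), col_offset_(col_offset) {}

    double &operator()(int i, int j) { return parent_(i + row_offset_, j + col_offset_); }
    double operator()(int i, int j) const { return parent_(i + row_offset_, j + col_offset_); }
};

template <typename AType, typename BType, int ARows, int ACols, int BCols>
Matrix<ARows, BCols> operator*(const MatrixBase<AType, ARows, ACols> &A,
                               const MatrixBase<BType, ACols, BCols> &B)
{
    Matrix<ARows, BCols> C;  // starts at zero thanks to storage{}

    for (int i = 0; i < ARows; ++i)
        for (int j = 0; j < BCols; ++j)
            for (int k = 0; k < ACols; ++k)
                C(i, j) += A(i, k) * B(k, j);

    return C;
}

// Hypothetical demo: multiply a 2x2 matrix by its own last column.
inline Matrix<2, 1> multiply_by_last_column()
{
    Matrix<2, 2> A;
    A(0, 0) = 1.0;
    A(0, 1) = 2.0;
    A(1, 0) = 3.0;
    A(1, 1) = 4.0;

    Block<Matrix<2, 2>, 2, 1> B(A, 0, 1);  // the column (2, 4)
    return A * B;                          // (1*2 + 2*4, 3*2 + 4*4) = (10, 22)
}
```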

And that's it! The other operators are basically more of the same. If you want to see them in action, check out the tests for the matrix related code here.
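If you'd like a head start on that exercise, here's one way a few of the remaining operations could look. This is a sketch rather than the project's actual code: it assumes const overloads of operator(), zero-initialized storage, and a free function that I've called euclidean_norm (a hypothetical name), which takes the square root of the sum of squared elements.

```cpp
#include <array>
#include <cmath>

template <typename DerivedType, int Rows, int Cols>
struct MatrixBase
{
    // Assumed read-only overload, forwarded to the derived type
    double operator()(int i, int j = 0) const
    {
        return static_cast<const DerivedType *>(this)->operator()(i, j);
    }
};

template <int Rows, int Cols>
class Matrix : public MatrixBase<Matrix<Rows, Cols>, Rows, Cols>
{
   public:
    std::array<double, Rows * Cols> storage{};  // zero-initialized

    double &operator()(int i, int j = 0) { return storage[i * Cols + j]; }
    double operator()(int i, int j = 0) const { return storage[i * Cols + j]; }
};

// Matrix addition: only defined when both operands have the same dimensions.
template <typename AType, typename BType, int Rows, int Cols>
Matrix<Rows, Cols> operator+(const MatrixBase<AType, Rows, Cols> &A,
                             const MatrixBase<BType, Rows, Cols> &B)
{
    Matrix<Rows, Cols> C;
    for (int i = 0; i < Rows; ++i)
        for (int j = 0; j < Cols; ++j)
            C(i, j) = A(i, j) + B(i, j);
    return C;
}

// Elementwise multiplication by a scalar.
template <typename AType, int Rows, int Cols>
Matrix<Rows, Cols> operator*(const MatrixBase<AType, Rows, Cols> &A, double s)
{
    Matrix<Rows, Cols> C;
    for (int i = 0; i < Rows; ++i)
        for (int j = 0; j < Cols; ++j)
            C(i, j) = A(i, j) * s;
    return C;
}

// Euclidean norm: square root of the sum of squared elements.
template <typename AType, int Rows, int Cols>
double euclidean_norm(const MatrixBase<AType, Rows, Cols> &A)
{
    double sum_of_squares = 0.0;
    for (int i = 0; i < Rows; ++i)
        for (int j = 0; j < Cols; ++j)
            sum_of_squares += A(i, j) * A(i, j);
    return std::sqrt(sum_of_squares);
}
```

Subtraction follows the same pattern as addition with the sign flipped.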
