-
Notifications
You must be signed in to change notification settings - Fork 275
Open
Description
Is your feature request related to a problem? Please describe.
MatrixExpr's dot operation is really slow on large inputs.
A constant matrix (200x200) @ a MatrixExpr (200x200) takes about 6s.
Describe the solution you'd like
Using quicksum could speed this operation up.
MatrixExpr @ MatrixExpr would not see any performance gain, because Expr * Expr is the bottleneck.
Additional context
It is about 2.8x faster than before for matrices of shape (200, 200).
from time import time
import numpy as np
from pyscipopt import MatrixExpr, MatrixVariable, Model, quicksum
class SpeedMatrix(MatrixVariable):
    """MatrixVariable view that accelerates constant-by-expression products.

    Matrix products in which exactly one operand has a numeric dtype are
    handed to ``_vec_dot``; every other ufunc call is forwarded unchanged to
    the parent class.
    """

    def __array_ufunc__(self, ufunc, method, *args, **kwargs):
        """Intercept matrix products where exactly one operand is numeric.

        Falls through to ``MatrixVariable.__array_ufunc__`` when the call is
        not a plain ``__call__`` of ``np.matmul``/``np.dot``, or when both
        operands are numeric or both are non-numeric.
        """
        # NOTE(review): np.dot is not a ufunc, so the np.dot comparison
        # presumably never matches here -- confirm before relying on it.
        is_product = method == "__call__" and ufunc in (np.matmul, np.dot)
        if is_product:
            lhs = _ensure_array(args[0])
            rhs = _ensure_array(args[1])
            # dtype kinds: f=float, i=signed int, u=unsigned int, b=bool.
            lhs_numeric = lhs.dtype.kind in "fiub"
            rhs_numeric = rhs.dtype.kind in "fiub"
            if lhs_numeric != rhs_numeric:
                return _vec_dot(lhs, rhs, lhs_numeric)
        return super().__array_ufunc__(ufunc, method, *args, **kwargs)
@np.vectorize(otypes=[object], excluded=[2], signature="(m,n),(n,p)->(m,p)")
def _vec_dot(a, b, is_numeric):
    """Multiply two 2-D core matrices, exactly one of which is numeric.

    The ``np.vectorize`` signature applies this function to ``(m, n)`` by
    ``(n, p)`` core matrices; ``is_numeric`` (positional index 2) is excluded
    from vectorization and passed through as-is.

    Args:
        a: Left operand.
        b: Right operand.
        is_numeric: True when ``a`` is the numeric operand, False when ``b`` is.

    Returns:
        The ``(m, p)`` product as produced by ``_core_dot``.
    """
    if is_numeric:
        return _core_dot(a, b)
    # _core_dot wants the numeric matrix first, so use a @ b == (b.T @ a.T).T.
    return _core_dot(b.T, a.T).T
def _core_dot(a, x):
    """Compute ``a @ x`` for a numeric matrix ``a`` and an array ``x`` of terms.

    Each output entry is built with ``quicksum`` over only the non-zero
    coefficients of the corresponding row of ``a``, so zero coefficients never
    produce a term.

    Args:
        a: 2-D numeric coefficient array of shape ``(m, n)``.
        x: Array with ``n`` rows. A 1-D ``x`` is treated as a single column
            of shape ``(n, 1)``.

    Returns:
        An object array of shape ``(m, k)`` viewed as ``MatrixExpr``, where
        ``k`` is the number of columns of ``x``. Rows of ``a`` that are
        entirely zero yield ``0.0`` in every column.
    """
    a = np.ascontiguousarray(a)
    if x.ndim == 1:
        # The loop below indexes x as x[rows, col]; a 1-D x would raise
        # IndexError there, so promote it to one column up front.
        x = x.reshape(-1, 1)
    m, _ = a.shape  # unpacking also rejects a non-2-D coefficient array
    k = x.shape[1]
    res = np.empty((m, k), dtype=object)
    for i in range(m):
        row = a[i, :]
        nonzer = np.flatnonzero(row)
        if nonzer.size == 0:
            # Nothing to sum for an all-zero row.
            res[i, :] = 0.0
            continue
        coeff = row[nonzer]
        for j in range(k):
            res[i, j] = quicksum(coeff * x[nonzer, j])
    return res.view(MatrixExpr)
def _ensure_array(arg):
    """Return ``arg`` as a base-class ``numpy.ndarray``.

    * ``ndarray`` instances (including subclasses) are re-viewed as plain
      ``np.ndarray`` without copying.
    * Lists and tuples are converted with ``np.asarray`` so NumPy infers the
      dtype.
    * Anything else is wrapped in an object-dtype array.
    """
    if isinstance(arg, np.ndarray):
        # Strip the subclass so later indexing and arithmetic use plain
        # ndarray behaviour.
        return arg.view(np.ndarray)
    if isinstance(arg, (list, tuple)):
        return np.asarray(arg)
    return np.array(arg, dtype=object)
if __name__ == "__main__":
    # Benchmark: time a constant matrix times a matrix of variables, once
    # through SpeedMatrix and once through the unmodified matrix variable.
    n = 200
    scip_model = Model()
    raw_vars = scip_model.addMatrixVar((n, n))
    fast_vars = raw_vars.view(SpeedMatrix)
    coeffs = np.random.rand(n, n)

    print(f"Testing SpeedMatrix Efficiency ({n}x{n})...")

    started = time()
    fast_product = coeffs @ fast_vars
    elapsed = time() - started
    print(f"Const @ Var (Ax) Time : {elapsed:.6f}s")
    # Recorded output: Const @ Var (Ax) Time : 5.979419s

    started = time()
    raw_product = coeffs @ raw_vars
    elapsed = time() - started
    print(f"Const @ Raw Var (Ax) Time: {elapsed:.6f}s")
# Const @ Raw Var (Ax) Time: 16.737834s
Metadata
Assignees
Labels
No labels