diff --git a/operatorspy/tests/matmul.py b/operatorspy/tests/matmul.py index 67daf48c..3876be41 100644 --- a/operatorspy/tests/matmul.py +++ b/operatorspy/tests/matmul.py @@ -79,6 +79,10 @@ def test( for i in range(NUM_PRERUN if PROFILE else 1): ans = matmul(c, beta, a, b, alpha) + + if torch_device == "npu": + torch.npu.synchronize() + if PROFILE: start_time = time.time() for i in range(NUM_ITERATIONS): @@ -86,7 +90,6 @@ def test( elapsed = (time.time() - start_time) / NUM_ITERATIONS print(f"pytorch time: {elapsed :6f}") - a_tensor = to_tensor(a, lib) b_tensor = to_tensor(b, lib) c_tensor = to_tensor(c, lib) @@ -283,6 +286,7 @@ def test_ascend(lib, test_cases): (1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None, torch.float32), (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1), torch.float16), (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1), torch.float32), + (1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1), torch.float16) ] args = get_args() lib = open_lib() diff --git a/src/devices/ascend/tensor_aclnn.cc b/src/devices/ascend/tensor_aclnn.cc index 556d57e2..c9319fb7 100644 --- a/src/devices/ascend/tensor_aclnn.cc +++ b/src/devices/ascend/tensor_aclnn.cc @@ -1,5 +1,6 @@ #include "tensor_aclnn.h" #include "../../ops/utils.h" +#include /// @brief Set aclnnTensorDescriptor from infiniopTensorDescriptor /// @param y infiniopTensorDescriptor @@ -34,16 +35,21 @@ infiniopStatus_t aclnnTensorDescriptor::fromInfiniOpTensorDescriptor(infiniopTen this->dataType = dt; this->format = format; + infiniopTensorDescriptor_t yOri; + CHECK_STATUS(inferOriginInfiniOpTensorDescriptor(y, &yOri), STATUS_SUCCESS); + // Infer continuous storageShape auto storageShape = new std::vector(ndim); for (uint64_t i = 0; i < ndim - 1; ++i) { - (*storageShape)[i] = ((*shape)[i] * (*strides)[i]) / - ((*shape)[i + 1] * (*strides)[i + 1]); + (*storageShape)[i] = ((yOri->shape)[i] * (yOri->strides)[i]) / + ((yOri->shape)[i + 1] * (yOri->strides)[i + 1]); } - (*storageShape)[ndim - 1] = (*shape)[ndim - 1]; + (*storageShape)[ndim - 1] = (yOri->shape)[ndim - 1]; this->storageShape = (*storageShape).data(); this->storageNdim = ndim; + CHECK_STATUS(infiniopDestroyTensorDescriptor(yOri), STATUS_SUCCESS); + return STATUS_SUCCESS; } @@ -70,10 +76,10 @@ infiniopStatus_t aclnnTensorDescriptor::createTensor() { } infiniopStatus_t aclnnTensorDescriptor::destroyTensor() { - auto status = aclDestroyTensor(this->t); - if (status != 0) { - return STATUS_EXECUTION_FAILED; - } + auto ret = aclDestroyTensor(this->t); + CHECK_RET(ret == ACL_SUCCESS, + LOG_PRINT("aclDesctroyTensor failed, ERROR: %d\n", ret); + return STATUS_EXECUTION_FAILED); t = nullptr; shape = nullptr; strides = nullptr; @@ -82,6 +88,39 @@ infiniopStatus_t aclnnTensorDescriptor::destroyTensor() { return STATUS_SUCCESS; } +infiniopStatus_t +aclnnTensorDescriptor::inferOriginInfiniOpTensorDescriptor(infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t *ori_ptr) { + auto shape = y->shape; + auto strides = y->strides; + auto ndim = y->ndim; + + std::vector indices(ndim); + for (uint64_t i = 0; i < ndim; ++i) { + indices[i] = i; + } + + std::sort(indices.begin(), indices.end(), [&](uint64_t a, uint64_t b) { + return strides[a] > strides[b]; + }); + + auto oriShape = new std::vector(ndim); + auto oriStrides = new std::vector(ndim); + for (uint64_t i = 0; i < ndim; ++i) { + (*oriShape)[i] = shape[indices[i]]; + (*oriStrides)[i] = strides[indices[i]]; + } + + auto status = infiniopCreateTensorDescriptor( + ori_ptr, + ndim, + (*oriShape).data(), + (*oriStrides).data(), + y->dt); + + return status; +} + aclnnTensorDescriptor::~aclnnTensorDescriptor() { if (this->t) { destroyTensor(); diff --git a/src/devices/ascend/tensor_aclnn.h b/src/devices/ascend/tensor_aclnn.h index 2042fd1c..d8d00858 100644 --- a/src/devices/ascend/tensor_aclnn.h +++ b/src/devices/ascend/tensor_aclnn.h @@ -2,6 +2,7 @@ #define __ACLNN_TENSOR__ #include "./common_ascend.h" +#include "tensor/tensor_descriptor.h" #include "operators.h" #include "tensor.h" #include @@ -27,6 +28,9 @@ struct aclnnTensorDescriptor { infiniopStatus_t fromInfiniOpTensorDescriptor(infiniopTensorDescriptor_t y_desc); infiniopStatus_t createTensor(); infiniopStatus_t destroyTensor(); + infiniopStatus_t + inferOriginInfiniOpTensorDescriptor(infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t *ori_ptr); ~aclnnTensorDescriptor(); char *toString(); diff --git a/src/ops/matmul/ascend/matmul_aclnn.cc b/src/ops/matmul/ascend/matmul_aclnn.cc index 65ad67c8..2d88f7cf 100644 --- a/src/ops/matmul/ascend/matmul_aclnn.cc +++ b/src/ops/matmul/ascend/matmul_aclnn.cc @@ -2,7 +2,7 @@ MatmulAclnnDescriptor::MatmulAclnnDescriptor(Device _device) { device = _device; - device_id = 0; + device_id = 0; executor = nullptr; info = nullptr; cDesc = new aclnnTensorDescriptor(); @@ -22,6 +22,9 @@ infiniopStatus_t aclnnCreateMatmulDescriptor(AscendHandle_t handle, infiniopTensorDescriptor_t b_desc, float beta, int8_t mt) { + if (c_desc->ndim == 3 && (alpha != 1.0 || beta != 0)) { + return STATUS_BAD_PARAM; + } *desc_ptr = new MatmulAclnnDescriptor(handle->device); (*desc_ptr)->device_id = handle->device_id; @@ -57,7 +60,7 @@ infiniopStatus_t aclnnCreateMatmulDescriptor(AscendHandle_t handle, aclTensor *tb = bDesc->t; aclnnStatus ret; - + if (b > 1) { // https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC3alpha003/apiref/aolapi/context/aclnnMatmul.md ret = aclnnMatmulGetWorkspaceSize(ta, @@ -72,8 +75,10 @@ infiniopStatus_t aclnnCreateMatmulDescriptor(AscendHandle_t handle, aclSetAclOpExecutorRepeatable(executor); } else { // Get transA and transB according strides - int64_t transA = aDesc->strides[aDesc->ndim - 1] == 1 ? 0 : 1; - int64_t transB = bDesc->strides[bDesc->ndim - 1] == 1 ? 0 : 1; + // int64_t transA = aDesc->strides[aDesc->ndim - 1] == 1 ? 0 : 1; + // int64_t transB = bDesc->strides[bDesc->ndim - 1] == 1 ? 0 : 1; + int64_t transA = 0; + int64_t transB = 0; // aclnnGemm support C = alpha * A @ B + beta * C // see https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC3alpha003/apiref/aolapi/context/aclnnGemm.md ret = aclnnGemmGetWorkspaceSize(ta, tb, tc, (*desc_ptr)->alpha, (*desc_ptr)->beta, transA, transB, tc,