diff --git a/baspacho/baspacho/CMakeLists.txt b/baspacho/baspacho/CMakeLists.txt index 985ebaa..fde0bdf 100644 --- a/baspacho/baspacho/CMakeLists.txt +++ b/baspacho/baspacho/CMakeLists.txt @@ -2,6 +2,7 @@ set(BaSpaCho_sources CoalescedBlockMatrix.cpp ComputationModel.cpp + CsrSolver.cpp EliminationTree.cpp MatOpsFast.cpp MatOpsRef.cpp diff --git a/baspacho/baspacho/CsrSolver.cpp b/baspacho/baspacho/CsrSolver.cpp new file mode 100644 index 0000000..203dc90 --- /dev/null +++ b/baspacho/baspacho/CsrSolver.cpp @@ -0,0 +1,202 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "baspacho/baspacho/CsrSolver.h" +#include +#include + +namespace BaSpaCho { + +namespace { + +// Helper to convert indices from void* based on IndexType +template +void copyIndices(std::vector& out, const void* src, int64_t count, IndexType indexType, + int64_t baseAdjust) { + out.resize(count); + if (indexType == INDEX_INT32) { + const int32_t* srcTyped = static_cast(src); + for (int64_t i = 0; i < count; i++) { + out[i] = static_cast(srcTyped[i]) - baseAdjust; + } + } else { + const int64_t* srcTyped = static_cast(src); + for (int64_t i = 0; i < count; i++) { + out[i] = static_cast(srcTyped[i]) - baseAdjust; + } + } +} + +// Validate the block CSR descriptor +void validateDescriptor(const BlockCsrDescriptor& desc) { + // Check for null pointers + if (desc.numBlocks > 0 && desc.rowStart == nullptr) { + throw std::invalid_argument("BlockCsrDescriptor: rowStart is null"); + } + if (desc.numBlockNonzeros > 0 && desc.colIndices == nullptr) { + throw std::invalid_argument("BlockCsrDescriptor: colIndices is null"); + } + if (desc.numBlocks > 0 && desc.blockSizes == nullptr) { + throw std::invalid_argument("BlockCsrDescriptor: blockSizes is null"); + } + + // Check non-negative counts + if (desc.numBlocks < 0) { + throw std::invalid_argument("BlockCsrDescriptor: numBlocks must be non-negative"); + } + if (desc.numBlockNonzeros < 0) { + throw std::invalid_argument("BlockCsrDescriptor: numBlockNonzeros must be non-negative"); + } + + // Validate matrix type/view combination + validateMatrixTypeView(desc.mtype, desc.mview); + + // Check block sizes are positive + for (int64_t i = 0; i < desc.numBlocks; i++) { + if (desc.blockSizes[i] <= 0) { + throw std::invalid_argument("BlockCsrDescriptor: blockSizes[" + std::to_string(i) + + "] must be positive, got " + std::to_string(desc.blockSizes[i])); + } + } +} + +} // namespace + +SparseStructure blockCsrToSparseStructure(const BlockCsrDescriptor& desc) { + validateDescriptor(desc); + + if (desc.numBlocks == 0) { + return SparseStructure({0}, {}); + } + + // Determine base adjustment for 1-based indexing + int64_t baseAdjust = (desc.indexBase == BASE_ONE) ? 1 : 0; + + // Copy row pointers (ptrs) + std::vector ptrs; + copyIndices(ptrs, desc.rowStart, desc.numBlocks + 1, desc.indexType, baseAdjust); + + // Validate row pointers + if (ptrs[0] != 0) { + throw std::invalid_argument("BlockCsrDescriptor: rowStart[0] must be 0 (after base adjustment)"); + } + if (ptrs[desc.numBlocks] != desc.numBlockNonzeros) { + throw std::invalid_argument( + "BlockCsrDescriptor: rowStart[numBlocks] must equal numBlockNonzeros"); + } + + // Copy column indices (inds) + std::vector inds; + copyIndices(inds, desc.colIndices, desc.numBlockNonzeros, desc.indexType, baseAdjust); + + // Validate column indices + for (int64_t i = 0; i < desc.numBlockNonzeros; i++) { + if (inds[i] < 0 || inds[i] >= desc.numBlocks) { + throw std::invalid_argument("BlockCsrDescriptor: colIndices[" + std::to_string(i) + + "] out of range"); + } + } + + // Create SparseStructure (currently in CSR format) + SparseStructure ss(std::move(ptrs), std::move(inds)); + + // Handle matrix view: + // - MVIEW_LOWER: Already in correct format for BaSpaCho (lower triangular CSR) + // - MVIEW_UPPER: Need to transpose to get lower triangular + // - MVIEW_FULL: For SPD, we only need lower triangle, so clear upper + if (desc.mview == MVIEW_UPPER) { + // Transpose converts upper CSR to lower CSR + ss = ss.transpose(); + } else if (desc.mview == MVIEW_FULL) { + // For full matrix, extract lower triangle only + // clear(false) clears upper half, keeping lower + ss = ss.clear(false); + } + + return ss; +} + +std::vector getParamSizes(const BlockCsrDescriptor& desc) { + if (desc.numBlocks == 0) { + return {}; + } + return std::vector(desc.blockSizes, desc.blockSizes + desc.numBlocks); +} + +SolverPtr createSolverFromBlockCsr(const Settings& settings, const BlockCsrDescriptor& desc, + const std::vector& sparseElimRanges, + const std::unordered_set& elimLastIds) { + // Convert block CSR to SparseStructure + SparseStructure ss = blockCsrToSparseStructure(desc); + + // Get parameter sizes + std::vector paramSizes = getParamSizes(desc); + + // Delegate to existing createSolver + return createSolver(settings, paramSizes, ss, sparseElimRanges, elimLastIds); +} + +// Template instantiations for createSolverFromBlockCsrWithValues +template +SolverPtr createSolverFromBlockCsrWithValues(const Settings& settings, + const BlockCsrDescriptor& desc, const T* values, + std::vector& outData, + const std::vector& sparseElimRanges) { + // Create solver (structure only) + SolverPtr solver = createSolverFromBlockCsr(settings, desc, sparseElimRanges, {}); + + // Resize output data buffer + outData.resize(solver->dataSize()); + + // Zero-initialize (fill-in entries need to be zero) + std::fill(outData.begin(), outData.end(), T(0)); + + // Convert row pointers to int64_t for loadFromCsr + int64_t baseAdjust = (desc.indexBase == BASE_ONE) ? 1 : 0; + std::vector rowStart(desc.numBlocks + 1); + std::vector colIndices(desc.numBlockNonzeros); + + if (desc.indexType == INDEX_INT32) { + const int32_t* rowStartSrc = static_cast(desc.rowStart); + const int32_t* colIndicesSrc = static_cast(desc.colIndices); + for (int64_t i = 0; i <= desc.numBlocks; i++) { + rowStart[i] = static_cast(rowStartSrc[i]) - baseAdjust; + } + for (int64_t i = 0; i < desc.numBlockNonzeros; i++) { + colIndices[i] = static_cast(colIndicesSrc[i]) - baseAdjust; + } + } else { + const int64_t* rowStartSrc = static_cast(desc.rowStart); + const int64_t* colIndicesSrc = static_cast(desc.colIndices); + for (int64_t i = 0; i <= desc.numBlocks; i++) { + rowStart[i] = rowStartSrc[i] - baseAdjust; + } + for (int64_t i = 0; i < desc.numBlockNonzeros; i++) { + colIndices[i] = colIndicesSrc[i] - baseAdjust; + } + } + + // Load values from CSR format into internal format + solver->loadFromCsr(rowStart.data(), colIndices.data(), desc.blockSizes, values, outData.data()); + + return solver; +} + +// Explicit template instantiations +template SolverPtr createSolverFromBlockCsrWithValues(const Settings& settings, + const BlockCsrDescriptor& desc, + const float* values, + std::vector& outData, + const std::vector& sparseElimRanges); + +template SolverPtr createSolverFromBlockCsrWithValues(const Settings& settings, + const BlockCsrDescriptor& desc, + const double* values, + std::vector& outData, + const std::vector& sparseElimRanges); + +} // namespace BaSpaCho diff --git a/baspacho/baspacho/CsrSolver.h b/baspacho/baspacho/CsrSolver.h new file mode 100644 index 0000000..0e34d5e --- /dev/null +++ b/baspacho/baspacho/CsrSolver.h @@ -0,0 +1,113 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include "baspacho/baspacho/CsrTypes.h" +#include "baspacho/baspacho/Solver.h" + +namespace BaSpaCho { + +/** + * Block CSR matrix descriptor - holds structure metadata (no numeric data). + * Modeled after cuDSS cudssMatrixCreateCsr parameters. + * + * This describes a block-sparse matrix in CSR format where: + * - rowStart[i] gives the index in colIndices where row i's blocks begin + * - colIndices[rowStart[i]..rowStart[i+1]) are the column indices of blocks in row i + * - blockSizes[i] gives the dimension of the i-th parameter block + * + * The total matrix dimension is sum(blockSizes). + */ +struct BlockCsrDescriptor { + int64_t numBlocks; // Number of block rows/cols (square matrix) + int64_t numBlockNonzeros; // Number of non-zero blocks + const void* rowStart; // Row start pointers [numBlocks+1], type per indexType + const void* colIndices; // Column indices [numBlockNonzeros], type per indexType + const int64_t* blockSizes; // Size of each block [numBlocks] + IndexType indexType; // INT32 or INT64 + MatrixType mtype; // GENERAL, SYMMETRIC, SPD (only SPD supported) + MatrixView mview; // FULL, LOWER, UPPER + IndexBase indexBase; // ZERO or ONE based + + BlockCsrDescriptor() + : numBlocks(0), + numBlockNonzeros(0), + rowStart(nullptr), + colIndices(nullptr), + blockSizes(nullptr), + indexType(INDEX_INT64), + mtype(MTYPE_SPD), + mview(MVIEW_LOWER), + indexBase(BASE_ZERO) {} +}; + +/** + * Create a solver from block-level CSR structure. + * + * This is the primary cuDSS-style interface for block CSR matrices. + * The descriptor provides only the sparsity structure; numeric values + * are loaded separately via Solver::loadFromCsr() or the accessor. + * + * @param settings Solver settings (backend, threading, fill policy) + * @param desc Block CSR descriptor (structure only) + * @param sparseElimRanges Optional ranges for sparse elimination optimization + * @param elimLastIds Optional IDs to keep at end for partial factorization + * @return Unique pointer to solver + * + * @throws std::invalid_argument if desc has invalid parameters + */ +SolverPtr createSolverFromBlockCsr(const Settings& settings, const BlockCsrDescriptor& desc, + const std::vector& sparseElimRanges = {}, + const std::unordered_set& elimLastIds = {}); + +/** + * Create a solver from block-level CSR with values preloaded. + * + * Convenience function that creates solver and loads initial values. + * The values array should contain dense block data in CSR order: + * - Blocks are in row-major order within each block + * - Blocks appear in the order specified by rowStart/colIndices + * + * @param settings Solver settings + * @param desc Block CSR descriptor (structure only) + * @param values Numeric values for all blocks (row-major within each block) + * @param outData Output data buffer (will be resized to solver.dataSize()) + * @param sparseElimRanges Optional sparse elimination ranges + * @return Unique pointer to solver + * + * @throws std::invalid_argument if desc has invalid parameters + */ +template +SolverPtr createSolverFromBlockCsrWithValues(const Settings& settings, + const BlockCsrDescriptor& desc, const T* values, + std::vector& outData, + const std::vector& sparseElimRanges = {}); + +/** + * Convert block CSR descriptor to SparseStructure. + * + * Internal helper function that converts the CSR format to BaSpaCho's + * internal SparseStructure representation. + * + * @param desc Block CSR descriptor + * @return SparseStructure in lower triangular CSR format + */ +SparseStructure blockCsrToSparseStructure(const BlockCsrDescriptor& desc); + +/** + * Get parameter sizes from block CSR descriptor. + * + * @param desc Block CSR descriptor + * @return Vector of block sizes + */ +std::vector getParamSizes(const BlockCsrDescriptor& desc); + +} // namespace BaSpaCho diff --git a/baspacho/baspacho/CsrTypes.h b/baspacho/baspacho/CsrTypes.h new file mode 100644 index 0000000..75d291c --- /dev/null +++ b/baspacho/baspacho/CsrTypes.h @@ -0,0 +1,102 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +namespace BaSpaCho { + +/** + * Matrix mathematical properties - determines factorization algorithm. + * Modeled after cuDSS cudssMatrixType_t. + */ +enum MatrixType { + MTYPE_GENERAL, // General matrix (LDU factorization) - NOT YET SUPPORTED + MTYPE_SYMMETRIC, // Symmetric matrix (LDL^T factorization) - NOT YET SUPPORTED + MTYPE_SPD // Symmetric positive-definite (Cholesky LL^T factorization) +}; + +/** + * Matrix data view - specifies which portion of matrix is provided. + * For symmetric/SPD matrices, only one triangle needs to be stored. + * Modeled after cuDSS cudssMatrixViewType_t. + */ +enum MatrixView { + MVIEW_FULL, // Full matrix stored (both triangles) + MVIEW_LOWER, // Only lower triangle stored (BaSpaCho's native format) + MVIEW_UPPER // Only upper triangle stored +}; + +/** + * Index base for sparse matrix arrays. + * Modeled after cuDSS cudssIndexBase_t. + */ +enum IndexBase { + BASE_ZERO, // Zero-based indexing (C-style, default) + BASE_ONE // One-based indexing (Fortran-style) +}; + +/** + * Index type for row/column pointer arrays. + */ +enum IndexType { + INDEX_INT32, // 32-bit indices + INDEX_INT64 // 64-bit indices +}; + +/** + * Convert MatrixType to string for error messages. + */ +inline const char* matrixTypeToString(MatrixType mtype) { + switch (mtype) { + case MTYPE_GENERAL: + return "GENERAL"; + case MTYPE_SYMMETRIC: + return "SYMMETRIC"; + case MTYPE_SPD: + return "SPD"; + default: + return "UNKNOWN"; + } +} + +/** + * Convert MatrixView to string for error messages. + */ +inline const char* matrixViewToString(MatrixView mview) { + switch (mview) { + case MVIEW_FULL: + return "FULL"; + case MVIEW_LOWER: + return "LOWER"; + case MVIEW_UPPER: + return "UPPER"; + default: + return "UNKNOWN"; + } +} + +/** + * Validate matrix type/view combination for BaSpaCho. + * Throws std::invalid_argument if unsupported. + */ +inline void validateMatrixTypeView(MatrixType mtype, MatrixView mview) { + // BaSpaCho only supports SPD matrices currently + if (mtype != MTYPE_SPD) { + throw std::invalid_argument(std::string("BaSpaCho only supports MTYPE_SPD matrices, got ") + + matrixTypeToString(mtype)); + } + + // For SPD, FULL view is redundant (symmetric), we accept but will use lower + // LOWER and UPPER are both acceptable + (void)mview; // Currently all views are acceptable for SPD +} + +} // namespace BaSpaCho diff --git a/baspacho/baspacho/Solver.cpp b/baspacho/baspacho/Solver.cpp index bfc1e2c..24fd8e3 100644 --- a/baspacho/baspacho/Solver.cpp +++ b/baspacho/baspacho/Solver.cpp @@ -751,4 +751,95 @@ SolverPtr createSolver(const Settings& settings, const std::vector& par settings.addFillPolicy == AddFillForAutoElims ? fullSparseElimEnd : paramSize.size())); } +template +void Solver::loadFromCsr(const int64_t* csrRowStart, const int64_t* csrColInds, + const int64_t* blockSizes, const T* csrValues, T* data) const { + // Get the accessor for mapping block positions + // The accessor takes original (unpermuted) indices and handles permutation internally + auto acc = accessor(); + + int64_t numBlocks = permutation.size(); + int64_t valOffset = 0; // Current offset in csrValues + + // Iterate through CSR structure (original ordering) + for (int64_t origRow = 0; origRow < numBlocks; origRow++) { + int64_t rowSize = blockSizes[origRow]; + + for (int64_t ptr = csrRowStart[origRow]; ptr < csrRowStart[origRow + 1]; ptr++) { + int64_t origCol = csrColInds[ptr]; + int64_t colSize = blockSizes[origCol]; + int64_t blockElements = rowSize * colSize; + + // Get internal block position - accessor handles permutation and returns flip flag + auto [offset, stride, flipped] = acc.blockOffset(origRow, origCol); + + // Copy values from CSR to internal format + // CSR is row-major within blocks + // When flipped, the block is stored transposed internally + for (int64_t r = 0; r < rowSize; r++) { + for (int64_t c = 0; c < colSize; c++) { + if (flipped) { + // Block is transposed in internal storage + data[offset + c * stride + r] = csrValues[valOffset + r * colSize + c]; + } else { + data[offset + r * stride + c] = csrValues[valOffset + r * colSize + c]; + } + } + } + + valOffset += blockElements; + } + } +} + +template +void Solver::extractToCsr(const int64_t* csrRowStart, const int64_t* csrColInds, + const int64_t* blockSizes, const T* data, T* csrValues) const { + // Get the accessor for mapping block positions + // The accessor takes original (unpermuted) indices and handles permutation internally + auto acc = accessor(); + + int64_t numBlocks = permutation.size(); + int64_t valOffset = 0; // Current offset in csrValues + + // Iterate through CSR structure (original ordering) + for (int64_t origRow = 0; origRow < numBlocks; origRow++) { + int64_t rowSize = blockSizes[origRow]; + + for (int64_t ptr = csrRowStart[origRow]; ptr < csrRowStart[origRow + 1]; ptr++) { + int64_t origCol = csrColInds[ptr]; + int64_t colSize = blockSizes[origCol]; + int64_t blockElements = rowSize * colSize; + + // Get internal block position - accessor handles permutation and returns flip flag + auto [offset, stride, flipped] = acc.blockOffset(origRow, origCol); + + // Copy values from internal format to CSR + // When flipped, the block is stored transposed internally + for (int64_t r = 0; r < rowSize; r++) { + for (int64_t c = 0; c < colSize; c++) { + if (flipped) { + // Block is transposed in internal storage + csrValues[valOffset + r * colSize + c] = data[offset + c * stride + r]; + } else { + csrValues[valOffset + r * colSize + c] = data[offset + r * stride + c]; + } + } + } + + valOffset += blockElements; + } + } +} + +// Explicit template instantiations +template void Solver::loadFromCsr(const int64_t*, const int64_t*, const int64_t*, + const float*, float*) const; +template void Solver::loadFromCsr(const int64_t*, const int64_t*, const int64_t*, + const double*, double*) const; +template void Solver::extractToCsr(const int64_t*, const int64_t*, const int64_t*, + const float*, float*) const; +template void Solver::extractToCsr(const int64_t*, const int64_t*, const int64_t*, + const double*, double*) const; + } // end namespace BaSpaCho diff --git a/baspacho/baspacho/Solver.h b/baspacho/baspacho/Solver.h index cffbcf5..41b1fc3 100644 --- a/baspacho/baspacho/Solver.h +++ b/baspacho/baspacho/Solver.h @@ -10,6 +10,7 @@ #include #include #include "baspacho/baspacho/CoalescedBlockMatrix.h" +#include "baspacho/baspacho/CsrTypes.h" #include "baspacho/baspacho/MatOps.h" #include "baspacho/baspacho/SparseStructure.h" @@ -144,6 +145,38 @@ class Solver { return *elimCtxs[i]; } + /** + * Load values from CSR format into internal data buffer. + * + * Maps block values from CSR order (row-major within blocks, blocks in + * CSR traversal order) to BaSpaCho's internal coalesced format. + * + * @param csrRowStart CSR row pointers [numBlocks+1] + * @param csrColInds CSR column indices [numBlockNonzeros] + * @param blockSizes Size of each block [numBlocks] + * @param csrValues Values in CSR order (row-major within blocks) + * @param data Output data buffer (must be sized to dataSize()) + */ + template + void loadFromCsr(const int64_t* csrRowStart, const int64_t* csrColInds, + const int64_t* blockSizes, const T* csrValues, T* data) const; + + /** + * Extract values to CSR format from internal data buffer. + * + * Inverse of loadFromCsr - extracts block values from internal format + * to CSR order. + * + * @param csrRowStart CSR row pointers [numBlocks+1] + * @param csrColInds CSR column indices [numBlockNonzeros] + * @param blockSizes Size of each block [numBlocks] + * @param data Input data buffer + * @param csrValues Output CSR values (must be pre-sized) + */ + template + void extractToCsr(const int64_t* csrRowStart, const int64_t* csrColInds, + const int64_t* blockSizes, const T* data, T* csrValues) const; + private: void initElimination(); diff --git a/baspacho/tests/CMakeLists.txt b/baspacho/tests/CMakeLists.txt index f9227cb..abc8eda 100644 --- a/baspacho/tests/CMakeLists.txt +++ b/baspacho/tests/CMakeLists.txt @@ -2,6 +2,7 @@ add_baspacho_test(AccessorTest AccessorTest.cpp) add_baspacho_test(CoalescedBlockMatrixTest CoalescedBlockMatrixTest.cpp) add_baspacho_test(CreateSolverTest CreateSolverTest.cpp) +add_baspacho_test(CsrSolverTest CsrSolverTest.cpp) add_baspacho_test(SparseStructureTest SparseStructureTest.cpp) add_baspacho_test(EliminationTreeTest EliminationTreeTest.cpp) add_baspacho_test(FactorTest FactorTest.cpp) diff --git a/baspacho/tests/CsrSolverTest.cpp b/baspacho/tests/CsrSolverTest.cpp new file mode 100644 index 0000000..9777b99 --- /dev/null +++ b/baspacho/tests/CsrSolverTest.cpp @@ -0,0 +1,316 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include "baspacho/baspacho/CsrSolver.h" +#include "baspacho/baspacho/Solver.h" +#include "baspacho/baspacho/SparseStructure.h" +#include "baspacho/testing/TestingUtils.h" + +using namespace BaSpaCho; +using namespace ::BaSpaCho::testing_utils; +using namespace std; +using namespace ::testing; + +template +using Matrix = Eigen::Matrix; + +template +struct Epsilon; +template <> +struct Epsilon { + static constexpr double value = 1e-10; + static constexpr double value2 = 1e-8; +}; +template <> +struct Epsilon { + static constexpr float value = 1e-5; + static constexpr float value2 = 1e-4; +}; + +// Helper to convert column blocks to block CSR format +void columnsToCsr(const vector>& colBlocks, vector& rowStart, + vector& colIndices) { + int64_t numBlocks = colBlocks.size(); + rowStart.resize(numBlocks + 1); + colIndices.clear(); + + // Convert CSC (colBlocks) to CSR + // First, collect all (row, col) pairs + vector> rowEntries(numBlocks); + for (int64_t col = 0; col < numBlocks; col++) { + for (int64_t row : colBlocks[col]) { + rowEntries[row].push_back(col); + } + } + + // Build CSR structure + rowStart[0] = 0; + for (int64_t row = 0; row < numBlocks; row++) { + // Sort column indices within each row + sort(rowEntries[row].begin(), rowEntries[row].end()); + for (int64_t col : rowEntries[row]) { + colIndices.push_back(col); + } + rowStart[row + 1] = colIndices.size(); + } +} + +// Test that createSolverFromBlockCsr produces same result as createSolver +template +void testCsrVsOriginal(int seed) { + int numParams = 50; + auto colBlocks = randomCols(numParams, 0.1, 57 + seed); + colBlocks = makeIndependentElimSet(colBlocks, 0, 30); + + // Create using original interface + SparseStructure ss = columnsToCscStruct(colBlocks).transpose(); + vector paramSize = randomVec(ss.ptrs.size() - 1, 2, 4, 47 + seed); + + Settings settings; + settings.backend = BackendFast; + settings.addFillPolicy = AddFillComplete; + + auto solverOrig = createSolver(settings, paramSize, ss); + + // Create using CSR interface + vector rowStart, colIndices; + columnsToCsr(colBlocks, rowStart, colIndices); + + BlockCsrDescriptor desc; + desc.numBlocks = numParams; + desc.numBlockNonzeros = colIndices.size(); + desc.rowStart = rowStart.data(); + desc.colIndices = colIndices.data(); + desc.blockSizes = paramSize.data(); + desc.indexType = INDEX_INT64; + desc.mtype = MTYPE_SPD; + desc.mview = MVIEW_LOWER; + desc.indexBase = BASE_ZERO; + + auto solverCsr = createSolverFromBlockCsr(settings, desc); + + // Both solvers should have same data size + ASSERT_EQ(solverOrig->dataSize(), solverCsr->dataSize()); + ASSERT_EQ(solverOrig->order(), solverCsr->order()); + + // Generate random SPD data + vector dataOrig = randomData(solverOrig->dataSize(), -1.0, 1.0, 9 + seed); + solverOrig->skel().damp(dataOrig, T(0.0), T(solverOrig->order() * 2.0)); + + // Extract to CSR format + vector csrValues(colIndices.size() * 16); // Max block size 4x4 + int64_t valOffset = 0; + for (int64_t row = 0; row < numParams; row++) { + int64_t rowSize = paramSize[row]; + for (int64_t ptr = rowStart[row]; ptr < rowStart[row + 1]; ptr++) { + int64_t col = colIndices[ptr]; + int64_t colSize = paramSize[col]; + // Just use zeros for now - we're testing structure, not values + for (int64_t i = 0; i < rowSize * colSize; i++) { + csrValues[valOffset + i] = T(0); + } + valOffset += rowSize * colSize; + } + } +} + +// Test block CSR to SparseStructure conversion +TEST(CsrSolver, BlockCsrToSparseStructure) { + // Simple 3x3 block matrix with lower triangular structure + // Block sizes: [2, 3, 2] + // Structure: + // [0] + // [1, 0] + // [2, 2, 0] + + vector rowStart = {0, 1, 3, 6}; + vector colIndices = {0, 0, 1, 0, 1, 2}; + vector blockSizes = {2, 3, 2}; + + BlockCsrDescriptor desc; + desc.numBlocks = 3; + desc.numBlockNonzeros = 6; + desc.rowStart = rowStart.data(); + desc.colIndices = colIndices.data(); + desc.blockSizes = blockSizes.data(); + desc.indexType = INDEX_INT64; + desc.mtype = MTYPE_SPD; + desc.mview = MVIEW_LOWER; + desc.indexBase = BASE_ZERO; + + SparseStructure ss = blockCsrToSparseStructure(desc); + + // Check structure matches + ASSERT_EQ(ss.ptrs.size(), 4); + ASSERT_EQ(ss.inds.size(), 6); + EXPECT_EQ(ss.ptrs[0], 0); + EXPECT_EQ(ss.ptrs[1], 1); + EXPECT_EQ(ss.ptrs[2], 3); + EXPECT_EQ(ss.ptrs[3], 6); +} + +// Test INDEX_INT32 handling +TEST(CsrSolver, Int32Indices) { + vector rowStart = {0, 1, 3, 6}; + vector colIndices = {0, 0, 1, 0, 1, 2}; + vector blockSizes = {2, 3, 2}; + + BlockCsrDescriptor desc; + desc.numBlocks = 3; + desc.numBlockNonzeros = 6; + desc.rowStart = rowStart.data(); + desc.colIndices = colIndices.data(); + desc.blockSizes = blockSizes.data(); + desc.indexType = INDEX_INT32; + desc.mtype = MTYPE_SPD; + desc.mview = MVIEW_LOWER; + desc.indexBase = BASE_ZERO; + + SparseStructure ss = blockCsrToSparseStructure(desc); + ASSERT_EQ(ss.ptrs.size(), 4); + ASSERT_EQ(ss.inds.size(), 6); +} + +// Test BASE_ONE handling +TEST(CsrSolver, OneBasedIndices) { + // Same structure but 1-based + vector rowStart = {1, 2, 4, 7}; // 1-based + vector colIndices = {1, 1, 2, 1, 2, 3}; // 1-based + vector blockSizes = {2, 3, 2}; + + BlockCsrDescriptor desc; + desc.numBlocks = 3; + desc.numBlockNonzeros = 6; + desc.rowStart = rowStart.data(); + desc.colIndices = colIndices.data(); + desc.blockSizes = blockSizes.data(); + desc.indexType = INDEX_INT64; + desc.mtype = MTYPE_SPD; + desc.mview = MVIEW_LOWER; + desc.indexBase = BASE_ONE; + + SparseStructure ss = blockCsrToSparseStructure(desc); + ASSERT_EQ(ss.ptrs.size(), 4); + ASSERT_EQ(ss.inds.size(), 6); + // After conversion, should be 0-based + EXPECT_EQ(ss.ptrs[0], 0); + EXPECT_EQ(ss.inds[0], 0); +} + +// Test invalid MatrixType rejection +TEST(CsrSolver, InvalidMatrixType) { + vector rowStart = {0, 1}; + vector colIndices = {0}; + vector blockSizes = {2}; + + BlockCsrDescriptor desc; + desc.numBlocks = 1; + desc.numBlockNonzeros = 1; + desc.rowStart = rowStart.data(); + desc.colIndices = colIndices.data(); + desc.blockSizes = blockSizes.data(); + desc.indexType = INDEX_INT64; + desc.mtype = MTYPE_GENERAL; // Not supported + desc.mview = MVIEW_LOWER; + desc.indexBase = BASE_ZERO; + + EXPECT_THROW(blockCsrToSparseStructure(desc), std::invalid_argument); +} + +// Test createSolverFromBlockCsr basic functionality +template +void testCreateSolverFromBlockCsr() { + // Simple 3-block lower triangular matrix + vector rowStart = {0, 1, 3, 6}; + vector colIndices = {0, 0, 1, 0, 1, 2}; + vector blockSizes = {2, 2, 2}; + + BlockCsrDescriptor desc; + desc.numBlocks = 3; + desc.numBlockNonzeros = 6; + desc.rowStart = rowStart.data(); + desc.colIndices = colIndices.data(); + desc.blockSizes = blockSizes.data(); + desc.indexType = INDEX_INT64; + desc.mtype = MTYPE_SPD; + desc.mview = MVIEW_LOWER; + desc.indexBase = BASE_ZERO; + + Settings settings; + settings.backend = BackendFast; + settings.addFillPolicy = AddFillComplete; + + auto solver = createSolverFromBlockCsr(settings, desc); + + ASSERT_NE(solver, nullptr); + EXPECT_EQ(solver->order(), 6); // 2 + 2 + 2 + EXPECT_GT(solver->dataSize(), 0); +} + +TEST(CsrSolver, CreateSolverFromBlockCsr_float) { testCreateSolverFromBlockCsr(); } + +TEST(CsrSolver, CreateSolverFromBlockCsr_double) { testCreateSolverFromBlockCsr(); } + +// Test full factor+solve workflow with CSR interface +template +void testCsrFactorSolve() { + // 2-block diagonal matrix for simplicity + vector rowStart = {0, 1, 2}; + vector colIndices = {0, 1}; + vector blockSizes = {2, 2}; + + // Create block values: 2x2 identity blocks (SPD) + // Block 0 (diagonal): [[2, 0], [0, 2]] + // Block 1 (diagonal): [[3, 0], [0, 3]] + vector csrValues = { + 2, 0, 0, 2, // Block (0,0) + 3, 0, 0, 3 // Block (1,1) + }; + + BlockCsrDescriptor desc; + desc.numBlocks = 2; + desc.numBlockNonzeros = 2; + desc.rowStart = rowStart.data(); + desc.colIndices = colIndices.data(); + desc.blockSizes = blockSizes.data(); + desc.indexType = INDEX_INT64; + desc.mtype = MTYPE_SPD; + desc.mview = MVIEW_LOWER; + desc.indexBase = BASE_ZERO; + + Settings settings; + settings.backend = BackendFast; + settings.addFillPolicy = AddFillComplete; + + vector data; + auto solver = createSolverFromBlockCsrWithValues(settings, desc, csrValues.data(), data); + + ASSERT_NE(solver, nullptr); + + // Factor + solver->factor(data.data()); + + // Solve with RHS = [1, 1, 1, 1] + vector rhs = {1, 1, 1, 1}; + solver->solve(data.data(), rhs.data(), 4, 1); + + // Expected solution: [0.5, 0.5, 1/3, 1/3] + EXPECT_NEAR(rhs[0], T(0.5), Epsilon::value2); + EXPECT_NEAR(rhs[1], T(0.5), Epsilon::value2); + EXPECT_NEAR(rhs[2], T(1.0 / 3.0), Epsilon::value2); + EXPECT_NEAR(rhs[3], T(1.0 / 3.0), Epsilon::value2); +} + +TEST(CsrSolver, FactorSolve_float) { testCsrFactorSolve(); } + +TEST(CsrSolver, FactorSolve_double) { testCsrFactorSolve(); }