diff --git a/src/platform/graphics/graphics-device.js b/src/platform/graphics/graphics-device.js index a47edb79942..283dc5d5727 100644 --- a/src/platform/graphics/graphics-device.js +++ b/src/platform/graphics/graphics-device.js @@ -8,14 +8,17 @@ import { Tracing } from '../../core/tracing.js'; import { Color } from '../../core/math/color.js'; import { TRACEID_TEXTURES } from '../../core/constants.js'; import { + BUFFER_STATIC, CULLFACE_BACK, CLEARFLAG_COLOR, CLEARFLAG_DEPTH, + INDEXFORMAT_UINT16, PRIMITIVE_POINTS, PRIMITIVE_TRIFAN, SEMANTIC_POSITION, TYPE_FLOAT32, PIXELFORMAT_111110F, PIXELFORMAT_RGBA16F, PIXELFORMAT_RGBA32F, DISPLAYFORMAT_LDR, semanticToLocation } from './constants.js'; import { BlendState } from './blend-state.js'; import { DepthState } from './depth-state.js'; +import { IndexBuffer } from './index-buffer.js'; import { ScopeSpace } from './scope-space.js'; import { VertexBuffer } from './vertex-buffer.js'; import { VertexFormat } from './vertex-format.js'; @@ -27,7 +30,6 @@ import { DebugGraphics } from './debug-graphics.js'; * @import { DEVICETYPE_WEBGL2, DEVICETYPE_WEBGPU } from './constants.js' * @import { DynamicBuffers } from './dynamic-buffers.js' * @import { GpuProfiler } from './gpu-profiler.js' - * @import { IndexBuffer } from './index-buffer.js' * @import { RenderTarget } from './render-target.js' * @import { Shader } from './shader.js' * @import { Texture } from './texture.js' @@ -420,6 +422,15 @@ class GraphicsDevice extends EventHandler { */ quadVertexBuffer; + /** + * An index buffer for drawing a quad as an indexed triangle list. + * Contains 6 indices: [0, 1, 2, 2, 1, 3] forming two triangles. + * + * @type {IndexBuffer} + * @ignore + */ + quadIndexBuffer; + /** * An object representing current blend state * @@ -596,6 +607,10 @@ class GraphicsDevice extends EventHandler { this.quadVertexBuffer = new VertexBuffer(this, vertexFormat, 4, { data: positions }); + + // create quad index buffer for indexed triangle list (two triangles forming a quad) + const indices = new Uint16Array([0, 1, 2, 2, 1, 3]); + this.quadIndexBuffer = new IndexBuffer(this, INDEXFORMAT_UINT16, 6, BUFFER_STATIC, indices.buffer); } /** @@ -630,6 +645,9 @@ class GraphicsDevice extends EventHandler { this.quadVertexBuffer?.destroy(); this.quadVertexBuffer = null; + this.quadIndexBuffer?.destroy(); + this.quadIndexBuffer = null; + this.dynamicBuffers?.destroy(); this.dynamicBuffers = null; diff --git a/src/scene/graphics/quad-render.js b/src/scene/graphics/quad-render.js index 1ce683d5d3d..192ace60893 100644 --- a/src/scene/graphics/quad-render.js +++ b/src/scene/graphics/quad-render.js @@ -1,7 +1,7 @@ import { Debug, DebugHelper } from '../../core/debug.js'; import { Vec4 } from '../../core/math/vec4.js'; import { BindGroup, DynamicBindGroup } from '../../platform/graphics/bind-group.js'; -import { BINDGROUP_MESH, BINDGROUP_MESH_UB, BINDGROUP_VIEW, PRIMITIVE_TRISTRIP } from '../../platform/graphics/constants.js'; +import { BINDGROUP_MESH, BINDGROUP_MESH_UB, BINDGROUP_VIEW, PRIMITIVE_TRIANGLES } from '../../platform/graphics/constants.js'; import { DebugGraphics } from '../../platform/graphics/debug-graphics.js'; import { ShaderProcessorOptions } from '../../platform/graphics/shader-processor-options.js'; import { UniformBuffer } from '../../platform/graphics/uniform-buffer.js'; @@ -12,11 +12,10 @@ import { ShaderUtils } from '../shader-lib/shader-utils.js'; */ const _quadPrimitive = { - type: PRIMITIVE_TRISTRIP, + type: PRIMITIVE_TRIANGLES, base: 0, - baseVertex: 0, - count: 4, - indexed: false + count: 6, + indexed: true }; const _tempViewport = new Vec4(); @@ -119,8 +118,12 @@ class QuadRender { * not changed if not provided. * @param {Vec4} [scissor] - The scissor rectangle of the quad, in pixels. Used only if the * viewport is provided. + * @param {number} [numInstances] - Number of instances to draw. When provided, renders + * multiple quads using instanced drawing. Each instance can use the instance index + * (`gl_InstanceID` in GLSL, `pcInstanceIndex` in WGSL) to fetch per-quad data from + * a texture or buffer, allowing each quad to be parameterized independently. */ - render(viewport, scissor) { + render(viewport, scissor, numInstances) { const device = this.shader.device; DebugGraphics.pushGpuMarker(device, 'QuadRender'); @@ -163,7 +166,7 @@ class QuadRender { } } - device.draw(_quadPrimitive); + device.draw(_quadPrimitive, device.quadIndexBuffer, numInstances); // restore if changed if (viewport) { diff --git a/src/scene/gsplat-unified/gsplat-info.js b/src/scene/gsplat-unified/gsplat-info.js index 6df6b6f454f..0078c82193a 100644 --- a/src/scene/gsplat-unified/gsplat-info.js +++ b/src/scene/gsplat-unified/gsplat-info.js @@ -3,10 +3,9 @@ import { Mat4 } from '../../core/math/mat4.js'; import { Vec2 } from '../../core/math/vec2.js'; import { Vec4 } from '../../core/math/vec4.js'; import { BoundingBox } from '../../core/shape/bounding-box.js'; -import { PIXELFORMAT_R32U, FILTER_NEAREST, ADDRESS_CLAMP_TO_EDGE } from '../../platform/graphics/constants.js'; +import { PIXELFORMAT_R32U, PIXELFORMAT_RGBA32U, FILTER_NEAREST, ADDRESS_CLAMP_TO_EDGE } from '../../platform/graphics/constants.js'; import { Texture } from '../../platform/graphics/texture.js'; import { TextureUtils } from '../../platform/graphics/texture-utils.js'; -import { GSplatIntervalTexture } from './gsplat-interval-texture.js'; /** * @import { GraphicsDevice } from "../../platform/graphics/graphics-device.js"; @@ -18,11 +17,11 @@ import { GSplatIntervalTexture } from './gsplat-interval-texture.js'; * @import { ScopeId } from '../../platform/graphics/scope-id.js'; */ -/** @type {Vec2[]} */ -const vecs = []; - const tmpSize = new Vec2(); +// Reusable buffer for sub-draw data (only grows, never shrinks) +let subDrawDataArray = new Uint32Array(0); + /** * Represents a snapshot of gsplat state for rendering. This class captures all necessary data * at a point in time and should not hold references back to the source placement. All required @@ -83,11 +82,20 @@ class GSplatInfo { aabb = new BoundingBox(); /** - * Manager for the intervals texture generation + * Small RGBA32U texture storing per-sub-draw data for instanced interval rendering. + * Each texel: R = rowStart | (numRows << 16), G = colStart, B = colEnd, A = sourceBase. + * Null when intervals are not used (non-LOD or full range). * - * @type {GSplatIntervalTexture|null} + * @type {Texture|null} */ - intervalTexture = null; + subDrawTexture = null; + + /** + * Number of sub-draw instances for instanced interval rendering. + * + * @type {number} + */ + subDrawCount = 0; /** * Small R32U texture mapping octree node index to sequential local bounds index. @@ -195,7 +203,9 @@ class GSplatInfo { destroy() { this.intervals.length = 0; - this.intervalTexture?.destroy(); + this.subDrawTexture?.destroy(); + this.subDrawTexture = null; + this.subDrawCount = 0; this.nodeToLocalBoundsTexture?.destroy(); this.nodeToLocalBoundsTexture = null; } @@ -206,11 +216,17 @@ class GSplatInfo { this.padding = textureSize * count - activeSplats; Debug.assert(this.padding >= 0); this.viewport.set(0, start, textureSize, count); + + // Build sub-draw data for instanced interval rendering + if (this.intervals.length > 0) { + this.updateSubDraws(textureSize); + } } /** - * Updates the flattened intervals array and GPU texture from placement intervals. - * Also updates the nodeToLocalBounds texture if octree nodes are available. + * Updates the flattened intervals array from placement intervals. Intervals are sorted and + * stored as half-open pairs [start, end). Called once from the constructor; sub-draw data + * is built later in setLines when the work buffer texture width is known. * * @param {Map} intervals - Map of node index to inclusive [x, y] intervals. */ @@ -223,52 +239,21 @@ class GSplatInfo { // If placement has intervals defined if (intervals.size > 0) { - // fast path: if the intervals cover the full range, intervals are not needed - // also collect references to inclusive Vec2 intervals into a reusable array for sorting + // Write half-open intervals and count total splats let totalCount = 0; - let used = 0; + let k = 0; + this.intervals.length = intervals.size * 2; for (const interval of intervals.values()) { + this.intervals[k++] = interval.x; + this.intervals[k++] = interval.y + 1; totalCount += (interval.y - interval.x + 1); - vecs[used++] = interval; } - // not full range - if (totalCount !== this.numSplats) { - - // finalize temp array length for sorting/merging - vecs.length = used; - - // sort by start - vecs.sort((a, b) => (a.x - b.x)); - - // pre-size to the upper bound - this.intervals.length = used * 2; - - // write merged intervals directly to this.intervals - let k = 0; - let currentStart = vecs[0].x; - let currentEnd = vecs[0].y; - for (let i = 1; i < used; i++) { - const p = vecs[i]; - if (p.x === currentEnd + 1) { // adjacent, extend current interval - currentEnd = p.y; - } else { // write half-open pair - this.intervals[k++] = currentStart; - this.intervals[k++] = currentEnd + 1; - currentStart = p.x; - currentEnd = p.y; - } - } - // write final half-open pair - this.intervals[k++] = currentStart; - this.intervals[k++] = currentEnd + 1; - - // trim to actual merged length - this.intervals.length = k; - - // update GPU texture and active splats count - this.intervalTexture = new GSplatIntervalTexture(this.device); - this.activeSplats = this.intervalTexture.update(this.intervals, totalCount); + // If intervals cover the full range, they're not needed + if (totalCount === this.numSplats) { + this.intervals.length = 0; + } else { + this.activeSplats = totalCount; } // Update nodeToLocalBounds mapping for GPU culling @@ -276,9 +261,6 @@ class GSplatInfo { this.placementIntervals = intervals; this.updateNodeToLocalBounds(intervals, this.octreeNodes.length); } - - // clear temp array - vecs.length = 0; } else { // Non-octree: single bounds entry this.numBoundsEntries = 1; @@ -293,6 +275,104 @@ class GSplatInfo { } } + /** + * Builds the sub-draw data texture from the current intervals. Each interval is split at + * row boundaries of the work buffer texture to produce axis-aligned rectangles. The result + * is a small RGBA32U texture where each texel stores the parameters for one instanced quad. + * Called once from setLines when the work buffer texture width is known. + * + * @param {number} textureWidth - The work buffer texture width. + */ + updateSubDraws(textureWidth) { + + const numIntervals = this.intervals.length / 2; + + // Split intervals at row boundaries. Each interval produces at most 3 sub-draws: + // partial first row, full middle rows, partial last row. + // Reuse module-scope buffer, growing if needed (4 uints per sub-draw, 3 sub-draws per interval max). + const maxSubDraws = numIntervals * 3; + const requiredSize = maxSubDraws * 4; + if (subDrawDataArray.length < requiredSize) { + subDrawDataArray = new Uint32Array(requiredSize); + } + const subDrawData = subDrawDataArray; + const intervals = this.intervals; + let subDrawCount = 0; + let targetOffset = 0; // running target index across all intervals + + for (let i = 0; i < numIntervals; i++) { + let sourceBase = intervals[i * 2]; + const size = intervals[i * 2 + 1] - sourceBase; + + let remaining = size; + let row = (targetOffset / textureWidth) | 0; + const col = targetOffset % textureWidth; + + // Partial first row (if not starting at column 0) + if (col > 0) { + const count = Math.min(remaining, textureWidth - col); + const idx = subDrawCount * 4; + subDrawData[idx] = row | (1 << 16); // rowStart | (numRows << 16) + subDrawData[idx + 1] = col; // colStart + subDrawData[idx + 2] = col + count; // colEnd + subDrawData[idx + 3] = sourceBase; // sourceBase + subDrawCount++; + sourceBase += count; + remaining -= count; + row++; + } + + // Full middle rows + const fullRows = (remaining / textureWidth) | 0; + if (fullRows > 0) { + const idx = subDrawCount * 4; + subDrawData[idx] = row | (fullRows << 16); // rowStart | (numRows << 16) + subDrawData[idx + 1] = 0; // colStart + subDrawData[idx + 2] = textureWidth; // colEnd + subDrawData[idx + 3] = sourceBase; // sourceBase + subDrawCount++; + sourceBase += fullRows * textureWidth; + remaining -= fullRows * textureWidth; + row += fullRows; + } + + // Partial last row + if (remaining > 0) { + const idx = subDrawCount * 4; + subDrawData[idx] = row | (1 << 16); // rowStart | (numRows << 16) + subDrawData[idx + 1] = 0; // colStart + subDrawData[idx + 2] = remaining; // colEnd + subDrawData[idx + 3] = sourceBase; // sourceBase + subDrawCount++; + } + + targetOffset += size; + } + + this.subDrawCount = subDrawCount; + + // Calculate 2D texture dimensions to stay within device limits + const { x: texWidth, y: texHeight } = TextureUtils.calcTextureSize(subDrawCount, tmpSize); + + // Create the sub-draw data texture + this.subDrawTexture = new Texture(this.device, { + name: 'subDrawData', + width: texWidth, + height: texHeight, + format: PIXELFORMAT_RGBA32U, + mipmaps: false, + minFilter: FILTER_NEAREST, + magFilter: FILTER_NEAREST, + addressU: ADDRESS_CLAMP_TO_EDGE, + addressV: ADDRESS_CLAMP_TO_EDGE + }); + + // Upload sub-draw data + const texData = this.subDrawTexture.lock(); + texData.set(subDrawData.subarray(0, subDrawCount * 4)); + this.subDrawTexture.unlock(); + } + update() { const worldMatrix = this.node.getWorldTransform(); const worldMatrixChanged = !this.previousWorldTransform.equals(worldMatrix); diff --git a/src/scene/gsplat-unified/gsplat-interval-texture.js b/src/scene/gsplat-unified/gsplat-interval-texture.js deleted file mode 100644 index bcac539d95b..00000000000 --- a/src/scene/gsplat-unified/gsplat-interval-texture.js +++ /dev/null @@ -1,175 +0,0 @@ -import { Texture } from '../../platform/graphics/texture.js'; -import { - ADDRESS_CLAMP_TO_EDGE, FILTER_NEAREST, PIXELFORMAT_R32U, PIXELFORMAT_RG32U, CULLFACE_NONE, - SEMANTIC_POSITION -} from '../../platform/graphics/constants.js'; -import { RenderTarget } from '../../platform/graphics/render-target.js'; -import { drawQuadWithShader } from '../graphics/quad-render-utils.js'; -import { BlendState } from '../../platform/graphics/blend-state.js'; -import { DepthState } from '../../platform/graphics/depth-state.js'; -import { ShaderUtils } from '../shader-lib/shader-utils.js'; -import gsplatIntervalTextureGLSL from '../shader-lib/glsl/chunks/gsplat/frag/gsplatIntervalTexture.js'; -import gsplatIntervalTextureWGSL from '../shader-lib/wgsl/chunks/gsplat/frag/gsplatIntervalTexture.js'; - -/** - * @import { GraphicsDevice } from '../../platform/graphics/graphics-device.js' - * @import { Shader } from '../../platform/graphics/shader.js' - */ - -/** - * Manages the intervals texture generation for GSplat LOD system using GPU acceleration. A list of - * intervals is provided to the update method, and the texture is generated on the GPU. The texture - * is then used to map target indices to source splat indices. - * - * @ignore - */ -class GSplatIntervalTexture { - /** @type {GraphicsDevice} */ - device; - - /** - * Texture that maps target indices to source splat indices based on intervals - * - * @type {Texture|null} - */ - texture = null; - - /** - * Render target for the intervals texture - * - * @type {RenderTarget|null} - */ - rt = null; - - /** - * Texture that stores interval data (start + accumulated sum pairs) for GPU processing - * - * @type {Texture|null} - */ - intervalsDataTexture = null; - - /** - * Shader for generating intervals texture on GPU - * - * @type {Shader|null} - */ - shader = null; - - /** - * @param {GraphicsDevice} device - The graphics device - */ - constructor(device) { - this.device = device; - } - - destroy() { - this.texture?.destroy(); - this.texture = null; - this.rt?.destroy(); - this.rt = null; - this.intervalsDataTexture?.destroy(); - this.intervalsDataTexture = null; - this.shader = null; - } - - /** - * Creates shader for GPU-based intervals texture generation - */ - getShader() { - if (!this.shader) { - this.shader = ShaderUtils.createShader(this.device, { - uniqueName: 'GSplatIntervalsShader', - attributes: { aPosition: SEMANTIC_POSITION }, - vertexChunk: 'quadVS', - fragmentGLSL: gsplatIntervalTextureGLSL, - fragmentWGSL: gsplatIntervalTextureWGSL, - fragmentOutputTypes: ['uint'] - }); - } - - return this.shader; - } - - /** - * Creates a texture with specified parameters - */ - createTexture(name, format, width, height) { - return new Texture(this.device, { - name: name, - width: width, - height: height, - format: format, - cubemap: false, - mipmaps: false, - minFilter: FILTER_NEAREST, - magFilter: FILTER_NEAREST, - addressU: ADDRESS_CLAMP_TO_EDGE, - addressV: ADDRESS_CLAMP_TO_EDGE - }); - } - - /** - * Updates the intervals texture based on provided intervals array - * - * @param {number[]} intervals - Array of intervals (start, end pairs) - * @param {number} totalIntervalSplats - Total number of splats referenced by the intervals - * @returns {number} The number of active splats - */ - update(intervals, totalIntervalSplats) { - - // Calculate texture dimensions for output intervals texture - const maxTextureSize = this.device.maxTextureSize; - let textureWidth = Math.ceil(Math.sqrt(totalIntervalSplats)); - textureWidth = Math.min(textureWidth, maxTextureSize); - const textureHeight = Math.ceil(totalIntervalSplats / textureWidth); - - // Create main intervals texture - this.texture = this.createTexture('intervalsTexture', PIXELFORMAT_R32U, textureWidth, textureHeight); - - this.rt = new RenderTarget({ - colorBuffer: this.texture, - depth: false - }); - - // Prepare intervals data with CPU prefix sum - const numIntervals = intervals.length / 2; - const dataTextureSize = Math.ceil(Math.sqrt(numIntervals)); - - // Create intervals data texture - this.intervalsDataTexture = this.createTexture('intervalsData', PIXELFORMAT_RG32U, dataTextureSize, dataTextureSize); - - // Compute intervals data with accumulated sums on CPU - // TODO: consider doing this using compute shader on WebGPU - const intervalsData = this.intervalsDataTexture.lock(); - let runningSum = 0; - - for (let i = 0; i < numIntervals; i++) { - const start = intervals[i * 2]; - const end = intervals[i * 2 + 1]; - const intervalSize = end - start; - runningSum += intervalSize; - - intervalsData[i * 2] = start; // R: interval start - intervalsData[i * 2 + 1] = runningSum; // G: accumulated sum - } - - this.intervalsDataTexture.unlock(); - - // Generate intervals texture on GPU - const scope = this.device.scope; - scope.resolve('uIntervalsTexture').setValue(this.intervalsDataTexture); - scope.resolve('uNumIntervals').setValue(numIntervals); - scope.resolve('uTextureWidth').setValue(textureWidth); - scope.resolve('uActiveSplats').setValue(totalIntervalSplats); - - this.device.setCullMode(CULLFACE_NONE); - this.device.setBlendState(BlendState.NOBLEND); - this.device.setDepthState(DepthState.NODEPTH); - - drawQuadWithShader(this.device, this.rt, this.getShader()); - - return totalIntervalSplats; - } -} - -export { GSplatIntervalTexture }; diff --git a/src/scene/gsplat-unified/gsplat-work-buffer-render-pass.js b/src/scene/gsplat-unified/gsplat-work-buffer-render-pass.js index bd745dfa9dd..20d6f9fd7e9 100644 --- a/src/scene/gsplat-unified/gsplat-work-buffer-render-pass.js +++ b/src/scene/gsplat-unified/gsplat-work-buffer-render-pass.js @@ -131,7 +131,8 @@ class GSplatWorkBufferRenderPass extends RenderPass { const scope = device.scope; Debug.assert(resource); - const { activeSplats, lineStart, viewport, intervalTexture } = splatInfo; + const { activeSplats, lineStart, viewport, subDrawTexture, subDrawCount } = splatInfo; + const useIntervals = subDrawTexture !== null && subDrawCount > 0; // Get work buffer modifier (live from placement, not a snapshot copy) const workBufferModifier = splatInfo.getWorkBufferModifier?.() ?? null; @@ -142,7 +143,7 @@ class GSplatWorkBufferRenderPass extends RenderPass { // quad renderer and material are cached in the resource const workBufferRenderInfo = resource.getWorkBufferRenderInfo( - intervalTexture !== null, + useIntervals, this.colorOnly, workBufferModifier, formatHash, @@ -153,11 +154,6 @@ class GSplatWorkBufferRenderPass extends RenderPass { // Assign material properties to scope workBufferRenderInfo.material.setParameters(device); - if (intervalTexture) { - // Set LOD intervals texture for remapping of indices - scope.resolve('uIntervalsTexture').setValue(intervalTexture.texture); - } - scope.resolve('uActiveSplats').setValue(activeSplats); scope.resolve('uStartLine').setValue(lineStart); scope.resolve('uViewportWidth').setValue(viewport.z); @@ -216,8 +212,17 @@ class GSplatWorkBufferRenderPass extends RenderPass { } } - // Render the quad - QuadRender handles all the complex setup internally - workBufferRenderInfo.quadRender.render(viewport); + if (useIntervals) { + // Instanced draw path: one quad per interval row-segment + scope.resolve('uSubDrawData').setValue(subDrawTexture); + scope.resolve('uLineCount').setValue(splatInfo.lineCount); + scope.resolve('uTextureWidth').setValue(viewport.z); + + workBufferRenderInfo.quadRender.render(viewport, undefined, subDrawCount); + } else { + // Standard single-quad draw path + workBufferRenderInfo.quadRender.render(viewport); + } } destroy() { diff --git a/src/scene/gsplat-unified/gsplat-work-buffer.js b/src/scene/gsplat-unified/gsplat-work-buffer.js index 6924937455f..155b2f25ca5 100644 --- a/src/scene/gsplat-unified/gsplat-work-buffer.js +++ b/src/scene/gsplat-unified/gsplat-work-buffer.js @@ -15,6 +15,8 @@ import { QuadRender } from '../graphics/quad-render.js'; import { ShaderUtils } from '../shader-lib/shader-utils.js'; import glslGsplatCopyToWorkBufferPS from '../shader-lib/glsl/chunks/gsplat/frag/gsplatCopyToWorkbuffer.js'; import wgslGsplatCopyToWorkBufferPS from '../shader-lib/wgsl/chunks/gsplat/frag/gsplatCopyToWorkbuffer.js'; +import glslGsplatCopyInstancedQuadVS from '../shader-lib/glsl/chunks/gsplat/vert/gsplatCopyInstancedQuad.js'; +import wgslGsplatCopyInstancedQuadVS from '../shader-lib/wgsl/chunks/gsplat/vert/gsplatCopyInstancedQuad.js'; import { GSplatNodeCullRenderPass } from './gsplat-node-cull-render-pass.js'; import { GSplatWorkBufferRenderPass } from './gsplat-work-buffer-render-pass.js'; import { GSplatStreams } from '../gsplat/gsplat-streams.js'; @@ -53,7 +55,6 @@ class WorkBufferRenderInfo { * @param {GSplatFormat} format - The work buffer format descriptor. */ constructor(device, key, material, colorOnly, format) { - this.device = device; this.material = material; const clonedDefines = new Map(material.defines); @@ -96,18 +97,30 @@ class WorkBufferRenderInfo { fragmentOutputTypes.push(info.returnType); } - const shader = ShaderUtils.createShader(device, { + // Use instanced vertex shader for LOD path, fullscreen quad for non-LOD + const useInstanced = clonedDefines.has('GSPLAT_LOD'); + + const shaderOptions = { uniqueName: `SplatCopyToWorkBuffer:${key}`, attributes: { vertex_position: SEMANTIC_POSITION }, vertexDefines: clonedDefines, fragmentDefines: clonedDefines, - vertexChunk: 'fullscreenQuadVS', fragmentGLSL: glslGsplatCopyToWorkBufferPS, fragmentWGSL: wgslGsplatCopyToWorkBufferPS, fragmentIncludes: fragmentIncludes, fragmentOutputTypes: fragmentOutputTypes - }); + }; + + if (useInstanced) { + // Instanced LOD path: custom vertex shader that positions quads per instance + shaderOptions.vertexGLSL = glslGsplatCopyInstancedQuadVS; + shaderOptions.vertexWGSL = wgslGsplatCopyInstancedQuadVS; + } else { + // Standard fullscreen quad path + shaderOptions.vertexChunk = 'fullscreenQuadVS'; + } + const shader = ShaderUtils.createShader(device, shaderOptions); this.quadRender = new QuadRender(shader); } diff --git a/src/scene/shader-lib/glsl/chunks/gsplat/frag/gsplatCopyToWorkbuffer.js b/src/scene/shader-lib/glsl/chunks/gsplat/frag/gsplatCopyToWorkbuffer.js index e90cbeea4b1..11dd819c94b 100644 --- a/src/scene/shader-lib/glsl/chunks/gsplat/frag/gsplatCopyToWorkbuffer.js +++ b/src/scene/shader-lib/glsl/chunks/gsplat/frag/gsplatCopyToWorkbuffer.js @@ -18,8 +18,8 @@ uniform int uStartLine; // Start row in destination texture uniform int uViewportWidth; // Width of the destination viewport in pixels #ifdef GSPLAT_LOD - // LOD intervals texture - uniform usampler2D uIntervalsTexture; + // Packed sub-draw params: (sourceBase, colStart, rowWidth, rowStart) + flat varying ivec4 vSubDraw; #endif uniform vec3 uColorMultiply; @@ -66,10 +66,10 @@ void main(void) { } else { #ifdef GSPLAT_LOD - // Use intervals texture to remap target index to source index - int intervalsSize = int(textureSize(uIntervalsTexture, 0).x); - ivec2 intervalUV = ivec2(targetIndex % intervalsSize, targetIndex / intervalsSize); - uint originalIndex = texelFetch(uIntervalsTexture, intervalUV, 0).r; + // Compute source index from packed sub-draw varying: (sourceBase, colStart, rowWidth, rowStart) + int localRow = int(gl_FragCoord.y) - uStartLine - vSubDraw.w; + int localCol = int(gl_FragCoord.x) - vSubDraw.y; + uint originalIndex = uint(vSubDraw.x + localRow * vSubDraw.z + localCol); #else uint originalIndex = uint(targetIndex); #endif diff --git a/src/scene/shader-lib/glsl/chunks/gsplat/frag/gsplatIntervalTexture.js b/src/scene/shader-lib/glsl/chunks/gsplat/frag/gsplatIntervalTexture.js deleted file mode 100644 index 22b24f17728..00000000000 --- a/src/scene/shader-lib/glsl/chunks/gsplat/frag/gsplatIntervalTexture.js +++ /dev/null @@ -1,68 +0,0 @@ -// fragment shader to generate intervals texture for GSplat LOD system -export default /* glsl */` - -precision highp usampler2D; - -// RG32U: (start, accumulatedSum) -uniform usampler2D uIntervalsTexture; -uniform int uNumIntervals; -uniform int uTextureWidth; -uniform int uActiveSplats; - -ivec2 getCoordFromIndex(int index, int textureWidth) { - return ivec2(index % textureWidth, index / textureWidth); -} - -void main() { - ivec2 coord = ivec2(gl_FragCoord.xy); - int targetIndex = coord.y * uTextureWidth + coord.x; - - if (targetIndex >= uActiveSplats) { - gl_FragColor = 0u; - return; - } - - // Binary search through accumulated sums (G channel) - int left = 0; - int right = uNumIntervals - 1; - int intervalIndex = 0; - - while (left <= right) { - int mid = (left + right) / 2; - - int textureWidth = textureSize(uIntervalsTexture, 0).x; - ivec2 intervalCoord = getCoordFromIndex(mid, textureWidth); - uvec2 intervalData = texelFetch(uIntervalsTexture, intervalCoord, 0).rg; - - uint accumulatedSum = intervalData.g; // G channel - - if (uint(targetIndex) < accumulatedSum) { - intervalIndex = mid; - right = mid - 1; - } else { - left = mid + 1; - } - } - - // Get interval data (both start and accumulated sum in one fetch) - int textureWidth = textureSize(uIntervalsTexture, 0).x; - ivec2 intervalCoord = getCoordFromIndex(intervalIndex, textureWidth); - uvec2 intervalData = texelFetch(uIntervalsTexture, intervalCoord, 0).rg; - - uint intervalStart = intervalData.r; // R channel - uint currentAccSum = intervalData.g; // G channel - - // Get previous accumulated sum - uint prevAccSum = 0u; - if (intervalIndex > 0) { - ivec2 prevCoord = getCoordFromIndex(intervalIndex - 1, textureWidth); - prevAccSum = texelFetch(uIntervalsTexture, prevCoord, 0).g; - } - - // Calculate original splat index - uint offsetInInterval = uint(targetIndex) - prevAccSum; - uint originalIndex = intervalStart + offsetInInterval; - - gl_FragColor = originalIndex; -} -`; diff --git a/src/scene/shader-lib/glsl/chunks/gsplat/vert/gsplatCopyInstancedQuad.js b/src/scene/shader-lib/glsl/chunks/gsplat/vert/gsplatCopyInstancedQuad.js new file mode 100644 index 00000000000..05b19d414c0 --- /dev/null +++ b/src/scene/shader-lib/glsl/chunks/gsplat/vert/gsplatCopyInstancedQuad.js @@ -0,0 +1,43 @@ +// vertex shader for instanced LOD quad rendering to work buffer. +// Each instance covers one row-aligned segment of an interval. +// The fragment shader computes originalIndex from flat varyings. +export default /* glsl */` + +attribute vec2 vertex_position; + +// Sub-draw data texture: RGBA32U +// R = rowStart | (numRows << 16) +// G = colStart +// B = colEnd +// A = sourceBase +precision highp usampler2D; +uniform usampler2D uSubDrawData; +uniform int uLineCount; +uniform int uTextureWidth; + +// packed sub-draw params: (sourceBase, colStart, rowWidth, rowStart) +flat varying ivec4 vSubDraw; + +void main(void) { + // Read sub-draw parameters from 2D data texture + int subDrawWidth = textureSize(uSubDrawData, 0).x; + uvec4 data = texelFetch(uSubDrawData, ivec2(gl_InstanceID % subDrawWidth, gl_InstanceID / subDrawWidth), 0); + int rowStart = int(data.r & 0xFFFFu); + int numRows = int(data.r >> 16u); + int colStart = int(data.g); + int colEnd = int(data.b); + int sourceBase = int(data.a); + + // Quad corner from gl_VertexID (0-3 via index buffer [0,1,2, 2,1,3]) + float u = float(gl_VertexID & 1); // 0 or 1 (left or right) + float v = float(gl_VertexID >> 1); // 0 or 1 (bottom or top) + + // Map to NDC within the viewport + vec4 ndc = vec4(colStart, colEnd, rowStart, rowStart + numRows) / vec4(uTextureWidth, uTextureWidth, uLineCount, uLineCount) * 2.0 - 1.0; + + gl_Position = vec4(mix(ndc.x, ndc.y, u), mix(ndc.z, ndc.w, v), 0.5, 1.0); + + // Output packed flat varying: (sourceBase, colStart, rowWidth, rowStart) + vSubDraw = ivec4(sourceBase, colStart, colEnd - colStart, rowStart); +} +`; diff --git a/src/scene/shader-lib/wgsl/chunks/gsplat/frag/gsplatCopyToWorkbuffer.js b/src/scene/shader-lib/wgsl/chunks/gsplat/frag/gsplatCopyToWorkbuffer.js index c0cf98ed452..0640011d222 100644 --- a/src/scene/shader-lib/wgsl/chunks/gsplat/frag/gsplatCopyToWorkbuffer.js +++ b/src/scene/shader-lib/wgsl/chunks/gsplat/frag/gsplatCopyToWorkbuffer.js @@ -24,8 +24,8 @@ uniform uStartLine: i32; // Start row in destination texture uniform uViewportWidth: i32; // Width of the destination viewport in pixels #ifdef GSPLAT_LOD - // LOD intervals texture - var uIntervalsTexture: texture_2d; + // Packed sub-draw params: (sourceBase, colStart, rowWidth, rowStart) + varying @interpolate(flat) vSubDraw: vec4i; #endif uniform uColorMultiply: vec3f; @@ -70,10 +70,10 @@ fn fragmentMain(input: FragmentInput) -> FragmentOutput { } else { #ifdef GSPLAT_LOD - // Use intervals texture to remap target index to source index - let intervalsSize = i32(textureDimensions(uIntervalsTexture, 0).x); - let intervalUV = vec2i(targetIndex % intervalsSize, targetIndex / intervalsSize); - let originalIndex = textureLoad(uIntervalsTexture, intervalUV, 0).r; + // Compute source index from packed sub-draw varying: (sourceBase, colStart, rowWidth, rowStart) + let localRow = i32(input.position.y) - uniform.uStartLine - input.vSubDraw.w; + let localCol = i32(input.position.x) - input.vSubDraw.y; + let originalIndex = u32(input.vSubDraw.x + localRow * input.vSubDraw.z + localCol); #else let originalIndex = targetIndex; #endif diff --git a/src/scene/shader-lib/wgsl/chunks/gsplat/frag/gsplatIntervalTexture.js b/src/scene/shader-lib/wgsl/chunks/gsplat/frag/gsplatIntervalTexture.js deleted file mode 100644 index bda5e3f8c98..00000000000 --- a/src/scene/shader-lib/wgsl/chunks/gsplat/frag/gsplatIntervalTexture.js +++ /dev/null @@ -1,70 +0,0 @@ -// fragment shader to generate intervals texture for GSplat LOD system -export default /* wgsl */` - -// RG32U: (start, accumulatedSum) -var uIntervalsTexture: texture_2d; -uniform uNumIntervals: i32; -uniform uTextureWidth: i32; -uniform uActiveSplats: i32; - -fn getCoordFromIndex(index: i32, textureWidth: i32) -> vec2i { - return vec2i(index % textureWidth, index / textureWidth); -} - -@fragment -fn fragmentMain(input: FragmentInput) -> FragmentOutput { - var output: FragmentOutput; - - let coord = vec2i(i32(input.position.x), i32(input.position.y)); - let targetIndex = coord.y * uniform.uTextureWidth + coord.x; - - if (targetIndex >= uniform.uActiveSplats) { - output.color = 0u; - return output; - } - - // Binary search through accumulated sums (G channel) - var left = 0i; - var right = uniform.uNumIntervals - 1; - var intervalIndex = 0i; - - while (left <= right) { - let mid = (left + right) / 2; - - let textureWidth = i32(textureDimensions(uIntervalsTexture, 0).x); - let intervalCoord = getCoordFromIndex(mid, textureWidth); - let intervalData = textureLoad(uIntervalsTexture, intervalCoord, 0).rg; - - let accumulatedSum = intervalData.g; // G channel - - if (u32(targetIndex) < accumulatedSum) { - intervalIndex = mid; - right = mid - 1; - } else { - left = mid + 1; - } - } - - // Get interval data (both start and accumulated sum in one fetch) - let textureWidth = i32(textureDimensions(uIntervalsTexture, 0).x); - let intervalCoord = getCoordFromIndex(intervalIndex, textureWidth); - let intervalData = textureLoad(uIntervalsTexture, intervalCoord, 0).rg; - - let intervalStart = intervalData.r; // R channel - let currentAccSum = intervalData.g; // G channel - - // Get previous accumulated sum - var prevAccSum = 0u; - if (intervalIndex > 0) { - let prevCoord = getCoordFromIndex(intervalIndex - 1, textureWidth); - prevAccSum = textureLoad(uIntervalsTexture, prevCoord, 0).g; - } - - // Calculate original splat index - let offsetInInterval = u32(targetIndex) - prevAccSum; - let originalIndex = intervalStart + offsetInInterval; - - output.color = originalIndex; - return output; -} -`; diff --git a/src/scene/shader-lib/wgsl/chunks/gsplat/vert/gsplatCopyInstancedQuad.js b/src/scene/shader-lib/wgsl/chunks/gsplat/vert/gsplatCopyInstancedQuad.js new file mode 100644 index 00000000000..12b224d33e4 --- /dev/null +++ b/src/scene/shader-lib/wgsl/chunks/gsplat/vert/gsplatCopyInstancedQuad.js @@ -0,0 +1,50 @@ +// vertex shader for instanced LOD quad rendering to work buffer. +// Each instance covers one row-aligned segment of an interval. +// The fragment shader computes originalIndex from flat varyings. +export default /* wgsl */` + +attribute vertex_position: vec2f; + +// Sub-draw data texture: RGBA32U +// R = rowStart | (numRows << 16) +// G = colStart +// B = colEnd +// A = sourceBase +var uSubDrawData: texture_2d; +uniform uLineCount: i32; +uniform uTextureWidth: i32; + +// packed sub-draw params: (sourceBase, colStart, rowWidth, rowStart) +varying @interpolate(flat) vSubDraw: vec4i; + +@vertex +fn vertexMain(input: VertexInput) -> VertexOutput { + var output: VertexOutput; + + // Read sub-draw parameters from 2D data texture + let subDrawWidth = i32(textureDimensions(uSubDrawData, 0).x); + let instIdx = i32(input.instanceIndex); + let data = textureLoad(uSubDrawData, vec2i(instIdx % subDrawWidth, instIdx / subDrawWidth), 0); + let rowStart = i32(data.r & 0xFFFFu); + let numRows = i32(data.r >> 16u); + let colStart = i32(data.g); + let colEnd = i32(data.b); + let sourceBase = i32(data.a); + + // Quad corner from vertexIndex (0-3 via index buffer [0,1,2, 2,1,3]) + let u = f32(i32(input.vertexIndex) & 1); // 0 or 1 (left or right) + let v = f32(i32(input.vertexIndex) >> 1u); // 0 or 1 (bottom or top) + + // Map to NDC within the viewport + // WebGPU viewport transform inverts Y: y_pixel = viewport.y + viewport.h * (1 - y_ndc) / 2 + // so we negate Y compared to the GLSL version to get correct row positioning + let ndc = vec4f(f32(colStart), f32(colEnd), f32(rowStart), f32(rowStart + numRows)) / vec4f(f32(uniform.uTextureWidth), f32(uniform.uTextureWidth), f32(uniform.uLineCount), f32(uniform.uLineCount)) * 2.0 - 1.0; + + output.position = vec4f(mix(ndc.x, ndc.y, u), mix(-ndc.z, -ndc.w, v), 0.5, 1.0); + + // Output packed flat varying: (sourceBase, colStart, rowWidth, rowStart) + output.vSubDraw = vec4i(sourceBase, colStart, colEnd - colStart, rowStart); + + return output; +} +`;