Eccentric Developments


Bounding Volume Hierarchies

As I showed in the previous post about bringing triangles back into the path tracer, performance is extremely poor. The problem is not the triangles themselves, but the huge number of them needed to create good-looking objects.

To make sense of why this happens, remember that for each pixel in the image, the rendering algorithm tests the primary ray against all objects in the scene. When an intersection is found, another ray is randomly generated and checked for intersections against the scene again; this continues until the recursion limit is reached, or a light (or nothing at all) is hit.

To illustrate how many intersection tests are done for a single image, here is the formula for the worst-case scenario:

f(m,n,o) = m * n * o

Where:

  • m is the number of pixels
  • n is the max depth of recursion
  • o is the number of objects

If we plug in the numbers for the torus scene at the current depth configuration (5), it gives us f(307200, 5, 580) = 890,880,000, which is the total number of intersection tests per image. That is a lot of work for a simple scene.
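
To put that formula into code, here is a quick sanity check of the same worst-case count (a minimal sketch; the 307200 pixels correspond to a 640x480 image):

function worstCaseTests(pixels, maxDepth, objects) {
  // Every ray at every bounce is tested against every object in the scene.
  return pixels * maxDepth * objects;
}

console.log(worstCaseTests(640 * 480, 5, 580)); // 890880000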

As with any performance improvement, the goal is to decrease the amount of work done; in the case of the path tracer, that means reducing the number of intersection tests done per ray. To achieve this, rendering engines use specialized data arrangements called acceleration structures.

Acceleration structures reduce the time it takes to find a ray-object intersection by discarding as many objects as possible. Think about it like this: why would you test ray-object intersections with the objects on the left of the screen (i.e., X < 320 for a 640px-wide image) if you know your ray origin and direction are on the right (i.e., X > 320)? Makes sense, right?

There are several different acceleration structures: kd-trees, octrees, binary space partitioning (BSP) trees, and bounding volume hierarchies.

In this post, I'll be writing about bounding volume hierarchies (BVH), which is the acceleration structure I am most familiar with. In the next article, I'll explain a bit about octrees, which I use in a very particular way.

Bounding Boxes

To understand bounding volume hierarchies, the first thing to explore is the concept of bounding boxes.

A good analogy for bounding boxes is the bookshelves you find in libraries: they not only keep the books in a tidy arrangement, they also make it much easier to find any book you are looking for.

Bounding boxes delimit the 3D space of a scene, grouping the primitives they contain and allowing faster intersection tests between rays and the scene.

There are several types of bounding boxes; the most common, and also the easiest to understand, are Axis-Aligned Bounding Boxes, or AABBs.

AABBs are very simple in concept: they are a data structure composed of two points in 3D space that represent a box enclosing one or more primitives, with the drawback that they cannot be rotated and, most of the time, they will not be a tight fit for the contained objects.

Nevertheless, they are useful because the memory required to store them is very low, and the axis-alignment property makes very fast ray intersection tests possible.
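
To make the idea concrete, here is a minimal sketch of an AABB: just the two opposite corners of the box. The names here are only illustrative; the structure used later in this article also carries the list of contained triangles.

// An AABB is fully described by its minimum and maximum corners.
const aabb = {
  min: [-1, -1, -1],
  max: [1, 1, 1],
};

// Axis-alignment keeps queries cheap: a point is inside the box when every
// component lies between the corresponding min and max components.
const containsPoint = (box, point) =>
  point.every((value, axis) => value >= box.min[axis] && value <= box.max[axis]);

console.log(containsPoint(aabb, [0.5, 0, -0.25])); // true
console.log(containsPoint(aabb, [2, 0, 0])); // false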

AABBs are used to form hierarchical structures that divide the space occupied by the objects in a 3D scene. This hierarchical structure is known as a Bounding Volume Hierarchy (BVH).

Bounding Volume Hierarchies

A BVH is analogous to a binary search tree, where each node is a bounding box (in our case an AABB) that in turn contains more space subdivisions in the form of AABBs.

BVHs bottom out at AABBs that contain 3D primitives and no further space subdivisions, in other words, leaves. It is at this point that the ray intersection tests happen against actual scene geometry.

To construct this acceleration structure, the algorithm follows these steps:

  1. Calculate the AABB of the geometry in scope.
  2. If the current geometry count is at or below a given threshold, usually 4, mark this AABB as a leaf and stop the recursion.
  3. Divide the primitives evenly into two groups, one for each space subdivision.
  4. Recursively generate the left child of this AABB, using the first half of the geometry and continuing from step 1.
  5. Recursively generate the right child of this AABB, using the second half of the geometry and continuing from step 1.

The algorithm will stop when there are no more subdivisions to calculate.

Step 3 of the algorithm is the most critical and can be implemented in different ways. In this entry, I'll show a naive partition that simply divides the scene primitives into halves, without regard to any other property of the scene.
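
For contrast, here is a minimal sketch of one common alternative partition (not the implementation used below): sort the triangles by their centroid along the longest axis of the enclosing AABB, then split at the median. The helper names are illustrative; only the { pt0, pt1, pt2 } triangle shape and the { min, max } AABB shape used throughout this article are assumed.

// Centroid of a triangle, one component per axis.
const centroid = (triangle) =>
  [0, 1, 2].map((axis) => (triangle.pt0[axis] + triangle.pt1[axis] + triangle.pt2[axis]) / 3);

// Split the triangles in halves along the widest axis of the enclosing AABB.
const partitionByLongestAxis = (triangles, aabb) => {
  const extents = [0, 1, 2].map((axis) => aabb.max[axis] - aabb.min[axis]);
  const axis = extents.indexOf(Math.max(...extents));
  const sorted = [...triangles].sort((a, b) => centroid(a)[axis] - centroid(b)[axis]);
  const mid = Math.floor(sorted.length / 2);
  return [sorted.slice(0, mid), sorted.slice(mid)];
};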

BVH Usage

Using the BVH is simple. Given a ray that needs to be checked for intersections and a pointer to the root of the BVH, execute these steps:

  1. Test the ray and current AABB for intersection.
  2. If no intersection is found, stop the recursion.
  3. If the AABB is a leaf, add the contained primitives to a list to check later.
  4. If the AABB is an internal node, recurse into its left child, continuing from step 1.
  5. Then recurse into its right child, also continuing from step 1.

At the end of the BVH traversal, there will be a list of 3D primitives that the ray is then checked against for intersections.

Using a BVH this way is similar to using a binary search tree in the sense that, on a well-balanced BVH, the expected running time is O(log(n)), which makes the rendering algorithm much faster.
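
As a rough, hedged estimate of what this buys us: assuming a perfectly balanced BVH and ignoring the triangle tests done at the leaves, the worst case for the torus scene drops to about m * n * log2(o) = 307200 * 5 * log2(580) ≈ 14,100,000 AABB tests per image, compared with the 890,880,000 brute-force tests computed earlier.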

Now Show Me The Code!

The following JavaScript code implements all the previous path tracing functions, plus a new one called initializeBVH, which does the following:

  1. Implements a function to create an AABB out of a single triangle.
  2. Implements a function to merge two AABBs into a larger one.
  3. Recursively builds the BVH using a naive partition.
  4. Exports a traceBvh function that receives a ray, tests it against the BVH, and returns the closest triangle found, if any.
function initializeBVH(args) {
  const { scene, triangleIntersect } = args;
  const getComponents = (triangle, idx) => [triangle.pt0[idx], triangle.pt1[idx], triangle.pt2[idx]];
  const getTriangleAABB = (triangle) => ({
    min: [
      Math.min(...getComponents(triangle, 0)),
      Math.min(...getComponents(triangle, 1)),
      Math.min(...getComponents(triangle, 2)),
    ],
    max: [
      Math.max(...getComponents(triangle, 0)),
      Math.max(...getComponents(triangle, 1)),
      Math.max(...getComponents(triangle, 2)),
    ],
    triangles: [triangle],
  });
  const joinAABBs = (a, b) => ({
    min: [Math.min(a.min[0], b.min[0]), Math.min(a.min[1], b.min[1]), Math.min(a.min[2], b.min[2])],
    max: [Math.max(a.max[0], b.max[0]), Math.max(a.max[1], b.max[1]), Math.max(a.max[2], b.max[2])],
    triangles: [...a.triangles, ...b.triangles],
  });

  // Naive implementation
  // Using a recursive function, divide the triangles array in halves and create BVH nodes for them
  // Stop when there are 4 triangles or less
  // return the root of the BVH

  let totalTriangles = 0;

  const recursiveNaiveBuild = (triangles) => {
    if (!triangles || !triangles.length) return null;

    const aabb = triangles.map(getTriangleAABB).reduce(joinAABBs);

    if (aabb.triangles.length <= 4) {
      totalTriangles += aabb.triangles.length;
      return {
        aabb,
        left: null,
        right: null,
      };
    }

    const mid = Math.floor(triangles.length / 2);
    const left = recursiveNaiveBuild(triangles.slice(0, mid));
    const right = recursiveNaiveBuild(triangles.slice(mid, triangles.length));

    return {
      aabb,
      left,
      right,
    };
  };

  const bvh = recursiveNaiveBuild(scene);

  const intersectBvh = (ray, bvh) => {
    if (!bvh) return false;

    let dxi = 1.0 / ray.direction[0];
    let dyi = 1.0 / ray.direction[1];
    let dzi = 1.0 / ray.direction[2];

    let sx = bvh.aabb.min;
    let rsx = bvh.aabb.max;
    if (dxi < 0.0) {
      sx = bvh.aabb.max;
      rsx = bvh.aabb.min;
    }
    let sy = bvh.aabb.min;
    let rsy = bvh.aabb.max;
    if (dyi < 0.0) {
      sy = bvh.aabb.max;
      rsy = bvh.aabb.min;
    }
    let sz = bvh.aabb.min;
    let rsz = bvh.aabb.max;
    if (dzi < 0.0) {
      sz = bvh.aabb.max;
      rsz = bvh.aabb.min;
    }

    let tmin = (sx[0] - ray.origin[0]) * dxi;
    let tymax = (rsy[1] - ray.origin[1]) * dyi;
    if (tmin > tymax) {
      return false;
    }

    let tmax = (rsx[0] - ray.origin[0]) * dxi;
    let tymin = (sy[1] - ray.origin[1]) * dyi;
    if (tymin > tmax) {
      return false;
    }

    tmin = tymin > tmin ? tymin : tmin;
    tmax = tymax < tmax ? tymax : tmax;
    let tzmin = (sz[2] - ray.origin[2]) * dzi;
    let tzmax = (rsz[2] - ray.origin[2]) * dzi;
    return !(tmin > tzmax || tzmin > tmax);
  };

  // recursively find the triangles in the leaves of the BVH that intersect the given ray

  const recTraverseBvh = (ray, acc, node) => {
    // Check if there is a hit
    if (!intersectBvh(ray, node)) return;

    // Check if this is a leaf
    if (!node.left && !node.right) {
      acc.push(...node.aabb.triangles);
      return;
    }

    recTraverseBvh(ray, acc, node.left);
    recTraverseBvh(ray, acc, node.right);
  };

  // To use with the regular rendering pipeline, return a trace function that works like the regular trace one
  const traceBvh = (ray) => {
    const trianglesFound = [];
    recTraverseBvh(ray, trianglesFound, bvh);
    return trianglesFound
      .map((obj) => ({ ...triangleIntersect(ray, obj), obj }))
      .filter(({ hit }) => hit)
      .reduce((acc, intersection) => (intersection.distance < acc.distance ? intersection : acc), {
        hit: false,
        distance: Number.MAX_VALUE,
      });
  };

  return {
    traceBvh,
  };
}

function createScene(args) {
  const {
    vector: { sub, unit, cross },
    memory: { allocStaticFloat32Array, set },
  } = args;

  function rotatePoint([x, y, z], [ax, ay, az]) {
    let sinax = Math.sin(ax);
    let cosax = Math.cos(ax);

    [y, z] = [y * cosax - z * sinax, y * sinax + z * cosax];

    let sinay = Math.sin(ay);
    let cosay = Math.cos(ay);

    [x, z] = [x * cosay + z * sinay, -x * sinay + z * cosay];

    let sinaz = Math.sin(az);
    let cosaz = Math.cos(az);

    [x, y] = [x * cosaz - y * sinaz, x * sinaz + y * cosaz];

    return [x, y, z];
  }

  function translatePoint(point, [x, y, z]) {
    return [point[0] + x, point[1] + y, point[2] + z];
  }

  function createTorus(radius = 1.0, tubeRadius = 0.3, center, rotation, props) {
    const triangles = [];
    const rings = [];
    const ringSegments = 12;
    const tubeSegments = 24;
    const sourceRing = [];

    const ringSegmentRad = (2 * Math.PI) / ringSegments;
    for (let i = 0; i < ringSegments; i++) {
      sourceRing.push(rotatePoint([tubeRadius, 0, 0], [0, 0, ringSegmentRad * i]));
    }

    const tubeSegmentRad = (2 * Math.PI) / tubeSegments;
    for (let i = 0; i < tubeSegments; i++) {
      const ring = structuredClone(sourceRing)
        .map((pt) => translatePoint(pt, [radius, 0, 0]))
        .map((pt) => rotatePoint(pt, [0, tubeSegmentRad * i, 0]))
        .map((pt) => translatePoint(pt, center))
        .map((pt) => rotatePoint(pt, rotation));
      rings.push(ring);
    }

    for (let i = 0; i < tubeSegments; i++) {
      const ni = (i + 1) % tubeSegments;
      for (let j = 0; j < ringSegments; j++) {
        let pt0 = rings[i][j];
        let pt1 = rings[ni][j];
        let pt2 = rings[i][(j + 1) % ringSegments];
        triangles.push({
          pt0,
          pt1,
          pt2,
          ...props,
        });
        pt0 = rings[i][(j + 1) % ringSegments];
        pt1 = rings[ni][j];
        pt2 = rings[ni][(j + 1) % ringSegments];
        triangles.push({
          pt0,
          pt1,
          pt2,
          ...props,
        });
      }
    }

    return triangles;
  }

  function createPlane(width, height, center, rotation, props) {
    const triangles = [];
    let pt0 = [-width / 2, height / 2, 0];
    let pt1 = [width / 2, height / 2, 0];
    let pt2 = [-width / 2, -height / 2, 0];

    pt0 = translatePoint(rotatePoint(pt0, rotation), center);
    pt1 = translatePoint(rotatePoint(pt1, rotation), center);
    pt2 = translatePoint(rotatePoint(pt2, rotation), center);
    triangles.push({
      pt0,
      pt1,
      pt2,
      ...props,
    });

    pt0 = [width / 2, height / 2, 0];
    pt1 = [width / 2, -height / 2, 0];
    pt2 = [-width / 2, -height / 2, 0];
    pt0 = translatePoint(rotatePoint(pt0, rotation), center);
    pt1 = translatePoint(rotatePoint(pt1, rotation), center);
    pt2 = translatePoint(rotatePoint(pt2, rotation), center);
    triangles.push({
      pt0,
      pt1,
      pt2,
      ...props,
    });

    return triangles;
  }

  const scene1 = [
    ...createTorus(4, 2, [0, 0, 0], [Math.PI / 2, 0, 0], { color: [0, 1, 0], isLight: false }),
    ...createPlane(20, 20, [0, 10, 0], [Math.PI / 2, 0, 0], { color: [3, 3, 3], isLight: true }),
    ...createPlane(1000, 1000, [0, -5, 0], [Math.PI / 2, 0, 0], { color: [1, 1, 1], isLight: false }),
  ];

  const scene = scene1.map((obj, id) => {
    const edge0 = sub(obj.pt1, obj.pt0);
    const edge1 = sub(obj.pt2, obj.pt0);
    obj.normal = unit(cross(edge0, edge1));

    obj.id = id;
    return obj;
  });

  return {
    scene,
  };
}

function createMemoryFunctions(args) {
  const { wasm } = args;
  const totalAvailable = 1024 * 4;
  const memPtr = wasm.alloc(totalAvailable);
  const memView = new DataView(wasm.memory.buffer, memPtr, totalAvailable);
  let staOffset = 0;
  let dynOffset = 2048; // half of totalAvailable

  function get(A, idx) {
    return memView.getFloat32(A + idx * 4 - memPtr, true);
  }
  function set(A, idx, v) {
    memView.setFloat32(A + idx * 4 - memPtr, v, true);
  }
  const allocFloat32Array = (size) => {
    const byteOffset = memPtr + dynOffset;
    dynOffset += size * 4;
    return byteOffset;
  };

  const allocStaticFloat32Array = (size) => {
    const byteOffset = memPtr + staOffset;
    staOffset += size * 4;
    return byteOffset;
  };
  const free = () => (dynOffset = 2048);
  return {
    memory: {
      allocFloat32Array,
      allocStaticFloat32Array,
      free,
      get,
      set,
    },
  };
}

function createAspectRatioFunction() {
  return {
    aspectRatio: (width, height) => {
      let gcd = width;
      let remainder = height;
      while (remainder != 0) {
        const temp = remainder;
        remainder = gcd % remainder;
        gcd = temp;
      }
      return [width / gcd, height / gcd];
    },
  };
}

function createCamera(args) {
  const { width, height, aspectRatio } = args;
  const [w, h] = aspectRatio(width, height);
  return {
    camera: {
      leftTop: [-w, h + 1, -50.0],
      rightTop: [w, h + 1, -50.0],
      leftBottom: [-w, -h + 1, -50.0],
      eye: [0.0, 0.0, -65.0],
    },
  };
}

function createImageGeometry({ width, height }) {
  return {
    imageGeometry: {
      width,
      height,
    },
  };
}

function createVectorFunctions() {
  const sub = (A, B) => [A[0] - B[0], A[1] - B[1], A[2] - B[2]];
  const add = (A, B) => [A[0] + B[0], A[1] + B[1], A[2] + B[2]];
  const mul = (A, B) => [A[0] * B[0], A[1] * B[1], A[2] * B[2]];
  const dot = (A, B) => A[0] * B[0] + A[1] * B[1] + A[2] * B[2];
  // 135ms
  const scale = (A, s) => [A[0] * s, A[1] * s, A[2] * s];
  const norm = (A) => Math.sqrt(dot(A, A));
  const unit = (A) => scale(A, 1.0 / norm(A));
  const abs = (A) => [Math.abs(A[0]), Math.abs(A[1]), Math.abs(A[2])];
  const maxDimension = (A) => {
    if (A[0] > A[1] && A[0] > A[2]) return 0;
    if (A[1] > A[0] && A[1] > A[2]) return 1;
    return 2;
  };
  const permute = (A, i, j, k) => [A[i], A[j], A[k]];
  const cross = (A, B) => {
    const j = A[1] * B[2] - B[1] * A[2];
    const k = A[2] * B[0] - A[0] * B[2];
    const l = A[0] * B[1] - A[1] * B[0];
    return [j, k, l];
  };
  const vector = {
    sub,
    add,
    mul,
    dot,
    scale,
    norm,
    unit,
    abs,
    maxDimension,
    permute,
    cross,
  };

  return {
    vector,
  };
}

function calculatePrimaryRays(args) {
  const {
    camera: { rightTop, leftTop, leftBottom, eye },
    imageGeometry: { width, height },
    vector: { scale, add, sub, unit },
  } = args;
  const vdu = scale(sub(rightTop, leftTop), 1.0 / width);
  const vdv = scale(sub(leftBottom, leftTop), 1.0 / height);
  const primaryRays = [];
  for (let y = 0; y < height; y++) {
    for (let x = 0; x < width; x++) {
      const pixel = y * width + x;
      const origin = eye;
      const direction = unit(sub(add(add(scale(vdu, x), scale(vdv, y)), leftTop), origin));
      primaryRays[pixel] = {
        pixel,
        origin,
        direction,
      };
    }
  }

  return {
    primaryRays,
  };
}

function createRandomDirectionFunction(args) {
  const {
    vector: { dot, norm },
  } = args;

  const randomDirection = (normal) => {
    const p = [0, 0, 0];
    while (true) {
      p[0] = Math.random() - 0.5;
      p[1] = Math.random() - 0.5;
      p[2] = Math.random() - 0.5;
      const n = 1.0 / norm(p);
      p[0] *= n;
      p[1] *= n;
      p[2] *= n;
      if (dot(p, normal) >= 0) {
        return p;
      }
    }
  };
  return {
    randomDirection,
  };
}

function createTriangleIntersectFunction(args) {
  const { sub, abs, maxDimension, permute, scale, add } = args.vector;
  const triangleIntersect = (ray, triangle) => {
    const { pt0, pt1, pt2, normal } = triangle;
    let pt0_t = sub(pt0, ray.origin);
    let pt1_t = sub(pt1, ray.origin);
    let pt2_t = sub(pt2, ray.origin);
    // Use the dominant axis of the ray direction so the shear division below stays well conditioned.
    const k = maxDimension(abs(ray.direction));
    const i = (k + 1) % 3;
    const j = (i + 1) % 3;

    const [pdx, pdy, pdz] = permute(ray.direction, i, j, k);
    const sz = 1.0 / pdz;
    const sx = -pdx * sz;
    const sy = -pdy * sz;

    pt0_t = permute(pt0_t, i, j, k);
    pt1_t = permute(pt1_t, i, j, k);
    pt2_t = permute(pt2_t, i, j, k);

    const pt0_t_0 = pt0_t[0] + sx * pt0_t[2];
    const pt0_t_1 = pt0_t[1] + sy * pt0_t[2];
    const pt0_t_2 = pt0_t[2] * sz;
    const pt1_t_0 = pt1_t[0] + sx * pt1_t[2];
    const pt1_t_1 = pt1_t[1] + sy * pt1_t[2];
    const pt1_t_2 = pt1_t[2] * sz;
    const pt2_t_0 = pt2_t[0] + sx * pt2_t[2];
    const pt2_t_1 = pt2_t[1] + sy * pt2_t[2];
    const pt2_t_2 = pt2_t[2] * sz;

    const e0 = pt1_t_0 * pt2_t_1 - pt1_t_1 * pt2_t_0;
    const e1 = pt2_t_0 * pt0_t_1 - pt2_t_1 * pt0_t_0;
    const e2 = pt0_t_0 * pt1_t_1 - pt0_t_1 * pt1_t_0;

    if ((e0 < 0.0 || e1 < 0.0 || e2 < 0.0) && (e0 > 0.0 || e1 > 0.0 || e2 > 0.0)) {
      return { hit: false };
    }

    const det = e0 + e1 + e2;
    if (det == 0.0) {
      return { hit: false };
    }

    const t_scaled = e0 * pt0_t_2 + e1 * pt1_t_2 + e2 * pt2_t_2;

    if (det < 0.0 && t_scaled >= 0.0) {
      return { hit: false };
    }

    if (det > 0.0 && t_scaled <= 0.0) {
      return { hit: false };
    }

    const t = t_scaled / det;

    if (t > 0.007) {
      const point = add(scale(ray.direction, t), ray.origin);
      return {
        hit: true,
        distance: t,
        point,
        normal,
      };
    }

    return {
      hit: false,
    };
  };

  return {
    triangleIntersect,
  };
}

function createTraceFunction(args) {
  const { scene, triangleIntersect } = args;
  const trace = (ray) =>
    scene
      .map((obj) => ({ ...triangleIntersect(ray, obj), obj }))
      .filter(({ hit }) => hit)
      .reduce((acc, intersection) => (intersection.distance < acc.distance ? intersection : acc), {
        hit: false,
        distance: Number.MAX_VALUE,
      });
  return {
    trace,
  };
}

function createTracePrimaryRaysFunction(args) {
  const { trace, traceBvh, primaryRays, width } = args;
  const traceResults = [];
  const tracePrimaryRays = (section, tileSize) => {
    let idx = 0;
    const startPixel = section;
    const endPixel = section + width * tileSize;
    for (let i = startPixel; i < endPixel; i++) {
      const ray = primaryRays[i];
      traceResults[idx++] = traceBvh(ray);
    }
    return traceResults;
  };
  return {
    tracePrimaryRays,
  };
}

function createGenerateBitmapFunction(args) {
  const {
    shading,
    vector: { mul },
  } = args;

  const generateBitmap = (traceResults, section, bbp) => {
    let idx = section * 4 * 3;
    for (const it of traceResults) {
      let pixel = [0, 0, 0];
      if (it.hit) {
        pixel = it.obj.color;
        if (!it.obj.isLight) {
          const intensity = shading(it.point, it.normal, 0);
          pixel = mul(pixel, intensity);
        }
      }
      bbp.setFloat32(idx, pixel[0], true);
      bbp.setFloat32(idx + 4, pixel[1], true);
      bbp.setFloat32(idx + 8, pixel[2], true);
      idx += 12;
    }
  };
  return { generateBitmap };
}

function createRenderFunction(args) {
  const { tracePrimaryRays, generateBitmap, width, height } = args;
  const totalPixels = width * height;
  const render = (tileSize, bitmap, syncBuffer) => {
    const sync = new Uint32Array(syncBuffer);
    const sectionSize = tileSize * width;
    let section = Atomics.add(sync, 0, sectionSize);
    const bpp = new DataView(bitmap); // bitmap is an ArrayBuffer, so the view covers its full byteLength
    while (section < totalPixels) {
      const traceResults = tracePrimaryRays(section, tileSize);
      generateBitmap(traceResults, section, bpp);
      section = Atomics.add(sync, 0, sectionSize);
    }
  };
  return { render };
}

function pipeline(fns) {
  return (args) => {
    let acc = { ...args };
    for (const fn of fns) {
      const result = fn(acc);
      acc = { ...acc, ...result };
    }
    return acc;
  };
}

function createShadingFunction(args) {
  const {
    vector: { dot },
    trace,
    traceBvh,
    randomDirection,
  } = args;
  const shading = (shadingPoint, pointNormal, depth) => {
    const color = [0, 0, 0];
    if (depth === 5) {
      return color;
    }
    const origin = [];
    origin[0] = pointNormal[0] * 0.1 + shadingPoint[0];
    origin[1] = pointNormal[1] * 0.1 + shadingPoint[1];
    origin[2] = pointNormal[2] * 0.1 + shadingPoint[2];

    const direction = randomDirection(pointNormal);
    const d = dot(pointNormal, direction);
    const ray = { origin, direction };
    const tr = traceBvh(ray);
    if (!tr.hit) {
      return color; //black up to this point;
    }
    color[0] = tr.obj.color[0] * d;
    color[1] = tr.obj.color[1] * d;
    color[2] = tr.obj.color[2] * d;
    if (!tr.obj.isLight) {
      const ncolor = shading(tr.point, tr.normal, depth + 1);
      color[0] *= ncolor[0];
      color[1] *= ncolor[1];
      color[2] *= ncolor[2];
    }
    return color;
  };

  return {
    shading,
  };
}

module.exports = {
  createScene,
  createMemoryFunctions,
  createAspectRatioFunction,
  createCamera,
  createImageGeometry,
  createVectorFunctions,
  calculatePrimaryRays,
  createRandomDirectionFunction,
  createTriangleIntersectFunction,
  initializeBVH,
  createTraceFunction,
  createTracePrimaryRaysFunction,
  createGenerateBitmapFunction,
  createRenderFunction,
  pipeline,
  createShadingFunction,
};

In the worker function, you will notice I made some changes so it no longer has all the rendering functions inside it. Instead, the rendering functions are stringified and passed to the worker as part of the initialization. During this step, the functions are parsed and the rendering pipeline is created.
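
As a minimal sketch of that mechanism (the double function here is only an illustration), this is the same toString/Function round trip that parseFunctions performs below:

// On the main thread, functions are serialized to their source text...
const double = (x) => x * 2;
const serialized = double.toString();

// ...and inside the worker they are rebuilt with the Function constructor.
const rebuilt = new Function(`return ${serialized}`)();
console.log(rebuilt(21)); // 42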

function workerFunction() {
  async function loadWasm(wasmFile) {
    const {
      instance: { exports: wasm },
    } = await WebAssembly.instantiate(wasmFile, {});
    return wasm;
  }

  function parseFunctions(functions) {
    const fns = {};
    for (const fn of functions) {
      fns[fn[0]] = new Function(`return ${fn[1]}`)();
    }
    return fns;
  }
  let render = null;

  async function init(wasmFile, width, height, functions) {
    const wasm = await loadWasm(wasmFile);
    const fns = parseFunctions(functions);
    const renderingPipeline = fns.pipeline([
      fns.createMemoryFunctions,
      fns.createVectorFunctions,
      fns.createAspectRatioFunction,
      fns.createScene,
      fns.createCamera,
      fns.createImageGeometry,
      fns.createRandomDirectionFunction,
      fns.calculatePrimaryRays,
      fns.createTriangleIntersectFunction,
      fns.initializeBVH,
      fns.createTraceFunction,
      fns.createShadingFunction,
      fns.createTracePrimaryRaysFunction,
      fns.createGenerateBitmapFunction,
      fns.createRenderFunction,
    ]);
    render = renderingPipeline({
      wasm,
      width,
      height,
      sceneSelector: 1,
    }).render;
  }

  this.onmessage = async (msg) => {
    const { data } = msg;
    if (data.operation === 'init') {
      const { wasmFile, width, height, functions } = data;
      await init(wasmFile, width, height, functions);
      this.postMessage(0);
    } else {
      render(data.tileSize, data.bitmap, data.syncBuffer);
      this.postMessage(1);
    }
  };
}

module.exports = workerFunction;

All of the following code is the regular path tracing implementation, with some minor changes for readability; other than that, it is all the same.

class WorkersPool {
  #blobUrl;
  poolSize;
  maxPoolSize;
  #pool = [];
  results = [];
  constructor(poolSize, maxPoolSize, workerFunction) {
    const workerFunctionString = `(${workerFunction.toString()})()`;
    const blob = new Blob([workerFunctionString], {
      type: 'application/javascript',
    });
    this.#blobUrl = URL.createObjectURL(blob);
    this.poolSize = poolSize;
    this.maxPoolSize = maxPoolSize;
  }

  init(initPayload) {
    this.#pool = [];
    this.poolSize = Math.max(1, this.poolSize);
    return new Promise((resolve) => {
      let workersDone = 0;
      for (let i = 0; i < this.maxPoolSize; i++) {
        const worker = new Worker(this.#blobUrl);
        this.#pool.push(worker);
        worker.onmessage = () => {
          if (++workersDone === this.maxPoolSize) {
            resolve();
          }
        };
        worker.postMessage(initPayload);
      }
    });
  }

  resize(newSize) {
    this.poolSize = Math.max(1, newSize);
  }

  async process(payload) {
    return new Promise((resolve) => {
      let currentJob = 0;
      for (let i = 0; i < this.poolSize; i++) {
        let wrk = this.#pool[i];
        wrk.onmessage = async (msg) => {
          if (++currentJob === this.poolSize) {
            resolve();
          }
        };
        wrk.postMessage(payload);
      }
    });
  }
}

module.exports = WorkersPool;

Finally, this is the main entry point for the path tracer. Hit run to see it in action!

const workerFunction = require('./worker-function.js');
const WorkersPool = require('./workers-pool.js');
const pathTracerFunctions = require('./path-tracer-functions.js');
const workersPool = new WorkersPool(1, 1 /* navigator.hardwareConcurrency */, workerFunction);

async function renderAndPresent(canvas, frameCount, framesAcc, syncBuffer, sharedBuffer, finalBitmap) {
  const ctx = canvas.getContext('2d');
  const width = canvas.width;
  const renderStart = performance.now();
  const bitmap = new Float32Array(sharedBuffer);
  const sync = new Uint32Array(syncBuffer);
  const tileSize = 8;

  sync[0] = 0;
  await workersPool.process({ tileSize, bitmap: sharedBuffer, syncBuffer });

  for (let i = 0; i < bitmap.length; i += 3) {
    const r = bitmap[i] + (framesAcc[i] || 0);
    const g = bitmap[i + 1] + (framesAcc[i + 1] || 0);
    const b = bitmap[i + 2] + (framesAcc[i + 2] || 0);
    framesAcc[i] = r;
    framesAcc[i + 1] = g;
    framesAcc[i + 2] = b;
    finalBitmap[i / 3] =
      (255 << 24) |
      ((Math.min(b / frameCount, 1) * 255) << 16) |
      ((Math.min(g / frameCount, 1) * 255) << 8) |
      (Math.min(r / frameCount, 1) * 255);
  }
  const imageData = new ImageData(new Uint8ClampedArray(finalBitmap.buffer), width);
  ctx.putImageData(imageData, 0, 0);
  const elapsed = Math.floor(performance.now() - renderStart);
  const elapsedMs = `${elapsed}ms|${(Math.round(10000 / elapsed) / 10).toFixed(1)}fps|${workersPool.poolSize}/${
    workersPool.maxPoolSize
  }threads(${window.adjustPool ? '?' : '!'})`;
  ctx.font = '20px monospace';
  ctx.textBaseline = 'top';
  const measure = ctx.measureText(elapsedMs);
  ctx.fillStyle = '#000000';
  ctx.fillRect(0, 0, measure.width, measure.fontBoundingBoxDescent);
  ctx.fillStyle = '#999999';
  ctx.fillText(elapsedMs, 0, 0);
}

window.running = false;
(async () => {
  const canvas = document.getElementById('canvas-1');
  const width = canvas.width;
  const height = canvas.height;
  const wasmFile = await (await fetch('wasm/vector_simd.wasm')).arrayBuffer();
  await workersPool.init({
    operation: 'init',
    wasmFile,
    width,
    height,
    functions: Object.entries(pathTracerFunctions).map(([name, fn]) => [name, fn.toString()]),
  });
  let frameCount = 0;
  const framesAcc = new Array(width * height * 3);
  const sharedBuffer = new SharedArrayBuffer(width * height * 3 * 4);
  const syncBuffer = new SharedArrayBuffer(4);
  const finalBitmap = new Uint32Array(width * height);
  window.running = true;

  /* START: Auto adjust poolSize */
  const renderTimes = {};
  let renderStart = performance.now();
  window.adjustPool = true;
  /* END: Auto adjust poolSize */

  const animation = async () => {
    frameCount++;
    await renderAndPresent(canvas, frameCount, framesAcc, syncBuffer, sharedBuffer, finalBitmap);
    /* START: Auto adjust poolSize */
    if (window.adjustPool && frameCount % 10 === 0) {
      renderTimes[workersPool.poolSize] = performance.now() - renderStart;
      if (workersPool.poolSize < workersPool.maxPoolSize) {
        await workersPool.resize(workersPool.poolSize + 1);
        renderStart = performance.now();
      } else {
        window.adjustPool = false;
        let fastest = Number.MAX_VALUE;
        let poolSize = 0;
        for (const [size, time] of Object.entries(renderTimes)) {
          if (time < fastest) {
            fastest = time;
            poolSize = size;
          }
        }
        await workersPool.resize(poolSize);
      }
    }
    /* END: Auto adjust poolSize */
    window.running && window.requestAnimationFrame(animation);
  };
  window.requestAnimationFrame(animation);
})();

Results & Summary

Here are some numbers comparing the previous tracing implementation (brute force), and the new one using the BVH traversal:

Algorithm    | Milliseconds per Frame
Brute Force  | 90,300
BVH          | 2,500

It is quite obvious there is a massive speedup, around 36x, even though this implementation uses neither multithreading nor SIMD vectorization.

The current naive implementation works well because the vast majority of triangles are added to the scene in an organized way and all of them sit in the middle of the scene. If that were not the case, the speedup would not be as good, and might even be worse.

A better algorithm for building the BVH, and a data structure not based on pointers (the tree), might help even more; I will explore those topics in later articles.

Final Note

You have seen me posting rendering speed numbers, but be warned that they are only valid within the context of the current article.

I often use different configurations (hardware [not my work laptop, mind you], OS, browser) from article to article when running these algorithms. As such, absolute running times cannot be compared as-is; it is better to use the speedup numbers.

Rest assured, though, that the performance numbers within a single article correspond to the same configuration; it would be disingenuous to do otherwise.

Enrique CR - 2024-08-01