
Subsurface scattering in Vulkan path tracing

Does anyone have any recommendations on how to do subsurface scattering in a ray/path tracer?

I tried just randomizing the index of refraction coefficient, but that didn't really work.

I'm starting with this: subsurface scattering only occurs when opacity is larger than 0 and reflectivity is less than 1. Sound reasonable?

taby said:
I'm starting with this: subsurface scattering only occurs when opacity is larger than 0 and reflectivity is less than 1. Sound reasonable?

That's not correct. You will need to use a BSSRDF instead of a BRDF for SSS surfaces. It uses different parameters (like the density of the material).

It was introduced by Jensen, if I'm not mistaken, and there is also a quite solid chapter in PBRT. If you have a proper understanding of the math, it should be straightforward to implement (at least it was for me).
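For reference, a BSSRDF lets light enter the surface at one point and leave at another, which a BRDF cannot express. A sketch of the outgoing radiance in the usual textbook notation follows; the symbols ($S$ for the BSSRDF, $x_i$/$x_o$ for the entry and exit points, $A$ for the surface area) are the common convention and are not defined anywhere in this thread:

```latex
L_o(x_o, \omega_o) = \int_A \int_{2\pi} S(x_i, \omega_i; x_o, \omega_o)\,
    L_i(x_i, \omega_i)\, (n \cdot \omega_i)\; d\omega_i \, dA(x_i)
```

With a BRDF the area integral disappears because $x_i = x_o$; the extra integral over the surface is what a path tracer has to sample, for example with a random walk inside the volume.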

Are you saying that I need to replace the diffuse reflection code with something that does subsurface scattering? This is what I gathered from the paper by Jensen, which I skimmed. Thanks for pointing me toward it.

For subsurface diffuse reflection, I embed the origin just inside the mesh, and randomize the direction. I take step sizes based on the density of the material. Sound better?
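A minimal CPU-side sketch of that idea, assuming a homogeneous medium inside a sphere centred at the origin: start just inside the surface, pick a uniform random direction at every step, and draw the step length from an exponential distribution whose mean is 1/density. The function names and the sphere are hypothetical stand-ins for illustration, not code from the shader.

```python
import math
import random


def random_unit_vector(rng):
    # Uniform direction on the unit sphere (uniform z, uniform azimuth).
    z = rng.uniform(-1.0, 1.0)
    a = rng.uniform(0.0, 2.0 * math.pi)
    r = math.sqrt(max(0.0, 1.0 - z * z))
    return (r * math.cos(a), r * math.sin(a), z)


def length(v):
    return math.sqrt(v[0] * v[0] + v[1] * v[1] + v[2] * v[2])


def walk_until_exit(start, density, radius, max_steps, rng):
    """Random walk from `start` inside a sphere of `radius` centred at the origin.

    Returns (position, steps_taken, exited)."""
    pos = start
    for step in range(1, max_steps + 1):
        d = random_unit_vector(rng)
        # Exponential free-path length with mean 1/density:
        # a denser material gives shorter steps.
        t = -math.log(1.0 - rng.random()) / density
        pos = (pos[0] + d[0] * t, pos[1] + d[1] * t, pos[2] + d[2] * t)
        if length(pos) > radius:
            return pos, step, True
    return pos, max_steps, False


rng = random.Random(1)
pos, steps, exited = walk_until_exit((0.0, 0.0, 0.9), 5.0, 1.0, 10000, rng)
print(steps, exited)
```

The returned position lies slightly beyond the surface. A renderer would instead intersect the last segment with the mesh and continue the path from that exit point.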

Well, I have one half of it working… the light transport works well, and as expected. That is, subsurface scattering scatters the light and the caustics get spread out and diluted.

float trace_path_backward(const int steps, const vec3 origin, const vec3 direction, const float hue, const float eta)
{
	vec3 o = origin;
	vec3 d = direction;

	const float energy = 1;
	const float caustic_energy = 1;

	float ret_colour = 0;
	float local_colour = energy;
	float total = 0;

	bool doing_refraction_caustic = false;

	const vec3 mask = hsv2rgb(vec3(hue, 1.0, 1.0));

	for(int i = 0; i < steps; i++)
	{
		const float tmin = 0.001;
		const float tmax = 10000.0;

		traceRayEXT(topLevelAS, gl_RayFlagsOpaqueEXT, 0xff, 0, 0, 0, o, tmin, d, tmax, 0);

		total += mask.r;
		total += mask.g;
		total += mask.b;

		// The first hit after a refraction adds energy (caustic),
		// every other hit attenuates the path
		if(doing_refraction_caustic)
		{
			local_colour += caustic_energy*(rayPayload.color.r*mask.r + rayPayload.color.g*mask.g + rayPayload.color.b*mask.b);
			doing_refraction_caustic = false;
		}
		else
		{
			local_colour *= (rayPayload.color.r*mask.r + rayPayload.color.g*mask.g + rayPayload.color.b*mask.b);
		}

		// If hit the sky
		if(rayPayload.dist == -1.0)
		{
			ret_colour += local_colour;
			break;
		}

		// If this is simply the final step
		// then don't throw away perfectly 
		// good data
		if(i == steps - 1)
		{
			ret_colour += local_colour;
			break;
		}

		// Do light transport
		vec3 hitPos = o + d * rayPayload.dist;

		if(stepAndOutputRNGFloat(prng_state) <= rayPayload.opacity)
		{
			vec3 o_reflect = hitPos + rayPayload.normal * 0.01;
			vec3 d_reflect = reflect(d, rayPayload.normal);

			vec3 o_scatter, d_scatter;

			if(stepAndOutputRNGFloat(prng_state) <= rayPayload.subsurface)
			{
				// Subsurface: start just inside the mesh, in a random direction
				o_scatter = hitPos - rayPayload.normal * 0.001 + RandomUnitVector(prng_state)*0.001;
				d_scatter = RandomUnitVector(prng_state);
			}
			else
			{
				// Ordinary diffuse bounce
				o_scatter = hitPos + rayPayload.normal * 0.01;
				d_scatter = cosWeightedRandomHemisphereDirection(rayPayload.normal, prng_state);
			}

			o = mix(o_scatter, o_reflect, rayPayload.reflector);
			d = normalize(mix(d_scatter, d_reflect, rayPayload.reflector));
		}
		else
		{
			doing_refraction_caustic = true;

			vec3 o_transparent = vec3(0.0);
			vec3 d_transparent = vec3(0.0);

			// Incoming
			if(dot(d, rayPayload.normal) <= 0.0)
			{
				o_transparent = hitPos - rayPayload.normal * 0.01f;
				d_transparent = refract(d, rayPayload.normal, eta);
			}
			else // Outgoing
			{
				vec3 temp_dir = refract(d, -rayPayload.normal, 1.0 / eta);

				if(temp_dir != vec3(0.0))
				{
					o_transparent = hitPos + rayPayload.normal * 0.01f;
					d_transparent = mix(temp_dir, RandomUnitVector(prng_state), rayPayload.subsurface);
				}
				else
				{
					// Total internal reflection
					o_transparent = hitPos - rayPayload.normal * 0.01f;
					d_transparent = mix(reflect(d, -rayPayload.normal), RandomUnitVector(prng_state), rayPayload.subsurface);
				}
			}

			o = o_transparent;
			d = normalize(d_transparent);
		}
	}

	return ret_colour / total;
}
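The "Outgoing" branch relies on GLSL's refract() returning a zero vector when total internal reflection occurs, which is why the code compares temp_dir against vec3(0.0). A small CPU-side sketch of that behaviour follows; it mirrors the formula from the GLSL specification, and the Python function names are only illustrative stand-ins.

```python
import math


def dot(a, b):
    return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]


def refract(i, n, eta):
    """GLSL-style refract for unit vectors i (incident) and n (normal).

    Returns (0, 0, 0) when total internal reflection occurs."""
    n_dot_i = dot(n, i)
    k = 1.0 - eta * eta * (1.0 - n_dot_i * n_dot_i)
    if k < 0.0:
        return (0.0, 0.0, 0.0)
    s = eta * n_dot_i + math.sqrt(k)
    return tuple(eta * a - s * b for a, b in zip(i, n))


def reflect(i, n):
    """GLSL-style reflect: i - 2 * dot(n, i) * n."""
    d = dot(n, i)
    return tuple(a - 2.0 * d * b for a, b in zip(i, n))


eta = 0.75  # same value the shader uses for min_eta and max_eta

# Entering the surface head-on: the ray continues straight through.
print(refract((0.0, 0.0, -1.0), (0.0, 0.0, 1.0), eta))

# Leaving at a shallow angle with the inverted ratio: total internal reflection.
leaving = (0.8, 0.0, 0.6)
flipped_normal = (0.0, 0.0, -1.0)
out = refract(leaving, flipped_normal, 1.0 / eta)
if out == (0.0, 0.0, 0.0):
    out = reflect(leaving, flipped_normal)
print(out)
```

Comparing against the zero vector works because refract() returns exactly zero in that case rather than a small non-zero value.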

I found that I need to keep track of the location of the ray ends, to see if they're inside or outside the mesh. Any questions or comments?

bool is_inside_mesh(const vec3 location)
{
	// Save the payload, since the probe ray overwrites it
	RayPayload r = rayPayload;
	bool is_outside = false;

	const vec3 direction = vec3(0, 1, 0);

	traceRayEXT(topLevelAS, gl_RayFlagsOpaqueEXT, 0xff, 0, 0, 0, location, 0.001, direction, 10000.0, 0);

	// Outside if the probe misses everything or hits a front face
	if(rayPayload.dist == -1.0 || dot(direction, rayPayload.normal) <= 0.0)
		is_outside = true;

	rayPayload = r;
	return !is_outside;
}


taby said:
bool is_inside_mesh(const vec3 location)

I've talked about how to do such a test robustly and efficiently in the past, as i needed it to voxelize non-manifold input.

But it's out of reach for realtime application like this.
In theory a precomputed SDF would work, but lacks the accuracy you need close to the surface.

taby said:
const vec3 direction = vec3(0, 1, 0);

Maybe this can be improved…

I assume your algorithm is something like this:
At a hitpoint, do some random small offset to sample the volume around it.
At the offset, do your inside test and reject if outside, or generate a new random offset to try again.

If so, you could use the surface normal of the initial hitpoint for the inside test. If it's inside, it should find the surface much quicker this way, and you could also set a short maximum ray length (e.g. twice the radius from the random offset), to minimize the searching also in case you're outside.
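A minimal sketch of that variant, using an analytic sphere in place of the mesh so it can run on the CPU: the probe ray goes along the surface normal of the original hit point, its length is capped at twice the offset radius, and the point counts as inside only if the probe hits a back face. All names here are hypothetical and not taken from the shader.

```python
import math


def dot(a, b):
    return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]


def ray_sphere(origin, direction, radius, t_min, t_max):
    """Nearest hit of a unit-length ray with a sphere centred at the origin.

    Returns (t, outward_normal), or None if there is no hit in [t_min, t_max]."""
    b = dot(origin, direction)
    c = dot(origin, origin) - radius * radius
    disc = b * b - c
    if disc < 0.0:
        return None
    s = math.sqrt(disc)
    for t in (-b - s, -b + s):
        if t_min <= t <= t_max:
            hit = tuple(o + d * t for o, d in zip(origin, direction))
            normal = tuple(h / radius for h in hit)
            return t, normal
    return None


def is_inside(point, surface_normal, radius, offset_radius):
    # Probe along the normal of the original hit point and stop early:
    # an inside point should find the surface within about 2 * offset_radius.
    hit = ray_sphere(point, surface_normal, radius, 1e-3, 2.0 * offset_radius)
    if hit is None:
        return False
    _, n = hit
    # A back-face hit (normal pointing along the probe) means we started inside.
    return dot(surface_normal, n) > 0.0


normal = (0.0, 0.0, 1.0)  # normal at the original hit point (0, 0, 1)
print(is_inside((0.0, 0.0, 0.95), normal, 1.0, 0.05))  # True
print(is_inside((0.0, 0.0, 1.05), normal, 1.0, 0.05))  # False
```

In the shader the same effect would come from passing the hit normal as the probe direction and a small tmax to traceRayEXT, instead of the fixed vec3(0, 1, 0) and 10000.0.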

This should be quite an ‘optimization’ if i make the proper assumptions about your approach.

But it's still terrible.
An efficient solution would be to do a random geodesic walk on the surface. If you get a triangle index with the hit, you could traverse adjacent triangles to do this.
But ofc. game models for rendering usually have no adjacency information, and adding it takes a lot of memory.
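To make the adjacency idea concrete, here is a rough CPU-side sketch, assuming a plain indexed triangle list like the shader's index buffer. It builds an edge-to-triangle map, derives per-triangle neighbours, and hops between neighbours at random. Hopping from triangle to triangle is only a crude stand-in for a true geodesic walk, and the function names are made up for this example.

```python
import random
from collections import defaultdict


def build_adjacency(indices):
    """Return, for each triangle, the sorted list of triangles sharing an edge.

    `indices` is a flat list with three vertex indices per triangle."""
    edge_to_tris = defaultdict(list)
    tri_count = len(indices) // 3
    for t in range(tri_count):
        a, b, c = indices[3 * t:3 * t + 3]
        for u, v in ((a, b), (b, c), (c, a)):
            # Store edges with sorted endpoints so both windings match.
            edge_to_tris[(min(u, v), max(u, v))].append(t)
    adjacency = [set() for _ in range(tri_count)]
    for tris in edge_to_tris.values():
        for t in tris:
            adjacency[t].update(o for o in tris if o != t)
    return [sorted(s) for s in adjacency]


def random_triangle_walk(adjacency, start, steps, rng):
    """Hop to a random edge-neighbour `steps` times and return the final triangle."""
    tri = start
    for _ in range(steps):
        if not adjacency[tri]:
            break  # isolated triangle, nowhere to go
        tri = rng.choice(adjacency[tri])
    return tri


# A tetrahedron: every face touches the other three.
tetra = [0, 1, 2, 0, 3, 1, 1, 3, 2, 2, 3, 0]
adj = build_adjacency(tetra)
print(adj)
print(random_triangle_walk(adj, 0, 5, random.Random(3)))
```

The memory cost mentioned above is visible here: the neighbour lists add roughly three extra indices per triangle on a closed manifold mesh, on top of the index buffer itself.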

It's all bad for realtime.
So i would have hoped this works:
You don't do the inside test at all, but accept any offset.
If you trace the ray from the inside, the ray should pass through back-facing triangles to ignore them. (I assume HW RT can be set to do this without generating a redundant hitpoint.)

This should work good enough? Ofc. if your object is thin, random offset behind the normal could exit the model entirely, and then you would hit the backside of a curtain and accumulate error. But i guess it's practical.

Btw, the visual effect i would expect is that edges like this:

would become soft, no longer that hard.
But maybe it fails because your model is composed from boxes, generating surfaces inside the volume defined from solid walls.
So you would need to remodel the mesh to remove all interior surfaces, eventually.

Yeah, it's not working entirely the way I want it to, but I think that I'm on the right track.

#version 460
#extension GL_EXT_ray_tracing : require
#extension GL_EXT_nonuniform_qualifier : enable

layout(location = 2) rayPayloadEXT bool shadowed;

struct ray
{
	vec4 direction;
	vec4 origin;
	bool in_use;

	vec3 normal;
	float depth;

	int child_refract_id;
	int child_reflect_id;
	int child_subsurface_id;
	int parent_id;

	float base_color;
	float accumulated_color;
	float base_opacity;

	float reflection_constant;
	float refraction_constant;
	float sss_constant;
	float sss_density;

	float tint_constant;
	vec3 tint_colour;

	bool external_reflection_ray;
	bool external_refraction_ray;
	bool internal_subsurface_ray;

	int level;
};

float mod289(float x){return x - floor(x * (1.0 / 289.0)) * 289.0;}
vec4 mod289(vec4 x){return x - floor(x * (1.0 / 289.0)) * 289.0;}
vec4 perm(vec4 x){return mod289(((x * 34.0) + 1.0) * x);}

float noise3(vec3 p, float wavelength){

	p /= wavelength;

    vec3 a = floor(p);
    vec3 d = p - a;
    d = d * d * (3.0 - 2.0 * d);

    vec4 b = a.xxyy + vec4(0.0, 1.0, 0.0, 1.0);
    vec4 k1 = perm(b.xyxy);
    vec4 k2 = perm(k1.xyxy + b.zzww);

    vec4 c = k2 + a.zzzz;
    vec4 k3 = perm(c);
    vec4 k4 = perm(c + 1.0);

    vec4 o1 = fract(k3 * (1.0 / 41.0));
    vec4 o2 = fract(k4 * (1.0 / 41.0));

    vec4 o3 = o2 * d.z + o1 * (1.0 - d.z);
    vec2 o4 = o3.yw * d.x + o3.xz * (1.0 - d.x);

    return o4.y * d.y + o4.x * (1.0 - d.y);
}

const float TWO_PI = 8.0 * atan(1.0);

vec3 aabb_min = vec3(-2.5, -2.5, -2.5);
vec3 aabb_max = vec3(2.5, 2.5, 2.5);

const float opacity_factor = 0.01;

vec3 fog_colour = vec3(1.0, 1.0, 1.0);

uint prng_state = 0;

layout(binding = 0, set = 0) uniform accelerationStructureEXT topLevelAS;
layout(binding = 1, set = 0, rgba8) uniform image2D color_image;
layout(binding = 2, set = 0) uniform UBO 
{
	mat4 viewInverse;	
	mat4 projInverse;

	mat4 transformation_matrix;

	vec3 camera_pos;
	int vertexSize;
	bool screenshot_mode;

	uint tri_count;
	uint light_tri_count;

} ubo;

layout(binding = 3, set = 0) buffer Vertices { vec4 v[]; } vertices;
layout(binding = 4, set = 0) buffer Indices { uint i[]; } indices;

struct RayPayload
{
	vec3 color;
	float dist;
	vec3 normal;
	float reflector;
	float opacity;
	float tint;
	vec3 tint_colour;
	float subsurface;
	float density;
};

layout(location = 0) rayPayloadEXT RayPayload rayPayload;

// Max. number of recursion is passed via a specialization constant
layout (constant_id = 0) const int MAX_RECURSION = 0;
const int buffer_size = 16;

vec3 rgb2hsv(vec3 c)
{
    vec4 K = vec4(0.0, -1.0 / 3.0, 2.0 / 3.0, -1.0);
    vec4 p = mix(vec4(c.bg, K.wz), vec4(c.gb, K.xy), step(c.b, c.g));
    vec4 q = mix(vec4(p.xyw, c.r), vec4(c.r, p.yzx), step(p.x, c.r));

    float d = q.x - min(q.w, q.y);
    float e = 1.0e-10;
    return vec3(abs(q.z + (q.w - q.y) / (6.0 * d + e)), d / (q.x + e), q.x);
}

vec3 hsv2rgb(vec3 c)
{
    vec4 K = vec4(1.0, 2.0 / 3.0, 1.0 / 3.0, 3.0);
    vec3 p = abs(fract(c.xxx + K.xyz) * 6.0 - K.www);
    return c.z * mix(K.xxx, clamp(p - K.xxx, 0.0, 1.0), c.y);
}

float stepAndOutputRNGFloat(inout uint rngState)
{
	// Condensed version of pcg_output_rxs_m_xs_32_32, with simple conversion to floating-point [0,1].
	rngState  = rngState * 747796405 + 1;
	uint word = ((rngState >> ((rngState >> 28) + 4)) ^ rngState) * 277803737;
	word      = (word >> 22) ^ word;
	return float(word) / 4294967295.0f;
}

// Uniformly distributed direction on the unit sphere
vec3 RandomUnitVector(inout uint state)
{
	float z = stepAndOutputRNGFloat(state) * 2.0f - 1.0f;
	float a = stepAndOutputRNGFloat(state) * TWO_PI;
	float r = sqrt(1.0f - z * z);
	float x = r * cos(a);
	float y = r * sin(a);
	return normalize(vec3(x, y, z));
}

// I forget where this came from
vec3 cosWeightedRandomHemisphereDirection( const vec3 n, inout uint state )
{
	vec2 r = vec2(stepAndOutputRNGFloat(state), stepAndOutputRNGFloat(state));

	// Tangent frame around n (degenerate if n is parallel to (0,1,1))
	vec3  uu = normalize( cross( n, vec3(0.0,1.0,1.0) ) );
	vec3  vv = cross( uu, n );
	float ra = sqrt(r.y);
	float rx = ra*cos(TWO_PI*r.x); 
	float ry = ra*sin(TWO_PI*r.x);
	float rz = sqrt( 1.0-r.y );
	vec3  rr = vec3( rx*uu + ry*vv + rz*n );
	return normalize( rr );
}

bool in_aabb(vec3 pos, vec3 aabb_min, vec3 aabb_max)
{
	if((pos.x >= aabb_min.x && pos.x <= aabb_max.x) &&
	(pos.y >= aabb_min.y && pos.y <= aabb_max.y) &&
	(pos.z >= aabb_min.z && pos.z <= aabb_max.z))
		return true;

	return false;
}

// Slab test: returns true if the ray hits the box in front of the origin
bool BBoxIntersect(const vec3 boxMin, const vec3 boxMax, const vec3 origin, const vec3 dir, out float out_t0, out float out_t1)
{
	vec3 invdir = 1.0 / dir;
	vec3 tbot = invdir * (boxMin - origin);
	vec3 ttop = invdir * (boxMax - origin);
	vec3 tmin = min(ttop, tbot);
	vec3 tmax = max(ttop, tbot);
	vec2 t = max(tmin.xx, tmin.yz);
	float t0 = max(t.x, t.y);
	t = min(tmax.xx, tmax.yz);
	float t1 = min(t.x, t.y);
	out_t0 = t0;
	out_t1 = t1;

	return t1 > max(t0, 0.0);
}

float trace_path_backward(const int steps, const vec3 origin, const vec3 direction, const float hue, const float eta)
{
	vec3 o = origin;
	vec3 d = direction;

	const float energy = 1;
	const float caustic_energy = 1;

	float ret_colour = 0;
	float local_colour = energy;
	float total = 0;

	bool doing_refraction_caustic = false;

	const vec3 mask = hsv2rgb(vec3(hue, 1.0, 1.0));

	for(int i = 0; i < steps; i++)
	{
		const float tmin = 0.001;
		const float tmax = 10000.0;

		traceRayEXT(topLevelAS, gl_RayFlagsOpaqueEXT, 0xff, 0, 0, 0, o, tmin, d, tmax, 0);

		total += mask.r;
		total += mask.g;
		total += mask.b;

		// The first hit after a refraction adds energy (caustic),
		// every other hit attenuates the path
		if(doing_refraction_caustic)
		{
			local_colour += caustic_energy*(rayPayload.color.r*mask.r + rayPayload.color.g*mask.g + rayPayload.color.b*mask.b);
			doing_refraction_caustic = false;
		}
		else
		{
			local_colour *= (rayPayload.color.r*mask.r + rayPayload.color.g*mask.g + rayPayload.color.b*mask.b);
		}

		// If hit the sky
		if(rayPayload.dist == -1.0)
		{
			ret_colour += local_colour;
			break;
		}

		// If this is simply the final step
		// then don't throw away perfectly 
		// good data
		if(i == steps - 1)
		{
			ret_colour += local_colour;
			break;
		}

		vec3 hitPos = o + d * rayPayload.dist;

		if(stepAndOutputRNGFloat(prng_state) <= rayPayload.opacity)
		{
			vec3 o_reflect = hitPos + rayPayload.normal * 0.01;
			vec3 d_reflect = reflect(d, rayPayload.normal);

			vec3 o_scatter, d_scatter;

			if(stepAndOutputRNGFloat(prng_state) <= rayPayload.subsurface)
			{
				// Subsurface: start just inside the mesh, in a random direction
				o_scatter = hitPos - rayPayload.normal * 0.001 + RandomUnitVector(prng_state)*0.001;
				d_scatter = RandomUnitVector(prng_state);
			}
			else
			{
				// Ordinary diffuse bounce
				o_scatter = hitPos + rayPayload.normal * 0.01;
				d_scatter = cosWeightedRandomHemisphereDirection(rayPayload.normal, prng_state);
			}

			o = mix(o_scatter, o_reflect, rayPayload.reflector);
			d = normalize(mix(d_scatter, d_reflect, rayPayload.reflector));
		}
		else
		{
			doing_refraction_caustic = true;

			vec3 o_transparent = vec3(0.0);
			vec3 d_transparent = vec3(0.0);

			// Incoming
			if(dot(d, rayPayload.normal) <= 0.0)
			{
				o_transparent = hitPos - rayPayload.normal * 0.01f;
				d_transparent = refract(d, rayPayload.normal, eta);
			}
			else // Outgoing
			{
				vec3 temp_dir = refract(d, -rayPayload.normal, 1.0 / eta);

				if(temp_dir != vec3(0.0))
				{
					o_transparent = hitPos + rayPayload.normal * 0.01f;
					d_transparent = mix(temp_dir, RandomUnitVector(prng_state), rayPayload.subsurface);
				}
				else
				{
					// Total internal reflection
					o_transparent = hitPos - rayPayload.normal * 0.01f;
					d_transparent = mix(reflect(d, -rayPayload.normal), RandomUnitVector(prng_state), rayPayload.subsurface);
				}
			}

			o = o_transparent;
			d = normalize(d_transparent);
		}
	}

	return ret_colour / total;
}

float get_radiance_backward(const int samples, const int steps, const vec3 origin, const vec3 seed_direction, const float hue, const float eta)
{
	// Save the payload, since the traced paths overwrite it
	RayPayload r = rayPayload;

	float ret_colour = 0.0;

	for(int s = 0; s < samples; s++)
		ret_colour += trace_path_backward(steps, origin, seed_direction, hue, eta);

	rayPayload = r;

	return ret_colour / samples;
}

float get_omni_radiance_backward(const int samples, const int steps, const vec3 origin, const float hue, const float eta)
{
	RayPayload r = rayPayload;

	float ret_colour = 0.0;

	// Same as above, but each path starts in a random direction
	for(int s = 0; s < samples; s++)
		ret_colour += trace_path_backward(steps, origin, RandomUnitVector(prng_state), hue, eta);

	rayPayload = r;

	return ret_colour / samples;
}

void get_fog(inout float dist_color, inout float dist_opacity, vec3 origin, vec3 direction, vec3 hitPos, float hue, float eta)
{
	vec3 start = origin;
	vec3 end = origin;

	float t0 = 0.0;
	float t1 = 0.0;
	const float target_step_length = 0.1;

	// Find the part of the ray that lies inside the fog AABB
	if(in_aabb(origin, aabb_min, aabb_max))
	{
		vec3 backout_pos = origin - direction*10000.0;

		if(BBoxIntersect(aabb_min, aabb_max, backout_pos, direction, t0, t1))
		{
			start = origin;
			end = backout_pos + direction*t1;
		}
	}
	else
	{
		if(BBoxIntersect(aabb_min, aabb_max, origin, direction, t0, t1))
		{
			start = origin + direction*t0;
			end = origin + direction*t1;
		}
	}

	// Don't march past the surface that the ray hit
	if(distance(origin, start) > distance(origin, hitPos))
		start = hitPos + rayPayload.normal * 0.01f;

	if(distance(origin, end) > distance(origin, hitPos))
		end = hitPos + rayPayload.normal * 0.01f;

	const int num_steps = int(floor((distance(start, end) / target_step_length)));

	if(num_steps >= 2)
	{
		const vec3 step = (end - start) / (num_steps - 1);

		vec3 curr_step = start;

		for(int j = 0; j < num_steps; j++, curr_step += step)
		{
			float colour = get_omni_radiance_backward(10, 5, curr_step, hue, eta);

			float noise = noise3(curr_step, 1.0);

			//noise *= noise3(curr_step, 1.0/2.0);
			//noise *= noise3(curr_step, 1.0/4.0);
			//noise *= noise3(curr_step, 1.0/8.0);

			//colour *= noise;

			// Front-to-back compositing
			const float trans = 1.0 - clamp(dist_opacity, 0.0, 1.0);
			dist_color += colour*trans;
			dist_opacity += 0.1*colour*trans;
		}
	}

	const vec3 mask = hsv2rgb(vec3(hue, 1.0, 1.0));
	dist_color *= (fog_colour.r*mask.r + fog_colour.g*mask.g + fog_colour.b*mask.b);
}

void get_origin_and_direction(inout vec4 origin, inout vec4 direction, float x, float y, float image_width, float image_height)
{
	const vec2 pixelCenter = vec2(x, y) + vec2(0.5);
	const vec2 inUV = pixelCenter/vec2(image_width, image_height);
	vec2 d = inUV * 2.0 - 1.0;

	origin = ubo.viewInverse * vec4(0,0,0,1);
	vec4 target = ubo.projInverse * vec4(d.x, d.y, 1, 1);
	direction = ubo.viewInverse*vec4(normalize(target.xyz / target.w), 0);
}

bool is_inside_mesh(const vec3 location)
{
	// Save the payload, since the probe ray overwrites it
	RayPayload r = rayPayload;
	bool is_outside = false;

	const vec3 direction = vec3(0, 1, 0);

	traceRayEXT(topLevelAS, gl_RayFlagsOpaqueEXT, 0xff, 0, 0, 0, location, 0.001, direction, 10000.0, 0);

	// Outside if the probe misses everything or hits a front face
	if(rayPayload.dist == -1.0 || dot(direction, rayPayload.normal) <= 0.0)
		is_outside = true;

	rayPayload = r;
	return !is_outside;
}

int do_random_walk_until_exits_mesh(const float density, const int max_walk_length, vec3 pos, inout vec3 final_pos, inout vec3 final_dir, inout float base_opacity)
{
	final_pos = pos;
	final_dir = RandomUnitVector(prng_state);
	base_opacity = 0.0;

	int i = 1;

	// Assumption: the walked position is the one to test, and the
	// step counter has to advance for the loop to terminate
	while(is_inside_mesh(final_pos) && i <= max_walk_length)
	{
		final_dir = RandomUnitVector(prng_state);
		final_pos += final_dir*(1 - density);
		base_opacity += (1.0 - clamp(base_opacity, 0.0, 1.0))*density;

		i++;
	}

	return i;
}

float get_ray0(vec4 origin, vec4 direction, const float hue, const float eta)
{
	// This algorithm stops when the buffer runs out of space,
	// or when the rays miss everything,
	// or when the level is too deep

	ray rays[buffer_size];
	int current_buffer_index = 0;

	uint rayFlags = gl_RayFlagsOpaqueEXT;
	uint cullMask = 0xff;
	float tmin = 0.001;
	float tmax = 10000.0;

	// Step one: make tree of ray segments
	for(int i = 0; i < buffer_size; i++)
	{
		// Initialize buffer
		rays[i].in_use = false;
		rays[i].child_reflect_id = -1;
		rays[i].child_refract_id = -1;
		rays[i].child_subsurface_id = -1;
		rays[i].parent_id = -1;
		rays[i].external_reflection_ray = false;
		rays[i].external_refraction_ray = false;
		rays[i].internal_subsurface_ray = false;
		rays[i].normal = vec3(0, 0, 0);
		rays[i].depth = 0;
		rays[i].tint_constant = 0;
		rays[i].tint_colour = vec3(0, 0, 0);
	}

	rays[0].direction = direction;
	rays[0].origin = origin;
	rays[0].in_use = true;
	rays[0].level = 0;
	rays[0].external_reflection_ray = true;
	current_buffer_index++; // assumption: slot 0 is taken by the primary ray

	// Assumption: keep processing until no ray segment is in use
	while(true)
	{
		int used_count = 0;

		for(int i = 0; i < buffer_size; i++)
		{
			if(rays[i].in_use)
			{
				used_count++;

				// Assumption: more samples when taking a screenshot
				if(ubo.screenshot_mode)
					rays[i].base_color = get_radiance_backward(1000, 5, rays[i].origin.xyz, rays[i].direction.xyz, hue, eta);
				else
					rays[i].base_color = get_radiance_backward(250, 5, rays[i].origin.xyz, rays[i].direction.xyz, hue, eta);

				traceRayEXT(topLevelAS, rayFlags, cullMask, 0, 0, 0, rays[i].origin.xyz, tmin, rays[i].direction.xyz, tmax, 0);

				// this particular ray missed everything, or got too deep
				if(rayPayload.dist == -1.0 || rays[i].level >= MAX_RECURSION)
				{
					rays[i].in_use = false;
					continue;
				}

				vec4 hitPos = rays[i].origin + rays[i].direction * rayPayload.dist;

				rays[i].depth = distance(hitPos, rays[i].origin);
				rays[i].tint_constant = rayPayload.tint;
				rays[i].tint_colour = rayPayload.tint_colour;
				rays[i].normal = rayPayload.normal;
				rays[i].reflection_constant = rayPayload.reflector;
				rays[i].refraction_constant = rayPayload.opacity;
				rays[i].sss_constant = rayPayload.subsurface;
				rays[i].sss_density = rayPayload.density;

				rays[i].base_opacity = 0.0;

				// Do some fog in an AABB
				//float dist_color = 0.0;
				//float dist_opacity = 0.0;
				//get_fog(dist_color, dist_opacity, rays[i].origin.xyz, rays[i].direction.xyz, hitPos.xyz, hue, eta);
				//dist_color = clamp(dist_color, 0.0, 1.0);
				//dist_opacity = clamp(dist_opacity, 0.0, 1.0);
				//rays[i].base_color = mix(rays[i].base_color, dist_color, dist_opacity);

				// entering mesh
				if(dot(rays[i].direction.xyz, rayPayload.normal) <= 0.0)
				{
					//generate new ray segments

					if(current_buffer_index < buffer_size && rays[i].reflection_constant != 0.0)
					{
						rays[current_buffer_index].base_opacity = 1.0;

						rays[i].child_reflect_id = current_buffer_index;
						rays[current_buffer_index].origin.xyz = hitPos.xyz + rayPayload.normal * 0.01f;
						rays[current_buffer_index].direction.xyz = reflect(rays[i].direction.xyz, rayPayload.normal);
						rays[current_buffer_index].in_use = true;
						rays[current_buffer_index].level = rays[i].level + 1;
						rays[current_buffer_index].external_reflection_ray = true;
						rays[current_buffer_index].external_refraction_ray = false;
						rays[current_buffer_index].internal_subsurface_ray = false;
						rays[current_buffer_index].parent_id = i;
						current_buffer_index++;

						if(current_buffer_index < buffer_size && rays[i].sss_constant != 0.0)
						{
							rays[i].child_subsurface_id = current_buffer_index;
							rays[current_buffer_index].origin.xyz = hitPos.xyz - rayPayload.normal * 0.01 + RandomUnitVector(prng_state)*0.01;
							rays[current_buffer_index].direction.xyz = RandomUnitVector(prng_state);

							rays[current_buffer_index].base_opacity = 0.0;

							// do random walk here to get base opacity
							vec3 new_pos = vec3(0, 0, 0);
							vec3 new_dir = vec3(0, 0, 0);
							int step_count = do_random_walk_until_exits_mesh(0.01, 10, rays[current_buffer_index].origin.xyz, new_pos, new_dir, rays[current_buffer_index].base_opacity);

							//generate new ray segment
							rays[current_buffer_index].origin.xyz = new_pos;
							rays[current_buffer_index].direction.xyz = new_dir;

							rays[current_buffer_index].in_use = true;
							rays[current_buffer_index].level = rays[i].level + 1;
							rays[current_buffer_index].external_reflection_ray = false;
							rays[current_buffer_index].external_refraction_ray = false;
							rays[current_buffer_index].internal_subsurface_ray = true;
							rays[current_buffer_index].parent_id = i;
							current_buffer_index++;
						}
					}

					if(current_buffer_index < buffer_size && rays[i].refraction_constant != 1.0)
					{
						rays[i].child_refract_id = current_buffer_index;

						//generate new ray segment
						rays[current_buffer_index].base_opacity = 1.0;
						rays[current_buffer_index].origin.xyz = hitPos.xyz - rayPayload.normal * 0.01f;
						rays[current_buffer_index].direction.xyz = refract(rays[i].direction.xyz, rayPayload.normal, eta);
						rays[current_buffer_index].in_use = true;
						rays[current_buffer_index].level = rays[i].level + 1;
						rays[current_buffer_index].external_reflection_ray = false;
						rays[current_buffer_index].external_refraction_ray = false;
						rays[current_buffer_index].internal_subsurface_ray = false;
						rays[current_buffer_index].parent_id = i;
						current_buffer_index++;
					}
				}
				// exiting mesh
				else
				{
					if(current_buffer_index < buffer_size)
					{
						rays[i].child_refract_id = current_buffer_index;

						vec3 temp_dir = refract(rays[i].direction.xyz, -rayPayload.normal, 1.0/eta);

						if(temp_dir != vec3(0.0))
						{
							rays[current_buffer_index].base_opacity = 0.0;
							rays[current_buffer_index].origin.xyz = hitPos.xyz + rayPayload.normal * 0.01;
							rays[current_buffer_index].direction.xyz = temp_dir;
							rays[current_buffer_index].external_reflection_ray = false;
							rays[current_buffer_index].external_refraction_ray = true;
							rays[current_buffer_index].internal_subsurface_ray = false;
							rays[current_buffer_index].in_use = true;
							rays[current_buffer_index].level = rays[i].level + 1;
							rays[current_buffer_index].parent_id = i;
						}
						else
						{
							// Total internal reflection
							rays[current_buffer_index].base_opacity = 0.0;
							rays[current_buffer_index].origin.xyz = hitPos.xyz - rayPayload.normal * 0.01f;
							rays[current_buffer_index].direction.xyz = reflect(rays[i].direction.xyz, -rayPayload.normal);
							rays[current_buffer_index].in_use = true;
							rays[current_buffer_index].level = rays[i].level + 1;
							rays[current_buffer_index].external_reflection_ray = false;
							rays[current_buffer_index].external_refraction_ray = false;
							rays[current_buffer_index].parent_id = i;
						}

						current_buffer_index++;
					}
				}

				// The processing of this ray segment is complete
				rays[i].in_use = false;
			}
		}

		if(used_count == 0)
			break;
	}

	// Step two: this is the Fresnel reflection-refraction code
	// Start at the tips of the branches, work backwards to the root
	for(int i = current_buffer_index - 1; i >= 0; i--)
	{
		bool pure_refraction = false;
		bool pure_reflection = false;
		bool neither = false;
		bool both = false;

		if(rays[i].child_refract_id != -1 && rays[i].child_reflect_id == -1)
			pure_refraction = true;

		if(rays[i].child_refract_id == -1 && rays[i].child_reflect_id != -1)
			pure_reflection = true;

		if(rays[i].child_refract_id == -1 && rays[i].child_reflect_id == -1)
			neither = true;

		if(rays[i].child_refract_id != -1 && rays[i].child_reflect_id != -1)
			both = true;

		float accum = 0.0;

		if(neither)
		{
			accum = rays[i].base_color;
		}
		else if(both)
		{
			// Fake the Fresnel refraction-reflection
			const float ratio = 1.0 - dot(-normalize(rays[i].direction.xyz), rays[i].normal);

			float reflect_accum = mix(rays[i].base_color, rays[rays[i].child_reflect_id].accumulated_color, rays[i].reflection_constant);
			float refract_accum = mix(rays[i].base_color, rays[rays[i].child_refract_id].accumulated_color, 1.0 - rays[i].refraction_constant);

			if(rays[i].child_subsurface_id != -1)
			{
				float subsurface_accum = mix(rays[i].base_color, rays[rays[i].child_subsurface_id].accumulated_color, rays[i].sss_constant);
				reflect_accum = mix(reflect_accum, subsurface_accum, 0.0);//rays[i].base_opacity);
			}

			accum = mix(refract_accum, reflect_accum, ratio);
		}
		else if(pure_refraction)
		{
			accum = mix(rays[i].base_color, rays[rays[i].child_refract_id].accumulated_color, 1.0 - rays[i].refraction_constant);
		}
		else if(pure_reflection)
		{
			accum = mix(rays[i].base_color, rays[rays[i].child_reflect_id].accumulated_color, rays[i].reflection_constant);
		}

		// Do tinting
		const vec3 mask = hsv2rgb(vec3(hue, 1.0, 1.0));
		const float t = rays[i].tint_colour.r*mask.r + rays[i].tint_colour.g*mask.g + rays[i].tint_colour.b*mask.b;
		float x = accum;
		accum = mix(x, t, rays[i].tint_constant);
		accum = min(x, accum);

		rays[i].accumulated_color = accum;
	}

	// Show level depth as grayscale colour
	//float s = 1.0 - float(rays[current_buffer_index - 1].level) / float(MAX_RECURSION);
	// return s;

	// Show buffer fullness as grayscale colour
	//float s = 1.0 - float(current_buffer_index - 1) / float(buffer_size);
	//return s;

	return rays[0].accumulated_color;
}

vec3 plane_ray_intersection(vec3 planeP, vec3 planeN, vec3 rayP, vec3 rayD)
{
    float d = dot(planeP, -planeN);
    float t = -(d + dot(rayP, planeN)) / dot(rayD, planeN);
    return rayP + t * rayD;
}

void main() 
{
	// Only seed once per pixel
	prng_state = gl_LaunchIDEXT.y*gl_LaunchSizeEXT.x + gl_LaunchIDEXT.x;

	vec4 origin;
	vec4 direction;

	// Assumption: the primary ray is set up through get_origin_and_direction()
	get_origin_and_direction(origin, direction, float(gl_LaunchIDEXT.x), float(gl_LaunchIDEXT.y), float(gl_LaunchSizeEXT.x), float(gl_LaunchSizeEXT.y));

	// Do chromatic aberration (good for making rainbows via prisms)
	const int channels = 3;
	const float max_hue = rgb2hsv(vec3(0.0, 0.0, 1.0)).x; // blue
	const float min_hue = rgb2hsv(vec3(1.0, 0.0, 0.0)).x; // red

	const float max_eta = 0.75;
	const float min_eta = 0.75;

	const float hue_diff = max_hue - min_hue;
	const float hue_step_size = hue_diff / (channels - 1);

	const float eta_diff = max_eta - min_eta;
	const float eta_step_size = eta_diff / (channels - 1);

	vec3 color = vec3(0.0);
	vec3 total = vec3(0.0);

	// Do primary ray
	float curr_hue = min_hue;
	float curr_eta = min_eta;

	for(int c = 0; c < channels; c++, curr_hue += hue_step_size, curr_eta += eta_step_size)
	{
		const float f = get_ray0(origin, direction, curr_hue, curr_eta);
		const vec3 mask = hsv2rgb(vec3(curr_hue, 1.0, 1.0));

		total += mask;
		color += f*mask;
	}

	// Do pseudorandom depth of field (DOF) rays
	const float focal_distance = 6.0;
	const vec4 focal_point_bootstrap = origin + normalize(direction) * focal_distance;
	// Assumption: intersect the primary ray with the focal plane through focal_point_bootstrap
	const vec4 focal_point_on_plane = vec4(plane_ray_intersection(focal_point_bootstrap.xyz, (origin - focal_point_bootstrap).xyz, origin.xyz, direction.xyz), 0.0);
	const int DOF_samples = 0;

	for(int s = 0; s < DOF_samples; s++)
	{
		curr_hue = min_hue;
		curr_eta = min_eta;

		vec3 r = RandomUnitVector(prng_state)*0.01;

		vec4 rand_origin = origin + vec4(r, 0.0);
		vec4 rand_direction = focal_point_on_plane - rand_origin;

		for(int c = 0; c < channels; c++, curr_hue += hue_step_size, curr_eta += eta_step_size)
		{
			const float f = get_ray0(rand_origin, rand_direction, curr_hue, curr_eta);
			const vec3 mask = hsv2rgb(vec3(curr_hue, 1.0, 1.0));

			total += mask;
			color += f*mask;
		}
	}

	color /= total;
	//color = pow(color, vec3(1.0/2.2));

	imageStore(color_image, ivec2(gl_LaunchIDEXT.xy), vec4(color, 1.0));
}

After trying a handful of different methods of attack, I am finally progressing.

Here's a render showing the colouring from marching through the mesh. Next is to add subsurface scattering (it's partially implemented).


