#pragma kernel CSMain FOG_ELLIPSOIDS ANISOTROPY POINT_LIGHTS DIR_LIGHT DIR_LIGHT_SHADOWS /*FOG_BOMB*/ /*ATTENUATION_LEGACY*/ FLIP_SHADOWS

// Froxel grid dimensions. CSMain divides the thread id by (resolution - 1) per axis.
float3 _FroxelResolution;
// Output volume written by CSMain: rgb = _Intensity * density * light color, a = density.
RWTexture3D<half4> _VolumeInject;
// Four corner rays blended by FrustumRay(): 0->1 along uv.x at uv.y = 0, 3->2 along uv.x at uv.y = 1.
float4 _FrustumRays[4];
// Camera position; only .xyz is read in this file.
float4 _CameraPos;
// Not referenced in the visible part of this file.
float4 _FrustumRaysLight[4];
// Not referenced in the visible part of this file.
float4 _CameraPosLight;
// Final multiplier applied to the fog amount in Density().
float _Density;
// Scales the rgb output of CSMain.
float _Intensity;
// The 'g' parameter of the phase term computed in anisotropy().
float _Anisotropy;
// Noise texture sampled by noise(); its .y and .x channels are read.
Texture2D _Noise;
SamplerState sampler_Noise;
// Fog parameters read by Density(): x = constant term, y = exponent scale,
// z = height offset, w = multiplier of the exponential height term.
float4 _FogParams;
// Arguments Density() forwards to ScrollNoise() for the global fog:
// amount, spatial scale, scroll speed and scroll direction.
float _NoiseFogAmount;
float _NoiseFogScale;
float _WindSpeed;
float3 _WindDir;
// Time value multiplied by the scroll speed in ScrollNoise().
// NOTE(review): declared as a scalar here, so it is presumably set explicitly by the host code - confirm.
float _Time;
// Not referenced in the visible part of this file.
Texture2D _LightTextureB0;
SamplerState sampler_LightTextureB0;
// Lower bound of the depth fraction in CSMain: z is remapped from [0, 1] to [_NearOverFarClip, 1].
float _NearOverFarClip;
// Starting light color of every froxel before directional/point lights are added.
float3 _AmbientLight;

#ifdef FOG_BOMB
// Radius of the fog bomb effect; Density() only applies the effect when this is > 0.
float _FogBombRadius;
// Centre of the fog bomb; Density() measures the distance from each sample position to it.
float3 _FogBombPos;
#endif

#ifdef DIR_LIGHT
// Color returned by DirectionalLight() before shadowing and phase attenuation are applied.
float3 _DirLightColor;
// Directional light direction; used for the shadow sample offset and the phase angle.
float3 _DirLightDir;
#ifdef DIR_LIGHT_SHADOWS
// Shadow lookup is performed only when this is > 0.
float _DirLightShadows;
// Not referenced in the visible part of this file.
float _ESMExponentDirLight;
// Distance the shadow sample position is moved along _DirLightDir.
float _DirLightOffset;
// Scales the minimum variance (_VSMBias * mean^2) in ChebyshevUpperBound().
float _VSMBias;
// Blends the shadow term towards 1.0 (fully lit) in DirectionalLight().
float _DirBias;

// Cascade data; only element 0 of the buffer is read in this file.
struct ShadowParams
{
	// One world-to-shadow matrix per cascade.
	float4x4 worldToShadow[4];
	// Per-cascade sphere; .xyz is used as the centre.
	float4 shadowSplitSpheres[4];
	// Per-cascade thresholds compared against squared distance to the sphere centres.
	float4 shadowSplitSqRadii;
};
RWStructuredBuffer<ShadowParams> _ShadowParams;
// Shadow map whose .xy channels are passed as the two moments to ChebyshevUpperBound().
Texture2D _DirectionalShadowmap;
SamplerState sampler_DirectionalShadowmap;
#endif
#endif

#ifdef POINT_LIGHTS
// One entry per point light processed by PointLights().
struct PointLight
{
	// Light position, in the same space as the froxel positions.
	float3 pos;
	// Multiplied with the squared distance to the light before attenuation.
	// NOTE(review): presumably 1 / range^2 - confirm against the host code.
	float range;
	// Light color, accumulated after attenuation.
	float3 color;
	// Not read in this file.
	float padding;
};
StructuredBuffer<PointLight> _PointLights;
// Number of entries of _PointLights to process (stored as float, used as the loop bound).
float _PointLightsCount;
#endif

#ifdef FOG_ELLIPSOIDS
// One entry per fog volume processed by FogEllipsoids().
struct FogEllipsoid
{
	// Centre of the volume.
	float3 pos;
	// Outer edge of the falloff. It is compared against a SQUARED distance in FogEllipsoids().
	// NOTE(review): presumably the host uploads radius^2 - confirm.
	float radius;
	// Axis along which the offset is stretched; also the noise scroll direction.
	float3 axis;
	// Amount of the along-axis offset component added back to the offset.
	float stretch;
	// Density contribution at full strength.
	float density;
	// Forwarded to ScrollNoise() as amount, speed and scale.
	float noiseAmount;
	float noiseSpeed;
	float noiseScale;
	// Fraction of 'radius' at which the falloff starts.
	float feather;
	// 0 = add the contribution to the density, 1 = multiply the density by it.
	float blend;
	// Not read in this file.
	float padding1;
	float padding2;
};
StructuredBuffer<FogEllipsoid> _FogEllipsoids;
// Number of entries of _FogEllipsoids to process (stored as float, used as the loop bound).
float _FogEllipsoidsCount;
#endif

// Maps a scalar seed to a pseudo-random value: the fractional part of a scaled sine.
float hash(float n)
{
	float scaledSine = sin(n) * 753.5453123;
	return frac(scaledSine);
}
// Procedural value noise: hashes the eight corners of the lattice cell that contains x
// and blends them trilinearly using smoothstep-shaped weights.
float noisep(float3 x)
{
	float3 cell = floor(x);
	float3 t = frac(x);
	t = t*t*(3.0-2.0*t);

	// Flatten the cell coordinate into a scalar seed (strides: x = 1, y = 157, z = 113).
	float seed = cell.x + cell.y*157.0 + 113.0*cell.z;

	// Corner values, named c<x><y><z>.
	float c000 = hash(seed+  0.0);
	float c100 = hash(seed+  1.0);
	float c010 = hash(seed+157.0);
	float c110 = hash(seed+158.0);
	float c001 = hash(seed+113.0);
	float c101 = hash(seed+114.0);
	float c011 = hash(seed+270.0);
	float c111 = hash(seed+271.0);

	// Blend along x, then y, then z.
	float nearSlice = lerp(lerp(c000, c100, t.x), lerp(c010, c110, t.x), t.y);
	float farSlice = lerp(lerp(c001, c101, t.x), lerp(c011, c111, t.x), t.y);
	return lerp(nearSlice, farSlice, t.z);
}

// Texture-based 3D noise remapped to [-1, 1] from the sampled channel values.
// Two z slices are fetched with a single lookup of _Noise (channels y and x) and blended by
// the smoothed fractional z.
float noise(float3 x)
{
	float3 cell = floor(x);
	float3 t = frac(x);
	t = t * t * (3.0 - 2.0 * t);

	// Each integer z step shifts the 2D lookup by (37, 17) texels.
	float2 texel = (cell.xy + float2(37.0,17.0) * cell.z) + t.xy;
	float2 lookupUv = (texel + 0.5) / 128.0;
	float2 slices = _Noise.SampleLevel(sampler_Noise, lookupUv, 0).yx;

	float blended = lerp(slices.x, slices.y, t.z);
	return -1.0 + 2.0 * blended;
}

// Two-octave scrolling value noise used to modulate fog density.
//   pos    - sample position
//   speed  - multiplied by _Time to get the scroll distance
//   scale  - spatial frequency applied to the position
//   dir    - scroll direction
//   amount - blend between 1.0 (amount = 0, no noise) and the noise value (amount = 1)
//   bias   - added to the summed octaves before scaling
//   mult   - scales the biased sum before it is clamped at zero
// Returns lerp(1, max((octaves + bias) * mult, 0), amount).
float ScrollNoise(float3 pos, float speed, float scale, float3 dir, float amount, float bias = 0.0, float mult = 1.0)
{
	float time = _Time * speed;
	float3 noiseScroll = dir * time;

	// First octave: sampled at the scrolled, scaled position.
	float3 q = pos - noiseScroll;
	q *= scale;
	float f = 0.5 * noisep(q);

	// Second octave: adding the scaled scroll back cancels the offset, so this octave is
	// sampled at the un-scrolled position. The relative motion between the two octaves
	// makes the pattern morph instead of just sliding.
	q += noiseScroll * scale;
	q = q * 2.01;
	f += 0.25 * noisep(q);

	f += bias;
	f *= mult;

	f = max(f, 0.0);
	return lerp(1.0, f, amount);
}

#if ATTENUATION_LEGACY

// Legacy falloff: 1 / (1 + 25 * distNorm). Does not reach zero at distNorm = 1.
float Attenuation(float distNorm)
{
	float falloffDenominator = 1.0 + 25.0 * distNorm;
	return 1.0 / falloffDenominator;
}

// Legacy falloff that is additionally pulled to zero at distNorm = 1.
// From distNorm = 0.8 * 0.8 onwards the attenuation is multiplied by a linear ramp
//   1 - (distNorm - 0.64) / (1 - 0.64)  ==  (1 - distNorm) / 0.36  ~=  (1 - distNorm) * 2.78
// and it is forced to zero once distNorm exceeds 1.
float AttenuationToZero(float distNorm)
{
	float remaining = 1.0 - distNorm;

	// 1.0 below the 0.64 threshold, the linear ramp at or above it.
	float fade = lerp(1.0, remaining * 2.78, step(0.64, distNorm));

	// 1.0 while distNorm <= 1, 0.0 beyond.
	float insideRange = step(0.0, remaining);

	return Attenuation(distNorm) * fade * insideRange;
}

#else

// Point light falloff: 1 / (1 + d / r)^2 with d = sqrt(distSqr) and r = 0.25.
// Does not reach zero at d = 1; see AttenuationToZero for the cut-off variant.
float Attenuation(float distSqr)
{
	float d = sqrt(distSqr);
	float kDefaultPointLightRadius = 0.25;
	// Squared with a multiply rather than pow(x, 2): cheaper, and avoids relying on pow
	// for what is a plain square.
	float falloff = 1.0 + d / kDefaultPointLightRadius;
	return 1.0 / (falloff * falloff);
}

// Point light falloff that reaches exactly zero at distance 1.0.
//   attenuation = 1 / (1 + distance_to_light / light_radius)^2
//               = 1 / (1 + 2*(d/r) + (d/r)^2)
// For more details see: https://imdoingitwrong.wordpress.com/2011/01/31/light-attenuation/
// The curve is shifted down by its value at d = 1 and rescaled so it still equals 1 at d = 0.
float AttenuationToZero(float distSqr)
{
	float d = sqrt(distSqr);
	float kDefaultPointLightRadius = 0.25;

	// Squares are written as multiplies instead of pow(x, 2).
	float falloff = 1.0 + d / kDefaultPointLightRadius;
	float atten = 1.0 / (falloff * falloff);

	// Cutoff equal to the attenuation at distance 1.0
	float falloffAtOne = 1.0 + 1.0 / kDefaultPointLightRadius;
	float kCutoff = 1.0 / (falloffAtOne * falloffAtOne);

	// Force attenuation to fall towards zero at distance 1.0
	atten = (atten - kCutoff) / (1.f - kCutoff);
	if (d >= 1.f)
		atten = 0.f;

	return atten;
}

#endif

#ifdef FOG_ELLIPSOIDS
// Applies every fog ellipsoid to 'density' in buffer order.
// Each ellipsoid yields contribution = noise * falloff * ellipsoid density, which is then
// added to the running density (blend = 0), multiplied into it (blend = 1), or a mix of both.
void FogEllipsoids(float3 pos, inout float density)
{
	for (int i = 0; i < _FogEllipsoidsCount; i++)
	{
		FogEllipsoid e = _FogEllipsoids[i];

		float3 toCenter = e.pos - pos;
		float3 alongAxis = dot(toCenter, e.axis) * e.axis;

		// Noise is sampled with the unstretched offset and scrolls along the ellipsoid axis.
		float noiseFactor = ScrollNoise(toCenter, e.noiseSpeed, e.noiseScale, e.axis, e.noiseAmount);

		// Stretch the offset along the axis to turn the sphere test into an ellipsoid test.
		float3 stretched = toCenter + alongAxis * e.stretch;
		float distSq = dot(stretched, stretched);

		// 1 inside radius * feather, fading to 0 at radius (both compared to squared distance).
		float falloff = (1.0 - smoothstep (e.radius * e.feather, e.radius, distSq));

		float contribution = noiseFactor * falloff * e.density;
		density = lerp(density + contribution, density * contribution, e.blend);
	}
}
#endif

#ifdef FOG_BOMB
// Smooth bump centred at c: rises from 0 at (c - w) to 1 at c, then falls back to 0 at (c + w).
float Pulse(float c, float w, float x)
{
	float rise = smoothstep(c - w, c, x);
	float fall = smoothstep(c, c + w, x);
	return rise - fall;
}
#endif

// Fog density at a position: constant + exponential height fog, optionally carved out by
// the fog bomb, modulated by wind-scrolled noise, then combined with the fog ellipsoids.
// The result is scaled by _Density and clamped to be non-negative.
float Density(float3 pos)
{
	// Constant term plus the height term exp(y * (z - pos.y)) * w, the latter clamped at zero.
	float heightTerm = exp(_FogParams.y*(-pos.y + _FogParams.z)) * _FogParams.w;
	float fog = _FogParams.x + max(heightTerm, 0.0);

	// Position used for the noise lookup; the fog bomb may displace it.
	float3 noisePos = pos;

	#ifdef FOG_BOMB
	if (_FogBombRadius > 0)
	{
		float3 toBomb = _FogBombPos - pos;
		float bombDist = length(toBomb);

		// Clear the fog inside the radius, with a soft edge between 0.9 and 1.1 radii.
		fog *= smoothstep (_FogBombRadius * 0.9, _FogBombRadius * 1.1, bombDist);
		// Up to 50% denser shell centred at 1.35 radii.
		fog *= 1.0 + 0.5 * Pulse(_FogBombRadius * 1.35, 0.7, bombDist);
		// Pull the noise lookup towards the bomb, fading out at 1.4 radii.
		float pull = 1 - smoothstep(_FogBombRadius, _FogBombRadius * 1.4, bombDist);
		noisePos += pull * toBomb * 0.3;
	}
	#endif

	fog *= ScrollNoise(noisePos, _WindSpeed, _NoiseFogScale, _WindDir, _NoiseFogAmount, -0.3, 8.0);

	#ifdef FOG_ELLIPSOIDS
	// Ellipsoids use the undisplaced position.
	FogEllipsoids(pos, fog);
	#endif

	float scaled = fog * _Density;
	return max(scaled, 0.0);
}

// Bilinearly interpolates the four corner rays.
// Rays 0 -> 1 span uv.x on the uv.y = 0 edge, rays 3 -> 2 span uv.x on the uv.y = 1 edge.
float3 FrustumRay(float2 uv, float4 frustumRays[4])
{
	float3 edgeAtV0 = lerp(frustumRays[0].xyz, frustumRays[1].xyz, uv.x);
	float3 edgeAtV1 = lerp(frustumRays[3].xyz, frustumRays[2].xyz, uv.x);
	float3 ray = lerp(edgeAtV0, edgeAtV1, uv.y);
	return ray;
}

#ifdef ANISOTROPY
// Henyey-Greenstein-shaped phase term without the 1 / (4 * pi) normalization:
//   (1 - g^2) / (1 + g^2 - 2 * g * cos(theta))^(3/2),  with g = _Anisotropy.
float anisotropy(float costheta)
{
	float g = _Anisotropy;
	float gSquared = g*g;
	float base = 1 + gSquared - 2.0 * g * costheta;

	// base^(3/2) evaluated as sqrt(base^3); the cube is clamped at zero before the root.
	float cubed = base * base * base;
	float denominator = sqrt(max(0, cubed));

	return (1 - gSquared) / denominator;
}
#endif

#if DIR_LIGHT
#if DIR_LIGHT_SHADOWS
// Variance shadow map visibility estimate.
//   moments - x: first moment, y: second moment, as sampled from the shadow map
//   mean    - value compared against moments.x (DirectionalLight passes the shadow-space depth)
// Returns 1.0 when the comparison says "lit", otherwise the one-tailed Chebyshev bound.
// FLIP_SHADOWS reverses the direction of the depth comparison.
float ChebyshevUpperBound(float2 moments, float mean)
{
	// Variance from the two moments, kept above _VSMBias * mean^2.
	float sigmaSq = max(moments.y - (moments.x * moments.x), _VSMBias * mean * mean);

	// Probabilistic upper bound
	float delta = mean - moments.x;
	float pMax = sigmaSq / (sigmaSq + (delta * delta));

#ifdef FLIP_SHADOWS
	bool fullyLit = mean >= moments.x;
#else
	bool fullyLit = mean <= moments.x;
#endif

	return (fullyLit ? 1.0f : pMax);
}

// For each of the four cascade split spheres, returns 1.0 in the matching component when
// pos is at or beyond that sphere's squared radius, and 0.0 when it is inside.
float4 getCascadeWeights_splitSpheres(float3 pos)
{
	float4 distSq;

	float3 offset = pos - _ShadowParams[0].shadowSplitSpheres[0].xyz;
	distSq.x = dot(offset, offset);

	offset = pos - _ShadowParams[0].shadowSplitSpheres[1].xyz;
	distSq.y = dot(offset, offset);

	offset = pos - _ShadowParams[0].shadowSplitSpheres[2].xyz;
	distSq.z = dot(offset, offset);

	offset = pos - _ShadowParams[0].shadowSplitSpheres[3].xyz;
	distSq.w = dot(offset, offset);

	return float4(distSq >= _ShadowParams[0].shadowSplitSqRadii);
}

// Transforms pos into the shadow space of the cascade selected by cascadeWeights.
//   cascadeWeights - per-sphere 0/1 flags from getCascadeWeights_splitSpheres(); their sum
//                    is the number of split spheres the position lies outside of.
// The sum is 4 when the position is outside every split sphere, which would index past
// the end of worldToShadow[4]; the index is clamped so such positions use the last cascade.
float4 getShadowCoord(float3 pos, float4 cascadeWeights)
{
	int cascade = (int)dot(cascadeWeights, float4(1,1,1,1));
	cascade = clamp(cascade, 0, 3);
	return mul(_ShadowParams[0].worldToShadow[cascade], float4(pos, 1));
}
#endif

// Directional light contribution at pos: _DirLightColor scaled by the shadow term
// (when DIR_LIGHT_SHADOWS is on and _DirLightShadows > 0) and by the phase term
// (when ANISOTROPY is on).
float3 DirectionalLight(float3 pos)
{
	float att = 1;

	#if DIR_LIGHT_SHADOWS
	if (_DirLightShadows > 0.0)
	{
		// The cascade is chosen from the original position; the lookup itself uses a
		// position moved along the light direction by _DirLightOffset.
		float4 weights = getCascadeWeights_splitSpheres(pos);
		float3 offsetPos = pos + (_DirLightDir * _DirLightOffset);
		float4 shadowCoord = getShadowCoord(offsetPos, weights).xyzw;

		// Variance shadow map lookup: .xy hold the two moments.
		// (Earlier variants, kept for reference: a hard compare of the .r channel against
		// shadowCoord.z, and an exponential term saturate(shadowmap.r * exp(-40.0 * z)).)
		float2 moments = _DirectionalShadowmap.SampleLevel(sampler_DirectionalShadowmap, shadowCoord.xy, 0).xy;
		float receiverDepth = shadowCoord.z / shadowCoord.w;
		float visibility = ChebyshevUpperBound(moments.xy, receiverDepth);

		// _DirBias pushes the result towards fully lit.
		att = saturate(lerp(visibility, 1.0, _DirBias));
	}
	#endif

	#if ANISOTROPY
	// Phase term from the angle between the direction to the camera and the light direction.
	float3 toCamera = normalize(_CameraPos.xyz - pos);
	att *= anisotropy(dot(toCamera, _DirLightDir));
	#endif

	return _DirLightColor * att;
}
#endif

#ifdef POINT_LIGHTS
// Sums the contribution of the first _PointLightsCount entries of _PointLights at pos.
// Each light's color is scaled by Attenuation(squared distance * range) and, when
// ANISOTROPY is on, by the phase term for the angle between the view ray and the
// direction to the light.
// NOTE(review): Attenuation() does not reach zero at the light's range, so lights keep
// contributing beyond it; AttenuationToZero() exists if a hard cut-off is wanted - confirm intent.
float3 PointLights(float3 pos)
{
	float3 color = 0;

	#if ANISOTROPY
	// The view direction depends only on pos, so it is computed once instead of per light.
	float3 cameraToPos = normalize(pos - _CameraPos.xyz);
	#endif

	for (int i = 0; i < _PointLightsCount; i++)
	{
		float3 posToLight = _PointLights[i].pos - pos;
		float distNorm = dot(posToLight, posToLight) * _PointLights[i].range;
		float att = Attenuation(distNorm);

		#if ANISOTROPY
		float costheta = dot(cameraToPos, normalize(posToLight));
		att *= anisotropy(costheta);
		#endif

		color += _PointLights[i].color * att;
	}
	return color;
}
#endif

// Kernel entry point: one thread per froxel.
// Reconstructs the froxel's position from its grid coordinate, accumulates ambient,
// directional and point lighting, evaluates the fog density there, and writes
// (intensity * density * light, density) to _VolumeInject.
[numthreads(16,2,16)]
void CSMain (uint3 id : SV_DispatchThreadID)
{
	// Thread groups are 16x2x16; when the froxel grid is not a multiple of that, the
	// surplus threads fall outside the volume. Skip them instead of shading and writing
	// out of bounds.
	if (any((float3)id >= _FroxelResolution))
		return;

	float3 color = _AmbientLight;

	// Normalize the grid coordinate to [0, 1]. The divisor is kept at >= 1 so an axis with
	// a single froxel maps to 0 instead of producing 0 / 0.
	float3 lastIndex = max(_FroxelResolution - 1, 1);
	float2 uv = float2(id.x / lastIndex.x, id.y / lastIndex.y);
	float z = id.z / lastIndex.z;

	// Remap depth so that 0 lands on the near plane: z in [_NearOverFarClip, 1].
	z = _NearOverFarClip + z * (1 - _NearOverFarClip);
	float3 pos = FrustumRay(uv, _FrustumRays) * z + _CameraPos.xyz;

	// Directional light
	#ifdef DIR_LIGHT
	color += DirectionalLight(pos);
	#endif

	// Point lights
	#ifdef POINT_LIGHTS
	color += PointLights(pos);
	#endif

	// Density
	float density = Density(pos);

	// Output
	float4 output;
	output.rgb = _Intensity * density * color;
	output.a = density;
	_VolumeInject[id] = output;
}