// Sphere parameters, bound from constant register c8 (float4[2][2] spans c8..c11).
//   sp[i][0] : xyz = sphere centre, w = sphere radius
//   sp[i][1] : .y is used as the sphere's refractive index; .x/.z/.w are not read in this file
const float4 sp[2][2] : register(c8);

// Per-pixel ray origin in rgb; alpha is a scalar weight that gets scaled by Fresnel factors.
sampler sPos : register(s0);
// Per-pixel ray direction in rgb (assumed unit length); alpha is a second scalar weight.
sampler sDir : register(s1);
// Mask texture: pixels whose green channel exceeds 0.5 are discarded via clip().
sampler sMask : register(s3);

// Pixel-shader input.
struct P_IN
{
	float4 pos : POSITION0;   // not read by any shader in this file
	float2 tex : TEXCOORD0;   // UV used to sample sPos, sDir and sMask
};

// Two output rays written to four render targets (COLOR0..COLOR3).
// Each ray is a pair: (hit point in xyz, weight in w) and (normalized direction in xyz, weight in w).
struct P_OUT2
{
	float4 pos1 : COLOR0;   // ray 1 origin; sphere(): reflected ray, inner(): transmitted ray
	float4 dir1 : COLOR1;   // ray 1 direction
	float4 pos2 : COLOR2;   // ray 2 origin; sphere(): refracted ray, inner(): internally reflected ray
	float4 dir2 : COLOR3;   // ray 2 direction
};

// Mirror-reflects direction i about the plane with normal n (n assumed unit length).
// c must be dot(i, n); callers pass it precomputed. Returns i - 2 c n.
inline float3 reflect_impl(const float3 i, const float3 n, const float c)
{
	const float3 normalComponent = n * c;
	return i - normalComponent * 2.0;
}

// Refracts direction i through a surface with normal n (vector form of Snell's law).
//   i : incident direction (assumed unit length)
//   n : unit surface normal facing against i, so that c = dot(i, n) <= 0
//   c : dot(i, n), precomputed by the caller
//   r : ratio of refractive indices, n_incident / n_transmitted
// Returns the (unnormalized) refracted direction.
//
// Under total internal reflection the radicand 1 - (1 - c^2) r^2 is negative. It can also
// round slightly below zero at grazing incidence. sqrt() would then yield NaN, which poisons
// every later value, so the radicand is clamped to 0. The result degrades to the tangential
// direction (i - n c) r. A caller that takes cos_t = dot(normalize(result), n) then sees
// cos_t = 0, and the Fresnel amplitude ratios become magnitude 1 (full reflection) for c != 0.
inline float3 refract_impl(const float3 i, const float3 n, const float c, const float r)
{
	const float radicand = 1.0 - (1.0 - c * c) * r * r;
	return (i - n * c) * r - n * sqrt(max(radicand, 0.0));
}

// Intersects the incoming ray with sphere sp[index] from the outside.
// The ray is split at the near hit point into a reflected ray (pair 1) and a refracted ray
// (pair 2). Each ray's weights are scaled by Fresnel reflectance / transmittance.
inline P_OUT2 sphere(const P_IN In, const int index)
{
	// Skip pixels whose mask green channel exceeds 0.5.
	clip(0.5 - tex2D(sMask, In.tex).g);

	// Incoming ray: origin and direction in rgb; each alpha is a scalar weight.
	const float4 originSample = tex2D(sPos, In.tex);
	const float4 directionSample = tex2D(sDir, In.tex);
	const float3 rayOrigin = originSample.rgb;
	const float3 rayDir = directionSample.rgb;

	// Sphere description: xyz = centre, w = radius; sp[index][1].y acts as refractive index.
	const float4 ball = sp[index][0];
	const float ior = sp[index][1].y;

	// Ray/sphere test via the point of closest approach (rayDir assumed unit length).
	const float3 toCentre = ball.xyz - rayOrigin;
	const float tClosest = dot(toCentre, rayDir);
	const float3 perpendicular = toCentre - rayDir * tClosest;
	const float missDistSq = dot(perpendicular, perpendicular);
	const float halfChordSq = ball.w * ball.w - missDistSq;

	// Discard when the closest approach lies behind the origin or the ray misses the sphere.
	clip(float2(tClosest, halfChordSq));

	// Near intersection and the outward unit normal there.
	const float3 hit = rayOrigin + rayDir * (tClosest - sqrt(halfChordSq));
	const float3 outward = (hit - ball.xyz) / ball.w;

	// Entering the sphere: reflect off the outer surface, refract with ratio 1/ior.
	const float cosIn = dot(rayDir, outward);
	const float3 reflected = normalize(reflect_impl(rayDir, outward, cosIn));
	const float3 refracted = normalize(refract_impl(rayDir, outward, cosIn, 1.0 / ior));
	const float cosOut = dot(refracted, outward);

	// Fresnel amplitude coefficients (cosIn and cosOut are both negative here; the ratios
	// are unaffected by the shared sign).
	// NOTE(review): with n1 = 1, n2 = ior the textbook s-polarised form is
	// (cosIn - ior*cosOut)/(cosIn + ior*cosOut); this file labels that value "p".
	const float ampP = (cosIn - cosOut * ior) / (cosIn + cosOut * ior);
	const float ampS = (cosIn * ior - cosOut) / (cosIn * ior + cosOut);

	P_OUT2 result;
	// Pair 1: reflected ray, weights scaled by reflectance.
	result.pos1 = float4(hit, originSample.a * ampP * ampP);
	result.dir1 = float4(reflected, directionSample.a * ampS * ampS);
	// Pair 2: refracted ray, weights scaled by transmittance.
	result.pos2 = float4(hit, originSample.a * (1.0 - ampP * ampP));
	result.dir2 = float4(refracted, directionSample.a * (1.0 - ampS * ampS));
	return result;
}

// Presumably a pixel-shader entry point: outer-surface hit against sphere slot 0.
P_OUT2 sphere1(const P_IN In)
{
	const P_OUT2 result = sphere(In, 0);
	return result;
}

// Continues a ray travelling inside sphere sp[index] to its exit point.
// At the exit the ray splits into a transmitted ray leaving the sphere (pair 1) and an
// internally reflected ray (pair 2), with weights scaled by Fresnel factors.
inline P_OUT2 inner(const P_IN In, const int index)
{
	// Skip pixels whose mask green channel exceeds 0.5.
	clip(0.5 - tex2D(sMask, In.tex).g);

	const float4 originSample = tex2D(sPos, In.tex);
	const float4 directionSample = tex2D(sDir, In.tex);
	const float3 rayOrigin = originSample.rgb;
	const float3 rayDir = directionSample.rgb;

	// Sphere description: xyz = centre, w = radius; sp[index][1].y acts as refractive index.
	const float4 ball = sp[index][0];
	const float ior = sp[index][1].y;

	// Exit point = twice the distance to the chord midpoint. This is the far intersection
	// only if rayOrigin already lies on the sphere surface (assumed; no intersection test here).
	const float tClosest = dot(ball.xyz - rayOrigin, rayDir);
	const float3 hit = rayOrigin + rayDir * tClosest * 2.0;
	// Inward-facing unit normal, opposing the outgoing ray.
	const float3 inward = (ball.xyz - hit) / ball.w;

	// Leaving the sphere: refract with ratio ior, reflect back inside.
	const float cosIn = dot(rayDir, inward);
	const float3 reflected = normalize(reflect_impl(rayDir, inward, cosIn));
	const float3 refracted = normalize(refract_impl(rayDir, inward, cosIn, ior));
	const float cosOut = dot(refracted, inward);

	// Fresnel amplitude coefficients with n1 = ior, n2 = 1 (same "p"/"s" labelling as sphere()).
	const float ampP = (cosIn * ior - cosOut) / (cosIn * ior + cosOut);
	const float ampS = (cosIn - cosOut * ior) / (cosIn + cosOut * ior);

	P_OUT2 result;
	// Pair 1: ray transmitted out of the sphere.
	result.pos1 = float4(hit, originSample.a * (1.0 - ampP * ampP));
	result.dir1 = float4(refracted, directionSample.a * (1.0 - ampS * ampS));
	// Pair 2: ray reflected back into the sphere.
	result.pos2 = float4(hit, originSample.a * ampP * ampP);
	result.dir2 = float4(reflected, directionSample.a * ampS * ampS);
	return result;
}

// Presumably a pixel-shader entry point: interior exit for sphere slot 0.
P_OUT2 inner1(const P_IN In)
{
	const P_OUT2 result = inner(In, 0);
	return result;
}