AO benchをXNAで動かそう

 以前より、Syoyo FujitaさんのAOのコードなどをC#に移植して遊んだりしていました。

 こちらの記事が動機で、情報を見返したのですが、「4k Procedual Gfx Monitor」にて、ほぼそのままGLSL版になっているのを見ました。

 それならば、HLSLでも同じようなことができるはずなので、移植して、XNAでも動かしてみました。
 コードは少しだけシェーダっぽく書き換えてみたりしています。ベタ移植で不自然になるところはそれらしく置き換えています。

 ソースコード(「ao.fx」と「AOGame.cs」)などは以下の通りです。

最初からプロジェクトを作成する方法

  1. Visual C# 2008 Express Editionをインストールする。
  2. XNA Game Studio 3.0をインストールする。
  3. Visual C# 2008からテンプレート「Windows Game (3.0)」で新しいプロジェクトを作る。
  4. テンプレートの「Game1.cs」を「AOGame.cs」に置き換える。
  5. 「ao.fx」のソースファイルをサブプロジェクト「Content」に追加する。
  6. ビルドして実行する。
  7. ベンチマーク機能はないが、球の1つが飛びまわる様子が確認できる。

※ntheta = 4, nphi = 4にしていて、GLSL版より荒くなっています。これより大きくするとシェーダが実行できませんでした。

ao.fx (HLSLコード)

// porting of Ambient Occlusion to HLSL/XNA by XELF 2009
// [http://kioku.sys-k.net/4kgfxmon/]
// porting GLSL by kioku based on syoyo's AS3 Ambient Occlusion
// [http://lucille.atso-net.jp/blog/?p=638]

struct Ray
{
	float3 org;
	float3 dir;
};
struct Sphere
{
	float3 center;
	float radius;
};
struct Plane
{
	float3 p;
	float3 n;
};

struct Intersection
{
    float t;
    float3 p;     // hit point
    float3 n;     // normal
    int hit;
};

Sphere sphere[3];
Plane plane;
float aspectRatio = (800 / 600.0);
static int seed = 0;

void shpere_intersect(Sphere s, Ray ray, inout Intersection isect)
{
    // rs = ray.org - sphere.center
    float3 rs = ray.org - s.center;
    float B = dot(rs, ray.dir);
    float C = dot(rs, rs) - (s.radius * s.radius);
    float D = B * B - C;

    if (D > 0.0)
    {
		float t = -B - sqrt(D);
		if ( (t > 0.0) && (t < isect.t) )
		{
			isect.t = t;
			isect.hit = 1;

			// calculate normal.
			isect.p = ray.org + ray.dir * t;
			isect.n = normalize(isect.p - s.center);
		}
	}
}

void plane_intersect(Plane pl, Ray ray, inout Intersection isect)
{
	// d = -(p . n)
	// t = -(ray.org . n + d) / (ray.dir . n)
	float d = -dot(pl.p, pl.n);
	float v = dot(ray.dir, pl.n);

	if (abs(v) < 1.0e-6) {
		; // the plane is parallel to the ray.
	} else {
		float t = -(dot(ray.org, pl.n) + d) / v;

		if ( (t > 0.0) && (t < isect.t) )
		{
			isect.hit = 1;
			isect.t   = t;
			isect.n   = pl.n;

			isect.p = ray.org + t * ray.dir;
		}
	}
}

void Intersect(Ray r, inout Intersection i)
{
	for (int c = 0; c < 3; c++)
	{
		shpere_intersect(sphere[c], r, i);
	}
	plane_intersect(plane, r, i);
}

static const float3 basisTable[8] = {
	float3(1, 0, 0), float3(1, 0, 0), float3(0, 1, 0), float3(1, 0, 0),
	float3(0, 0, 1), float3(1, 0, 0), float3(0, 1, 0), float3(1, 0, 0),
};

void orthoBasis(out float3x3 basis, float3 n)
{
	basis[2] = n;
	basis[1] = basisTable[dot(abs(n) < 0.6, float3(1, 2, 4))];

	basis[0] = cross(basis[1], basis[2]);
	basis[0] = normalize(basis[0]);

	basis[1] = cross(basis[2], basis[0]);
	basis[1] = normalize(basis[1]);
}

float random()
{
	seed = int(fmod(float(seed)*1364.0+626.0,509.0));
	return float(seed)/509.0;
}

float3 computeAO(inout Intersection isect)
{
	const int ntheta = 4;
	const int nphi   = 4;
	const float eps  = 0.0001;

    // Slightly move ray org towards ray dir to avoid numerical problem.
    float3 p = isect.p + eps * isect.n;

    // Calculate orthogonal basis.
    float3x3 basis;
    orthoBasis(basis, isect.n);

    float occlusion = 0.0;

    for (int j = 0; j < ntheta; j++)
    {
		for (int i = 0; i < nphi; i++)
		{
			// Pick a random ray direction with importance sampling.
			// p = cos(theta) / 3.141592
			float r = random();
			float phi = 2.0 * 3.141592 * random();

			float3 ref;
			float s, c;
			sincos(phi, s, c);
			ref.x = c * sqrt(1.0 - r);
			ref.y = s * sqrt(1.0 - r);
			ref.z = sqrt(r);

			Ray ray;
			ray.org = p;
			// local -> global
			ray.dir = mul(ref, basis);

			Intersection occIsect;
			occIsect.hit = 0;
			occIsect.t = 1.0e30;
			occIsect.n = occIsect.p = float3(0, 0, 0);
			Intersect(ray, occIsect);
			occlusion += (occIsect.hit != 0);
		}
	}

	// [0.0, 1.0]
	occlusion = (float(ntheta * nphi) - occlusion) / float(ntheta * nphi);
	return occlusion.xxx;
}

float4 PS(
	float3 org : TEXCOORD0,
	float3 dir : TEXCOORD1) : COLOR0
{
	//sphere[0].center = float3(-2.0, 0.0, -3.5);
	sphere[0].radius = 0.5;
	sphere[1].center = float3(-0.5, 0.0, -3.0);
	sphere[1].radius = 0.5;
	sphere[2].center = float3(1.0, 0.0, -2.2);
	sphere[2].radius = 0.5;
	plane.p = float3(0,-0.5, 0);
	plane.n = float3(0, 1.0, 0);
	
	Intersection i;
	i.hit = 0;
	i.t = 1.0e30;
	i.n = i.p = float3(0, 0, 0);
		
	Ray r;
	r.org = org;
	r.dir = normalize(dir);
	seed = (int(fmod((10 + dir.x) * (10 + dir.y) * 4525434.0, 65536.0)));
	
	float4 col = float4(0,0,0,0);
	Intersect(r, i);
	if (i.hit != 0)
	{
		col.rgb = computeAO(i);
	}
	
	return col;
}

void VS(
	float4 InPosition : POSITION0,
	out float4 OutPosition : POSITION0,
	out float3 org : TEXCOORD0,
	out float3 dir : TEXCOORD1) {
	OutPosition = InPosition;
	org = float3(0,0,0);
	dir = normalize(-float3(-InPosition.x * aspectRatio, -InPosition.y, 1));
}

technique AO {
	pass AO {
		VertexShader = compile vs_3_0 VS();
		PixelShader = compile ps_3_0 PS();
	}
}

AOGame.cs (C#コード)

using System;
using Microsoft.Xna.Framework;
using Microsoft.Xna.Framework.Graphics;
using Microsoft.Xna.Framework.Input;

namespace AO {

	public class AOGame : Microsoft.Xna.Framework.Game {
		readonly GraphicsDeviceManager graphics;
		SpriteBatch spriteBatch;
		Effect effect;
		QuadFiller quad;
		Random random = new Random();

		public AOGame() {
			graphics = new GraphicsDeviceManager(this);
			Content.RootDirectory = "Content";
		}
		protected override void LoadContent() {
			spriteBatch = new SpriteBatch(GraphicsDevice);

			effect = Content.Load<Effect>("ao");
			quad = new QuadFiller(GraphicsDevice);
		}

		protected override void Update(GameTime gameTime) {
			if (GamePad.GetState(PlayerIndex.One).Buttons.Back == ButtonState.Pressed)
				this.Exit();

			float x = (float)Math.Sin(gameTime.TotalGameTime.TotalSeconds * 0.64f);
			float y = (float)Math.Sin(gameTime.TotalGameTime.TotalSeconds * 0.73f);
			float z = (float)Math.Sin(gameTime.TotalGameTime.TotalSeconds * 0.35f) - 3.5f;
			effect.Parameters["sphere"].Elements[0].StructureMembers["center"].SetValue(new Vector3(x, y, z));

			base.Update(gameTime);
		}

		protected override void Draw(GameTime gameTime) {
			GraphicsDevice.Clear(Color.CornflowerBlue);

			quad.Fill(effect);

			base.Draw(gameTime);
		}
	}


	public class QuadFiller {
		readonly GraphicsDevice graphicsDevice;
		readonly VertexDeclaration declaration;
		readonly VertexBuffer vb;

		public QuadFiller(GraphicsDevice graphicsDevice) {
			this.graphicsDevice = graphicsDevice;
			vb = new VertexBuffer(graphicsDevice, typeof(VertexPositionTexture), 4, BufferUsage.WriteOnly);
			vb.SetData(new VertexPositionTexture[]{
					new VertexPositionTexture(new Vector3(-1, +1, 0), Vector2.Zero),
					new VertexPositionTexture(new Vector3(+1, +1, 0), Vector2.UnitX),
					new VertexPositionTexture(new Vector3(-1, -1, 0), Vector2.UnitY),
					new VertexPositionTexture(new Vector3(+1, -1, 0), Vector2.One),
				});
			declaration = new VertexDeclaration(graphicsDevice, VertexPositionTexture.VertexElements);
		}
		public void Fill(Effect effect) {
			graphicsDevice.RenderState.CullMode = CullMode.None;
			effect.Begin();
			foreach (var pass in effect.CurrentTechnique.Passes) {
				pass.Begin();
				Fill();
				pass.End();
			}
			effect.End();
		}
		public void Fill() {
			graphicsDevice.Vertices[0].SetSource(vb, 0, VertexPositionTexture.SizeInBytes);
			graphicsDevice.VertexDeclaration = declaration;
			graphicsDevice.DrawPrimitives(PrimitiveType.TriangleStrip, 0, 2);
		}

	}

}