はなちるのマイノート

Unityをメインとした技術ブログ。自分らしくまったりやっていきたいと思いますー!

【Unity】ComputeShaderでメディアンフィルタを実装してみる【Q10】

メディアンフィルタとは

以下の画像のように自身と周辺画素の9個の中で画素値が中央値に置き換えます。

f:id:hanaaaaaachiru:20200213230020p:plain

この場合はソートをして15->40->40->78->102->145->180->200->255となり、真ん中にあるやつが中央値になるので102が出力値になるというわけです。

コード

ComputeShaderはこちら。

#pragma kernel Median

RWTexture2D<float4> Result;
Texture2D<float4> Texture;

[numthreads(8,8,1)]
void Median (uint3 id : SV_DispatchThreadID)
{
    float4 upperLeft = Texture[id.xy + int2(-1, 1)];
    float4 up = Texture[id.xy + int2(0, 1)];
    float4 upperRight = Texture[id.xy + int2(1, 1)];
    float4 left = Texture[id.xy + int2(-1, 0)];
    float4 middle = Texture[id.xy];
    float4 right = Texture[id.xy + int2(1, 0)];
    float4 lowerLeft = Texture[id.xy + int2(-1, -1)];
    float4 down = Texture[id.xy + int2(0, -1)];
    float4 lowerRight = Texture[id.xy + int2(1, -1)];
    int i, j;
    float4 tmp;
    float4 rgb2gray = float4(0.2126, 0.7152, 0.0722, 0);

    float4 array[9] = {
        float4(upperLeft.x, upperLeft.y, upperLeft.z, dot(upperLeft, rgb2gray)),
        float4(up.x, up.y, up.z, dot(up, rgb2gray)),
        float4(upperRight.x, upperRight.y, upperRight.z, dot(upperRight, rgb2gray)),
        float4(left.x, left.y, left.z, dot(left, rgb2gray)),
        float4(middle.x, middle.y, middle.z, dot(middle, rgb2gray)),
        float4(right.x, right.y, right.z, dot(right, rgb2gray)),
        float4(lowerLeft.x, lowerLeft.y, lowerLeft.z, dot(lowerLeft, rgb2gray)),
        float4(down.x, down.y, down.z, dot(down, rgb2gray)),
        float4(lowerRight.x, lowerRight.y, lowerRight.z, dot(lowerRight, rgb2gray))
    };

    // バブルソート
    for(i = 0; i < 8; i++){
        for(j = 0; j < 8 - i; j++){
            tmp = array[i + 1];
            array[i + 1] = (array[i].w > array[i+1].w) ? array[i] : array[i + 1];
            array[i] = (array[i].w > array[i+1].w) ? tmp : array[i];
        }
    }

    Result[id.xy] = float4(array[4].x, array[4].y, array[4].z, 1);
}


CPU側のコードはこちら。

using UnityEngine;
using UnityEngine.UI;

public class Median : MonoBehaviour
{
    [SerializeField] private ComputeShader _computeShader;
    [SerializeField] private Texture2D _tex;
    [SerializeField] private RawImage _renderer;

    struct ThreadSize
    {
        public uint x;
        public uint y;
        public uint z;

        public ThreadSize(uint x, uint y, uint z)
        {
            this.x = x;
            this.y = y;
            this.z = z;
        }
    }

    private void Start()
    {
        if (!SystemInfo.supportsComputeShaders)
        {
            Debug.LogError("Comppute Shader is not support.");
            return;
        }

        // RenderTextueの初期化
        var result = new RenderTexture(_tex.width, _tex.height, 0, RenderTextureFormat.ARGB32);
        result.enableRandomWrite = true;
        result.Create();

        // Medianのカーネルインデックス(0)を取得
        var kernelIndex = _computeShader.FindKernel("Median");

        // 一つのグループの中に何個のスレッドがあるか
        ThreadSize threadSize = new ThreadSize();
        _computeShader.GetKernelThreadGroupSizes(kernelIndex, out threadSize.x, out threadSize.y, out threadSize.z);

        // GPUにデータをコピーする
        _computeShader.SetTexture(kernelIndex, "Texture", _tex);
        _computeShader.SetTexture(kernelIndex, "Result", result);

        // GPUの処理を実行する
        _computeShader.Dispatch(kernelIndex, _tex.width / (int)threadSize.x, _tex.height / (int)threadSize.y, (int)threadSize.z);

        // テクスチャを適応
        _renderer.texture = result;
    }
}

さいごに

ソートの部分をバブルソートで行ってしまったので、パフォーマンス的にはよくありませんがたった9個しか対象のピクセルがないのでさほど影響はしない気もします。

またHLSLではif文は遅いらしいので、無理やり三項演算子を使っています。

結構最適化の余地はありますが、こんな実装でも実感ができないほど一瞬で処理が終わるのが本当にすごいところですよね。


あとメディアンフィルタってグレースケールの中央値でおそらくいいんですよね?

ネットで調べてみるとグレースケールの中央値でやっているかたとRGBのそれぞれの中央値でやっている方がいて結構不安になっていたりします。


とりあえず1~10個までを一通り終えたので、いっかい宣伝がてらまとめ記事を書こうと思ってます。

ではまた。