はなちるのマイノート

Unityをメインとした技術ブログ。自分らしくまったりやっていきたいと思いますー!

【Unity】ComputeShaderで大津の2値化をしてみる【Q4】

大津の2値化とは?

詳細はこちらの記事がとても参考になりました。
qiita.com

途中経過を省いて結果だけ書いてしまうと、あるt(=0 ~ 255)を境に、グレースケールが小さいものをクラス0 ・大きいものをクラス1とし

w0 = クラス0に含まれる画素数 / 全体の画素数
w1 =  クラス1に含まれる画素数 / 全体の画素数
m0 = クラス0のグレースケールの平均値
m1 = クラス1のグレースケールの平均値

において、

Sb^2 = w0 × w1 × (m0−m1) × (m0-m1)

最大になるtを閾値に用いて2値化します。

コード

ComputeShaderはこちら。

#pragma kernel GrayScale
#pragma kernel FindThreshold
#pragma kernel CreateTexture

Texture2D<float4> Texture;
RWTexture2D<float4> Result;

int Width;
int Height;

RWStructuredBuffer<float> buffer;

float Threshold;

[numthreads(8,8,1)]
void GrayScale (uint3 id : SV_DispatchThreadID)
{
    float gray = 0.2126 * Texture[id.xy].x + 0.7152 * Texture[id.xy].y + 0.0722 * Texture[id.xy].z;

    Result[id.xy] = float4(gray, gray, gray, 1);
}

[numthreads(16,1,1)]
void FindThreshold (uint3 id : SV_DispatchThreadID)
{
    float m0 = 0;   // クラスの平均
    float m1 = 0;  
    float w0 = 0;   // クラスのピクセル数をピクセル総数で割ったもの
    float w1 = 0;
    float sb2 = 0;   // クラス間分散(=sb^2)
    uint i = 0;
    uint j = 0;

    for(i = 0; i < Width; i++){
        for(j = 0; j < Height; j++){
            if(Texture[int2(i,j)].x < (id.x / 255.0)){
               w0 += 1;                        // 一時的にクラスの個数を代入
               m0 += Texture[int2(i,j)].x;     // 一時的にクラスの合計を代入
            }
            else
            {
                w1 += 1;
                m1 += Texture[int2(i,j)].x;
           }
        }
    }

    m0 = (w0 == 0) ? 0 : m0 / w0;   // クラスの平均を算出
    w0 = w0 / (Width * Height);     // 全体における割合の算出
    m1 = (w1 == 0) ? 0 : m1 / w1;
    w1 = w1 / (Width * Height);

    sb2 = w0 * w1 * (m0 - m1) * (m0 - m1);

    buffer[id.x] = sb2;
}

[numthreads(8,8,1)]
void CreateTexture (uint3 id : SV_DispatchThreadID)
{
    float th = Threshold / 255.0;
    Result[id.xy] = (Texture[id.xy] > th) ? float4(1,1,1,1) : float4(0,0,0,1);
}


CPU側のコードこっち。

using UnityEngine;
using UnityEngine.UI;
using System.Linq;

public class OtsuBinarization : MonoBehaviour
{
    [SerializeField] private ComputeShader _computeShader;
    [SerializeField] private Texture2D _tex;
    [SerializeField] private RawImage _renderer;
    private ComputeBuffer _buffer;

    private const int COLOR_SIZE = 256;

    struct ThreadSize
    {
        public uint x;
        public uint y;
        public uint z;

        public ThreadSize(uint x, uint y, uint z)
        {
            this.x = x;
            this.y = y;
            this.z = z;
        }
    }

    private void Start()
    {
        var result = BinarizeOtsu(_tex);
        _renderer.texture = result;
    }

    /// <summary>
    /// グレースケールから大津の2値化を用いたテクスチャを生成
    /// </summary>
    private RenderTexture BinarizeOtsu(Texture2D texture)
    {
        if (!SystemInfo.supportsComputeShaders)
        {
            Debug.LogError("Comppute Shader is not support.");
            return null;
        }

        var target = RGB2Gray(texture);
        var threshold = FindThreshold(target);
        CreateTexture(target, threshold);

        return target;
    }

    /// <summary>
    /// RGBからグレースケールを算出
    /// </summary>
    private RenderTexture RGB2Gray(Texture2D texture)
    {
        // RenderTextureの初期化
        var target = new RenderTexture(_tex.width, _tex.height, 0, RenderTextureFormat.ARGB32);
        target.enableRandomWrite = true;
        target.Create();

        // GrayScaleのカーネルインデックスを取得
        var kernelIndex = _computeShader.FindKernel("GrayScale");

        // 一つのグループの中に何個のスレッドがあるか
        ThreadSize threadSize = new ThreadSize();
        _computeShader.GetKernelThreadGroupSizes(kernelIndex, out threadSize.x, out threadSize.y, out threadSize.z);

        // GPUにデータをコピーする
        _computeShader.SetTexture(kernelIndex, "Texture", texture);
        _computeShader.SetTexture(kernelIndex, "Result", target);

        // GPUの処理を実行する
        _computeShader.Dispatch(kernelIndex, _tex.width / (int)threadSize.x, _tex.height / (int)threadSize.y, (int)threadSize.z);

        return target;
    }

    /// <summary>
    /// グレースケールから大津の2値化を用いて閾値を求める
    /// </summary>
    private float FindThreshold(RenderTexture gray)
    {
        // FindThresholdのカーネルインデックスを取得
        var kernelIndex = _computeShader.FindKernel("FindThreshold");

        // float[255](大津2値化のsb^2)を受け取るための準備
        _buffer = new ComputeBuffer(COLOR_SIZE, sizeof(float));

        // 一つのグループの中に何個のスレッドがあるか
        ThreadSize threadSize = new ThreadSize();
        _computeShader.GetKernelThreadGroupSizes(kernelIndex, out threadSize.x, out threadSize.y, out threadSize.z);

        // GPUにデータをコピーする
        _computeShader.SetBuffer(kernelIndex, "buffer", _buffer);
        _computeShader.SetTexture(kernelIndex, "Texture", gray);
        _computeShader.SetInt("Width", _tex.width);
        _computeShader.SetInt("Height", _tex.height);

        // GPUの処理を実行する
        _computeShader.Dispatch(kernelIndex, COLOR_SIZE / (int)threadSize.x, (int)threadSize.y, (int)threadSize.z);

        // 結果を取得
        var result = new float[COLOR_SIZE];
        _buffer.GetData(result);

        // 最大値のインデックスが閾値
        return result.Select((p, i) => new { Sb2 = p, Index = i })
            .OrderByDescending(p => p.Sb2)
            .First()
            .Index;
    }

    /// <summary>
    /// 求めた閾値を使って画像を2値化する
    /// </summary>
    private void CreateTexture(RenderTexture texture, float threshold)
    {
        // CreateTextureのカーネルインデックスを取得
        var kernelIndex = _computeShader.FindKernel("CreateTexture");

        // 一つのグループの中に何個のスレッドがあるか
        ThreadSize threadSize = new ThreadSize();
        _computeShader.GetKernelThreadGroupSizes(kernelIndex, out threadSize.x, out threadSize.y, out threadSize.z);

        // GPUにデータをコピーする
        _computeShader.SetTexture(kernelIndex, "Texture", texture);
        _computeShader.SetTexture(kernelIndex, "Result", texture);
        _computeShader.SetFloat("Threshold", threshold);

        // GPUの処理を実行する
        _computeShader.Dispatch(kernelIndex, _tex.width / (int)threadSize.x, _tex.height / (int)threadSize.y, (int)threadSize.z);
    }

    private void OnDestroy()
    {
        _buffer.Release();
        _buffer = null;
    }
}

解説

今回は大きく3つのパートに分けて処理をしました。(※ここの反省を最後にかきました)

  1. テクスチャ2Dからグレースケールを求める処理
  1. グレースケールからクラス間分散を求める処理
  1. 求めた閾値tから画像を生成する処理

f:id:hanaaaaaachiru:20200204160621p:plain


これらの処理がそれぞれRGB2GrayFindThresholdCreateTextureに対応しています。

さいごに

今回は分かりやすいかなと思い大きく3つのパートに分けていましたが、今思うとCPU->GPU, GPU -> CPUへのデータの転送回数が増え無駄な処理時間が増えてしまっていると思います。

また実際の処理はGPU側のブラックボックスにした方がおそらく良いですよね。

ただクラス間分散を求める際に他のスレッドの計算結果が必要になるので一筋縄ではいかない?ような。

もし改善することができたらまた記事を書きたいと思います。

ではまた。