// Kernel entry points; each name below is a [numthreads] function defined at the
// bottom of this file.
#pragma kernel Bilinear32x32
#pragma kernel Bilinear16x16
#pragma kernel Bilinear16x16FromShortArray
#pragma kernel Bilinear8x8

// Source texture for the texture-to-texture kernels (read by BilinearInterpolation,
// which queries its size with GetDimensions).
Texture2D<float> src;
// Output texture written by every kernel.
RWTexture2D<float> dest;

// Packed source samples for the short-array path. Each int element holds two
// samples: the even-indexed one in the low 16 bits and the odd-indexed one in the
// high 16 bits; only the low 15 bits of each half are read (mask 0x7FFF).
RWStructuredBuffer<int> srcArray;

// Dimensions used by the short-array functions. The texture path ignores these and
// uses GetDimensions on src/dest instead.
cbuffer params
{
    int srcWidth;   // samples per row in srcArray (used as the row stride)
    int srcHeight;  // number of rows in srcArray
    int destWidth;  // threads with id.x >= destWidth write nothing
    int destHeight; // threads with id.y >= destHeight write nothing
};

// NOTE(review): not referenced by any function in this file — confirm whether it is
// still needed (or used by another pass sharing these globals).
float sharpness;

// Scale applied to a raw 15-bit sample (presumably millimetres -> metres; confirm
// with whatever produces srcArray).
static const float kRawSampleScale = 0.001;
// Offset added to raw samples <= 1. These appear to mark "no data" and are pushed
// far out of range — TODO confirm the meaning with the producer of srcArray.
static const float kInvalidSampleOffset = 30.0;
// Only the low 15 bits of each packed 16-bit half are meaningful.
static const int kRawSampleMask = 0x7FFF;

// Returns the raw 15-bit sample at linear `index` in the packed srcArray buffer.
// Each int element packs two samples: even indices in the low 16 bits, odd indices
// in the high 16 bits. `index` must be non-negative.
float ReadRawSample(int index)
{
    int packed = srcArray[index >> 1];   // one load per sample instead of two
    int shift = (index & 1) * 16;
    // Masking after the shift discards any sign bits the arithmetic shift brings in.
    return float((packed >> shift) & kRawSampleMask);
}

// Reads the sample at linear `index` and converts it to output units:
// raw * kRawSampleScale, plus kInvalidSampleOffset when raw <= 1
// (step(-1.0, -raw) is 1 exactly when raw <= 1).
float ReadScaledSample(int index)
{
    float raw = ReadRawSample(index);
    return raw * kRawSampleScale + step(-1.0, -raw) * kInvalidSampleOffset;
}

// Copies srcArray into dest one sample per thread (no resampling), converting each
// sample with ReadScaledSample.
// id: dispatch thread ID. Threads outside destWidth x destHeight do nothing.
// If dest is larger than the source, coordinates are clamped to the last source
// row/column. Reads therefore never wrap into the next row or run past the buffer.
// When the sizes match, this clamp is a no-op.
void ConvertFromShortArray(uint2 id)
{
    if(destWidth <= id.x || destHeight <= id.y)
        return;

    // srcWidth is the row stride, so the bounds must come from the source, not dest.
    int rx = min(int(id.x), srcWidth - 1);
    int ry = min(int(id.y), srcHeight - 1);

    dest[id.xy] = ReadScaledSample(ry * srcWidth + rx);
}

// Bilinearly resamples the srcWidth x srcHeight packed srcArray into dest.
// id: dispatch thread ID. Threads outside destWidth x destHeight do nothing.
// A destination pixel maps to source position id * (srcSize / destSize), with no
// half-pixel offset. The +1 neighbours are clamped to the last source row/column.
// Samples are converted with ReadScaledSample *before* blending, so a neighbour
// that received kInvalidSampleOffset contributes that value to the blend.
void BilinearInterpolationFromShortArray(uint2 id)
{
    if(destWidth <= id.x || destHeight <= id.y)
        return;

    float2 sampleUnit = float2(float(srcWidth) / destWidth, float(srcHeight) / destHeight);
    float sx = sampleUnit.x * id.x;
    float sy = sampleUnit.y * id.y;
    int rx = int(sx);
    int ry = int(sy);
    int rx1 = min(srcWidth - 1, rx + 1);
    int ry1 = min(srcHeight - 1, ry + 1);

    float c0 = ReadScaledSample(ry * srcWidth + rx);    // (rx,  ry)
    float c1 = ReadScaledSample(ry * srcWidth + rx1);   // (rx1, ry)
    float c2 = ReadScaledSample(ry1 * srcWidth + rx);   // (rx,  ry1)
    float c3 = ReadScaledSample(ry1 * srcWidth + rx1);  // (rx1, ry1)

    dest[id.xy] = (rx + 1.0 - sx) * ((ry + 1.0 - sy) * c0 + (sy - ry) * c2)
                + (sx - rx) * ((ry + 1.0 - sy) * c1 + (sy - ry) * c3);
}

// Texture-to-texture bilinear resample of `src` into `dest`.
// id: dispatch thread ID; threads outside dest's dimensions return without writing.
// Destination pixel id maps to source position id * (srcSize / destSize), with no
// half-pixel offset. The four neighbours are the truncated position and its +1
// x/y neighbours, clamped to the last source column/row.
void BilinearInterpolation(uint2 id)
{
    uint2 destSize;
    dest.GetDimensions(destSize.x, destSize.y);

    uint2 srcSize;
    src.GetDimensions(srcSize.x, srcSize.y);

    bool outsideDest = id.x >= destSize.x || id.y >= destSize.y;
    if(outsideDest)
        return;

    float2 scale = float2(srcSize) / float2(destSize);
    float2 pos = scale * float2(id);
    int2 base = int2(pos);
    uint2 next = uint2(min(srcSize.x - 1, base.x + 1), min(srcSize.y - 1, base.y + 1));

    // cXY: X selects base.x (0) or next.x (1); Y selects base.y (0) or next.y (1).
    float c00 = src[uint2(base)];
    float c10 = src[uint2(next.x, base.y)];
    float c01 = src[uint2(base.x, next.y)];
    float c11 = src[next];

    // Linear weights toward the base (w*0) and the +1 neighbour (w*1) on each axis.
    float wx0 = base.x + 1.0 - pos.x;
    float wx1 = pos.x - base.x;
    float wy0 = base.y + 1.0 - pos.y;
    float wy1 = pos.y - base.y;

    dest[id.xy] = wx0 * (wy0 * c00 + wy1 * c01) + wx1 * (wy0 * c10 + wy1 * c11);
}

// Kernel entry points. Each thread passes SV_DispatchThreadID.xy as its destination
// pixel; the variants differ only in thread-group size and in which helper they call.

// Texture resample with 32x32 threads per group.
[numthreads(32, 32, 1)]
void Bilinear32x32(uint3 id : SV_DispatchThreadID)
{
    uint2 pixel = uint2(id.x, id.y);
    BilinearInterpolation(pixel);
}

// Texture resample with 16x16 threads per group.
[numthreads(16, 16, 1)]
void Bilinear16x16(uint3 id : SV_DispatchThreadID)
{
    uint2 pixel = uint2(id.x, id.y);
    BilinearInterpolation(pixel);
}

// Packed short-array path with 16x16 threads per group.
// NOTE(review): despite the name, this calls ConvertFromShortArray (one sample per
// pixel, no interpolation). BilinearInterpolationFromShortArray is not called by any
// kernel — confirm which behaviour is intended.
[numthreads(16, 16, 1)]
void Bilinear16x16FromShortArray(uint3 id : SV_DispatchThreadID)
{
    uint2 pixel = uint2(id.x, id.y);
    ConvertFromShortArray(pixel);
}

// Texture resample with 8x8 threads per group.
[numthreads(8, 8, 1)]
void Bilinear8x8(uint3 id : SV_DispatchThreadID)
{
    uint2 pixel = uint2(id.x, id.y);
    BilinearInterpolation(pixel);
}