CPU implementation of trilinear only supports batch_size == 1 #61

Open
kamo262 opened this issue Feb 24, 2022 · 1 comment

kamo262 commented Feb 24, 2022

I noticed that the CPU implementations of the trilinear forward and backward functions only support batch_size == 1. When the functions are called with batch_size > 1, only the first sample in the batch is computed.
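
Why only the first sample is touched, as a minimal sketch: assuming `image` and `output` are contiguous NCHW buffers of shape (batch, channels, height, width), element (n, c, y, x) lives at offset ((n * channels + c) * height + y) * width + x, so sample n starts at n * channels * height * width. A loop that walks `index` over height * width and adds channel offsets of width * height, but has no batch term, never reaches that boundary for n = 1. The helper names below (`nchw_offset`, `max_offset_without_batch_term`, `stays_in_first_sample`) are hypothetical and not part of the repository.

```cpp
#include <cassert>

// Hypothetical helper: flat offset of element (n, c, y, x) in a contiguous
// NCHW buffer of shape (batch, channels, height, width).
inline int nchw_offset(int n, int c, int y, int x,
                       int channels, int height, int width) {
    assert(n >= 0 && c >= 0 && c < channels &&
           y >= 0 && y < height && x >= 0 && x < width);
    return ((n * channels + c) * height + y) * width + x;
}

// Hypothetical helper: largest offset read by a per-pixel loop that ignores
// the batch index (index in [0, height*width) plus c * height * width).
inline int max_offset_without_batch_term(int channels, int height, int width) {
    const int plane = height * width;
    return (channels - 1) * plane + (plane - 1);
}

// Hypothetical helper: true if every offset such a loop touches lies inside
// sample 0, i.e. below the first element of sample 1.
inline bool stays_in_first_sample(int channels, int height, int width) {
    return max_offset_without_batch_term(channels, height, width)
           < nchw_offset(1, 0, 0, 0, channels, height, width);
}
```

For a (2, 3, 4, 5) input, `nchw_offset(1, 0, 0, 0, 3, 4, 5)` is 60, while the largest offset the batch-unaware loop reads is 59, the last pixel of sample 0.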

I had to modify the functions as follows so that they process every sample in a batch.

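#include <cmath>

// CPU forward pass of the trilinear 3D-LUT lookup, extended to loop over the batch.
// image and output: contiguous NCHW float buffers of shape (batch, channels, height, width);
// the R, G, B planes of each sample are read and written.
// lut: three dim*dim*dim tables (one per output channel), `shift` floats apart,
// shared by every sample in the batch.
void TriLinearForwardCpu(const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int width, const int height, const int channels, const int batch)
{
    const int output_size = height * width;  // pixels per channel plane

    for (int batch_index = 0; batch_index < batch; ++batch_index) {
        // First element of sample `batch_index`; without this offset only sample 0 is processed.
        const int batch_start_index = batch_index * output_size * channels;
        for (int index = 0; index < output_size; ++index)
        {
            // Input colour of this pixel (channel planes are output_size floats apart).
            const float r = image[batch_start_index + index];
            const float g = image[batch_start_index + index + output_size];
            const float b = image[batch_start_index + index + output_size * 2];

            // Lower lattice point along each axis. Assumes binsize is chosen so that
            // r_id + 1, g_id + 1 and b_id + 1 stay below dim for every input value.
            const int r_id = static_cast<int>(std::floor(r / binsize));
            const int g_id = static_cast<int>(std::floor(g / binsize));
            const int b_id = static_cast<int>(std::floor(b / binsize));

            // Fractional position of the colour inside its bin.
            const float r_d = std::fmod(r, binsize) / binsize;
            const float g_d = std::fmod(g, binsize) / binsize;
            const float b_d = std::fmod(b, binsize) / binsize;

            // Flat indices of the eight surrounding lattice points (r varies fastest).
            const int id000 = r_id + g_id * dim + b_id * dim * dim;
            const int id100 = r_id + 1 + g_id * dim + b_id * dim * dim;
            const int id010 = r_id + (g_id + 1) * dim + b_id * dim * dim;
            const int id110 = r_id + 1 + (g_id + 1) * dim + b_id * dim * dim;
            const int id001 = r_id + g_id * dim + (b_id + 1) * dim * dim;
            const int id101 = r_id + 1 + g_id * dim + (b_id + 1) * dim * dim;
            const int id011 = r_id + (g_id + 1) * dim + (b_id + 1) * dim * dim;
            const int id111 = r_id + 1 + (g_id + 1) * dim + (b_id + 1) * dim * dim;

            // Trilinear weights of the eight lattice points (they sum to 1).
            const float w000 = (1 - r_d) * (1 - g_d) * (1 - b_d);
            const float w100 = r_d * (1 - g_d) * (1 - b_d);
            const float w010 = (1 - r_d) * g_d * (1 - b_d);
            const float w110 = r_d * g_d * (1 - b_d);
            const float w001 = (1 - r_d) * (1 - g_d) * b_d;
            const float w101 = r_d * (1 - g_d) * b_d;
            const float w011 = (1 - r_d) * g_d * b_d;
            const float w111 = r_d * g_d * b_d;

            // R output: weighted sum over the first table.
            output[batch_start_index + index] =
                w000 * lut[id000] + w100 * lut[id100] +
                w010 * lut[id010] + w110 * lut[id110] +
                w001 * lut[id001] + w101 * lut[id101] +
                w011 * lut[id011] + w111 * lut[id111];

            // G output: second table, `shift` floats further into lut.
            output[batch_start_index + index + output_size] =
                w000 * lut[id000 + shift] + w100 * lut[id100 + shift] +
                w010 * lut[id010 + shift] + w110 * lut[id110 + shift] +
                w001 * lut[id001 + shift] + w101 * lut[id101 + shift] +
                w011 * lut[id011 + shift] + w111 * lut[id111 + shift];

            // B output: third table.
            output[batch_start_index + index + output_size * 2] =
                w000 * lut[id000 + shift * 2] + w100 * lut[id100 + shift * 2] +
                w010 * lut[id010 + shift * 2] + w110 * lut[id110 + shift * 2] +
                w001 * lut[id001 + shift * 2] + w101 * lut[id101 + shift * 2] +
                w011 * lut[id011 + shift * 2] + w111 * lut[id111 + shift * 2];
        }
    }
}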
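
The issue states that the backward function needed the same fix, but only the forward code is shown. The following is a minimal sketch of a batched backward, not the repository's implementation. Assumptions: the name `TriLinearBackwardCpuBatched` and its parameter list are hypothetical; it uses the same NCHW image layout and LUT layout as the forward above; one LUT is shared by all samples, so its gradient is summed over the batch; only the gradient with respect to the LUT is computed; and `lut_grad` (3 * shift floats) is zeroed by the caller.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical batched backward matching the forward above. Accumulates
// d(loss)/d(lut) into lut_grad, which the caller must zero beforehand.
// Assumes contiguous NCHW image/output_grad of shape (batch, channels, height, width)
// and a single LUT (three tables, `shift` floats apart) shared by the whole batch.
void TriLinearBackwardCpuBatched(const float* image, const float* output_grad,
                                 float* lut_grad, const int dim, const int shift,
                                 const float binsize, const int width,
                                 const int height, const int channels,
                                 const int batch)
{
    assert(dim >= 2 && batch >= 0);  // basic sanity check on the assumed inputs
    const int output_size = height * width;
    for (int batch_index = 0; batch_index < batch; ++batch_index) {
        const int start = batch_index * output_size * channels;
        for (int index = 0; index < output_size; ++index) {
            const float r = image[start + index];
            const float g = image[start + index + output_size];
            const float b = image[start + index + output_size * 2];

            const int r_id = static_cast<int>(std::floor(r / binsize));
            const int g_id = static_cast<int>(std::floor(g / binsize));
            const int b_id = static_cast<int>(std::floor(b / binsize));

            const float r_d = std::fmod(r, binsize) / binsize;
            const float g_d = std::fmod(g, binsize) / binsize;
            const float b_d = std::fmod(b, binsize) / binsize;

            // Same eight corners and weights as the forward, in the order
            // 000, 100, 010, 110, 001, 101, 011, 111.
            const int base = r_id + g_id * dim + b_id * dim * dim;
            const int ids[8] = {base, base + 1, base + dim, base + 1 + dim,
                                base + dim * dim, base + 1 + dim * dim,
                                base + dim + dim * dim, base + 1 + dim + dim * dim};
            const float w[8] = {
                (1 - r_d) * (1 - g_d) * (1 - b_d), r_d * (1 - g_d) * (1 - b_d),
                (1 - r_d) * g_d * (1 - b_d),       r_d * g_d * (1 - b_d),
                (1 - r_d) * (1 - g_d) * b_d,       r_d * (1 - g_d) * b_d,
                (1 - r_d) * g_d * b_d,             r_d * g_d * b_d};

            // output[c] = sum_k w[k] * lut[ids[k] + c * shift], so the
            // gradient w.r.t. each of those LUT entries is w[k] * upstream grad.
            for (int c = 0; c < 3; ++c) {
                const float grad = output_grad[start + index + c * output_size];
                for (int k = 0; k < 8; ++k)
                    lut_grad[ids[k] + c * shift] += w[k] * grad;
            }
        }
    }
}
```

Because each output channel is a linear combination of eight LUT entries, its gradient with respect to each entry is just the corresponding trilinear weight, which is why this sketch reuses the forward's indices and weights and only adds the batch loop around them.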

HuiZeng (repository owner) commented Feb 28, 2022

Hi, thanks for sharing this code.
