Added reference implementation for xCONVGEMM for half-precision

pull/319/head
Cedric Nugteren 2018-09-07 22:04:08 +02:00
parent c788e040f7
commit bbb4523b7c
1 changed files with 22 additions and 0 deletions

View File

@ -214,6 +214,28 @@ StatusCode RunReference(const Arguments<T> &args, BuffersHost<T> &buffers_host)
return StatusCode::kSuccess;
}
// Half-precision version calling the above reference implementation after conversions
template <>
StatusCode RunReference<half>(const Arguments<half> &args, BuffersHost<half> &buffers_host) {
auto a_buffer2 = HalfToFloatBuffer(buffers_host.a_mat);
auto b_buffer2 = HalfToFloatBuffer(buffers_host.b_mat);
auto c_buffer2 = HalfToFloatBuffer(buffers_host.c_mat);
auto dummy = std::vector<float>(0);
auto buffers2 = BuffersHost<float>{dummy, dummy, a_buffer2, b_buffer2, c_buffer2, dummy, dummy};
auto args2 = Arguments<float>();
args2.a_size = args.a_size; args2.b_size = args.b_size; args2.c_size = args.c_size;
args2.channels = args.channels; args2.height = args.height; args2.width = args.width;
args2.kernel_h = args.kernel_h; args2.kernel_w = args.kernel_w;
args2.pad_h = args.pad_h; args2.pad_w = args.pad_w;
args2.stride_h = args.stride_h; args2.stride_w = args.stride_w;
args2.dilation_h = args.dilation_h; args2.dilation_w = args.dilation_w;
args2.num_kernels = args.num_kernels; args2.batch_count = args.batch_count;
args2.a_offset = args.a_offset; args2.b_offset = args.b_offset; args2.c_offset = args.c_offset;
auto status = RunReference(args2, buffers2);
FloatToHalfBuffer(buffers_host.c_mat, buffers2.c_mat);
return status;
}
// =================================================================================================
} // namespace clblast