mirror of
https://github.com/CNugteren/CLBlast.git
synced 2024-07-02 20:36:58 +02:00
Fixed half-precision tests for im2col and col2im
This commit is contained in:
parent
4215bbe62a
commit
469c346a8e
|
@ -197,6 +197,26 @@ StatusCode RunReference(const Arguments<T> &args, BuffersHost<T> &buffers_host)
|
||||||
return StatusCode::kSuccess;
|
return StatusCode::kSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Half-precision version calling the above reference implementation after conversions
|
||||||
|
template <>
|
||||||
|
StatusCode RunReference<half>(const Arguments<half> &args, BuffersHost<half> &buffers_host) {
|
||||||
|
auto a_buffer2 = HalfToFloatBuffer(buffers_host.a_mat);
|
||||||
|
auto b_buffer2 = HalfToFloatBuffer(buffers_host.b_mat);
|
||||||
|
auto dummy = std::vector<float>(0);
|
||||||
|
auto buffers2 = BuffersHost<float>{dummy, dummy, a_buffer2, b_buffer2, dummy, dummy, dummy};
|
||||||
|
auto args2 = Arguments<float>();
|
||||||
|
args2.a_size = args.a_size; args2.b_size = args.b_size;
|
||||||
|
args2.channels = args.channels; args2.height = args.height; args2.width = args.width;
|
||||||
|
args2.kernel_h = args.kernel_h; args2.kernel_w = args.kernel_w;
|
||||||
|
args2.pad_h = args.pad_h; args2.pad_w = args.pad_w;
|
||||||
|
args2.stride_h = args.stride_h; args2.stride_w = args.stride_w;
|
||||||
|
args2.dilation_h = args.dilation_h; args2.dilation_w = args.dilation_w;
|
||||||
|
args2.a_offset = args.a_offset; args2.b_offset = args.b_offset;
|
||||||
|
auto status = RunReference(args2, buffers2);
|
||||||
|
FloatToHalfBuffer(buffers_host.a_mat, buffers2.a_mat);
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
// =================================================================================================
|
// =================================================================================================
|
||||||
} // namespace clblast
|
} // namespace clblast
|
||||||
|
|
||||||
|
|
|
@ -188,6 +188,26 @@ StatusCode RunReference(const Arguments<T> &args, BuffersHost<T> &buffers_host)
|
||||||
return StatusCode::kSuccess;
|
return StatusCode::kSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Half-precision version calling the above reference implementation after conversions
|
||||||
|
template <>
|
||||||
|
StatusCode RunReference<half>(const Arguments<half> &args, BuffersHost<half> &buffers_host) {
|
||||||
|
auto a_buffer2 = HalfToFloatBuffer(buffers_host.a_mat);
|
||||||
|
auto b_buffer2 = HalfToFloatBuffer(buffers_host.b_mat);
|
||||||
|
auto dummy = std::vector<float>(0);
|
||||||
|
auto buffers2 = BuffersHost<float>{dummy, dummy, a_buffer2, b_buffer2, dummy, dummy, dummy};
|
||||||
|
auto args2 = Arguments<float>();
|
||||||
|
args2.a_size = args.a_size; args2.b_size = args.b_size;
|
||||||
|
args2.channels = args.channels; args2.height = args.height; args2.width = args.width;
|
||||||
|
args2.kernel_h = args.kernel_h; args2.kernel_w = args.kernel_w;
|
||||||
|
args2.pad_h = args.pad_h; args2.pad_w = args.pad_w;
|
||||||
|
args2.stride_h = args.stride_h; args2.stride_w = args.stride_w;
|
||||||
|
args2.dilation_h = args.dilation_h; args2.dilation_w = args.dilation_w;
|
||||||
|
args2.a_offset = args.a_offset; args2.b_offset = args.b_offset;
|
||||||
|
auto status = RunReference(args2, buffers2);
|
||||||
|
FloatToHalfBuffer(buffers_host.b_mat, buffers2.b_mat);
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
// =================================================================================================
|
// =================================================================================================
|
||||||
} // namespace clblast
|
} // namespace clblast
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue