From b487d4dd44179293c9e08ddf2ce3ed902fa749c8 Mon Sep 17 00:00:00 2001 From: Cedric Nugteren Date: Thu, 26 May 2016 13:15:27 +0200 Subject: [PATCH] Added half-precision tests for the CBLAS reference through conversion to single-precision --- scripts/generator/generator.py | 40 ++- scripts/generator/routine.py | 57 ++++- test/correctness/routines/level1/xamax.cc | 1 + test/correctness/routines/level1/xasum.cc | 1 + test/correctness/routines/level1/xaxpy.cc | 1 + test/correctness/routines/level1/xcopy.cc | 1 + test/correctness/routines/level1/xdot.cc | 1 + test/correctness/routines/level1/xnrm2.cc | 1 + test/correctness/routines/level1/xscal.cc | 1 + test/correctness/routines/level1/xswap.cc | 1 + test/correctness/routines/level2/xgbmv.cc | 1 + test/correctness/routines/level2/xgemv.cc | 1 + test/correctness/routines/level2/xger.cc | 1 + test/correctness/routines/level2/xsbmv.cc | 1 + test/correctness/routines/level2/xspmv.cc | 1 + test/correctness/routines/level2/xspr.cc | 1 + test/correctness/routines/level2/xspr2.cc | 1 + test/correctness/routines/level2/xsymv.cc | 1 + test/correctness/routines/level2/xsyr.cc | 1 + test/correctness/routines/level2/xsyr2.cc | 1 + test/correctness/routines/level2/xtbmv.cc | 1 + test/correctness/routines/level2/xtpmv.cc | 1 + test/correctness/routines/level2/xtrmv.cc | 1 + test/correctness/routines/level3/xgemm.cc | 1 + test/correctness/routines/level3/xsymm.cc | 1 + test/correctness/routines/level3/xsyr2k.cc | 1 + test/correctness/routines/level3/xsyrk.cc | 1 + test/correctness/routines/level3/xtrmm.cc | 1 + test/correctness/routines/level3/xtrsm.cc | 1 + test/correctness/testblas.cc | 8 +- test/correctness/tester.cc | 13 +- test/wrapper_cblas.h | 268 ++++++++++++++++++--- 32 files changed, 366 insertions(+), 47 deletions(-) diff --git a/scripts/generator/generator.py b/scripts/generator/generator.py index d78c3201..3d07c5a3 100644 --- a/scripts/generator/generator.py +++ b/scripts/generator/generator.py @@ -259,12 +259,14 @@ def 
wrapper_cblas(routines): if routine.has_tests: result += "\n// Forwards the Netlib BLAS calls for %s\n" % (routine.ShortNamesTested()) for flavour in routine.flavours: - indent = " "*(10 + routine.Length()) result += routine.RoutineHeaderWrapperC(flavour, False, 12)+" {\n" + + # There is a version available in CBLAS if flavour.precision_name in ["S","D","C","Z"]: + indent = " "*(10 + routine.Length()) arguments = routine.ArgumentsWrapperC(flavour) - # Double-precision scalars + # Complex scalars for scalar in routine.scalars: if flavour.IsComplex(scalar): result += " const auto "+scalar+"_array = std::vector<"+flavour.buffertype[:-1]+">{"+scalar+".real(), "+scalar+".imag()};\n" @@ -294,10 +296,27 @@ def wrapper_cblas(routines): result += " "+assignment+"cblas_"+flavour.name.lower()+routine.name+postfix+"(" result += (",\n"+indent).join([a for a in arguments]) - result += extra_argument+endofline+");" - else: - result += " return;" - result += "\n}\n" + result += extra_argument+endofline+");\n" + + # There is no CBLAS available, forward the call to one of the available functions + else: # Half-precision + indent = " "*(9 + routine.Length()) + + # Convert to float (note: also integer buffers are stored as half/float) + for buf in routine.inputs + routine.outputs: + result += " auto "+buf+"_buffer_bis = HalfToFloatBuffer("+buf+"_buffer);\n" + + # Call the float routine + result += " cblasX"+routine.name+"(" + result += (",\n"+indent).join([a for a in routine.ArgumentsHalf()]) + result += ");\n" + + # Convert back to half + for buf in routine.outputs: + result += " FloatToHalfBuffer("+buf+"_buffer, "+buf+"_buffer_bis);\n" + + # Complete + result += "}\n" return result # ================================================================================================== @@ -317,7 +336,7 @@ files = [ path_clblast+"/test/wrapper_clblas.h", path_clblast+"/test/wrapper_cblas.h", ] -header_lines = [84, 71, 93, 22, 29, 41] +header_lines = [84, 71, 93, 22, 29, 51] footer_lines = 
[17, 71, 19, 14, 6, 6] # Checks whether the command-line arguments are valid; exists otherwise @@ -376,10 +395,9 @@ for level in [1,2,3]: body += "int main(int argc, char *argv[]) {\n" not_first = "false" for flavour in routine.flavours: - if flavour.precision_name in ["S","D","C","Z"]: - body += " clblast::RunTests, double, double>(argc, argv, true, "iDAMAX"); clblast::RunTests, float2, float2>(argc, argv, true, "iCAMAX"); clblast::RunTests, double2, double2>(argc, argv, true, "iZAMAX"); + clblast::RunTests, half, half>(argc, argv, true, "iHAMAX"); return 0; } diff --git a/test/correctness/routines/level1/xasum.cc b/test/correctness/routines/level1/xasum.cc index 5ec20596..d3b036c7 100644 --- a/test/correctness/routines/level1/xasum.cc +++ b/test/correctness/routines/level1/xasum.cc @@ -22,6 +22,7 @@ int main(int argc, char *argv[]) { clblast::RunTests, double, double>(argc, argv, true, "DASUM"); clblast::RunTests, float2, float2>(argc, argv, true, "ScASUM"); clblast::RunTests, double2, double2>(argc, argv, true, "DzASUM"); + clblast::RunTests, half, half>(argc, argv, true, "HASUM"); return 0; } diff --git a/test/correctness/routines/level1/xaxpy.cc b/test/correctness/routines/level1/xaxpy.cc index 746e0001..04f4c128 100644 --- a/test/correctness/routines/level1/xaxpy.cc +++ b/test/correctness/routines/level1/xaxpy.cc @@ -22,6 +22,7 @@ int main(int argc, char *argv[]) { clblast::RunTests, double, double>(argc, argv, true, "DAXPY"); clblast::RunTests, float2, float2>(argc, argv, true, "CAXPY"); clblast::RunTests, double2, double2>(argc, argv, true, "ZAXPY"); + clblast::RunTests, half, half>(argc, argv, true, "HAXPY"); return 0; } diff --git a/test/correctness/routines/level1/xcopy.cc b/test/correctness/routines/level1/xcopy.cc index 3e16ffc6..316c6982 100644 --- a/test/correctness/routines/level1/xcopy.cc +++ b/test/correctness/routines/level1/xcopy.cc @@ -22,6 +22,7 @@ int main(int argc, char *argv[]) { clblast::RunTests, double, double>(argc, argv, true, 
"DCOPY"); clblast::RunTests, float2, float2>(argc, argv, true, "CCOPY"); clblast::RunTests, double2, double2>(argc, argv, true, "ZCOPY"); + clblast::RunTests, half, half>(argc, argv, true, "HCOPY"); return 0; } diff --git a/test/correctness/routines/level1/xdot.cc b/test/correctness/routines/level1/xdot.cc index 5ea105e0..72dc9d5e 100644 --- a/test/correctness/routines/level1/xdot.cc +++ b/test/correctness/routines/level1/xdot.cc @@ -20,6 +20,7 @@ using double2 = clblast::double2; int main(int argc, char *argv[]) { clblast::RunTests, float, float>(argc, argv, false, "SDOT"); clblast::RunTests, double, double>(argc, argv, true, "DDOT"); + clblast::RunTests, half, half>(argc, argv, true, "HDOT"); return 0; } diff --git a/test/correctness/routines/level1/xnrm2.cc b/test/correctness/routines/level1/xnrm2.cc index 97fb0ad6..0fe8dc33 100644 --- a/test/correctness/routines/level1/xnrm2.cc +++ b/test/correctness/routines/level1/xnrm2.cc @@ -22,6 +22,7 @@ int main(int argc, char *argv[]) { clblast::RunTests, double, double>(argc, argv, true, "DNRM2"); clblast::RunTests, float2, float2>(argc, argv, true, "ScNRM2"); clblast::RunTests, double2, double2>(argc, argv, true, "DzNRM2"); + clblast::RunTests, half, half>(argc, argv, true, "HNRM2"); return 0; } diff --git a/test/correctness/routines/level1/xscal.cc b/test/correctness/routines/level1/xscal.cc index 4d138fad..9146e5ce 100644 --- a/test/correctness/routines/level1/xscal.cc +++ b/test/correctness/routines/level1/xscal.cc @@ -22,6 +22,7 @@ int main(int argc, char *argv[]) { clblast::RunTests, double, double>(argc, argv, true, "DSCAL"); clblast::RunTests, float2, float2>(argc, argv, true, "CSCAL"); clblast::RunTests, double2, double2>(argc, argv, true, "ZSCAL"); + clblast::RunTests, half, half>(argc, argv, true, "HSCAL"); return 0; } diff --git a/test/correctness/routines/level1/xswap.cc b/test/correctness/routines/level1/xswap.cc index 38f110f7..636a5b0f 100644 --- a/test/correctness/routines/level1/xswap.cc +++ 
b/test/correctness/routines/level1/xswap.cc @@ -22,6 +22,7 @@ int main(int argc, char *argv[]) { clblast::RunTests, double, double>(argc, argv, true, "DSWAP"); clblast::RunTests, float2, float2>(argc, argv, true, "CSWAP"); clblast::RunTests, double2, double2>(argc, argv, true, "ZSWAP"); + clblast::RunTests, half, half>(argc, argv, true, "HSWAP"); return 0; } diff --git a/test/correctness/routines/level2/xgbmv.cc b/test/correctness/routines/level2/xgbmv.cc index b28c5978..528a3325 100644 --- a/test/correctness/routines/level2/xgbmv.cc +++ b/test/correctness/routines/level2/xgbmv.cc @@ -22,6 +22,7 @@ int main(int argc, char *argv[]) { clblast::RunTests, double, double>(argc, argv, true, "DGBMV"); clblast::RunTests, float2, float2>(argc, argv, true, "CGBMV"); clblast::RunTests, double2, double2>(argc, argv, true, "ZGBMV"); + clblast::RunTests, half, half>(argc, argv, true, "HGBMV"); return 0; } diff --git a/test/correctness/routines/level2/xgemv.cc b/test/correctness/routines/level2/xgemv.cc index 14eb74d1..fc1cf3eb 100644 --- a/test/correctness/routines/level2/xgemv.cc +++ b/test/correctness/routines/level2/xgemv.cc @@ -22,6 +22,7 @@ int main(int argc, char *argv[]) { clblast::RunTests, double, double>(argc, argv, true, "DGEMV"); clblast::RunTests, float2, float2>(argc, argv, true, "CGEMV"); clblast::RunTests, double2, double2>(argc, argv, true, "ZGEMV"); + clblast::RunTests, half, half>(argc, argv, true, "HGEMV"); return 0; } diff --git a/test/correctness/routines/level2/xger.cc b/test/correctness/routines/level2/xger.cc index c37a5c41..c3c33ae6 100644 --- a/test/correctness/routines/level2/xger.cc +++ b/test/correctness/routines/level2/xger.cc @@ -20,6 +20,7 @@ using double2 = clblast::double2; int main(int argc, char *argv[]) { clblast::RunTests, float, float>(argc, argv, false, "SGER"); clblast::RunTests, double, double>(argc, argv, true, "DGER"); + clblast::RunTests, half, half>(argc, argv, true, "HGER"); return 0; } diff --git 
a/test/correctness/routines/level2/xsbmv.cc b/test/correctness/routines/level2/xsbmv.cc index 212e2c3a..c2effcc2 100644 --- a/test/correctness/routines/level2/xsbmv.cc +++ b/test/correctness/routines/level2/xsbmv.cc @@ -20,6 +20,7 @@ using double2 = clblast::double2; int main(int argc, char *argv[]) { clblast::RunTests, float, float>(argc, argv, false, "SSBMV"); clblast::RunTests, double, double>(argc, argv, true, "DSBMV"); + clblast::RunTests, half, half>(argc, argv, true, "HSBMV"); return 0; } diff --git a/test/correctness/routines/level2/xspmv.cc b/test/correctness/routines/level2/xspmv.cc index dc833024..4142636d 100644 --- a/test/correctness/routines/level2/xspmv.cc +++ b/test/correctness/routines/level2/xspmv.cc @@ -20,6 +20,7 @@ using double2 = clblast::double2; int main(int argc, char *argv[]) { clblast::RunTests, float, float>(argc, argv, false, "SSPMV"); clblast::RunTests, double, double>(argc, argv, true, "DSPMV"); + clblast::RunTests, half, half>(argc, argv, true, "HSPMV"); return 0; } diff --git a/test/correctness/routines/level2/xspr.cc b/test/correctness/routines/level2/xspr.cc index a0104dd4..c068b448 100644 --- a/test/correctness/routines/level2/xspr.cc +++ b/test/correctness/routines/level2/xspr.cc @@ -20,6 +20,7 @@ using double2 = clblast::double2; int main(int argc, char *argv[]) { clblast::RunTests, float, float>(argc, argv, false, "SSPR"); clblast::RunTests, double, double>(argc, argv, true, "DSPR"); + clblast::RunTests, half, half>(argc, argv, true, "HSPR"); return 0; } diff --git a/test/correctness/routines/level2/xspr2.cc b/test/correctness/routines/level2/xspr2.cc index 5fe5827f..904870d5 100644 --- a/test/correctness/routines/level2/xspr2.cc +++ b/test/correctness/routines/level2/xspr2.cc @@ -20,6 +20,7 @@ using double2 = clblast::double2; int main(int argc, char *argv[]) { clblast::RunTests, float, float>(argc, argv, false, "SSPR2"); clblast::RunTests, double, double>(argc, argv, true, "DSPR2"); + clblast::RunTests, half, half>(argc, 
argv, true, "HSPR2"); return 0; } diff --git a/test/correctness/routines/level2/xsymv.cc b/test/correctness/routines/level2/xsymv.cc index 6224739f..eb9b6eb7 100644 --- a/test/correctness/routines/level2/xsymv.cc +++ b/test/correctness/routines/level2/xsymv.cc @@ -20,6 +20,7 @@ using double2 = clblast::double2; int main(int argc, char *argv[]) { clblast::RunTests, float, float>(argc, argv, false, "SSYMV"); clblast::RunTests, double, double>(argc, argv, true, "DSYMV"); + clblast::RunTests, half, half>(argc, argv, true, "HSYMV"); return 0; } diff --git a/test/correctness/routines/level2/xsyr.cc b/test/correctness/routines/level2/xsyr.cc index a47b918f..eccf95e0 100644 --- a/test/correctness/routines/level2/xsyr.cc +++ b/test/correctness/routines/level2/xsyr.cc @@ -20,6 +20,7 @@ using double2 = clblast::double2; int main(int argc, char *argv[]) { clblast::RunTests, float, float>(argc, argv, false, "SSYR"); clblast::RunTests, double, double>(argc, argv, true, "DSYR"); + clblast::RunTests, half, half>(argc, argv, true, "HSYR"); return 0; } diff --git a/test/correctness/routines/level2/xsyr2.cc b/test/correctness/routines/level2/xsyr2.cc index 1743632c..46c939d2 100644 --- a/test/correctness/routines/level2/xsyr2.cc +++ b/test/correctness/routines/level2/xsyr2.cc @@ -20,6 +20,7 @@ using double2 = clblast::double2; int main(int argc, char *argv[]) { clblast::RunTests, float, float>(argc, argv, false, "SSYR2"); clblast::RunTests, double, double>(argc, argv, true, "DSYR2"); + clblast::RunTests, half, half>(argc, argv, true, "HSYR2"); return 0; } diff --git a/test/correctness/routines/level2/xtbmv.cc b/test/correctness/routines/level2/xtbmv.cc index d3bbbade..252abdc4 100644 --- a/test/correctness/routines/level2/xtbmv.cc +++ b/test/correctness/routines/level2/xtbmv.cc @@ -22,6 +22,7 @@ int main(int argc, char *argv[]) { clblast::RunTests, double, double>(argc, argv, true, "DTBMV"); clblast::RunTests, float2, float2>(argc, argv, true, "CTBMV"); clblast::RunTests, double2, 
double2>(argc, argv, true, "ZTBMV"); + clblast::RunTests, half, half>(argc, argv, true, "HTBMV"); return 0; } diff --git a/test/correctness/routines/level2/xtpmv.cc b/test/correctness/routines/level2/xtpmv.cc index 95489a65..b8776faa 100644 --- a/test/correctness/routines/level2/xtpmv.cc +++ b/test/correctness/routines/level2/xtpmv.cc @@ -22,6 +22,7 @@ int main(int argc, char *argv[]) { clblast::RunTests, double, double>(argc, argv, true, "DTPMV"); clblast::RunTests, float2, float2>(argc, argv, true, "CTPMV"); clblast::RunTests, double2, double2>(argc, argv, true, "ZTPMV"); + clblast::RunTests, half, half>(argc, argv, true, "HTPMV"); return 0; } diff --git a/test/correctness/routines/level2/xtrmv.cc b/test/correctness/routines/level2/xtrmv.cc index ca50af88..256fe900 100644 --- a/test/correctness/routines/level2/xtrmv.cc +++ b/test/correctness/routines/level2/xtrmv.cc @@ -22,6 +22,7 @@ int main(int argc, char *argv[]) { clblast::RunTests, double, double>(argc, argv, true, "DTRMV"); clblast::RunTests, float2, float2>(argc, argv, true, "CTRMV"); clblast::RunTests, double2, double2>(argc, argv, true, "ZTRMV"); + clblast::RunTests, half, half>(argc, argv, true, "HTRMV"); return 0; } diff --git a/test/correctness/routines/level3/xgemm.cc b/test/correctness/routines/level3/xgemm.cc index 632724ed..f8c8a891 100644 --- a/test/correctness/routines/level3/xgemm.cc +++ b/test/correctness/routines/level3/xgemm.cc @@ -22,6 +22,7 @@ int main(int argc, char *argv[]) { clblast::RunTests, double, double>(argc, argv, true, "DGEMM"); clblast::RunTests, float2, float2>(argc, argv, true, "CGEMM"); clblast::RunTests, double2, double2>(argc, argv, true, "ZGEMM"); + clblast::RunTests, half, half>(argc, argv, true, "HGEMM"); return 0; } diff --git a/test/correctness/routines/level3/xsymm.cc b/test/correctness/routines/level3/xsymm.cc index 046fca16..c29f03dd 100644 --- a/test/correctness/routines/level3/xsymm.cc +++ b/test/correctness/routines/level3/xsymm.cc @@ -22,6 +22,7 @@ int main(int 
argc, char *argv[]) { clblast::RunTests, double, double>(argc, argv, true, "DSYMM"); clblast::RunTests, float2, float2>(argc, argv, true, "CSYMM"); clblast::RunTests, double2, double2>(argc, argv, true, "ZSYMM"); + clblast::RunTests, half, half>(argc, argv, true, "HSYMM"); return 0; } diff --git a/test/correctness/routines/level3/xsyr2k.cc b/test/correctness/routines/level3/xsyr2k.cc index db2b83d9..9f9c87d8 100644 --- a/test/correctness/routines/level3/xsyr2k.cc +++ b/test/correctness/routines/level3/xsyr2k.cc @@ -22,6 +22,7 @@ int main(int argc, char *argv[]) { clblast::RunTests, double, double>(argc, argv, true, "DSYR2K"); clblast::RunTests, float2, float2>(argc, argv, true, "CSYR2K"); clblast::RunTests, double2, double2>(argc, argv, true, "ZSYR2K"); + clblast::RunTests, half, half>(argc, argv, true, "HSYR2K"); return 0; } diff --git a/test/correctness/routines/level3/xsyrk.cc b/test/correctness/routines/level3/xsyrk.cc index 3dad3535..12343074 100644 --- a/test/correctness/routines/level3/xsyrk.cc +++ b/test/correctness/routines/level3/xsyrk.cc @@ -22,6 +22,7 @@ int main(int argc, char *argv[]) { clblast::RunTests, double, double>(argc, argv, true, "DSYRK"); clblast::RunTests, float2, float2>(argc, argv, true, "CSYRK"); clblast::RunTests, double2, double2>(argc, argv, true, "ZSYRK"); + clblast::RunTests, half, half>(argc, argv, true, "HSYRK"); return 0; } diff --git a/test/correctness/routines/level3/xtrmm.cc b/test/correctness/routines/level3/xtrmm.cc index 2d843e3e..aca73f0d 100644 --- a/test/correctness/routines/level3/xtrmm.cc +++ b/test/correctness/routines/level3/xtrmm.cc @@ -22,6 +22,7 @@ int main(int argc, char *argv[]) { clblast::RunTests, double, double>(argc, argv, true, "DTRMM"); clblast::RunTests, float2, float2>(argc, argv, true, "CTRMM"); clblast::RunTests, double2, double2>(argc, argv, true, "ZTRMM"); + clblast::RunTests, half, half>(argc, argv, true, "HTRMM"); return 0; } diff --git a/test/correctness/routines/level3/xtrsm.cc 
b/test/correctness/routines/level3/xtrsm.cc index b5f5045e..b050269a 100644 --- a/test/correctness/routines/level3/xtrsm.cc +++ b/test/correctness/routines/level3/xtrsm.cc @@ -22,6 +22,7 @@ int main(int argc, char *argv[]) { clblast::RunTests, double, double>(argc, argv, true, "DTRSM"); clblast::RunTests, float2, float2>(argc, argv, true, "CTRSM"); clblast::RunTests, double2, double2>(argc, argv, true, "ZTRSM"); + clblast::RunTests, half, half>(argc, argv, true, "HTRSM"); return 0; } diff --git a/test/correctness/testblas.cc b/test/correctness/testblas.cc index e70c0361..cbf8b0a0 100644 --- a/test/correctness/testblas.cc +++ b/test/correctness/testblas.cc @@ -20,6 +20,7 @@ namespace clblast { // ================================================================================================= // The transpose-options to test with (data-type dependent) +template <> const std::vector TestBlas::kTransposes = {Transpose::kNo, Transpose::kYes}; template <> const std::vector TestBlas::kTransposes = {Transpose::kNo, Transpose::kYes}; template <> const std::vector TestBlas::kTransposes = {Transpose::kNo, Transpose::kYes}; template <> const std::vector TestBlas::kTransposes = {Transpose::kNo, Transpose::kYes, Transpose::kConjugate}; @@ -147,10 +148,8 @@ void TestBlas::TestRegular(std::vector> &test_vector, const st if (verbose_) { if (get_id2_(args) == 1) { fprintf(stdout, "\n Error at index %zu: ", id1); } else { fprintf(stdout, "\n Error at %zu,%zu: ", id1, id2); } - std::cout << result1[index]; - fprintf(stdout, " (reference) versus "); - std::cout << result2[index]; - fprintf(stdout, " (CLBlast)"); + fprintf(stdout, " %s (reference) versus ", ToString(result1[index]).c_str()); + fprintf(stdout, " %s (CLBlast)", ToString(result2[index]).c_str()); } } } @@ -222,6 +221,7 @@ void TestBlas::TestInvalid(std::vector> &test_vector, const st // ================================================================================================= // Compiles the templated class 
+template class TestBlas; template class TestBlas; template class TestBlas; template class TestBlas; diff --git a/test/correctness/tester.cc b/test/correctness/tester.cc index 85ae7091..5b603585 100644 --- a/test/correctness/tester.cc +++ b/test/correctness/tester.cc @@ -351,11 +351,11 @@ bool TestSimilarity(const T val1, const T val2) { } } -// Compiles the default case for non-complex data-types +// Compiles the default case for standard data-types template bool TestSimilarity(const float, const float); template bool TestSimilarity(const double, const double); -// Specialisations for complex data-types +// Specialisations for non-standard data-types template <> bool TestSimilarity(const float2 val1, const float2 val2) { auto real = TestSimilarity(val1.real(), val2.real()); @@ -368,6 +368,10 @@ bool TestSimilarity(const double2 val1, const double2 val2) { auto imag = TestSimilarity(val1.imag(), val2.imag()); return (real && imag); } +template <> +bool TestSimilarity(const half val1, const half val2) { + return TestSimilarity(HalfToFloat(val1), HalfToFloat(val2)); +} // ================================================================================================= @@ -389,10 +393,15 @@ template <> const std::vector GetExampleScalars(const bool full_test) { if (full_test) { return {{0.0, 0.0}, {1.0, 1.3}, {2.42, 3.14}}; } else { return {{2.42, 3.14}}; } } +template <> const std::vector GetExampleScalars(const bool full_test) { + if (full_test) { return {FloatToHalf(0.0f), FloatToHalf(1.0f), FloatToHalf(3.14f)}; } + else { return {FloatToHalf(3.14f)}; } +} // ================================================================================================= // Compiles the templated class +template class Tester; template class Tester; template class Tester; template class Tester; diff --git a/test/wrapper_cblas.h b/test/wrapper_cblas.h index 2fcab4d0..06ce6269 100644 --- a/test/wrapper_cblas.h +++ b/test/wrapper_cblas.h @@ -31,6 +31,16 @@ CBLAS_UPLO 
convertToCBLAS(const Triangle v) { return (v == Triangle::kUpper) ? C CBLAS_DIAG convertToCBLAS(const Diagonal v) { return (v == Diagonal::kUnit) ? CblasUnit : CblasNonUnit; } CBLAS_SIDE convertToCBLAS(const Side v) { return (v == Side::kLeft) ? CblasLeft : CblasRight; } +// Conversions from and to half-precision +std::vector HalfToFloatBuffer(const std::vector& source) { + auto result = std::vector(source.size()); + for (auto i = size_t(0); i < source.size(); ++i) { result[i] = HalfToFloat(source[i]); } + return result; +} +void FloatToHalfBuffer(std::vector& result, const std::vector& source) { + for (auto i = size_t(0); i < source.size(); ++i) { result[i] = FloatToHalf(source[i]); } +} + // OpenBLAS is not fully Netlib CBLAS compatible #ifdef OPENBLAS_VERSION using return_pointer_float = openblas_complex_float*; @@ -164,7 +174,13 @@ void cblasXswap(const size_t n, void cblasXswap(const size_t n, std::vector& x_buffer, const size_t x_offset, const size_t x_inc, std::vector& y_buffer, const size_t y_offset, const size_t y_inc) { - return; + auto x_buffer_bis = HalfToFloatBuffer(x_buffer); + auto y_buffer_bis = HalfToFloatBuffer(y_buffer); + cblasXswap(n, + x_buffer_bis, x_offset, x_inc, + y_buffer_bis, y_offset, y_inc); + FloatToHalfBuffer(x_buffer, x_buffer_bis); + FloatToHalfBuffer(y_buffer, y_buffer_bis); } // Forwards the Netlib BLAS calls for SSCAL/DSCAL/CSCAL/ZSCAL @@ -201,7 +217,11 @@ void cblasXscal(const size_t n, void cblasXscal(const size_t n, const half alpha, std::vector& x_buffer, const size_t x_offset, const size_t x_inc) { - return; + auto x_buffer_bis = HalfToFloatBuffer(x_buffer); + cblasXscal(n, + HalfToFloat(alpha), + x_buffer_bis, x_offset, x_inc); + FloatToHalfBuffer(x_buffer, x_buffer_bis); } // Forwards the Netlib BLAS calls for SCOPY/DCOPY/CCOPY/ZCOPY @@ -236,7 +256,12 @@ void cblasXcopy(const size_t n, void cblasXcopy(const size_t n, const std::vector& x_buffer, const size_t x_offset, const size_t x_inc, std::vector& y_buffer, const 
size_t y_offset, const size_t y_inc) { - return; + auto x_buffer_bis = HalfToFloatBuffer(x_buffer); + auto y_buffer_bis = HalfToFloatBuffer(y_buffer); + cblasXcopy(n, + x_buffer_bis, x_offset, x_inc, + y_buffer_bis, y_offset, y_inc); + FloatToHalfBuffer(y_buffer, y_buffer_bis); } // Forwards the Netlib BLAS calls for SAXPY/DAXPY/CAXPY/ZAXPY @@ -282,7 +307,13 @@ void cblasXaxpy(const size_t n, const half alpha, const std::vector& x_buffer, const size_t x_offset, const size_t x_inc, std::vector& y_buffer, const size_t y_offset, const size_t y_inc) { - return; + auto x_buffer_bis = HalfToFloatBuffer(x_buffer); + auto y_buffer_bis = HalfToFloatBuffer(y_buffer); + cblasXaxpy(n, + HalfToFloat(alpha), + x_buffer_bis, x_offset, x_inc, + y_buffer_bis, y_offset, y_inc); + FloatToHalfBuffer(y_buffer, y_buffer_bis); } // Forwards the Netlib BLAS calls for SDOT/DDOT @@ -306,7 +337,14 @@ void cblasXdot(const size_t n, std::vector& dot_buffer, const size_t dot_offset, const std::vector& x_buffer, const size_t x_offset, const size_t x_inc, const std::vector& y_buffer, const size_t y_offset, const size_t y_inc) { - return; + auto x_buffer_bis = HalfToFloatBuffer(x_buffer); + auto y_buffer_bis = HalfToFloatBuffer(y_buffer); + auto dot_buffer_bis = HalfToFloatBuffer(dot_buffer); + cblasXdot(n, + dot_buffer_bis, dot_offset, + x_buffer_bis, x_offset, x_inc, + y_buffer_bis, y_offset, y_inc); + FloatToHalfBuffer(dot_buffer, dot_buffer_bis); } // Forwards the Netlib BLAS calls for CDOTU/ZDOTU @@ -377,7 +415,12 @@ void cblasXnrm2(const size_t n, void cblasXnrm2(const size_t n, std::vector& nrm2_buffer, const size_t nrm2_offset, const std::vector& x_buffer, const size_t x_offset, const size_t x_inc) { - return; + auto x_buffer_bis = HalfToFloatBuffer(x_buffer); + auto nrm2_buffer_bis = HalfToFloatBuffer(nrm2_buffer); + cblasXnrm2(n, + nrm2_buffer_bis, nrm2_offset, + x_buffer_bis, x_offset, x_inc); + FloatToHalfBuffer(nrm2_buffer, nrm2_buffer_bis); } // Forwards the Netlib BLAS calls for 
SASUM/DASUM/ScASUM/DzASUM @@ -408,7 +451,12 @@ void cblasXasum(const size_t n, void cblasXasum(const size_t n, std::vector& asum_buffer, const size_t asum_offset, const std::vector& x_buffer, const size_t x_offset, const size_t x_inc) { - return; + auto x_buffer_bis = HalfToFloatBuffer(x_buffer); + auto asum_buffer_bis = HalfToFloatBuffer(asum_buffer); + cblasXasum(n, + asum_buffer_bis, asum_offset, + x_buffer_bis, x_offset, x_inc); + FloatToHalfBuffer(asum_buffer, asum_buffer_bis); } // Forwards the Netlib BLAS calls for iSAMAX/iDAMAX/iCAMAX/iZAMAX/iHAMAX @@ -439,7 +487,12 @@ void cblasXamax(const size_t n, void cblasXamax(const size_t n, std::vector& imax_buffer, const size_t imax_offset, const std::vector& x_buffer, const size_t x_offset, const size_t x_inc) { - return; + auto x_buffer_bis = HalfToFloatBuffer(x_buffer); + auto imax_buffer_bis = HalfToFloatBuffer(imax_buffer); + cblasXamax(n, + imax_buffer_bis, imax_offset, + x_buffer_bis, x_offset, x_inc); + FloatToHalfBuffer(imax_buffer, imax_buffer_bis); } // ================================================================================================= @@ -518,7 +571,17 @@ void cblasXgemv(const CBLAS_ORDER layout, const CBLAS_TRANSPOSE a_transpose, const std::vector& x_buffer, const size_t x_offset, const size_t x_inc, const half beta, std::vector& y_buffer, const size_t y_offset, const size_t y_inc) { - return; + auto a_buffer_bis = HalfToFloatBuffer(a_buffer); + auto x_buffer_bis = HalfToFloatBuffer(x_buffer); + auto y_buffer_bis = HalfToFloatBuffer(y_buffer); + cblasXgemv(layout, a_transpose, + m, n, + HalfToFloat(alpha), + a_buffer_bis, a_offset, a_ld, + x_buffer_bis, x_offset, x_inc, + HalfToFloat(beta), + y_buffer_bis, y_offset, y_inc); + FloatToHalfBuffer(y_buffer, y_buffer_bis); } // Forwards the Netlib BLAS calls for SGBMV/DGBMV/CGBMV/ZGBMV @@ -593,7 +656,17 @@ void cblasXgbmv(const CBLAS_ORDER layout, const CBLAS_TRANSPOSE a_transpose, const std::vector& x_buffer, const size_t x_offset, const 
size_t x_inc, const half beta, std::vector& y_buffer, const size_t y_offset, const size_t y_inc) { - return; + auto a_buffer_bis = HalfToFloatBuffer(a_buffer); + auto x_buffer_bis = HalfToFloatBuffer(x_buffer); + auto y_buffer_bis = HalfToFloatBuffer(y_buffer); + cblasXgbmv(layout, a_transpose, + m, n, kl, ku, + HalfToFloat(alpha), + a_buffer_bis, a_offset, a_ld, + x_buffer_bis, x_offset, x_inc, + HalfToFloat(beta), + y_buffer_bis, y_offset, y_inc); + FloatToHalfBuffer(y_buffer, y_buffer_bis); } // Forwards the Netlib BLAS calls for CHEMV/ZHEMV @@ -742,7 +815,17 @@ void cblasXsymv(const CBLAS_ORDER layout, const CBLAS_UPLO triangle, const std::vector& x_buffer, const size_t x_offset, const size_t x_inc, const half beta, std::vector& y_buffer, const size_t y_offset, const size_t y_inc) { - return; + auto a_buffer_bis = HalfToFloatBuffer(a_buffer); + auto x_buffer_bis = HalfToFloatBuffer(x_buffer); + auto y_buffer_bis = HalfToFloatBuffer(y_buffer); + cblasXsymv(layout, triangle, + n, + HalfToFloat(alpha), + a_buffer_bis, a_offset, a_ld, + x_buffer_bis, x_offset, x_inc, + HalfToFloat(beta), + y_buffer_bis, y_offset, y_inc); + FloatToHalfBuffer(y_buffer, y_buffer_bis); } // Forwards the Netlib BLAS calls for SSBMV/DSBMV @@ -783,7 +866,17 @@ void cblasXsbmv(const CBLAS_ORDER layout, const CBLAS_UPLO triangle, const std::vector& x_buffer, const size_t x_offset, const size_t x_inc, const half beta, std::vector& y_buffer, const size_t y_offset, const size_t y_inc) { - return; + auto a_buffer_bis = HalfToFloatBuffer(a_buffer); + auto x_buffer_bis = HalfToFloatBuffer(x_buffer); + auto y_buffer_bis = HalfToFloatBuffer(y_buffer); + cblasXsbmv(layout, triangle, + n, k, + HalfToFloat(alpha), + a_buffer_bis, a_offset, a_ld, + x_buffer_bis, x_offset, x_inc, + HalfToFloat(beta), + y_buffer_bis, y_offset, y_inc); + FloatToHalfBuffer(y_buffer, y_buffer_bis); } // Forwards the Netlib BLAS calls for SSPMV/DSPMV @@ -824,7 +917,17 @@ void cblasXspmv(const CBLAS_ORDER layout, const 
CBLAS_UPLO triangle, const std::vector& x_buffer, const size_t x_offset, const size_t x_inc, const half beta, std::vector& y_buffer, const size_t y_offset, const size_t y_inc) { - return; + auto ap_buffer_bis = HalfToFloatBuffer(ap_buffer); + auto x_buffer_bis = HalfToFloatBuffer(x_buffer); + auto y_buffer_bis = HalfToFloatBuffer(y_buffer); + cblasXspmv(layout, triangle, + n, + HalfToFloat(alpha), + ap_buffer_bis, ap_offset, + x_buffer_bis, x_offset, x_inc, + HalfToFloat(beta), + y_buffer_bis, y_offset, y_inc); + FloatToHalfBuffer(y_buffer, y_buffer_bis); } // Forwards the Netlib BLAS calls for STRMV/DTRMV/CTRMV/ZTRMV @@ -868,7 +971,13 @@ void cblasXtrmv(const CBLAS_ORDER layout, const CBLAS_UPLO triangle, const CBLAS const size_t n, const std::vector& a_buffer, const size_t a_offset, const size_t a_ld, std::vector& x_buffer, const size_t x_offset, const size_t x_inc) { - return; + auto a_buffer_bis = HalfToFloatBuffer(a_buffer); + auto x_buffer_bis = HalfToFloatBuffer(x_buffer); + cblasXtrmv(layout, triangle, a_transpose, diagonal, + n, + a_buffer_bis, a_offset, a_ld, + x_buffer_bis, x_offset, x_inc); + FloatToHalfBuffer(x_buffer, x_buffer_bis); } // Forwards the Netlib BLAS calls for STBMV/DTBMV/CTBMV/ZTBMV @@ -912,7 +1021,13 @@ void cblasXtbmv(const CBLAS_ORDER layout, const CBLAS_UPLO triangle, const CBLAS const size_t n, const size_t k, const std::vector& a_buffer, const size_t a_offset, const size_t a_ld, std::vector& x_buffer, const size_t x_offset, const size_t x_inc) { - return; + auto a_buffer_bis = HalfToFloatBuffer(a_buffer); + auto x_buffer_bis = HalfToFloatBuffer(x_buffer); + cblasXtbmv(layout, triangle, a_transpose, diagonal, + n, k, + a_buffer_bis, a_offset, a_ld, + x_buffer_bis, x_offset, x_inc); + FloatToHalfBuffer(x_buffer, x_buffer_bis); } // Forwards the Netlib BLAS calls for STPMV/DTPMV/CTPMV/ZTPMV @@ -956,7 +1071,13 @@ void cblasXtpmv(const CBLAS_ORDER layout, const CBLAS_UPLO triangle, const CBLAS const size_t n, const std::vector& 
ap_buffer, const size_t ap_offset, std::vector& x_buffer, const size_t x_offset, const size_t x_inc) { - return; + auto ap_buffer_bis = HalfToFloatBuffer(ap_buffer); + auto x_buffer_bis = HalfToFloatBuffer(x_buffer); + cblasXtpmv(layout, triangle, a_transpose, diagonal, + n, + ap_buffer_bis, ap_offset, + x_buffer_bis, x_offset, x_inc); + FloatToHalfBuffer(x_buffer, x_buffer_bis); } // Forwards the Netlib BLAS calls for STRSV/DTRSV/CTRSV/ZTRSV @@ -1106,7 +1227,16 @@ void cblasXger(const CBLAS_ORDER layout, const std::vector& x_buffer, const size_t x_offset, const size_t x_inc, const std::vector& y_buffer, const size_t y_offset, const size_t y_inc, std::vector& a_buffer, const size_t a_offset, const size_t a_ld) { - return; + auto x_buffer_bis = HalfToFloatBuffer(x_buffer); + auto y_buffer_bis = HalfToFloatBuffer(y_buffer); + auto a_buffer_bis = HalfToFloatBuffer(a_buffer); + cblasXger(layout, + m, n, + HalfToFloat(alpha), + x_buffer_bis, x_offset, x_inc, + y_buffer_bis, y_offset, y_inc, + a_buffer_bis, a_offset, a_ld); + FloatToHalfBuffer(a_buffer, a_buffer_bis); } // Forwards the Netlib BLAS calls for CGERU/ZGERU @@ -1305,7 +1435,14 @@ void cblasXsyr(const CBLAS_ORDER layout, const CBLAS_UPLO triangle, const half alpha, const std::vector& x_buffer, const size_t x_offset, const size_t x_inc, std::vector& a_buffer, const size_t a_offset, const size_t a_ld) { - return; + auto x_buffer_bis = HalfToFloatBuffer(x_buffer); + auto a_buffer_bis = HalfToFloatBuffer(a_buffer); + cblasXsyr(layout, triangle, + n, + HalfToFloat(alpha), + x_buffer_bis, x_offset, x_inc, + a_buffer_bis, a_offset, a_ld); + FloatToHalfBuffer(a_buffer, a_buffer_bis); } // Forwards the Netlib BLAS calls for SSPR/DSPR @@ -1336,7 +1473,14 @@ void cblasXspr(const CBLAS_ORDER layout, const CBLAS_UPLO triangle, const half alpha, const std::vector& x_buffer, const size_t x_offset, const size_t x_inc, std::vector& ap_buffer, const size_t ap_offset) { - return; + auto x_buffer_bis = 
HalfToFloatBuffer(x_buffer); + auto ap_buffer_bis = HalfToFloatBuffer(ap_buffer); + cblasXspr(layout, triangle, + n, + HalfToFloat(alpha), + x_buffer_bis, x_offset, x_inc, + ap_buffer_bis, ap_offset); + FloatToHalfBuffer(ap_buffer, ap_buffer_bis); } // Forwards the Netlib BLAS calls for SSYR2/DSYR2 @@ -1372,7 +1516,16 @@ void cblasXsyr2(const CBLAS_ORDER layout, const CBLAS_UPLO triangle, const std::vector& x_buffer, const size_t x_offset, const size_t x_inc, const std::vector& y_buffer, const size_t y_offset, const size_t y_inc, std::vector& a_buffer, const size_t a_offset, const size_t a_ld) { - return; + auto x_buffer_bis = HalfToFloatBuffer(x_buffer); + auto y_buffer_bis = HalfToFloatBuffer(y_buffer); + auto a_buffer_bis = HalfToFloatBuffer(a_buffer); + cblasXsyr2(layout, triangle, + n, + HalfToFloat(alpha), + x_buffer_bis, x_offset, x_inc, + y_buffer_bis, y_offset, y_inc, + a_buffer_bis, a_offset, a_ld); + FloatToHalfBuffer(a_buffer, a_buffer_bis); } // Forwards the Netlib BLAS calls for SSPR2/DSPR2 @@ -1408,7 +1561,16 @@ void cblasXspr2(const CBLAS_ORDER layout, const CBLAS_UPLO triangle, const std::vector& x_buffer, const size_t x_offset, const size_t x_inc, const std::vector& y_buffer, const size_t y_offset, const size_t y_inc, std::vector& ap_buffer, const size_t ap_offset) { - return; + auto x_buffer_bis = HalfToFloatBuffer(x_buffer); + auto y_buffer_bis = HalfToFloatBuffer(y_buffer); + auto ap_buffer_bis = HalfToFloatBuffer(ap_buffer); + cblasXspr2(layout, triangle, + n, + HalfToFloat(alpha), + x_buffer_bis, x_offset, x_inc, + y_buffer_bis, y_offset, y_inc, + ap_buffer_bis, ap_offset); + FloatToHalfBuffer(ap_buffer, ap_buffer_bis); } // ================================================================================================= @@ -1487,7 +1649,17 @@ void cblasXgemm(const CBLAS_ORDER layout, const CBLAS_TRANSPOSE a_transpose, con const std::vector& b_buffer, const size_t b_offset, const size_t b_ld, const half beta, std::vector& c_buffer, const 
size_t c_offset, const size_t c_ld) { - return; + auto a_buffer_bis = HalfToFloatBuffer(a_buffer); + auto b_buffer_bis = HalfToFloatBuffer(b_buffer); + auto c_buffer_bis = HalfToFloatBuffer(c_buffer); + cblasXgemm(layout, a_transpose, b_transpose, + m, n, k, + HalfToFloat(alpha), + a_buffer_bis, a_offset, a_ld, + b_buffer_bis, b_offset, b_ld, + HalfToFloat(beta), + c_buffer_bis, c_offset, c_ld); + FloatToHalfBuffer(c_buffer, c_buffer_bis); } // Forwards the Netlib BLAS calls for SSYMM/DSYMM/CSYMM/ZSYMM @@ -1562,7 +1734,17 @@ void cblasXsymm(const CBLAS_ORDER layout, const CBLAS_SIDE side, const CBLAS_UPL const std::vector& b_buffer, const size_t b_offset, const size_t b_ld, const half beta, std::vector& c_buffer, const size_t c_offset, const size_t c_ld) { - return; + auto a_buffer_bis = HalfToFloatBuffer(a_buffer); + auto b_buffer_bis = HalfToFloatBuffer(b_buffer); + auto c_buffer_bis = HalfToFloatBuffer(c_buffer); + cblasXsymm(layout, side, triangle, + m, n, + HalfToFloat(alpha), + a_buffer_bis, a_offset, a_ld, + b_buffer_bis, b_offset, b_ld, + HalfToFloat(beta), + c_buffer_bis, c_offset, c_ld); + FloatToHalfBuffer(c_buffer, c_buffer_bis); } // Forwards the Netlib BLAS calls for CHEMM/ZHEMM @@ -1664,7 +1846,15 @@ void cblasXsyrk(const CBLAS_ORDER layout, const CBLAS_UPLO triangle, const CBLAS const std::vector& a_buffer, const size_t a_offset, const size_t a_ld, const half beta, std::vector& c_buffer, const size_t c_offset, const size_t c_ld) { - return; + auto a_buffer_bis = HalfToFloatBuffer(a_buffer); + auto c_buffer_bis = HalfToFloatBuffer(c_buffer); + cblasXsyrk(layout, triangle, a_transpose, + n, k, + HalfToFloat(alpha), + a_buffer_bis, a_offset, a_ld, + HalfToFloat(beta), + c_buffer_bis, c_offset, c_ld); + FloatToHalfBuffer(c_buffer, c_buffer_bis); } // Forwards the Netlib BLAS calls for CHERK/ZHERK @@ -1767,7 +1957,17 @@ void cblasXsyr2k(const CBLAS_ORDER layout, const CBLAS_UPLO triangle, const CBLA const std::vector& b_buffer, const size_t b_offset, 
const size_t b_ld, const half beta, std::vector& c_buffer, const size_t c_offset, const size_t c_ld) { - return; + auto a_buffer_bis = HalfToFloatBuffer(a_buffer); + auto b_buffer_bis = HalfToFloatBuffer(b_buffer); + auto c_buffer_bis = HalfToFloatBuffer(c_buffer); + cblasXsyr2k(layout, triangle, ab_transpose, + n, k, + HalfToFloat(alpha), + a_buffer_bis, a_offset, a_ld, + b_buffer_bis, b_offset, b_ld, + HalfToFloat(beta), + c_buffer_bis, c_offset, c_ld); + FloatToHalfBuffer(c_buffer, c_buffer_bis); } // Forwards the Netlib BLAS calls for CHER2K/ZHER2K @@ -1856,7 +2056,14 @@ void cblasXtrmm(const CBLAS_ORDER layout, const CBLAS_SIDE side, const CBLAS_UPL const half alpha, const std::vector& a_buffer, const size_t a_offset, const size_t a_ld, std::vector& b_buffer, const size_t b_offset, const size_t b_ld) { - return; + auto a_buffer_bis = HalfToFloatBuffer(a_buffer); + auto b_buffer_bis = HalfToFloatBuffer(b_buffer); + cblasXtrmm(layout, side, triangle, a_transpose, diagonal, + m, n, + HalfToFloat(alpha), + a_buffer_bis, a_offset, a_ld, + b_buffer_bis, b_offset, b_ld); + FloatToHalfBuffer(b_buffer, b_buffer_bis); } // Forwards the Netlib BLAS calls for STRSM/DTRSM/CTRSM/ZTRSM @@ -1911,7 +2118,14 @@ void cblasXtrsm(const CBLAS_ORDER layout, const CBLAS_SIDE side, const CBLAS_UPL const half alpha, const std::vector& a_buffer, const size_t a_offset, const size_t a_ld, std::vector& b_buffer, const size_t b_offset, const size_t b_ld) { - return; + auto a_buffer_bis = HalfToFloatBuffer(a_buffer); + auto b_buffer_bis = HalfToFloatBuffer(b_buffer); + cblasXtrsm(layout, side, triangle, a_transpose, diagonal, + m, n, + HalfToFloat(alpha), + a_buffer_bis, a_offset, a_ld, + b_buffer_bis, b_offset, b_ld); + FloatToHalfBuffer(b_buffer, b_buffer_bis); } // =================================================================================================