mirror of
https://github.com/CNugteren/CLBlast.git
synced 2024-07-07 12:23:46 +02:00
Added half-precision tests for the CBLAS reference through conversion to single-precison
This commit is contained in:
parent
4612ff3552
commit
b487d4dd44
|
@ -259,12 +259,14 @@ def wrapper_cblas(routines):
|
|||
if routine.has_tests:
|
||||
result += "\n// Forwards the Netlib BLAS calls for %s\n" % (routine.ShortNamesTested())
|
||||
for flavour in routine.flavours:
|
||||
indent = " "*(10 + routine.Length())
|
||||
result += routine.RoutineHeaderWrapperC(flavour, False, 12)+" {\n"
|
||||
|
||||
# There is a version available in CBLAS
|
||||
if flavour.precision_name in ["S","D","C","Z"]:
|
||||
indent = " "*(10 + routine.Length())
|
||||
arguments = routine.ArgumentsWrapperC(flavour)
|
||||
|
||||
# Double-precision scalars
|
||||
# Complex scalars
|
||||
for scalar in routine.scalars:
|
||||
if flavour.IsComplex(scalar):
|
||||
result += " const auto "+scalar+"_array = std::vector<"+flavour.buffertype[:-1]+">{"+scalar+".real(), "+scalar+".imag()};\n"
|
||||
|
@ -294,10 +296,27 @@ def wrapper_cblas(routines):
|
|||
|
||||
result += " "+assignment+"cblas_"+flavour.name.lower()+routine.name+postfix+"("
|
||||
result += (",\n"+indent).join([a for a in arguments])
|
||||
result += extra_argument+endofline+");"
|
||||
else:
|
||||
result += " return;"
|
||||
result += "\n}\n"
|
||||
result += extra_argument+endofline+");\n"
|
||||
|
||||
# There is no CBLAS available, forward the call to one of the available functions
|
||||
else: # Half-precision
|
||||
indent = " "*(9 + routine.Length())
|
||||
|
||||
# Convert to float (note: also integer buffers are stored as half/float)
|
||||
for buf in routine.inputs + routine.outputs:
|
||||
result += " auto "+buf+"_buffer_bis = HalfToFloatBuffer("+buf+"_buffer);\n"
|
||||
|
||||
# Call the float routine
|
||||
result += " cblasX"+routine.name+"("
|
||||
result += (",\n"+indent).join([a for a in routine.ArgumentsHalf()])
|
||||
result += ");\n"
|
||||
|
||||
# Convert back to half
|
||||
for buf in routine.outputs:
|
||||
result += " FloatToHalfBuffer("+buf+"_buffer, "+buf+"_buffer_bis);\n"
|
||||
|
||||
# Complete
|
||||
result += "}\n"
|
||||
return result
|
||||
|
||||
# ==================================================================================================
|
||||
|
@ -317,7 +336,7 @@ files = [
|
|||
path_clblast+"/test/wrapper_clblas.h",
|
||||
path_clblast+"/test/wrapper_cblas.h",
|
||||
]
|
||||
header_lines = [84, 71, 93, 22, 29, 41]
|
||||
header_lines = [84, 71, 93, 22, 29, 51]
|
||||
footer_lines = [17, 71, 19, 14, 6, 6]
|
||||
|
||||
# Checks whether the command-line arguments are valid; exists otherwise
|
||||
|
@ -376,10 +395,9 @@ for level in [1,2,3]:
|
|||
body += "int main(int argc, char *argv[]) {\n"
|
||||
not_first = "false"
|
||||
for flavour in routine.flavours:
|
||||
if flavour.precision_name in ["S","D","C","Z"]:
|
||||
body += " clblast::RunTests<clblast::TestX"+routine.name+flavour.TestTemplate()
|
||||
body += ">(argc, argv, "+not_first+", \""+flavour.name+routine.name.upper()+"\");\n"
|
||||
not_first = "true"
|
||||
body += " clblast::RunTests<clblast::TestX"+routine.name+flavour.TestTemplate()
|
||||
body += ">(argc, argv, "+not_first+", \""+flavour.name+routine.name.upper()+"\");\n"
|
||||
not_first = "true"
|
||||
body += " return 0;\n"
|
||||
body += "}\n"
|
||||
f.write(header+"\n")
|
||||
|
|
|
@ -99,6 +99,18 @@ class Routine():
|
|||
def IndexBuffers(self):
|
||||
return ["imax","imin"]
|
||||
|
||||
# Lists of input/output buffers not index (integer)
|
||||
def NonIndexInputs(self):
|
||||
buffers = self.inputs[:] # make a copy
|
||||
for i in self.IndexBuffers():
|
||||
if i in buffers: buffers.remove(i)
|
||||
return buffers
|
||||
def NonIndexOutputs(self):
|
||||
buffers = self.outputs[:] # make a copy
|
||||
for i in self.IndexBuffers():
|
||||
if i in buffers: buffers.remove(i)
|
||||
return buffers
|
||||
|
||||
# List of buffers without 'inc' or 'ld'
|
||||
def BuffersWithoutLdInc(self):
|
||||
return self.ScalarBuffersFirst() + self.ScalarBuffersSecond() + ["ap"]
|
||||
|
@ -152,6 +164,17 @@ class Routine():
|
|||
return [", ".join(a+b+c)]
|
||||
return []
|
||||
|
||||
# As above but with a '_bis' suffix for the buffer name
|
||||
def BufferBis(self, name):
|
||||
#if (name in self.IndexBuffers()):
|
||||
# return self.Buffer(name)
|
||||
if (name in self.inputs) or (name in self.outputs):
|
||||
a = [name+"_buffer_bis"]
|
||||
b = [name+"_offset"]
|
||||
c = [name+"_"+self.Postfix(name)] if (name not in self.BuffersWithoutLdInc()) else []
|
||||
return [", ".join(a+b+c)]
|
||||
return []
|
||||
|
||||
# As above but with data-types
|
||||
def BufferDef(self, name):
|
||||
prefix = "const " if (name in self.inputs) else ""
|
||||
|
@ -244,6 +267,12 @@ class Routine():
|
|||
return [name]
|
||||
return []
|
||||
|
||||
# As above, but converts from float to half
|
||||
def ScalarHalfToFloat(self, name):
|
||||
if name in self.scalars:
|
||||
return ["HalfToFloat("+name+")"]
|
||||
return []
|
||||
|
||||
# Retrieves the use of a scalar (alpha/beta)
|
||||
def ScalarUse(self, name, flavour):
|
||||
if name in self.scalars:
|
||||
|
@ -254,7 +283,7 @@ class Routine():
|
|||
return [name]
|
||||
return []
|
||||
|
||||
# Retrieves the use of a scalar (alpha/beta)
|
||||
# As above, but for the clBLAS wrapper
|
||||
def ScalarUseWrapper(self, name, flavour):
|
||||
if name in self.scalars:
|
||||
if name == "alpha":
|
||||
|
@ -264,7 +293,7 @@ class Routine():
|
|||
return [name]
|
||||
return []
|
||||
|
||||
# Retrieves the use of a scalar for CBLAS (alpha/beta)
|
||||
# As above, but for the CBLAS wrapper
|
||||
def ScalarUseWrapperC(self, name, flavour):
|
||||
if name in self.scalars:
|
||||
if flavour.IsComplex(name):
|
||||
|
@ -383,6 +412,28 @@ class Routine():
|
|||
|
||||
# ==============================================================================================
|
||||
|
||||
# Retrieves a combination of all the argument names (no types)
|
||||
def Arguments(self):
|
||||
return (self.Options() + self.Sizes() +
|
||||
list(chain(*[self.Buffer(b) for b in self.ScalarBuffersFirst()])) +
|
||||
self.Scalar("alpha") +
|
||||
list(chain(*[self.Buffer(b) for b in self.BuffersFirst()])) +
|
||||
self.Scalar("beta") +
|
||||
list(chain(*[self.Buffer(b) for b in self.BuffersSecond()])) +
|
||||
list(chain(*[self.Buffer(b) for b in self.ScalarBuffersSecond()])) +
|
||||
list(chain(*[self.Scalar(s) for s in self.OtherScalars()])))
|
||||
|
||||
# As above, but with conversions from half to float
|
||||
def ArgumentsHalf(self):
|
||||
return (self.Options() + self.Sizes() +
|
||||
list(chain(*[self.BufferBis(b) for b in self.ScalarBuffersFirst()])) +
|
||||
self.ScalarHalfToFloat("alpha") +
|
||||
list(chain(*[self.BufferBis(b) for b in self.BuffersFirst()])) +
|
||||
self.ScalarHalfToFloat("beta") +
|
||||
list(chain(*[self.BufferBis(b) for b in self.BuffersSecond()])) +
|
||||
list(chain(*[self.BufferBis(b) for b in self.ScalarBuffersSecond()])) +
|
||||
list(chain(*[self.Scalar(s) for s in self.OtherScalars()])))
|
||||
|
||||
# Retrieves a combination of all the argument names, with Claduc casts
|
||||
def ArgumentsCladuc(self, flavour, indent):
|
||||
return (self.Options() + self.Sizes() +
|
||||
|
@ -394,7 +445,7 @@ class Routine():
|
|||
list(chain(*[self.BufferCladuc(b) for b in self.ScalarBuffersSecond()])) +
|
||||
list(chain(*[self.Scalar(s) for s in self.OtherScalars()])))
|
||||
|
||||
# Retrieves a combination of all the argument names, with CLBlast casts
|
||||
# As above, but with CLBlast casts
|
||||
def ArgumentsCast(self, flavour, indent):
|
||||
return (self.OptionsCast(indent) + self.Sizes() +
|
||||
list(chain(*[self.Buffer(b) for b in self.ScalarBuffersFirst()])) +
|
||||
|
|
|
@ -22,6 +22,7 @@ int main(int argc, char *argv[]) {
|
|||
clblast::RunTests<clblast::TestXamax<double>, double, double>(argc, argv, true, "iDAMAX");
|
||||
clblast::RunTests<clblast::TestXamax<float2>, float2, float2>(argc, argv, true, "iCAMAX");
|
||||
clblast::RunTests<clblast::TestXamax<double2>, double2, double2>(argc, argv, true, "iZAMAX");
|
||||
clblast::RunTests<clblast::TestXamax<half>, half, half>(argc, argv, true, "iHAMAX");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ int main(int argc, char *argv[]) {
|
|||
clblast::RunTests<clblast::TestXasum<double>, double, double>(argc, argv, true, "DASUM");
|
||||
clblast::RunTests<clblast::TestXasum<float2>, float2, float2>(argc, argv, true, "ScASUM");
|
||||
clblast::RunTests<clblast::TestXasum<double2>, double2, double2>(argc, argv, true, "DzASUM");
|
||||
clblast::RunTests<clblast::TestXasum<half>, half, half>(argc, argv, true, "HASUM");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ int main(int argc, char *argv[]) {
|
|||
clblast::RunTests<clblast::TestXaxpy<double>, double, double>(argc, argv, true, "DAXPY");
|
||||
clblast::RunTests<clblast::TestXaxpy<float2>, float2, float2>(argc, argv, true, "CAXPY");
|
||||
clblast::RunTests<clblast::TestXaxpy<double2>, double2, double2>(argc, argv, true, "ZAXPY");
|
||||
clblast::RunTests<clblast::TestXaxpy<half>, half, half>(argc, argv, true, "HAXPY");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ int main(int argc, char *argv[]) {
|
|||
clblast::RunTests<clblast::TestXcopy<double>, double, double>(argc, argv, true, "DCOPY");
|
||||
clblast::RunTests<clblast::TestXcopy<float2>, float2, float2>(argc, argv, true, "CCOPY");
|
||||
clblast::RunTests<clblast::TestXcopy<double2>, double2, double2>(argc, argv, true, "ZCOPY");
|
||||
clblast::RunTests<clblast::TestXcopy<half>, half, half>(argc, argv, true, "HCOPY");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ using double2 = clblast::double2;
|
|||
int main(int argc, char *argv[]) {
|
||||
clblast::RunTests<clblast::TestXdot<float>, float, float>(argc, argv, false, "SDOT");
|
||||
clblast::RunTests<clblast::TestXdot<double>, double, double>(argc, argv, true, "DDOT");
|
||||
clblast::RunTests<clblast::TestXdot<half>, half, half>(argc, argv, true, "HDOT");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ int main(int argc, char *argv[]) {
|
|||
clblast::RunTests<clblast::TestXnrm2<double>, double, double>(argc, argv, true, "DNRM2");
|
||||
clblast::RunTests<clblast::TestXnrm2<float2>, float2, float2>(argc, argv, true, "ScNRM2");
|
||||
clblast::RunTests<clblast::TestXnrm2<double2>, double2, double2>(argc, argv, true, "DzNRM2");
|
||||
clblast::RunTests<clblast::TestXnrm2<half>, half, half>(argc, argv, true, "HNRM2");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ int main(int argc, char *argv[]) {
|
|||
clblast::RunTests<clblast::TestXscal<double>, double, double>(argc, argv, true, "DSCAL");
|
||||
clblast::RunTests<clblast::TestXscal<float2>, float2, float2>(argc, argv, true, "CSCAL");
|
||||
clblast::RunTests<clblast::TestXscal<double2>, double2, double2>(argc, argv, true, "ZSCAL");
|
||||
clblast::RunTests<clblast::TestXscal<half>, half, half>(argc, argv, true, "HSCAL");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ int main(int argc, char *argv[]) {
|
|||
clblast::RunTests<clblast::TestXswap<double>, double, double>(argc, argv, true, "DSWAP");
|
||||
clblast::RunTests<clblast::TestXswap<float2>, float2, float2>(argc, argv, true, "CSWAP");
|
||||
clblast::RunTests<clblast::TestXswap<double2>, double2, double2>(argc, argv, true, "ZSWAP");
|
||||
clblast::RunTests<clblast::TestXswap<half>, half, half>(argc, argv, true, "HSWAP");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ int main(int argc, char *argv[]) {
|
|||
clblast::RunTests<clblast::TestXgbmv<double>, double, double>(argc, argv, true, "DGBMV");
|
||||
clblast::RunTests<clblast::TestXgbmv<float2>, float2, float2>(argc, argv, true, "CGBMV");
|
||||
clblast::RunTests<clblast::TestXgbmv<double2>, double2, double2>(argc, argv, true, "ZGBMV");
|
||||
clblast::RunTests<clblast::TestXgbmv<half>, half, half>(argc, argv, true, "HGBMV");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ int main(int argc, char *argv[]) {
|
|||
clblast::RunTests<clblast::TestXgemv<double>, double, double>(argc, argv, true, "DGEMV");
|
||||
clblast::RunTests<clblast::TestXgemv<float2>, float2, float2>(argc, argv, true, "CGEMV");
|
||||
clblast::RunTests<clblast::TestXgemv<double2>, double2, double2>(argc, argv, true, "ZGEMV");
|
||||
clblast::RunTests<clblast::TestXgemv<half>, half, half>(argc, argv, true, "HGEMV");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ using double2 = clblast::double2;
|
|||
int main(int argc, char *argv[]) {
|
||||
clblast::RunTests<clblast::TestXger<float>, float, float>(argc, argv, false, "SGER");
|
||||
clblast::RunTests<clblast::TestXger<double>, double, double>(argc, argv, true, "DGER");
|
||||
clblast::RunTests<clblast::TestXger<half>, half, half>(argc, argv, true, "HGER");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ using double2 = clblast::double2;
|
|||
int main(int argc, char *argv[]) {
|
||||
clblast::RunTests<clblast::TestXsbmv<float>, float, float>(argc, argv, false, "SSBMV");
|
||||
clblast::RunTests<clblast::TestXsbmv<double>, double, double>(argc, argv, true, "DSBMV");
|
||||
clblast::RunTests<clblast::TestXsbmv<half>, half, half>(argc, argv, true, "HSBMV");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ using double2 = clblast::double2;
|
|||
int main(int argc, char *argv[]) {
|
||||
clblast::RunTests<clblast::TestXspmv<float>, float, float>(argc, argv, false, "SSPMV");
|
||||
clblast::RunTests<clblast::TestXspmv<double>, double, double>(argc, argv, true, "DSPMV");
|
||||
clblast::RunTests<clblast::TestXspmv<half>, half, half>(argc, argv, true, "HSPMV");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ using double2 = clblast::double2;
|
|||
int main(int argc, char *argv[]) {
|
||||
clblast::RunTests<clblast::TestXspr<float>, float, float>(argc, argv, false, "SSPR");
|
||||
clblast::RunTests<clblast::TestXspr<double>, double, double>(argc, argv, true, "DSPR");
|
||||
clblast::RunTests<clblast::TestXspr<half>, half, half>(argc, argv, true, "HSPR");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ using double2 = clblast::double2;
|
|||
int main(int argc, char *argv[]) {
|
||||
clblast::RunTests<clblast::TestXspr2<float>, float, float>(argc, argv, false, "SSPR2");
|
||||
clblast::RunTests<clblast::TestXspr2<double>, double, double>(argc, argv, true, "DSPR2");
|
||||
clblast::RunTests<clblast::TestXspr2<half>, half, half>(argc, argv, true, "HSPR2");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ using double2 = clblast::double2;
|
|||
int main(int argc, char *argv[]) {
|
||||
clblast::RunTests<clblast::TestXsymv<float>, float, float>(argc, argv, false, "SSYMV");
|
||||
clblast::RunTests<clblast::TestXsymv<double>, double, double>(argc, argv, true, "DSYMV");
|
||||
clblast::RunTests<clblast::TestXsymv<half>, half, half>(argc, argv, true, "HSYMV");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ using double2 = clblast::double2;
|
|||
int main(int argc, char *argv[]) {
|
||||
clblast::RunTests<clblast::TestXsyr<float>, float, float>(argc, argv, false, "SSYR");
|
||||
clblast::RunTests<clblast::TestXsyr<double>, double, double>(argc, argv, true, "DSYR");
|
||||
clblast::RunTests<clblast::TestXsyr<half>, half, half>(argc, argv, true, "HSYR");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ using double2 = clblast::double2;
|
|||
int main(int argc, char *argv[]) {
|
||||
clblast::RunTests<clblast::TestXsyr2<float>, float, float>(argc, argv, false, "SSYR2");
|
||||
clblast::RunTests<clblast::TestXsyr2<double>, double, double>(argc, argv, true, "DSYR2");
|
||||
clblast::RunTests<clblast::TestXsyr2<half>, half, half>(argc, argv, true, "HSYR2");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ int main(int argc, char *argv[]) {
|
|||
clblast::RunTests<clblast::TestXtbmv<double>, double, double>(argc, argv, true, "DTBMV");
|
||||
clblast::RunTests<clblast::TestXtbmv<float2>, float2, float2>(argc, argv, true, "CTBMV");
|
||||
clblast::RunTests<clblast::TestXtbmv<double2>, double2, double2>(argc, argv, true, "ZTBMV");
|
||||
clblast::RunTests<clblast::TestXtbmv<half>, half, half>(argc, argv, true, "HTBMV");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ int main(int argc, char *argv[]) {
|
|||
clblast::RunTests<clblast::TestXtpmv<double>, double, double>(argc, argv, true, "DTPMV");
|
||||
clblast::RunTests<clblast::TestXtpmv<float2>, float2, float2>(argc, argv, true, "CTPMV");
|
||||
clblast::RunTests<clblast::TestXtpmv<double2>, double2, double2>(argc, argv, true, "ZTPMV");
|
||||
clblast::RunTests<clblast::TestXtpmv<half>, half, half>(argc, argv, true, "HTPMV");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ int main(int argc, char *argv[]) {
|
|||
clblast::RunTests<clblast::TestXtrmv<double>, double, double>(argc, argv, true, "DTRMV");
|
||||
clblast::RunTests<clblast::TestXtrmv<float2>, float2, float2>(argc, argv, true, "CTRMV");
|
||||
clblast::RunTests<clblast::TestXtrmv<double2>, double2, double2>(argc, argv, true, "ZTRMV");
|
||||
clblast::RunTests<clblast::TestXtrmv<half>, half, half>(argc, argv, true, "HTRMV");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ int main(int argc, char *argv[]) {
|
|||
clblast::RunTests<clblast::TestXgemm<double>, double, double>(argc, argv, true, "DGEMM");
|
||||
clblast::RunTests<clblast::TestXgemm<float2>, float2, float2>(argc, argv, true, "CGEMM");
|
||||
clblast::RunTests<clblast::TestXgemm<double2>, double2, double2>(argc, argv, true, "ZGEMM");
|
||||
clblast::RunTests<clblast::TestXgemm<half>, half, half>(argc, argv, true, "HGEMM");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ int main(int argc, char *argv[]) {
|
|||
clblast::RunTests<clblast::TestXsymm<double>, double, double>(argc, argv, true, "DSYMM");
|
||||
clblast::RunTests<clblast::TestXsymm<float2>, float2, float2>(argc, argv, true, "CSYMM");
|
||||
clblast::RunTests<clblast::TestXsymm<double2>, double2, double2>(argc, argv, true, "ZSYMM");
|
||||
clblast::RunTests<clblast::TestXsymm<half>, half, half>(argc, argv, true, "HSYMM");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ int main(int argc, char *argv[]) {
|
|||
clblast::RunTests<clblast::TestXsyr2k<double>, double, double>(argc, argv, true, "DSYR2K");
|
||||
clblast::RunTests<clblast::TestXsyr2k<float2>, float2, float2>(argc, argv, true, "CSYR2K");
|
||||
clblast::RunTests<clblast::TestXsyr2k<double2>, double2, double2>(argc, argv, true, "ZSYR2K");
|
||||
clblast::RunTests<clblast::TestXsyr2k<half>, half, half>(argc, argv, true, "HSYR2K");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ int main(int argc, char *argv[]) {
|
|||
clblast::RunTests<clblast::TestXsyrk<double>, double, double>(argc, argv, true, "DSYRK");
|
||||
clblast::RunTests<clblast::TestXsyrk<float2>, float2, float2>(argc, argv, true, "CSYRK");
|
||||
clblast::RunTests<clblast::TestXsyrk<double2>, double2, double2>(argc, argv, true, "ZSYRK");
|
||||
clblast::RunTests<clblast::TestXsyrk<half>, half, half>(argc, argv, true, "HSYRK");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ int main(int argc, char *argv[]) {
|
|||
clblast::RunTests<clblast::TestXtrmm<double>, double, double>(argc, argv, true, "DTRMM");
|
||||
clblast::RunTests<clblast::TestXtrmm<float2>, float2, float2>(argc, argv, true, "CTRMM");
|
||||
clblast::RunTests<clblast::TestXtrmm<double2>, double2, double2>(argc, argv, true, "ZTRMM");
|
||||
clblast::RunTests<clblast::TestXtrmm<half>, half, half>(argc, argv, true, "HTRMM");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ int main(int argc, char *argv[]) {
|
|||
clblast::RunTests<clblast::TestXtrsm<double>, double, double>(argc, argv, true, "DTRSM");
|
||||
clblast::RunTests<clblast::TestXtrsm<float2>, float2, float2>(argc, argv, true, "CTRSM");
|
||||
clblast::RunTests<clblast::TestXtrsm<double2>, double2, double2>(argc, argv, true, "ZTRSM");
|
||||
clblast::RunTests<clblast::TestXtrsm<half>, half, half>(argc, argv, true, "HTRSM");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ namespace clblast {
|
|||
// =================================================================================================
|
||||
|
||||
// The transpose-options to test with (data-type dependent)
|
||||
template <> const std::vector<Transpose> TestBlas<half,half>::kTransposes = {Transpose::kNo, Transpose::kYes};
|
||||
template <> const std::vector<Transpose> TestBlas<float,float>::kTransposes = {Transpose::kNo, Transpose::kYes};
|
||||
template <> const std::vector<Transpose> TestBlas<double,double>::kTransposes = {Transpose::kNo, Transpose::kYes};
|
||||
template <> const std::vector<Transpose> TestBlas<float2,float2>::kTransposes = {Transpose::kNo, Transpose::kYes, Transpose::kConjugate};
|
||||
|
@ -147,10 +148,8 @@ void TestBlas<T,U>::TestRegular(std::vector<Arguments<U>> &test_vector, const st
|
|||
if (verbose_) {
|
||||
if (get_id2_(args) == 1) { fprintf(stdout, "\n Error at index %zu: ", id1); }
|
||||
else { fprintf(stdout, "\n Error at %zu,%zu: ", id1, id2); }
|
||||
std::cout << result1[index];
|
||||
fprintf(stdout, " (reference) versus ");
|
||||
std::cout << result2[index];
|
||||
fprintf(stdout, " (CLBlast)");
|
||||
fprintf(stdout, " %s (reference) versus ", ToString(result1[index]).c_str());
|
||||
fprintf(stdout, " %s (CLBlast)", ToString(result2[index]).c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -222,6 +221,7 @@ void TestBlas<T,U>::TestInvalid(std::vector<Arguments<U>> &test_vector, const st
|
|||
// =================================================================================================
|
||||
|
||||
// Compiles the templated class
|
||||
template class TestBlas<half, half>;
|
||||
template class TestBlas<float, float>;
|
||||
template class TestBlas<double, double>;
|
||||
template class TestBlas<float2, float2>;
|
||||
|
|
|
@ -351,11 +351,11 @@ bool TestSimilarity(const T val1, const T val2) {
|
|||
}
|
||||
}
|
||||
|
||||
// Compiles the default case for non-complex data-types
|
||||
// Compiles the default case for standard data-types
|
||||
template bool TestSimilarity<float>(const float, const float);
|
||||
template bool TestSimilarity<double>(const double, const double);
|
||||
|
||||
// Specialisations for complex data-types
|
||||
// Specialisations for non-standard data-types
|
||||
template <>
|
||||
bool TestSimilarity(const float2 val1, const float2 val2) {
|
||||
auto real = TestSimilarity(val1.real(), val2.real());
|
||||
|
@ -368,6 +368,10 @@ bool TestSimilarity(const double2 val1, const double2 val2) {
|
|||
auto imag = TestSimilarity(val1.imag(), val2.imag());
|
||||
return (real && imag);
|
||||
}
|
||||
template <>
|
||||
bool TestSimilarity(const half val1, const half val2) {
|
||||
return TestSimilarity(HalfToFloat(val1), HalfToFloat(val2));
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
|
@ -389,10 +393,15 @@ template <> const std::vector<double2> GetExampleScalars(const bool full_test) {
|
|||
if (full_test) { return {{0.0, 0.0}, {1.0, 1.3}, {2.42, 3.14}}; }
|
||||
else { return {{2.42, 3.14}}; }
|
||||
}
|
||||
template <> const std::vector<half> GetExampleScalars(const bool full_test) {
|
||||
if (full_test) { return {FloatToHalf(0.0f), FloatToHalf(1.0f), FloatToHalf(3.14f)}; }
|
||||
else { return {FloatToHalf(3.14f)}; }
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
// Compiles the templated class
|
||||
template class Tester<half, half>;
|
||||
template class Tester<float, float>;
|
||||
template class Tester<double, double>;
|
||||
template class Tester<float2, float2>;
|
||||
|
|
|
@ -31,6 +31,16 @@ CBLAS_UPLO convertToCBLAS(const Triangle v) { return (v == Triangle::kUpper) ? C
|
|||
CBLAS_DIAG convertToCBLAS(const Diagonal v) { return (v == Diagonal::kUnit) ? CblasUnit : CblasNonUnit; }
|
||||
CBLAS_SIDE convertToCBLAS(const Side v) { return (v == Side::kLeft) ? CblasLeft : CblasRight; }
|
||||
|
||||
// Conversions from and to half-precision
|
||||
std::vector<float> HalfToFloatBuffer(const std::vector<half>& source) {
|
||||
auto result = std::vector<float>(source.size());
|
||||
for (auto i = size_t(0); i < source.size(); ++i) { result[i] = HalfToFloat(source[i]); }
|
||||
return result;
|
||||
}
|
||||
void FloatToHalfBuffer(std::vector<half>& result, const std::vector<float>& source) {
|
||||
for (auto i = size_t(0); i < source.size(); ++i) { result[i] = FloatToHalf(source[i]); }
|
||||
}
|
||||
|
||||
// OpenBLAS is not fully Netlib CBLAS compatible
|
||||
#ifdef OPENBLAS_VERSION
|
||||
using return_pointer_float = openblas_complex_float*;
|
||||
|
@ -164,7 +174,13 @@ void cblasXswap(const size_t n,
|
|||
void cblasXswap(const size_t n,
|
||||
std::vector<half>& x_buffer, const size_t x_offset, const size_t x_inc,
|
||||
std::vector<half>& y_buffer, const size_t y_offset, const size_t y_inc) {
|
||||
return;
|
||||
auto x_buffer_bis = HalfToFloatBuffer(x_buffer);
|
||||
auto y_buffer_bis = HalfToFloatBuffer(y_buffer);
|
||||
cblasXswap(n,
|
||||
x_buffer_bis, x_offset, x_inc,
|
||||
y_buffer_bis, y_offset, y_inc);
|
||||
FloatToHalfBuffer(x_buffer, x_buffer_bis);
|
||||
FloatToHalfBuffer(y_buffer, y_buffer_bis);
|
||||
}
|
||||
|
||||
// Forwards the Netlib BLAS calls for SSCAL/DSCAL/CSCAL/ZSCAL
|
||||
|
@ -201,7 +217,11 @@ void cblasXscal(const size_t n,
|
|||
void cblasXscal(const size_t n,
|
||||
const half alpha,
|
||||
std::vector<half>& x_buffer, const size_t x_offset, const size_t x_inc) {
|
||||
return;
|
||||
auto x_buffer_bis = HalfToFloatBuffer(x_buffer);
|
||||
cblasXscal(n,
|
||||
HalfToFloat(alpha),
|
||||
x_buffer_bis, x_offset, x_inc);
|
||||
FloatToHalfBuffer(x_buffer, x_buffer_bis);
|
||||
}
|
||||
|
||||
// Forwards the Netlib BLAS calls for SCOPY/DCOPY/CCOPY/ZCOPY
|
||||
|
@ -236,7 +256,12 @@ void cblasXcopy(const size_t n,
|
|||
void cblasXcopy(const size_t n,
|
||||
const std::vector<half>& x_buffer, const size_t x_offset, const size_t x_inc,
|
||||
std::vector<half>& y_buffer, const size_t y_offset, const size_t y_inc) {
|
||||
return;
|
||||
auto x_buffer_bis = HalfToFloatBuffer(x_buffer);
|
||||
auto y_buffer_bis = HalfToFloatBuffer(y_buffer);
|
||||
cblasXcopy(n,
|
||||
x_buffer_bis, x_offset, x_inc,
|
||||
y_buffer_bis, y_offset, y_inc);
|
||||
FloatToHalfBuffer(y_buffer, y_buffer_bis);
|
||||
}
|
||||
|
||||
// Forwards the Netlib BLAS calls for SAXPY/DAXPY/CAXPY/ZAXPY
|
||||
|
@ -282,7 +307,13 @@ void cblasXaxpy(const size_t n,
|
|||
const half alpha,
|
||||
const std::vector<half>& x_buffer, const size_t x_offset, const size_t x_inc,
|
||||
std::vector<half>& y_buffer, const size_t y_offset, const size_t y_inc) {
|
||||
return;
|
||||
auto x_buffer_bis = HalfToFloatBuffer(x_buffer);
|
||||
auto y_buffer_bis = HalfToFloatBuffer(y_buffer);
|
||||
cblasXaxpy(n,
|
||||
HalfToFloat(alpha),
|
||||
x_buffer_bis, x_offset, x_inc,
|
||||
y_buffer_bis, y_offset, y_inc);
|
||||
FloatToHalfBuffer(y_buffer, y_buffer_bis);
|
||||
}
|
||||
|
||||
// Forwards the Netlib BLAS calls for SDOT/DDOT
|
||||
|
@ -306,7 +337,14 @@ void cblasXdot(const size_t n,
|
|||
std::vector<half>& dot_buffer, const size_t dot_offset,
|
||||
const std::vector<half>& x_buffer, const size_t x_offset, const size_t x_inc,
|
||||
const std::vector<half>& y_buffer, const size_t y_offset, const size_t y_inc) {
|
||||
return;
|
||||
auto x_buffer_bis = HalfToFloatBuffer(x_buffer);
|
||||
auto y_buffer_bis = HalfToFloatBuffer(y_buffer);
|
||||
auto dot_buffer_bis = HalfToFloatBuffer(dot_buffer);
|
||||
cblasXdot(n,
|
||||
dot_buffer_bis, dot_offset,
|
||||
x_buffer_bis, x_offset, x_inc,
|
||||
y_buffer_bis, y_offset, y_inc);
|
||||
FloatToHalfBuffer(dot_buffer, dot_buffer_bis);
|
||||
}
|
||||
|
||||
// Forwards the Netlib BLAS calls for CDOTU/ZDOTU
|
||||
|
@ -377,7 +415,12 @@ void cblasXnrm2(const size_t n,
|
|||
void cblasXnrm2(const size_t n,
|
||||
std::vector<half>& nrm2_buffer, const size_t nrm2_offset,
|
||||
const std::vector<half>& x_buffer, const size_t x_offset, const size_t x_inc) {
|
||||
return;
|
||||
auto x_buffer_bis = HalfToFloatBuffer(x_buffer);
|
||||
auto nrm2_buffer_bis = HalfToFloatBuffer(nrm2_buffer);
|
||||
cblasXnrm2(n,
|
||||
nrm2_buffer_bis, nrm2_offset,
|
||||
x_buffer_bis, x_offset, x_inc);
|
||||
FloatToHalfBuffer(nrm2_buffer, nrm2_buffer_bis);
|
||||
}
|
||||
|
||||
// Forwards the Netlib BLAS calls for SASUM/DASUM/ScASUM/DzASUM
|
||||
|
@ -408,7 +451,12 @@ void cblasXasum(const size_t n,
|
|||
void cblasXasum(const size_t n,
|
||||
std::vector<half>& asum_buffer, const size_t asum_offset,
|
||||
const std::vector<half>& x_buffer, const size_t x_offset, const size_t x_inc) {
|
||||
return;
|
||||
auto x_buffer_bis = HalfToFloatBuffer(x_buffer);
|
||||
auto asum_buffer_bis = HalfToFloatBuffer(asum_buffer);
|
||||
cblasXasum(n,
|
||||
asum_buffer_bis, asum_offset,
|
||||
x_buffer_bis, x_offset, x_inc);
|
||||
FloatToHalfBuffer(asum_buffer, asum_buffer_bis);
|
||||
}
|
||||
|
||||
// Forwards the Netlib BLAS calls for iSAMAX/iDAMAX/iCAMAX/iZAMAX/iHAMAX
|
||||
|
@ -439,7 +487,12 @@ void cblasXamax(const size_t n,
|
|||
void cblasXamax(const size_t n,
|
||||
std::vector<half>& imax_buffer, const size_t imax_offset,
|
||||
const std::vector<half>& x_buffer, const size_t x_offset, const size_t x_inc) {
|
||||
return;
|
||||
auto x_buffer_bis = HalfToFloatBuffer(x_buffer);
|
||||
auto imax_buffer_bis = HalfToFloatBuffer(imax_buffer);
|
||||
cblasXamax(n,
|
||||
imax_buffer_bis, imax_offset,
|
||||
x_buffer_bis, x_offset, x_inc);
|
||||
FloatToHalfBuffer(imax_buffer, imax_buffer_bis);
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
|
@ -518,7 +571,17 @@ void cblasXgemv(const CBLAS_ORDER layout, const CBLAS_TRANSPOSE a_transpose,
|
|||
const std::vector<half>& x_buffer, const size_t x_offset, const size_t x_inc,
|
||||
const half beta,
|
||||
std::vector<half>& y_buffer, const size_t y_offset, const size_t y_inc) {
|
||||
return;
|
||||
auto a_buffer_bis = HalfToFloatBuffer(a_buffer);
|
||||
auto x_buffer_bis = HalfToFloatBuffer(x_buffer);
|
||||
auto y_buffer_bis = HalfToFloatBuffer(y_buffer);
|
||||
cblasXgemv(layout, a_transpose,
|
||||
m, n,
|
||||
HalfToFloat(alpha),
|
||||
a_buffer_bis, a_offset, a_ld,
|
||||
x_buffer_bis, x_offset, x_inc,
|
||||
HalfToFloat(beta),
|
||||
y_buffer_bis, y_offset, y_inc);
|
||||
FloatToHalfBuffer(y_buffer, y_buffer_bis);
|
||||
}
|
||||
|
||||
// Forwards the Netlib BLAS calls for SGBMV/DGBMV/CGBMV/ZGBMV
|
||||
|
@ -593,7 +656,17 @@ void cblasXgbmv(const CBLAS_ORDER layout, const CBLAS_TRANSPOSE a_transpose,
|
|||
const std::vector<half>& x_buffer, const size_t x_offset, const size_t x_inc,
|
||||
const half beta,
|
||||
std::vector<half>& y_buffer, const size_t y_offset, const size_t y_inc) {
|
||||
return;
|
||||
auto a_buffer_bis = HalfToFloatBuffer(a_buffer);
|
||||
auto x_buffer_bis = HalfToFloatBuffer(x_buffer);
|
||||
auto y_buffer_bis = HalfToFloatBuffer(y_buffer);
|
||||
cblasXgbmv(layout, a_transpose,
|
||||
m, n, kl, ku,
|
||||
HalfToFloat(alpha),
|
||||
a_buffer_bis, a_offset, a_ld,
|
||||
x_buffer_bis, x_offset, x_inc,
|
||||
HalfToFloat(beta),
|
||||
y_buffer_bis, y_offset, y_inc);
|
||||
FloatToHalfBuffer(y_buffer, y_buffer_bis);
|
||||
}
|
||||
|
||||
// Forwards the Netlib BLAS calls for CHEMV/ZHEMV
|
||||
|
@ -742,7 +815,17 @@ void cblasXsymv(const CBLAS_ORDER layout, const CBLAS_UPLO triangle,
|
|||
const std::vector<half>& x_buffer, const size_t x_offset, const size_t x_inc,
|
||||
const half beta,
|
||||
std::vector<half>& y_buffer, const size_t y_offset, const size_t y_inc) {
|
||||
return;
|
||||
auto a_buffer_bis = HalfToFloatBuffer(a_buffer);
|
||||
auto x_buffer_bis = HalfToFloatBuffer(x_buffer);
|
||||
auto y_buffer_bis = HalfToFloatBuffer(y_buffer);
|
||||
cblasXsymv(layout, triangle,
|
||||
n,
|
||||
HalfToFloat(alpha),
|
||||
a_buffer_bis, a_offset, a_ld,
|
||||
x_buffer_bis, x_offset, x_inc,
|
||||
HalfToFloat(beta),
|
||||
y_buffer_bis, y_offset, y_inc);
|
||||
FloatToHalfBuffer(y_buffer, y_buffer_bis);
|
||||
}
|
||||
|
||||
// Forwards the Netlib BLAS calls for SSBMV/DSBMV
|
||||
|
@ -783,7 +866,17 @@ void cblasXsbmv(const CBLAS_ORDER layout, const CBLAS_UPLO triangle,
|
|||
const std::vector<half>& x_buffer, const size_t x_offset, const size_t x_inc,
|
||||
const half beta,
|
||||
std::vector<half>& y_buffer, const size_t y_offset, const size_t y_inc) {
|
||||
return;
|
||||
auto a_buffer_bis = HalfToFloatBuffer(a_buffer);
|
||||
auto x_buffer_bis = HalfToFloatBuffer(x_buffer);
|
||||
auto y_buffer_bis = HalfToFloatBuffer(y_buffer);
|
||||
cblasXsbmv(layout, triangle,
|
||||
n, k,
|
||||
HalfToFloat(alpha),
|
||||
a_buffer_bis, a_offset, a_ld,
|
||||
x_buffer_bis, x_offset, x_inc,
|
||||
HalfToFloat(beta),
|
||||
y_buffer_bis, y_offset, y_inc);
|
||||
FloatToHalfBuffer(y_buffer, y_buffer_bis);
|
||||
}
|
||||
|
||||
// Forwards the Netlib BLAS calls for SSPMV/DSPMV
|
||||
|
@ -824,7 +917,17 @@ void cblasXspmv(const CBLAS_ORDER layout, const CBLAS_UPLO triangle,
|
|||
const std::vector<half>& x_buffer, const size_t x_offset, const size_t x_inc,
|
||||
const half beta,
|
||||
std::vector<half>& y_buffer, const size_t y_offset, const size_t y_inc) {
|
||||
return;
|
||||
auto ap_buffer_bis = HalfToFloatBuffer(ap_buffer);
|
||||
auto x_buffer_bis = HalfToFloatBuffer(x_buffer);
|
||||
auto y_buffer_bis = HalfToFloatBuffer(y_buffer);
|
||||
cblasXspmv(layout, triangle,
|
||||
n,
|
||||
HalfToFloat(alpha),
|
||||
ap_buffer_bis, ap_offset,
|
||||
x_buffer_bis, x_offset, x_inc,
|
||||
HalfToFloat(beta),
|
||||
y_buffer_bis, y_offset, y_inc);
|
||||
FloatToHalfBuffer(y_buffer, y_buffer_bis);
|
||||
}
|
||||
|
||||
// Forwards the Netlib BLAS calls for STRMV/DTRMV/CTRMV/ZTRMV
|
||||
|
@ -868,7 +971,13 @@ void cblasXtrmv(const CBLAS_ORDER layout, const CBLAS_UPLO triangle, const CBLAS
|
|||
const size_t n,
|
||||
const std::vector<half>& a_buffer, const size_t a_offset, const size_t a_ld,
|
||||
std::vector<half>& x_buffer, const size_t x_offset, const size_t x_inc) {
|
||||
return;
|
||||
auto a_buffer_bis = HalfToFloatBuffer(a_buffer);
|
||||
auto x_buffer_bis = HalfToFloatBuffer(x_buffer);
|
||||
cblasXtrmv(layout, triangle, a_transpose, diagonal,
|
||||
n,
|
||||
a_buffer_bis, a_offset, a_ld,
|
||||
x_buffer_bis, x_offset, x_inc);
|
||||
FloatToHalfBuffer(x_buffer, x_buffer_bis);
|
||||
}
|
||||
|
||||
// Forwards the Netlib BLAS calls for STBMV/DTBMV/CTBMV/ZTBMV
|
||||
|
@ -912,7 +1021,13 @@ void cblasXtbmv(const CBLAS_ORDER layout, const CBLAS_UPLO triangle, const CBLAS
|
|||
const size_t n, const size_t k,
|
||||
const std::vector<half>& a_buffer, const size_t a_offset, const size_t a_ld,
|
||||
std::vector<half>& x_buffer, const size_t x_offset, const size_t x_inc) {
|
||||
return;
|
||||
auto a_buffer_bis = HalfToFloatBuffer(a_buffer);
|
||||
auto x_buffer_bis = HalfToFloatBuffer(x_buffer);
|
||||
cblasXtbmv(layout, triangle, a_transpose, diagonal,
|
||||
n, k,
|
||||
a_buffer_bis, a_offset, a_ld,
|
||||
x_buffer_bis, x_offset, x_inc);
|
||||
FloatToHalfBuffer(x_buffer, x_buffer_bis);
|
||||
}
|
||||
|
||||
// Forwards the Netlib BLAS calls for STPMV/DTPMV/CTPMV/ZTPMV
|
||||
|
@ -956,7 +1071,13 @@ void cblasXtpmv(const CBLAS_ORDER layout, const CBLAS_UPLO triangle, const CBLAS
|
|||
const size_t n,
|
||||
const std::vector<half>& ap_buffer, const size_t ap_offset,
|
||||
std::vector<half>& x_buffer, const size_t x_offset, const size_t x_inc) {
|
||||
return;
|
||||
auto ap_buffer_bis = HalfToFloatBuffer(ap_buffer);
|
||||
auto x_buffer_bis = HalfToFloatBuffer(x_buffer);
|
||||
cblasXtpmv(layout, triangle, a_transpose, diagonal,
|
||||
n,
|
||||
ap_buffer_bis, ap_offset,
|
||||
x_buffer_bis, x_offset, x_inc);
|
||||
FloatToHalfBuffer(x_buffer, x_buffer_bis);
|
||||
}
|
||||
|
||||
// Forwards the Netlib BLAS calls for STRSV/DTRSV/CTRSV/ZTRSV
|
||||
|
@ -1106,7 +1227,16 @@ void cblasXger(const CBLAS_ORDER layout,
|
|||
const std::vector<half>& x_buffer, const size_t x_offset, const size_t x_inc,
|
||||
const std::vector<half>& y_buffer, const size_t y_offset, const size_t y_inc,
|
||||
std::vector<half>& a_buffer, const size_t a_offset, const size_t a_ld) {
|
||||
return;
|
||||
auto x_buffer_bis = HalfToFloatBuffer(x_buffer);
|
||||
auto y_buffer_bis = HalfToFloatBuffer(y_buffer);
|
||||
auto a_buffer_bis = HalfToFloatBuffer(a_buffer);
|
||||
cblasXger(layout,
|
||||
m, n,
|
||||
HalfToFloat(alpha),
|
||||
x_buffer_bis, x_offset, x_inc,
|
||||
y_buffer_bis, y_offset, y_inc,
|
||||
a_buffer_bis, a_offset, a_ld);
|
||||
FloatToHalfBuffer(a_buffer, a_buffer_bis);
|
||||
}
|
||||
|
||||
// Forwards the Netlib BLAS calls for CGERU/ZGERU
|
||||
|
@ -1305,7 +1435,14 @@ void cblasXsyr(const CBLAS_ORDER layout, const CBLAS_UPLO triangle,
|
|||
const half alpha,
|
||||
const std::vector<half>& x_buffer, const size_t x_offset, const size_t x_inc,
|
||||
std::vector<half>& a_buffer, const size_t a_offset, const size_t a_ld) {
|
||||
return;
|
||||
auto x_buffer_bis = HalfToFloatBuffer(x_buffer);
|
||||
auto a_buffer_bis = HalfToFloatBuffer(a_buffer);
|
||||
cblasXsyr(layout, triangle,
|
||||
n,
|
||||
HalfToFloat(alpha),
|
||||
x_buffer_bis, x_offset, x_inc,
|
||||
a_buffer_bis, a_offset, a_ld);
|
||||
FloatToHalfBuffer(a_buffer, a_buffer_bis);
|
||||
}
|
||||
|
||||
// Forwards the Netlib BLAS calls for SSPR/DSPR
|
||||
|
@ -1336,7 +1473,14 @@ void cblasXspr(const CBLAS_ORDER layout, const CBLAS_UPLO triangle,
|
|||
const half alpha,
|
||||
const std::vector<half>& x_buffer, const size_t x_offset, const size_t x_inc,
|
||||
std::vector<half>& ap_buffer, const size_t ap_offset) {
|
||||
return;
|
||||
auto x_buffer_bis = HalfToFloatBuffer(x_buffer);
|
||||
auto ap_buffer_bis = HalfToFloatBuffer(ap_buffer);
|
||||
cblasXspr(layout, triangle,
|
||||
n,
|
||||
HalfToFloat(alpha),
|
||||
x_buffer_bis, x_offset, x_inc,
|
||||
ap_buffer_bis, ap_offset);
|
||||
FloatToHalfBuffer(ap_buffer, ap_buffer_bis);
|
||||
}
|
||||
|
||||
// Forwards the Netlib BLAS calls for SSYR2/DSYR2
|
||||
|
@ -1372,7 +1516,16 @@ void cblasXsyr2(const CBLAS_ORDER layout, const CBLAS_UPLO triangle,
|
|||
const std::vector<half>& x_buffer, const size_t x_offset, const size_t x_inc,
|
||||
const std::vector<half>& y_buffer, const size_t y_offset, const size_t y_inc,
|
||||
std::vector<half>& a_buffer, const size_t a_offset, const size_t a_ld) {
|
||||
return;
|
||||
auto x_buffer_bis = HalfToFloatBuffer(x_buffer);
|
||||
auto y_buffer_bis = HalfToFloatBuffer(y_buffer);
|
||||
auto a_buffer_bis = HalfToFloatBuffer(a_buffer);
|
||||
cblasXsyr2(layout, triangle,
|
||||
n,
|
||||
HalfToFloat(alpha),
|
||||
x_buffer_bis, x_offset, x_inc,
|
||||
y_buffer_bis, y_offset, y_inc,
|
||||
a_buffer_bis, a_offset, a_ld);
|
||||
FloatToHalfBuffer(a_buffer, a_buffer_bis);
|
||||
}
|
||||
|
||||
// Forwards the Netlib BLAS calls for SSPR2/DSPR2
|
||||
|
@ -1408,7 +1561,16 @@ void cblasXspr2(const CBLAS_ORDER layout, const CBLAS_UPLO triangle,
|
|||
const std::vector<half>& x_buffer, const size_t x_offset, const size_t x_inc,
|
||||
const std::vector<half>& y_buffer, const size_t y_offset, const size_t y_inc,
|
||||
std::vector<half>& ap_buffer, const size_t ap_offset) {
|
||||
return;
|
||||
auto x_buffer_bis = HalfToFloatBuffer(x_buffer);
|
||||
auto y_buffer_bis = HalfToFloatBuffer(y_buffer);
|
||||
auto ap_buffer_bis = HalfToFloatBuffer(ap_buffer);
|
||||
cblasXspr2(layout, triangle,
|
||||
n,
|
||||
HalfToFloat(alpha),
|
||||
x_buffer_bis, x_offset, x_inc,
|
||||
y_buffer_bis, y_offset, y_inc,
|
||||
ap_buffer_bis, ap_offset);
|
||||
FloatToHalfBuffer(ap_buffer, ap_buffer_bis);
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
|
@ -1487,7 +1649,17 @@ void cblasXgemm(const CBLAS_ORDER layout, const CBLAS_TRANSPOSE a_transpose, con
|
|||
const std::vector<half>& b_buffer, const size_t b_offset, const size_t b_ld,
|
||||
const half beta,
|
||||
std::vector<half>& c_buffer, const size_t c_offset, const size_t c_ld) {
|
||||
return;
|
||||
auto a_buffer_bis = HalfToFloatBuffer(a_buffer);
|
||||
auto b_buffer_bis = HalfToFloatBuffer(b_buffer);
|
||||
auto c_buffer_bis = HalfToFloatBuffer(c_buffer);
|
||||
cblasXgemm(layout, a_transpose, b_transpose,
|
||||
m, n, k,
|
||||
HalfToFloat(alpha),
|
||||
a_buffer_bis, a_offset, a_ld,
|
||||
b_buffer_bis, b_offset, b_ld,
|
||||
HalfToFloat(beta),
|
||||
c_buffer_bis, c_offset, c_ld);
|
||||
FloatToHalfBuffer(c_buffer, c_buffer_bis);
|
||||
}
|
||||
|
||||
// Forwards the Netlib BLAS calls for SSYMM/DSYMM/CSYMM/ZSYMM
|
||||
|
@ -1562,7 +1734,17 @@ void cblasXsymm(const CBLAS_ORDER layout, const CBLAS_SIDE side, const CBLAS_UPL
|
|||
const std::vector<half>& b_buffer, const size_t b_offset, const size_t b_ld,
|
||||
const half beta,
|
||||
std::vector<half>& c_buffer, const size_t c_offset, const size_t c_ld) {
|
||||
return;
|
||||
auto a_buffer_bis = HalfToFloatBuffer(a_buffer);
|
||||
auto b_buffer_bis = HalfToFloatBuffer(b_buffer);
|
||||
auto c_buffer_bis = HalfToFloatBuffer(c_buffer);
|
||||
cblasXsymm(layout, side, triangle,
|
||||
m, n,
|
||||
HalfToFloat(alpha),
|
||||
a_buffer_bis, a_offset, a_ld,
|
||||
b_buffer_bis, b_offset, b_ld,
|
||||
HalfToFloat(beta),
|
||||
c_buffer_bis, c_offset, c_ld);
|
||||
FloatToHalfBuffer(c_buffer, c_buffer_bis);
|
||||
}
|
||||
|
||||
// Forwards the Netlib BLAS calls for CHEMM/ZHEMM
|
||||
|
@ -1664,7 +1846,15 @@ void cblasXsyrk(const CBLAS_ORDER layout, const CBLAS_UPLO triangle, const CBLAS
|
|||
const std::vector<half>& a_buffer, const size_t a_offset, const size_t a_ld,
|
||||
const half beta,
|
||||
std::vector<half>& c_buffer, const size_t c_offset, const size_t c_ld) {
|
||||
return;
|
||||
auto a_buffer_bis = HalfToFloatBuffer(a_buffer);
|
||||
auto c_buffer_bis = HalfToFloatBuffer(c_buffer);
|
||||
cblasXsyrk(layout, triangle, a_transpose,
|
||||
n, k,
|
||||
HalfToFloat(alpha),
|
||||
a_buffer_bis, a_offset, a_ld,
|
||||
HalfToFloat(beta),
|
||||
c_buffer_bis, c_offset, c_ld);
|
||||
FloatToHalfBuffer(c_buffer, c_buffer_bis);
|
||||
}
|
||||
|
||||
// Forwards the Netlib BLAS calls for CHERK/ZHERK
|
||||
|
@ -1767,7 +1957,17 @@ void cblasXsyr2k(const CBLAS_ORDER layout, const CBLAS_UPLO triangle, const CBLA
|
|||
const std::vector<half>& b_buffer, const size_t b_offset, const size_t b_ld,
|
||||
const half beta,
|
||||
std::vector<half>& c_buffer, const size_t c_offset, const size_t c_ld) {
|
||||
return;
|
||||
auto a_buffer_bis = HalfToFloatBuffer(a_buffer);
|
||||
auto b_buffer_bis = HalfToFloatBuffer(b_buffer);
|
||||
auto c_buffer_bis = HalfToFloatBuffer(c_buffer);
|
||||
cblasXsyr2k(layout, triangle, ab_transpose,
|
||||
n, k,
|
||||
HalfToFloat(alpha),
|
||||
a_buffer_bis, a_offset, a_ld,
|
||||
b_buffer_bis, b_offset, b_ld,
|
||||
HalfToFloat(beta),
|
||||
c_buffer_bis, c_offset, c_ld);
|
||||
FloatToHalfBuffer(c_buffer, c_buffer_bis);
|
||||
}
|
||||
|
||||
// Forwards the Netlib BLAS calls for CHER2K/ZHER2K
|
||||
|
@ -1856,7 +2056,14 @@ void cblasXtrmm(const CBLAS_ORDER layout, const CBLAS_SIDE side, const CBLAS_UPL
|
|||
const half alpha,
|
||||
const std::vector<half>& a_buffer, const size_t a_offset, const size_t a_ld,
|
||||
std::vector<half>& b_buffer, const size_t b_offset, const size_t b_ld) {
|
||||
return;
|
||||
auto a_buffer_bis = HalfToFloatBuffer(a_buffer);
|
||||
auto b_buffer_bis = HalfToFloatBuffer(b_buffer);
|
||||
cblasXtrmm(layout, side, triangle, a_transpose, diagonal,
|
||||
m, n,
|
||||
HalfToFloat(alpha),
|
||||
a_buffer_bis, a_offset, a_ld,
|
||||
b_buffer_bis, b_offset, b_ld);
|
||||
FloatToHalfBuffer(b_buffer, b_buffer_bis);
|
||||
}
|
||||
|
||||
// Forwards the Netlib BLAS calls for STRSM/DTRSM/CTRSM/ZTRSM
|
||||
|
@ -1911,7 +2118,14 @@ void cblasXtrsm(const CBLAS_ORDER layout, const CBLAS_SIDE side, const CBLAS_UPL
|
|||
const half alpha,
|
||||
const std::vector<half>& a_buffer, const size_t a_offset, const size_t a_ld,
|
||||
std::vector<half>& b_buffer, const size_t b_offset, const size_t b_ld) {
|
||||
return;
|
||||
auto a_buffer_bis = HalfToFloatBuffer(a_buffer);
|
||||
auto b_buffer_bis = HalfToFloatBuffer(b_buffer);
|
||||
cblasXtrsm(layout, side, triangle, a_transpose, diagonal,
|
||||
m, n,
|
||||
HalfToFloat(alpha),
|
||||
a_buffer_bis, a_offset, a_ld,
|
||||
b_buffer_bis, b_offset, b_ld);
|
||||
FloatToHalfBuffer(b_buffer, b_buffer_bis);
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
|
|
Loading…
Reference in a new issue