mirror of
https://github.com/CNugteren/CLBlast.git
synced 2024-07-07 12:23:46 +02:00
Refactored the Python C++ generator script; now confirms to the PEP8 styleguide
This commit is contained in:
parent
b30b26b89e
commit
a2f8350703
|
@ -1,70 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# ==================================================================================================
|
||||
# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
|
||||
# project loosely follows the Google C++ styleguide and uses a max-width of 100 characters per line.
|
||||
#
|
||||
# Author(s):
|
||||
# Cedric Nugteren <www.cedricnugteren.nl>
|
||||
#
|
||||
# This file contains the 'DataType' class, used in the generator script to generate the CLBlast API
|
||||
# interface and implementation.
|
||||
#
|
||||
# ==================================================================================================
|
||||
|
||||
# Short-hands for data-types
|
||||
HLF = "half"
|
||||
FLT = "float"
|
||||
DBL = "double"
|
||||
FLT2 = "float2"
|
||||
DBL2 = "double2"
|
||||
|
||||
HCL = "cl_half"
|
||||
F2CL = "cl_float2"
|
||||
D2CL = "cl_double2"
|
||||
|
||||
# Structure holding data-type and precision information
|
||||
class DataType():
|
||||
def __init__(self, precision_name, name, template, scalars, buffertype):
|
||||
self.precision_name = precision_name
|
||||
self.name = name
|
||||
self.template = template
|
||||
self.alpha_cpp = scalars[0]
|
||||
self.beta_cpp = scalars[1]
|
||||
self.alpha_cl = scalars[2]
|
||||
self.beta_cl = scalars[3]
|
||||
self.buffertype = buffertype
|
||||
|
||||
# Outputs the name of the data-type (alpha/beta), possibly transforming into the right type
|
||||
def UseAlpha(self):
|
||||
if self.alpha_cpp in [FLT2, DBL2]:
|
||||
return self.alpha_cpp+"{alpha.s[0], alpha.s[1]}"
|
||||
return "alpha"
|
||||
def UseBeta(self):
|
||||
if self.beta_cpp in [FLT2, DBL2]:
|
||||
return self.beta_cpp+"{beta.s[0], beta.s[1]}"
|
||||
return "beta"
|
||||
|
||||
# As above, but the transformation is in the opposite direction
|
||||
def UseAlphaCL(self):
|
||||
if self.alpha_cpp in [FLT2, DBL2]:
|
||||
return self.alpha_cl+"{{alpha.real(), alpha.imag()}}"
|
||||
return "alpha"
|
||||
def UseBetaCL(self):
|
||||
if self.beta_cpp in [FLT2, DBL2]:
|
||||
return self.beta_cl+"{{beta.real(), beta.imag()}}"
|
||||
return "beta"
|
||||
|
||||
# Returns the template as used in the correctness/performance tests
|
||||
def TestTemplate(self):
|
||||
if self.buffertype != self.beta_cpp:
|
||||
return "<"+self.buffertype+","+self.beta_cpp+">, "+self.buffertype+", "+self.beta_cpp
|
||||
return "<"+self.buffertype+">, "+self.buffertype+", "+self.beta_cpp
|
||||
|
||||
# Current scalar is complex
|
||||
def IsComplex(self, scalar):
|
||||
return ((scalar == "alpha" and self.alpha_cpp in [FLT2, DBL2]) or
|
||||
(scalar == "beta" and self.beta_cpp in [FLT2, DBL2]))
|
||||
|
||||
|
||||
# ==================================================================================================
|
|
@ -1,14 +1,13 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# ==================================================================================================
|
||||
# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
|
||||
# project loosely follows the Google C++ styleguide and uses a max-width of 100 characters per line.
|
||||
# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
|
||||
# PEP8 Python style guide and uses a max-width of 120 characters per line.
|
||||
#
|
||||
# Author(s):
|
||||
# Cedric Nugteren <www.cedricnugteren.nl>
|
||||
#
|
||||
# This script automatically generates the bodies of the following files, creating the full CLBlast
|
||||
# API interface and implementation (C, C++, and reference BLAS wrappers):
|
||||
# This script automatically generates the bodies of the following files, creating the full CLBlast API interface and
|
||||
# implementation (C, C++, and reference BLAS wrappers):
|
||||
# clblast.h
|
||||
# clblast.cpp
|
||||
# clblast_c.h
|
||||
|
@ -19,45 +18,20 @@
|
|||
# test/correctness/routines/levelX/xYYYY.cpp
|
||||
# test/performance/routines/levelX/xYYYY.cpp
|
||||
# It also produces the API documentation found in doc/clblast.md
|
||||
#
|
||||
# ==================================================================================================
|
||||
|
||||
# System modules
|
||||
|
||||
import sys
|
||||
import os.path
|
||||
import argparse
|
||||
|
||||
# Local files
|
||||
from routine import Routine
|
||||
from datatype import DataType, HLF, FLT, DBL, FLT2, DBL2, HCL, F2CL, D2CL
|
||||
import generator.cpp as cpp
|
||||
import generator.doc as doc
|
||||
from generator.routine import Routine
|
||||
from generator.datatype import H, S, D, C, Z, Sc, Dz, iH, iS, iD, iC, iZ, Css, Zdd, Ccs, Zzd, T, Tc, TU
|
||||
|
||||
# ==================================================================================================
|
||||
|
||||
# Regular data-types
|
||||
H = DataType("H", "H", HLF, [HLF, HLF, HCL, HCL], HLF ) # half (16)
|
||||
S = DataType("S", "S", FLT, [FLT, FLT, FLT, FLT], FLT ) # single (32)
|
||||
D = DataType("D", "D", DBL, [DBL, DBL, DBL, DBL], DBL ) # double (64)
|
||||
C = DataType("C", "C", FLT2, [FLT2, FLT2, F2CL, F2CL], FLT2) # single-complex (3232)
|
||||
Z = DataType("Z", "Z", DBL2, [DBL2, DBL2, D2CL, D2CL], DBL2) # double-complex (6464)
|
||||
|
||||
# Special cases
|
||||
Sc = DataType("C", "Sc", FLT2, [FLT2, FLT2, FLT2, FLT2], FLT2) # As C, but with real output
|
||||
Dz = DataType("Z", "Dz", DBL2, [DBL2, DBL2, DBL2, DBL2], DBL2) # As Z, but with real output
|
||||
iH = DataType("H", "iH", HLF, [HLF, HLF, HLF, HLF], HLF ) # As H, but with integer output
|
||||
iS = DataType("S", "iS", FLT, [FLT, FLT, FLT, FLT], FLT ) # As S, but with integer output
|
||||
iD = DataType("D", "iD", DBL, [DBL, DBL, DBL, DBL], DBL ) # As D, but with integer output
|
||||
iC = DataType("C", "iC", FLT2, [FLT2, FLT2, F2CL, F2CL], FLT2) # As C, but with integer output
|
||||
iZ = DataType("Z", "iZ", DBL2, [DBL2, DBL2, D2CL, D2CL], DBL2) # As Z, but with integer output
|
||||
Css = DataType("C", "C", FLT, [FLT, FLT, FLT, FLT], FLT2) # As C, but with constants from S
|
||||
Zdd = DataType("Z", "Z", DBL, [DBL, DBL, DBL, DBL], DBL2) # As Z, but with constants from D
|
||||
Ccs = DataType("C", "C", FLT2+","+FLT, [FLT2, FLT, F2CL, FLT], FLT2) # As C, but with one constant from S
|
||||
Zzd = DataType("Z", "Z", DBL2+","+DBL, [DBL2, DBL, D2CL, DBL], DBL2) # As Z, but with one constant from D
|
||||
|
||||
# C++ template data-types
|
||||
T = DataType("T", "typename T", "T", ["T", "T", "T", "T"], "T") # regular routine
|
||||
Tc = DataType("Tc", "typename T", "std::complex<T>,T", ["T", "T", "T", "T"], "std::complex<T>") # for herk
|
||||
TU = DataType("TU", "typename T, typename U", "T,U", ["T", "U", "T", "U"], "T") # for her2k
|
||||
|
||||
# ==================================================================================================
|
||||
HEADER_LINES = [96, 73, 97, 22, 29, 41]
|
||||
FOOTER_LINES = [17, 75, 19, 14, 6, 6]
|
||||
|
||||
# Different possibilities for requirements
|
||||
ald_m = "The value of `a_ld` must be at least `m`."
|
||||
|
@ -77,472 +51,162 @@ cld_n = "The value of `c_ld` must be at least `n`."
|
|||
# ==================================================================================================
|
||||
|
||||
# Populates a list of routines
|
||||
routines = [
|
||||
[ # Level 1: vector-vector
|
||||
Routine(False, True, "1", "rotg", T, [S,D], [], [], [], ["sa","sb","sc","ss"], [], "", "Generate givens plane rotation", "", []),
|
||||
Routine(False, True, "1", "rotmg", T, [S,D], [], [], ["sy1"], ["sd1","sd2","sx1","sparam"], [], "", "Generate modified givens plane rotation", "", []),
|
||||
Routine(False, True, "1", "rot", T, [S,D], ["n"], [], [], ["x","y"], ["cos","sin"], "", "Apply givens plane rotation", "", []),
|
||||
Routine(False, True, "1", "rotm", T, [S,D], ["n"], [], [], ["x","y","sparam"], [], "", "Apply modified givens plane rotation", "", []),
|
||||
Routine(True, True, "1", "swap", T, [S,D,C,Z,H], ["n"], [], [], ["x","y"], [], "", "Swap two vectors", "Interchanges _n_ elements of vectors _x_ and _y_.", []),
|
||||
Routine(True, True, "1", "scal", T, [S,D,C,Z,H], ["n"], [], [], ["x"], ["alpha"], "", "Vector scaling", "Multiplies _n_ elements of vector _x_ by a scalar constant _alpha_.", []),
|
||||
Routine(True, True, "1", "copy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], [], "", "Vector copy", "Copies the contents of vector _x_ into vector _y_.", []),
|
||||
Routine(True, True, "1", "axpy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], ["alpha"], "", "Vector-times-constant plus vector", "Performs the operation _y = alpha * x + y_, in which _x_ and _y_ are vectors and _alpha_ is a scalar constant.", []),
|
||||
Routine(True, True, "1", "dot", T, [S,D,H], ["n"], [], ["x","y"], ["dot"], [], "n", "Dot product of two vectors", "Multiplies _n_ elements of the vectors _x_ and _y_ element-wise and accumulates the results. The sum is stored in the _dot_ buffer.", []),
|
||||
Routine(True, True, "1", "dotu", T, [C,Z], ["n"], [], ["x","y"], ["dot"], [], "n", "Dot product of two complex vectors", "See the regular xDOT routine.", []),
|
||||
Routine(True, True, "1", "dotc", T, [C,Z], ["n"], [], ["x","y"], ["dot"], [], "n", "Dot product of two complex vectors, one conjugated", "See the regular xDOT routine.", []),
|
||||
Routine(True, True, "1", "nrm2", T, [S,D,Sc,Dz,H], ["n"], [], ["x"], ["nrm2"], [], "2*n", "Euclidian norm of a vector", "Accumulates the square of _n_ elements in the _x_ vector and takes the square root. The resulting L2 norm is stored in the _nrm2_ buffer.", []),
|
||||
Routine(True, True, "1", "asum", T, [S,D,Sc,Dz,H], ["n"], [], ["x"], ["asum"], [], "n", "Absolute sum of values in a vector", "Accumulates the absolute value of _n_ elements in the _x_ vector. The results are stored in the _asum_ buffer.", []),
|
||||
Routine(True, False, "1", "sum", T, [S,D,Sc,Dz,H], ["n"], [], ["x"], ["sum"], [], "n", "Sum of values in a vector (non-BLAS function)", "Accumulates the values of _n_ elements in the _x_ vector. The results are stored in the _sum_ buffer. This routine is the non-absolute version of the xASUM BLAS routine.", []),
|
||||
Routine(True, True, "1", "amax", T, [iS,iD,iC,iZ,iH], ["n"], [], ["x"], ["imax"], [], "2*n", "Index of absolute maximum value in a vector", "Finds the index of the maximum of the absolute values in the _x_ vector. The resulting integer index is stored in the _imax_ buffer.", []),
|
||||
Routine(True, False, "1", "max", T, [iS,iD,iC,iZ,iH], ["n"], [], ["x"], ["imax"], [], "2*n", "Index of maximum value in a vector (non-BLAS function)", "Finds the index of the maximum of the values in the _x_ vector. The resulting integer index is stored in the _imax_ buffer. This routine is the non-absolute version of the IxAMAX BLAS routine.", []),
|
||||
Routine(True, False, "1", "min", T, [iS,iD,iC,iZ,iH], ["n"], [], ["x"], ["imin"], [], "2*n", "Index of minimum value in a vector (non-BLAS function)", "Finds the index of the minimum of the values in the _x_ vector. The resulting integer index is stored in the _imin_ buffer. This routine is the non-absolute minimum version of the IxAMAX BLAS routine.", []),
|
||||
ROUTINES = [
|
||||
[ # Level 1: vector-vector
|
||||
Routine(False, True, "1", "rotg", T, [S,D], [], [], [], ["sa","sb","sc","ss"], [], "", "Generate givens plane rotation", "", []),
|
||||
Routine(False, True, "1", "rotmg", T, [S,D], [], [], ["sy1"], ["sd1","sd2","sx1","sparam"], [], "", "Generate modified givens plane rotation", "", []),
|
||||
Routine(False, True, "1", "rot", T, [S,D], ["n"], [], [], ["x","y"], ["cos","sin"], "", "Apply givens plane rotation", "", []),
|
||||
Routine(False, True, "1", "rotm", T, [S,D], ["n"], [], [], ["x","y","sparam"], [], "", "Apply modified givens plane rotation", "", []),
|
||||
Routine(True, True, "1", "swap", T, [S,D,C,Z,H], ["n"], [], [], ["x","y"], [], "", "Swap two vectors", "Interchanges _n_ elements of vectors _x_ and _y_.", []),
|
||||
Routine(True, True, "1", "scal", T, [S,D,C,Z,H], ["n"], [], [], ["x"], ["alpha"], "", "Vector scaling", "Multiplies _n_ elements of vector _x_ by a scalar constant _alpha_.", []),
|
||||
Routine(True, True, "1", "copy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], [], "", "Vector copy", "Copies the contents of vector _x_ into vector _y_.", []),
|
||||
Routine(True, True, "1", "axpy", T, [S,D,C,Z,H], ["n"], [], ["x"], ["y"], ["alpha"], "", "Vector-times-constant plus vector", "Performs the operation _y = alpha * x + y_, in which _x_ and _y_ are vectors and _alpha_ is a scalar constant.", []),
|
||||
Routine(True, True, "1", "dot", T, [S,D,H], ["n"], [], ["x","y"], ["dot"], [], "n", "Dot product of two vectors", "Multiplies _n_ elements of the vectors _x_ and _y_ element-wise and accumulates the results. The sum is stored in the _dot_ buffer.", []),
|
||||
Routine(True, True, "1", "dotu", T, [C,Z], ["n"], [], ["x","y"], ["dot"], [], "n", "Dot product of two complex vectors", "See the regular xDOT routine.", []),
|
||||
Routine(True, True, "1", "dotc", T, [C,Z], ["n"], [], ["x","y"], ["dot"], [], "n", "Dot product of two complex vectors, one conjugated", "See the regular xDOT routine.", []),
|
||||
Routine(True, True, "1", "nrm2", T, [S,D,Sc,Dz,H], ["n"], [], ["x"], ["nrm2"], [], "2*n", "Euclidian norm of a vector", "Accumulates the square of _n_ elements in the _x_ vector and takes the square root. The resulting L2 norm is stored in the _nrm2_ buffer.", []),
|
||||
Routine(True, True, "1", "asum", T, [S,D,Sc,Dz,H], ["n"], [], ["x"], ["asum"], [], "n", "Absolute sum of values in a vector", "Accumulates the absolute value of _n_ elements in the _x_ vector. The results are stored in the _asum_ buffer.", []),
|
||||
Routine(True, False, "1", "sum", T, [S,D,Sc,Dz,H], ["n"], [], ["x"], ["sum"], [], "n", "Sum of values in a vector (non-BLAS function)", "Accumulates the values of _n_ elements in the _x_ vector. The results are stored in the _sum_ buffer. This routine is the non-absolute version of the xASUM BLAS routine.", []),
|
||||
Routine(True, True, "1", "amax", T, [iS,iD,iC,iZ,iH], ["n"], [], ["x"], ["imax"], [], "2*n", "Index of absolute maximum value in a vector", "Finds the index of the maximum of the absolute values in the _x_ vector. The resulting integer index is stored in the _imax_ buffer.", []),
|
||||
Routine(True, False, "1", "max", T, [iS,iD,iC,iZ,iH], ["n"], [], ["x"], ["imax"], [], "2*n", "Index of maximum value in a vector (non-BLAS function)", "Finds the index of the maximum of the values in the _x_ vector. The resulting integer index is stored in the _imax_ buffer. This routine is the non-absolute version of the IxAMAX BLAS routine.", []),
|
||||
Routine(True, False, "1", "min", T, [iS,iD,iC,iZ,iH], ["n"], [], ["x"], ["imin"], [], "2*n", "Index of minimum value in a vector (non-BLAS function)", "Finds the index of the minimum of the values in the _x_ vector. The resulting integer index is stored in the _imin_ buffer. This routine is the non-absolute minimum version of the IxAMAX BLAS routine.", []),
|
||||
],
|
||||
[ # Level 2: matrix-vector
|
||||
Routine(True, True, "2a", "gemv", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a","x"], ["y"], ["alpha","beta"], "", "General matrix-vector multiplication", "Performs the operation _y = alpha * A * x + beta * y_, in which _x_ is an input vector, _y_ is an input and output vector, _A_ is an input matrix, and _alpha_ and _beta_ are scalars. The matrix _A_ can optionally be transposed before performing the operation.", [ald_m]),
|
||||
Routine(True, True, "2a", "gbmv", T, [S,D,C,Z,H], ["m","n","kl","ku"], ["layout","a_transpose"], ["a","x"], ["y"], ["alpha","beta"], "", "General banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is banded instead.", [ald_kl_ku_one]),
|
||||
Routine(True, True, "2a", "hemv", T, [C,Z], ["n"], ["layout","triangle"], ["a","x"], ["y"], ["alpha","beta"], "", "Hermitian matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is an Hermitian matrix instead.", [ald_n]),
|
||||
Routine(True, True, "2a", "hbmv", T, [C,Z], ["n","k"], ["layout","triangle"], ["a","x"], ["y"], ["alpha","beta"], "", "Hermitian banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is an Hermitian banded matrix instead.", [ald_k_one]),
|
||||
Routine(True, True, "2a", "hpmv", T, [C,Z], ["n"], ["layout","triangle"], ["ap","x"], ["y"], ["alpha","beta"], "", "Hermitian packed matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is an Hermitian packed matrix instead and represented as _AP_.", []),
|
||||
Routine(True, True, "2a", "symv", T, [S,D,H], ["n"], ["layout","triangle"], ["a","x"], ["y"], ["alpha","beta"], "", "Symmetric matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is symmetric instead.", [ald_n]),
|
||||
Routine(True, True, "2a", "sbmv", T, [S,D,H], ["n","k"], ["layout","triangle"], ["a","x"], ["y"], ["alpha","beta"], "", "Symmetric banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is symmetric and banded instead.", [ald_k_one]),
|
||||
Routine(True, True, "2a", "spmv", T, [S,D,H], ["n"], ["layout","triangle"], ["ap","x"], ["y"], ["alpha","beta"], "", "Symmetric packed matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is a symmetric packed matrix instead and represented as _AP_.", []),
|
||||
Routine(True, True, "2a", "trmv", T, [S,D,C,Z,H], ["n"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [], "n", "Triangular matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is triangular instead.", [ald_n]),
|
||||
Routine(True, True, "2a", "tbmv", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [], "n", "Triangular banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is triangular and banded instead.", [ald_k_one]),
|
||||
Routine(True, True, "2a", "tpmv", T, [S,D,C,Z,H], ["n"], ["layout","triangle","a_transpose","diagonal"], ["ap"], ["x"], [], "n", "Triangular packed matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is a triangular packed matrix instead and repreented as _AP_.", []),
|
||||
Routine(False, True, "2a", "trsv", T, [S,D,C,Z], ["n"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [], "", "Solves a triangular system of equations", "", []),
|
||||
Routine(False, True, "2a", "tbsv", T, [S,D,C,Z], ["n","k"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [], "", "Solves a banded triangular system of equations", "", [ald_k_one]),
|
||||
Routine(False, True, "2a", "tpsv", T, [S,D,C,Z], ["n"], ["layout","triangle","a_transpose","diagonal"], ["ap"], ["x"], [], "", "Solves a packed triangular system of equations", "", []),
|
||||
[ # Level 2: matrix-vector
|
||||
Routine(True, True, "2a", "gemv", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a","x"], ["y"], ["alpha","beta"], "", "General matrix-vector multiplication", "Performs the operation _y = alpha * A * x + beta * y_, in which _x_ is an input vector, _y_ is an input and output vector, _A_ is an input matrix, and _alpha_ and _beta_ are scalars. The matrix _A_ can optionally be transposed before performing the operation.", [ald_m]),
|
||||
Routine(True, True, "2a", "gbmv", T, [S,D,C,Z,H], ["m","n","kl","ku"], ["layout","a_transpose"], ["a","x"], ["y"], ["alpha","beta"], "", "General banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is banded instead.", [ald_kl_ku_one]),
|
||||
Routine(True, True, "2a", "hemv", T, [C,Z], ["n"], ["layout","triangle"], ["a","x"], ["y"], ["alpha","beta"], "", "Hermitian matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is an Hermitian matrix instead.", [ald_n]),
|
||||
Routine(True, True, "2a", "hbmv", T, [C,Z], ["n","k"], ["layout","triangle"], ["a","x"], ["y"], ["alpha","beta"], "", "Hermitian banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is an Hermitian banded matrix instead.", [ald_k_one]),
|
||||
Routine(True, True, "2a", "hpmv", T, [C,Z], ["n"], ["layout","triangle"], ["ap","x"], ["y"], ["alpha","beta"], "", "Hermitian packed matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is an Hermitian packed matrix instead and represented as _AP_.", []),
|
||||
Routine(True, True, "2a", "symv", T, [S,D,H], ["n"], ["layout","triangle"], ["a","x"], ["y"], ["alpha","beta"], "", "Symmetric matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is symmetric instead.", [ald_n]),
|
||||
Routine(True, True, "2a", "sbmv", T, [S,D,H], ["n","k"], ["layout","triangle"], ["a","x"], ["y"], ["alpha","beta"], "", "Symmetric banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is symmetric and banded instead.", [ald_k_one]),
|
||||
Routine(True, True, "2a", "spmv", T, [S,D,H], ["n"], ["layout","triangle"], ["ap","x"], ["y"], ["alpha","beta"], "", "Symmetric packed matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is a symmetric packed matrix instead and represented as _AP_.", []),
|
||||
Routine(True, True, "2a", "trmv", T, [S,D,C,Z,H], ["n"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [], "n", "Triangular matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is triangular instead.", [ald_n]),
|
||||
Routine(True, True, "2a", "tbmv", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [], "n", "Triangular banded matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is triangular and banded instead.", [ald_k_one]),
|
||||
Routine(True, True, "2a", "tpmv", T, [S,D,C,Z,H], ["n"], ["layout","triangle","a_transpose","diagonal"], ["ap"], ["x"], [], "n", "Triangular packed matrix-vector multiplication", "Same operation as xGEMV, but matrix _A_ is a triangular packed matrix instead and repreented as _AP_.", []),
|
||||
Routine(False, True, "2a", "trsv", T, [S,D,C,Z], ["n"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [], "", "Solves a triangular system of equations", "", []),
|
||||
Routine(False, True, "2a", "tbsv", T, [S,D,C,Z], ["n","k"], ["layout","triangle","a_transpose","diagonal"], ["a"], ["x"], [], "", "Solves a banded triangular system of equations", "", [ald_k_one]),
|
||||
Routine(False, True, "2a", "tpsv", T, [S,D,C,Z], ["n"], ["layout","triangle","a_transpose","diagonal"], ["ap"], ["x"], [], "", "Solves a packed triangular system of equations", "", []),
|
||||
# Level 2: matrix update
|
||||
Routine(True, True, "2b", "ger", T, [S,D,H], ["m","n"], ["layout"], ["x","y"], ["a"], ["alpha"], "", "General rank-1 matrix update", "Performs the operation _A = alpha * x * y^T + A_, in which _x_ is an input vector, _y^T_ is the transpose of the input vector _y_, _A_ is the matrix to be updated, and _alpha_ is a scalar value.", [ald_m]),
|
||||
Routine(True, True, "2b", "geru", T, [C,Z], ["m","n"], ["layout"], ["x","y"], ["a"], ["alpha"], "", "General rank-1 complex matrix update", "Same operation as xGER, but with complex data-types.", [ald_m]),
|
||||
Routine(True, True, "2b", "gerc", T, [C,Z], ["m","n"], ["layout"], ["x","y"], ["a"], ["alpha"], "", "General rank-1 complex conjugated matrix update", "Same operation as xGERU, but the update is done based on the complex conjugate of the input vectors.", [ald_m]),
|
||||
Routine(True, True, "2b", "her", Tc, [Css,Zdd], ["n"], ["layout","triangle"], ["x"], ["a"], ["alpha"], "", "Hermitian rank-1 matrix update", "Performs the operation _A = alpha * x * x^T + A_, in which x is an input vector, x^T is the transpose of this vector, _A_ is the triangular Hermetian matrix to be updated, and alpha is a scalar value.", [ald_n]),
|
||||
Routine(True, True, "2b", "hpr", Tc, [Css,Zdd], ["n"], ["layout","triangle"], ["x"], ["ap"], ["alpha"], "", "Hermitian packed rank-1 matrix update", "Same operation as xHER, but matrix _A_ is an Hermitian packed matrix instead and represented as _AP_.", []),
|
||||
Routine(True, True, "2b", "her2", T, [C,Z], ["n"], ["layout","triangle"], ["x","y"], ["a"], ["alpha"], "", "Hermitian rank-2 matrix update", "Performs the operation _A = alpha * x * y^T + conj(alpha) * y * x^T + A_, in which _x_ is an input vector and _x^T_ its transpose, _y_ is an input vector and _y^T_ its transpose, _A_ is the triangular Hermetian matrix to be updated, _alpha_ is a scalar value and _conj(alpha)_ its complex conjugate.", [ald_n]),
|
||||
Routine(True, True, "2b", "hpr2", T, [C,Z], ["n"], ["layout","triangle"], ["x","y"], ["ap"], ["alpha"], "", "Hermitian packed rank-2 matrix update", "Same operation as xHER2, but matrix _A_ is an Hermitian packed matrix instead and represented as _AP_.", []),
|
||||
Routine(True, True, "2b", "syr", T, [S,D,H], ["n"], ["layout","triangle"], ["x"], ["a"], ["alpha"], "", "Symmetric rank-1 matrix update", "Same operation as xHER, but matrix A is a symmetric matrix instead.", [ald_n]),
|
||||
Routine(True, True, "2b", "spr", T, [S,D,H], ["n"], ["layout","triangle"], ["x"], ["ap"], ["alpha"], "", "Symmetric packed rank-1 matrix update", "Same operation as xSPR, but matrix _A_ is a symmetric packed matrix instead and represented as _AP_.", []),
|
||||
Routine(True, True, "2b", "syr2", T, [S,D,H], ["n"], ["layout","triangle"], ["x","y"], ["a"], ["alpha"], "", "Symmetric rank-2 matrix update", "Same operation as xHER2, but matrix _A_ is a symmetric matrix instead.", [ald_n]),
|
||||
Routine(True, True, "2b", "spr2", T, [S,D,H], ["n"], ["layout","triangle"], ["x","y"], ["ap"], ["alpha"], "", "Symmetric packed rank-2 matrix update", "Same operation as xSPR2, but matrix _A_ is a symmetric packed matrix instead and represented as _AP_.", []),
|
||||
Routine(True, True, "2b", "ger", T, [S,D,H], ["m","n"], ["layout"], ["x","y"], ["a"], ["alpha"], "", "General rank-1 matrix update", "Performs the operation _A = alpha * x * y^T + A_, in which _x_ is an input vector, _y^T_ is the transpose of the input vector _y_, _A_ is the matrix to be updated, and _alpha_ is a scalar value.", [ald_m]),
|
||||
Routine(True, True, "2b", "geru", T, [C,Z], ["m","n"], ["layout"], ["x","y"], ["a"], ["alpha"], "", "General rank-1 complex matrix update", "Same operation as xGER, but with complex data-types.", [ald_m]),
|
||||
Routine(True, True, "2b", "gerc", T, [C,Z], ["m","n"], ["layout"], ["x","y"], ["a"], ["alpha"], "", "General rank-1 complex conjugated matrix update", "Same operation as xGERU, but the update is done based on the complex conjugate of the input vectors.", [ald_m]),
|
||||
Routine(True, True, "2b", "her", Tc, [Css,Zdd], ["n"], ["layout","triangle"], ["x"], ["a"], ["alpha"], "", "Hermitian rank-1 matrix update", "Performs the operation _A = alpha * x * x^T + A_, in which x is an input vector, x^T is the transpose of this vector, _A_ is the triangular Hermetian matrix to be updated, and alpha is a scalar value.", [ald_n]),
|
||||
Routine(True, True, "2b", "hpr", Tc, [Css,Zdd], ["n"], ["layout","triangle"], ["x"], ["ap"], ["alpha"], "", "Hermitian packed rank-1 matrix update", "Same operation as xHER, but matrix _A_ is an Hermitian packed matrix instead and represented as _AP_.", []),
|
||||
Routine(True, True, "2b", "her2", T, [C,Z], ["n"], ["layout","triangle"], ["x","y"], ["a"], ["alpha"], "", "Hermitian rank-2 matrix update", "Performs the operation _A = alpha * x * y^T + conj(alpha) * y * x^T + A_, in which _x_ is an input vector and _x^T_ its transpose, _y_ is an input vector and _y^T_ its transpose, _A_ is the triangular Hermetian matrix to be updated, _alpha_ is a scalar value and _conj(alpha)_ its complex conjugate.", [ald_n]),
|
||||
Routine(True, True, "2b", "hpr2", T, [C,Z], ["n"], ["layout","triangle"], ["x","y"], ["ap"], ["alpha"], "", "Hermitian packed rank-2 matrix update", "Same operation as xHER2, but matrix _A_ is an Hermitian packed matrix instead and represented as _AP_.", []),
|
||||
Routine(True, True, "2b", "syr", T, [S,D,H], ["n"], ["layout","triangle"], ["x"], ["a"], ["alpha"], "", "Symmetric rank-1 matrix update", "Same operation as xHER, but matrix A is a symmetric matrix instead.", [ald_n]),
|
||||
Routine(True, True, "2b", "spr", T, [S,D,H], ["n"], ["layout","triangle"], ["x"], ["ap"], ["alpha"], "", "Symmetric packed rank-1 matrix update", "Same operation as xSPR, but matrix _A_ is a symmetric packed matrix instead and represented as _AP_.", []),
|
||||
Routine(True, True, "2b", "syr2", T, [S,D,H], ["n"], ["layout","triangle"], ["x","y"], ["a"], ["alpha"], "", "Symmetric rank-2 matrix update", "Same operation as xHER2, but matrix _A_ is a symmetric matrix instead.", [ald_n]),
|
||||
Routine(True, True, "2b", "spr2", T, [S,D,H], ["n"], ["layout","triangle"], ["x","y"], ["ap"], ["alpha"], "", "Symmetric packed rank-2 matrix update", "Same operation as xSPR2, but matrix _A_ is a symmetric packed matrix instead and represented as _AP_.", []),
|
||||
],
|
||||
[ # Level 3: matrix-matrix
|
||||
Routine(True, True, "3", "gemm", T, [S,D,C,Z,H], ["m","n","k"], ["layout","a_transpose","b_transpose"], ["a","b"], ["c"], ["alpha","beta"], "", "General matrix-matrix multiplication", "Performs the matrix product _C = alpha * A * B + beta * C_, in which _A_ (_m_ by _k_) and _B_ (_k_ by _n_) are two general rectangular input matrices, _C_ (_m_ by _n_) is the matrix to be updated, and _alpha_ and _beta_ are scalar values. The matrices _A_ and/or _B_ can optionally be transposed before performing the operation.", [ald_transa_m_k, bld_transb_k_n, cld_m]),
|
||||
Routine(True, True, "3", "symm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle"], ["a","b"], ["c"], ["alpha","beta"], "", "Symmetric matrix-matrix multiplication", "Same operation as xGEMM, but _A_ is symmetric instead. In case of `side == kLeft`, _A_ is a symmetric _m_ by _m_ matrix and _C = alpha * A * B + beta * C_ is performed. Otherwise, in case of `side == kRight`, _A_ is a symmtric _n_ by _n_ matrix and _C = alpha * B * A + beta * C_ is performed.", [ald_side_m_n, bld_m, cld_m]),
|
||||
Routine(True, True, "3", "hemm", T, [C,Z], ["m","n"], ["layout","side","triangle"], ["a","b"], ["c"], ["alpha","beta"], "", "Hermitian matrix-matrix multiplication", "Same operation as xSYMM, but _A_ is an Hermitian matrix instead.", [ald_side_m_n, bld_m, cld_m]),
|
||||
Routine(True, True, "3", "syrk", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","a_transpose"], ["a"], ["c"], ["alpha","beta"], "", "Rank-K update of a symmetric matrix", "Performs the matrix product _C = alpha * A * A^T + beta * C_ or _C = alpha * A^T * A + beta * C_, in which _A_ is a general matrix and _A^T_ is its transpose, _C_ (_n_ by _n_) is the symmetric matrix to be updated, and _alpha_ and _beta_ are scalar values.", [ald_trans_n_k, cld_m]),
|
||||
Routine(True, True, "3", "herk", Tc, [Css,Zdd], ["n","k"], ["layout","triangle","a_transpose"], ["a"], ["c"], ["alpha","beta"], "", "Rank-K update of a hermitian matrix", "Same operation as xSYRK, but _C_ is an Hermitian matrix instead.", [ald_trans_n_k, cld_m]),
|
||||
Routine(True, True, "3", "syr2k", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], ["alpha","beta"], "", "Rank-2K update of a symmetric matrix", "Performs the matrix product _C = alpha * A * B^T + alpha * B * A^T + beta * C_ or _C = alpha * A^T * B + alpha * B^T * A + beta * C_, in which _A_ and _B_ are general matrices and _A^T_ and _B^T_ are their transposed versions, _C_ (_n_ by _n_) is the symmetric matrix to be updated, and _alpha_ and _beta_ are scalar values.", [ald_trans_n_k, bld_trans_n_k, cld_n]),
|
||||
Routine(True, True, "3", "her2k", TU, [Ccs,Zzd], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], ["alpha","beta"], "", "Rank-2K update of a hermitian matrix", "Same operation as xSYR2K, but _C_ is an Hermitian matrix instead.", [ald_trans_n_k, bld_trans_n_k, cld_n]),
|
||||
Routine(True, True, "3", "trmm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], ["alpha"], "", "Triangular matrix-matrix multiplication", "Performs the matrix product _B = alpha * A * B_ or _B = alpha * B * A_, in which _A_ is a unit or non-unit triangular matrix, _B_ (_m_ by _n_) is the general matrix to be updated, and _alpha_ is a scalar value.", [ald_side_m_n, bld_m]),
|
||||
Routine(False, True, "3", "trsm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], ["alpha"], "", "Solves a triangular system of equations", "", []),
|
||||
[ # Level 3: matrix-matrix
|
||||
Routine(True, True, "3", "gemm", T, [S,D,C,Z,H], ["m","n","k"], ["layout","a_transpose","b_transpose"], ["a","b"], ["c"], ["alpha","beta"], "", "General matrix-matrix multiplication", "Performs the matrix product _C = alpha * A * B + beta * C_, in which _A_ (_m_ by _k_) and _B_ (_k_ by _n_) are two general rectangular input matrices, _C_ (_m_ by _n_) is the matrix to be updated, and _alpha_ and _beta_ are scalar values. The matrices _A_ and/or _B_ can optionally be transposed before performing the operation.", [ald_transa_m_k, bld_transb_k_n, cld_m]),
|
||||
Routine(True, True, "3", "symm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle"], ["a","b"], ["c"], ["alpha","beta"], "", "Symmetric matrix-matrix multiplication", "Same operation as xGEMM, but _A_ is symmetric instead. In case of `side == kLeft`, _A_ is a symmetric _m_ by _m_ matrix and _C = alpha * A * B + beta * C_ is performed. Otherwise, in case of `side == kRight`, _A_ is a symmtric _n_ by _n_ matrix and _C = alpha * B * A + beta * C_ is performed.", [ald_side_m_n, bld_m, cld_m]),
|
||||
Routine(True, True, "3", "hemm", T, [C,Z], ["m","n"], ["layout","side","triangle"], ["a","b"], ["c"], ["alpha","beta"], "", "Hermitian matrix-matrix multiplication", "Same operation as xSYMM, but _A_ is an Hermitian matrix instead.", [ald_side_m_n, bld_m, cld_m]),
|
||||
Routine(True, True, "3", "syrk", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","a_transpose"], ["a"], ["c"], ["alpha","beta"], "", "Rank-K update of a symmetric matrix", "Performs the matrix product _C = alpha * A * A^T + beta * C_ or _C = alpha * A^T * A + beta * C_, in which _A_ is a general matrix and _A^T_ is its transpose, _C_ (_n_ by _n_) is the symmetric matrix to be updated, and _alpha_ and _beta_ are scalar values.", [ald_trans_n_k, cld_m]),
|
||||
Routine(True, True, "3", "herk", Tc, [Css,Zdd], ["n","k"], ["layout","triangle","a_transpose"], ["a"], ["c"], ["alpha","beta"], "", "Rank-K update of a hermitian matrix", "Same operation as xSYRK, but _C_ is an Hermitian matrix instead.", [ald_trans_n_k, cld_m]),
|
||||
Routine(True, True, "3", "syr2k", T, [S,D,C,Z,H], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], ["alpha","beta"], "", "Rank-2K update of a symmetric matrix", "Performs the matrix product _C = alpha * A * B^T + alpha * B * A^T + beta * C_ or _C = alpha * A^T * B + alpha * B^T * A + beta * C_, in which _A_ and _B_ are general matrices and _A^T_ and _B^T_ are their transposed versions, _C_ (_n_ by _n_) is the symmetric matrix to be updated, and _alpha_ and _beta_ are scalar values.", [ald_trans_n_k, bld_trans_n_k, cld_n]),
|
||||
Routine(True, True, "3", "her2k", TU, [Ccs,Zzd], ["n","k"], ["layout","triangle","ab_transpose"], ["a","b"], ["c"], ["alpha","beta"], "", "Rank-2K update of a hermitian matrix", "Same operation as xSYR2K, but _C_ is an Hermitian matrix instead.", [ald_trans_n_k, bld_trans_n_k, cld_n]),
|
||||
Routine(True, True, "3", "trmm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], ["alpha"], "", "Triangular matrix-matrix multiplication", "Performs the matrix product _B = alpha * A * B_ or _B = alpha * B * A_, in which _A_ is a unit or non-unit triangular matrix, _B_ (_m_ by _n_) is the general matrix to be updated, and _alpha_ is a scalar value.", [ald_side_m_n, bld_m]),
|
||||
Routine(False, True, "3", "trsm", T, [S,D,C,Z,H], ["m","n"], ["layout","side","triangle","a_transpose","diagonal"], ["a"], ["b"], ["alpha"], "", "Solves a triangular system of equations", "", []),
|
||||
],
|
||||
[ # Level X: extra routines (not part of BLAS)
|
||||
Routine(True, True, "x", "omatcopy", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a"], ["b"], ["alpha"], "", "Scaling and out-place transpose/copy (non-BLAS function)", "Performs scaling and out-of-place transposition/copying of matrices according to _B = alpha*op(A)_, in which _A_ is an input matrix (_m_ rows by _n_ columns), _B_ an output matrix, and _alpha_ a scalar value. The operation _op_ can be a normal matrix copy, a transposition or a conjugate transposition.", [ald_m, bld_n]),
|
||||
[ # Level X: extra routines (not part of BLAS)
|
||||
Routine(True, True, "x", "omatcopy", T, [S,D,C,Z,H], ["m","n"], ["layout","a_transpose"], ["a"], ["b"], ["alpha"], "", "Scaling and out-place transpose/copy (non-BLAS function)", "Performs scaling and out-of-place transposition/copying of matrices according to _B = alpha*op(A)_, in which _A_ is an input matrix (_m_ rows by _n_ columns), _B_ an output matrix, and _alpha_ a scalar value. The operation _op_ can be a normal matrix copy, a transposition or a conjugate transposition.", [ald_m, bld_n]),
|
||||
]]
|
||||
|
||||
# ==================================================================================================
|
||||
# Translates an option name to a CLBlast data-type
|
||||
def PrecisionToFullName(x):
|
||||
return {
|
||||
'H': "Half",
|
||||
'S': "Single",
|
||||
'D': "Double",
|
||||
'C': "ComplexSingle",
|
||||
'Z': "ComplexDouble",
|
||||
}[x]
|
||||
|
||||
# ==================================================================================================
|
||||
def main(argv):
|
||||
|
||||
# Separators for the BLAS levels
|
||||
separators = ["""
|
||||
// =================================================================================================
|
||||
// BLAS level-1 (vector-vector) routines
|
||||
// =================================================================================================""",
|
||||
"""
|
||||
// =================================================================================================
|
||||
// BLAS level-2 (matrix-vector) routines
|
||||
// =================================================================================================""",
|
||||
"""
|
||||
// =================================================================================================
|
||||
// BLAS level-3 (matrix-matrix) routines
|
||||
// =================================================================================================""",
|
||||
"""
|
||||
// =================================================================================================
|
||||
// Extra non-BLAS routines (level-X)
|
||||
// ================================================================================================="""]
|
||||
# Parses the command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("clblast_root", help="Root of the CLBlast sources")
|
||||
parser.add_argument("-v", "--verbose", action="store_true", help="Increase verbosity of the script")
|
||||
cl_args = parser.parse_args(argv)
|
||||
library_root = cl_args.clblast_root
|
||||
|
||||
# Names of the level sub-folders
|
||||
levelnames = ["1", "2", "3", "x"]
|
||||
# Sets all the files the output
|
||||
files = [
|
||||
library_root + "/include/clblast.h",
|
||||
library_root + "/src/clblast.cpp",
|
||||
library_root + "/include/clblast_c.h",
|
||||
library_root + "/src/clblast_c.cpp",
|
||||
library_root + "/test/wrapper_clblas.hpp",
|
||||
library_root + "/test/wrapper_cblas.hpp",
|
||||
]
|
||||
|
||||
# Main header/footer for source files
|
||||
header = """
|
||||
// =================================================================================================
|
||||
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
|
||||
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
|
||||
// width of 100 characters per line.
|
||||
//
|
||||
// Author(s):
|
||||
// Cedric Nugteren <www.cedricnugteren.nl>
|
||||
//
|
||||
// =================================================================================================
|
||||
"""
|
||||
footer = """
|
||||
// =================================================================================================
|
||||
"""
|
||||
# Checks whether the command-line arguments are valid; exists otherwise
|
||||
for f in files:
|
||||
if not os.path.isfile(f):
|
||||
print("[ERROR] The path '" + library_root + "' does not point to the root of the CLBlast library")
|
||||
sys.exit()
|
||||
|
||||
# ==================================================================================================
|
||||
# Iterates over all regular files to output
|
||||
for i in range(0, len(files)):
|
||||
|
||||
# The C++ API header (.h)
|
||||
def clblast_h(routines):
|
||||
result = ""
|
||||
for routine in routines:
|
||||
result += "\n// "+routine.description+": "+routine.ShortNames()+"\n"
|
||||
result += routine.RoutineHeaderCPP(12, " = nullptr")+";\n"
|
||||
return result
|
||||
# Stores the header and the footer of the original file
|
||||
with open(files[i]) as f:
|
||||
original = f.readlines()
|
||||
file_header = original[:HEADER_LINES[i]]
|
||||
file_footer = original[-FOOTER_LINES[i]:]
|
||||
|
||||
# The C++ API implementation (.cpp)
|
||||
def clblast_cc(routines):
|
||||
result = ""
|
||||
for routine in routines:
|
||||
indent1 = " "*(20 + routine.Length())
|
||||
result += "\n// "+routine.description+": "+routine.ShortNames()+"\n"
|
||||
if routine.implemented:
|
||||
result += routine.RoutineHeaderCPP(12, "")+" {\n"
|
||||
result += " auto queue_cpp = Queue(*queue);\n"
|
||||
result += " auto routine = X"+routine.name+"<"+routine.template.template+">(queue_cpp, event);\n"
|
||||
result += " auto status = routine.SetUp();\n"
|
||||
result += " if (status != StatusCode::kSuccess) { return status; }\n"
|
||||
result += " return routine.Do"+routine.name.capitalize()+"("
|
||||
result += (",\n"+indent1).join([a for a in routine.ArgumentsCladuc(routine.template, indent1)])
|
||||
result += ");\n"
|
||||
else:
|
||||
result += routine.RoutineHeaderTypeCPP(12)+" {\n"
|
||||
result += " return StatusCode::kNotImplemented;\n"
|
||||
result += "}\n"
|
||||
for flavour in routine.flavours:
|
||||
indent2 = " "*(34 + routine.Length() + len(flavour.template))
|
||||
result += "template StatusCode PUBLIC_API "+routine.name.capitalize()+"<"+flavour.template+">("
|
||||
result += (",\n"+indent2).join([a for a in routine.ArgumentsType(flavour)])
|
||||
result += ",\n"+indent2+"cl_command_queue*, cl_event*);\n"
|
||||
return result
|
||||
# Re-writes the body of the file
|
||||
with open(files[i], "w") as f:
|
||||
body = ""
|
||||
levels = [1, 2, 3] if (i == 4 or i == 5) else [1, 2, 3, 4]
|
||||
for level in levels:
|
||||
body += cpp.LEVEL_SEPARATORS[level - 1] + "\n"
|
||||
for routine in ROUTINES[level - 1]:
|
||||
if i == 0:
|
||||
body += cpp.clblast_h(routine)
|
||||
if i == 1:
|
||||
body += cpp.clblast_cc(routine)
|
||||
if i == 2:
|
||||
body += cpp.clblast_c_h(routine)
|
||||
if i == 3:
|
||||
body += cpp.clblast_c_cc(routine)
|
||||
if i == 4:
|
||||
body += cpp.wrapper_clblas(routine)
|
||||
if i == 5:
|
||||
body += cpp.wrapper_cblas(routine)
|
||||
f.write("".join(file_header))
|
||||
f.write(body)
|
||||
f.write("".join(file_footer))
|
||||
|
||||
# ==================================================================================================
|
||||
# Outputs all the test implementations
|
||||
for level in [1, 2, 3, 4]:
|
||||
for routine in ROUTINES[level - 1]:
|
||||
if routine.has_tests:
|
||||
level_string = cpp.LEVEL_NAMES[level - 1]
|
||||
routine_suffix = "level" + level_string + "/x" + routine.name + ".cpp"
|
||||
|
||||
# The C API header (.h)
|
||||
def clblast_c_h(routines):
|
||||
result = ""
|
||||
for routine in routines:
|
||||
result += "\n// "+routine.description+": "+routine.ShortNames()+"\n"
|
||||
for flavour in routine.flavours:
|
||||
result += routine.RoutineHeaderC(flavour, 31, " PUBLIC_API")+";\n"
|
||||
return result
|
||||
# Correctness tests
|
||||
filename = library_root + "/test/correctness/routines/" + routine_suffix
|
||||
with open(filename, "w") as f:
|
||||
f.write(cpp.HEADER + "\n")
|
||||
f.write(cpp.correctness_test(routine, level_string))
|
||||
f.write(cpp.FOOTER)
|
||||
|
||||
# The C API implementation (.cpp)
|
||||
def clblast_c_cc(routines):
|
||||
result = ""
|
||||
for routine in routines:
|
||||
result += "\n// "+routine.name.upper()+"\n"
|
||||
for flavour in routine.flavours:
|
||||
template = "<"+flavour.template+">" if routine.NoScalars() else ""
|
||||
indent = " "*(26 + routine.Length() + len(template))
|
||||
result += routine.RoutineHeaderC(flavour, 20, "")+" {\n"
|
||||
result += " auto status = clblast::"+routine.name.capitalize()+template+"("
|
||||
result += (",\n"+indent).join([a for a in routine.ArgumentsCast(flavour, indent)])
|
||||
result += ",\n"+indent+"queue, event);"
|
||||
result += "\n return static_cast<StatusCode>(status);\n}\n"
|
||||
return result
|
||||
# Performance tests
|
||||
filename = library_root + "/test/performance/routines/" + routine_suffix
|
||||
with open(filename, "w") as f:
|
||||
f.write(cpp.HEADER + "\n")
|
||||
f.write(cpp.performance_test(routine, level_string))
|
||||
f.write(cpp.FOOTER)
|
||||
|
||||
# ==================================================================================================
|
||||
# Outputs the API documentation
|
||||
filename = cl_args.clblast_root + "/doc/clblast.md"
|
||||
with open(filename, "w") as f:
|
||||
|
||||
# The wrapper to the reference clBLAS routines (for performance/correctness testing)
|
||||
def wrapper_clblas(routines):
|
||||
result = ""
|
||||
for routine in routines:
|
||||
if routine.has_tests:
|
||||
result += "\n// Forwards the clBLAS calls for %s\n" % (routine.ShortNamesTested())
|
||||
if routine.NoScalars():
|
||||
result += routine.RoutineHeaderWrapperCL(routine.template, True, 21)+";\n"
|
||||
for flavour in routine.flavours:
|
||||
result += routine.RoutineHeaderWrapperCL(flavour, False, 21)+" {\n"
|
||||
# Outputs the header
|
||||
doc_header = doc.header()
|
||||
f.write(doc_header)
|
||||
|
||||
# There is a version available in clBLAS
|
||||
if flavour.precision_name in ["S","D","C","Z"]:
|
||||
indent = " "*(17 + routine.Length())
|
||||
arguments = routine.ArgumentsWrapperCL(flavour)
|
||||
if routine.scratch:
|
||||
result += " auto queue = Queue(queues[0]);\n"
|
||||
result += " auto context = queue.GetContext();\n"
|
||||
result += " auto scratch_buffer = Buffer<"+flavour.template+">(context, "+routine.scratch+");\n"
|
||||
arguments += ["scratch_buffer()"]
|
||||
result += " return clblas"+flavour.name+routine.name+"("
|
||||
result += (",\n"+indent).join([a for a in arguments])
|
||||
result += ",\n"+indent+"num_queues, queues, num_wait_events, wait_events, events);"
|
||||
# Generates the documentation for each routine
|
||||
for level in [1, 2, 3, 4]:
|
||||
for routine in ROUTINES[level - 1]:
|
||||
if routine.implemented:
|
||||
doc_routine = doc.generate(routine)
|
||||
f.write(doc_routine)
|
||||
|
||||
# There is no clBLAS available, forward the call to one of the available functions
|
||||
else: # Half-precision
|
||||
indent = " "*(24 + routine.Length())
|
||||
|
||||
# Convert to float (note: also integer buffers are stored as half/float)
|
||||
for buf in routine.inputs + routine.outputs:
|
||||
result += " auto "+buf+"_buffer_bis = HalfToFloatBuffer("+buf+"_buffer, queues[0]);\n"
|
||||
|
||||
# Call the float routine
|
||||
result += " auto status = clblasX"+routine.name+"("
|
||||
result += (",\n"+indent).join([a for a in routine.ArgumentsHalf()])
|
||||
result += ",\n"+indent+"num_queues, queues, num_wait_events, wait_events, events);"
|
||||
result += "\n"
|
||||
|
||||
# Convert back to half
|
||||
for buf in routine.outputs:
|
||||
result += " FloatToHalfBuffer("+buf+"_buffer, "+buf+"_buffer_bis, queues[0]);\n"
|
||||
result += " return status;"
|
||||
|
||||
# Complete
|
||||
result += "\n}\n"
|
||||
return result
|
||||
|
||||
# The wrapper to the reference CBLAS routines (for performance/correctness testing)
|
||||
def wrapper_cblas(routines):
|
||||
result = ""
|
||||
for routine in routines:
|
||||
if routine.has_tests:
|
||||
result += "\n// Forwards the Netlib BLAS calls for %s\n" % (routine.ShortNamesTested())
|
||||
for flavour in routine.flavours:
|
||||
result += routine.RoutineHeaderWrapperC(flavour, False, 12)+" {\n"
|
||||
|
||||
# There is a version available in CBLAS
|
||||
if flavour.precision_name in ["S","D","C","Z"]:
|
||||
indent = " "*(10 + routine.Length())
|
||||
arguments = routine.ArgumentsWrapperC(flavour)
|
||||
|
||||
# Complex scalars
|
||||
for scalar in routine.scalars:
|
||||
if flavour.IsComplex(scalar):
|
||||
result += " const auto "+scalar+"_array = std::vector<"+flavour.buffertype[:-1]+">{"+scalar+".real(), "+scalar+".imag()};\n"
|
||||
|
||||
# Special case for scalar outputs
|
||||
assignment = ""
|
||||
postfix = ""
|
||||
endofline = ""
|
||||
extra_argument = ""
|
||||
for output_buffer in routine.outputs:
|
||||
if output_buffer in routine.ScalarBuffersFirst():
|
||||
if flavour in [C,Z]:
|
||||
postfix += "_sub"
|
||||
indent += " "
|
||||
extra_argument += ",\n"+indent+"reinterpret_cast<return_pointer_"+flavour.buffertype[:-1]+">(&"+output_buffer+"_buffer["+output_buffer+"_offset])"
|
||||
elif output_buffer in routine.IndexBuffers():
|
||||
assignment = "((int*)&"+output_buffer+"_buffer[0])["+output_buffer+"_offset] = "
|
||||
indent += " "*len(assignment)
|
||||
else:
|
||||
assignment = output_buffer+"_buffer["+output_buffer+"_offset]"
|
||||
if (flavour.name in ["Sc","Dz"]):
|
||||
assignment = assignment+".real("
|
||||
endofline += ")"
|
||||
else:
|
||||
assignment = assignment+" = "
|
||||
indent += " "*len(assignment)
|
||||
|
||||
result += " "+assignment+"cblas_"+flavour.name.lower()+routine.name+postfix+"("
|
||||
result += (",\n"+indent).join([a for a in arguments])
|
||||
result += extra_argument+endofline+");\n"
|
||||
|
||||
# There is no CBLAS available, forward the call to one of the available functions
|
||||
else: # Half-precision
|
||||
indent = " "*(9 + routine.Length())
|
||||
|
||||
# Convert to float (note: also integer buffers are stored as half/float)
|
||||
for buf in routine.inputs + routine.outputs:
|
||||
result += " auto "+buf+"_buffer_bis = HalfToFloatBuffer("+buf+"_buffer);\n"
|
||||
|
||||
# Call the float routine
|
||||
result += " cblasX"+routine.name+"("
|
||||
result += (",\n"+indent).join([a for a in routine.ArgumentsHalf()])
|
||||
result += ");\n"
|
||||
|
||||
# Convert back to half
|
||||
for buf in routine.outputs:
|
||||
result += " FloatToHalfBuffer("+buf+"_buffer, "+buf+"_buffer_bis);\n"
|
||||
|
||||
# Complete
|
||||
result += "}\n"
|
||||
return result
|
||||
|
||||
# ==================================================================================================
|
||||
|
||||
# Checks for the number of command-line arguments
|
||||
if len(sys.argv) != 2:
|
||||
print "[ERROR] Usage: generator.py <root_of_clblast>"
|
||||
sys.exit()
|
||||
|
||||
# Parses the command-line arguments
|
||||
path_clblast = sys.argv[1]
|
||||
files = [
|
||||
path_clblast+"/include/clblast.h",
|
||||
path_clblast+"/src/clblast.cpp",
|
||||
path_clblast+"/include/clblast_c.h",
|
||||
path_clblast+"/src/clblast_c.cpp",
|
||||
path_clblast+"/test/wrapper_clblas.hpp",
|
||||
path_clblast+"/test/wrapper_cblas.hpp",
|
||||
]
|
||||
header_lines = [96, 73, 97, 22, 29, 41]
|
||||
footer_lines = [17, 75, 19, 14, 6, 6]
|
||||
|
||||
# Checks whether the command-line arguments are valid; exists otherwise
|
||||
for f in files:
|
||||
if not os.path.isfile(f):
|
||||
print "[ERROR] The path '"+path_clblast+"' does not point to the root of the CLBlast library"
|
||||
sys.exit()
|
||||
|
||||
# ==================================================================================================
|
||||
|
||||
# Iterates over all files to output
|
||||
for i in xrange(0,len(files)):
|
||||
|
||||
# Stores the header and the footer of the original file
|
||||
with open(files[i]) as f:
|
||||
original = f.readlines()
|
||||
file_header = original[:header_lines[i]]
|
||||
file_footer = original[-footer_lines[i]:]
|
||||
|
||||
# Re-writes the body of the file
|
||||
with open(files[i], "w") as f:
|
||||
body = ""
|
||||
levels = [1,2,3] if (i == 4 or i == 5) else [1,2,3,4]
|
||||
for level in levels:
|
||||
body += separators[level-1]+"\n"
|
||||
if i == 0:
|
||||
body += clblast_h(routines[level-1])
|
||||
if i == 1:
|
||||
body += clblast_cc(routines[level-1])
|
||||
if i == 2:
|
||||
body += clblast_c_h(routines[level-1])
|
||||
if i == 3:
|
||||
body += clblast_c_cc(routines[level-1])
|
||||
if i == 4:
|
||||
body += wrapper_clblas(routines[level-1])
|
||||
if i == 5:
|
||||
body += wrapper_cblas(routines[level-1])
|
||||
f.write("".join(file_header))
|
||||
f.write(body)
|
||||
f.write("".join(file_footer))
|
||||
|
||||
# ==================================================================================================
|
||||
|
||||
# Outputs all the correctness-test implementations
|
||||
for level in [1,2,3,4]:
|
||||
for routine in routines[level-1]:
|
||||
if routine.has_tests:
|
||||
filename = path_clblast+"/test/correctness/routines/level"+levelnames[level-1]+"/x"+routine.name+".cpp"
|
||||
with open(filename, "w") as f:
|
||||
body = ""
|
||||
body += "#include \"test/correctness/testblas.hpp\"\n"
|
||||
body += "#include \"test/routines/level"+levelnames[level-1]+"/x"+routine.name+".hpp\"\n\n"
|
||||
body += "// Shortcuts to the clblast namespace\n"
|
||||
body += "using float2 = clblast::float2;\n"
|
||||
body += "using double2 = clblast::double2;\n\n"
|
||||
body += "// Main function (not within the clblast namespace)\n"
|
||||
body += "int main(int argc, char *argv[]) {\n"
|
||||
body += " auto errors = size_t{0};\n"
|
||||
not_first = "false"
|
||||
for flavour in routine.flavours:
|
||||
body += " errors += clblast::RunTests<clblast::TestX"+routine.name+flavour.TestTemplate()
|
||||
body += ">(argc, argv, "+not_first+", \""+flavour.name+routine.name.upper()+"\");\n"
|
||||
not_first = "true"
|
||||
body += " if (errors > 0) { return 1; } else { return 0; }\n"
|
||||
body += "}\n"
|
||||
f.write(header+"\n")
|
||||
f.write(body)
|
||||
f.write(footer)
|
||||
|
||||
# Outputs all the performance-test implementations
|
||||
for level in [1,2,3,4]:
|
||||
for routine in routines[level-1]:
|
||||
if routine.has_tests:
|
||||
filename = path_clblast+"/test/performance/routines/level"+levelnames[level-1]+"/x"+routine.name+".cpp"
|
||||
with open(filename, "w") as f:
|
||||
body = ""
|
||||
body += "#include \"test/performance/client.hpp\"\n"
|
||||
body += "#include \"test/routines/level"+levelnames[level-1]+"/x"+routine.name+".hpp\"\n\n"
|
||||
body += "// Shortcuts to the clblast namespace\n"
|
||||
body += "using float2 = clblast::float2;\n"
|
||||
body += "using double2 = clblast::double2;\n\n"
|
||||
body += "// Main function (not within the clblast namespace)\n"
|
||||
body += "int main(int argc, char *argv[]) {\n"
|
||||
default = PrecisionToFullName(routine.flavours[0].precision_name)
|
||||
body += " switch(clblast::GetPrecision(argc, argv, clblast::Precision::k"+default+")) {\n"
|
||||
for precision in ["H","S","D","C","Z"]:
|
||||
body += " case clblast::Precision::k"+PrecisionToFullName(precision)+":"
|
||||
found = False
|
||||
for flavour in routine.flavours:
|
||||
if flavour.precision_name == precision:
|
||||
body += "\n clblast::RunClient<clblast::TestX"+routine.name+flavour.TestTemplate()
|
||||
body += ">(argc, argv); break;\n"
|
||||
found = True
|
||||
if not found:
|
||||
body += " throw std::runtime_error(\"Unsupported precision mode\");\n"
|
||||
body += " }\n"
|
||||
body += " return 0;\n"
|
||||
body += "}\n"
|
||||
f.write(header+"\n")
|
||||
f.write(body)
|
||||
f.write(footer)
|
||||
|
||||
# ==================================================================================================
|
||||
|
||||
# Outputs the API documentation
|
||||
filename = path_clblast+"/doc/clblast.md"
|
||||
with open(filename, "w") as f:
|
||||
|
||||
# Outputs the header
|
||||
f.write("CLBlast: API reference\n")
|
||||
f.write("================\n")
|
||||
f.write("\n\n")
|
||||
|
||||
# Loops over the routines
|
||||
for level in [1,2,3,4]:
|
||||
for routine in routines[level-1]:
|
||||
if routine.implemented:
|
||||
|
||||
# Routine header
|
||||
f.write("x"+routine.name.upper()+": "+routine.description+"\n")
|
||||
f.write("-------------\n")
|
||||
f.write("\n")
|
||||
f.write(routine.details+"\n")
|
||||
f.write("\n")
|
||||
|
||||
# Routine API
|
||||
f.write("C++ API:\n")
|
||||
f.write("```\n")
|
||||
f.write(routine.RoutineHeaderCPP(12, "")+"\n")
|
||||
f.write("```\n")
|
||||
f.write("\n")
|
||||
f.write("C API:\n")
|
||||
f.write("```\n")
|
||||
for flavour in routine.flavours:
|
||||
f.write(routine.RoutineHeaderC(flavour, 20, "")+"\n")
|
||||
f.write("```\n")
|
||||
f.write("\n")
|
||||
|
||||
# Routine arguments
|
||||
f.write("Arguments to "+routine.name.upper()+":\n")
|
||||
f.write("\n")
|
||||
for argument in routine.ArgumentsDoc():
|
||||
f.write("* "+argument+"\n")
|
||||
f.write("* `cl_command_queue* queue`: Pointer to an OpenCL command queue associated with a context and device to execute the routine on.\n")
|
||||
f.write("* `cl_event* event`: Pointer to an OpenCL event to be able to wait for completion of the routine's OpenCL kernel(s). This is an optional argument.\n")
|
||||
f.write("\n")
|
||||
|
||||
# Routine requirements
|
||||
if len(routine.RequirementsDoc()) > 0:
|
||||
f.write("Requirements for "+routine.name.upper()+":\n")
|
||||
f.write("\n")
|
||||
for requirement in routine.RequirementsDoc():
|
||||
f.write("* "+requirement+"\n")
|
||||
f.write("\n")
|
||||
|
||||
# Routine footer
|
||||
f.write("\n\n")
|
||||
|
||||
|
||||
# ==================================================================================================
|
||||
if __name__ == '__main__':
|
||||
main(sys.argv[1:])
|
||||
|
|
0
scripts/generator/generator/__init__.py
Normal file
0
scripts/generator/generator/__init__.py
Normal file
69
scripts/generator/generator/convert.py
Normal file
69
scripts/generator/generator/convert.py
Normal file
|
@ -0,0 +1,69 @@
|
|||
|
||||
# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
|
||||
# PEP8 Python style guide and uses a max-width of 120 characters per line.
|
||||
#
|
||||
# Author(s):
|
||||
# Cedric Nugteren <www.cedricnugteren.nl>
|
||||
|
||||
|
||||
def precision_to_full_name(x):
|
||||
"""Translates an option name to a CLBlast data-type"""
|
||||
return {
|
||||
'H': "Half",
|
||||
'S': "Single",
|
||||
'D': "Double",
|
||||
'C': "ComplexSingle",
|
||||
'Z': "ComplexDouble",
|
||||
}[x]
|
||||
|
||||
|
||||
def option_to_clblast(x):
|
||||
"""Translates an option name to a CLBlast data-type"""
|
||||
return {
|
||||
'layout': "Layout",
|
||||
'a_transpose': "Transpose",
|
||||
'b_transpose': "Transpose",
|
||||
'ab_transpose': "Transpose",
|
||||
'side': "Side",
|
||||
'triangle': "Triangle",
|
||||
'diagonal': "Diagonal",
|
||||
}[x]
|
||||
|
||||
|
||||
def option_to_clblas(x):
|
||||
"""As above, but for clBLAS data-types"""
|
||||
return {
|
||||
'layout': "clblasOrder",
|
||||
'a_transpose': "clblasTranspose",
|
||||
'b_transpose': "clblasTranspose",
|
||||
'ab_transpose': "clblasTranspose",
|
||||
'side': "clblasSide",
|
||||
'triangle': "clblasUplo",
|
||||
'diagonal': "clblasDiag",
|
||||
}[x]
|
||||
|
||||
|
||||
def option_to_cblas(x):
|
||||
"""As above, but for CBLAS data-types"""
|
||||
return {
|
||||
'layout': "CBLAS_ORDER",
|
||||
'a_transpose': "CBLAS_TRANSPOSE",
|
||||
'b_transpose': "CBLAS_TRANSPOSE",
|
||||
'ab_transpose': "CBLAS_TRANSPOSE",
|
||||
'side': "CBLAS_SIDE",
|
||||
'triangle': "CBLAS_UPLO",
|
||||
'diagonal': "CBLAS_DIAG",
|
||||
}[x]
|
||||
|
||||
|
||||
def option_to_documentation(x):
|
||||
"""Translates an option name to a documentation string"""
|
||||
return {
|
||||
'layout': "Data-layout of the matrices, either `Layout::kRowMajor` (101) for row-major layout or `Layout::kColMajor` (102) for column-major data-layout.",
|
||||
'a_transpose': "Transposing the input matrix A, either `Transpose::kNo` (111), `Transpose::kYes` (112), or `Transpose::kConjugate` (113) for a complex-conjugate transpose.",
|
||||
'b_transpose': "Transposing the input matrix B, either `Transpose::kNo` (111), `Transpose::kYes` (112), or `Transpose::kConjugate` (113) for a complex-conjugate transpose.",
|
||||
'ab_transpose': "Transposing the packed input matrix AP, either `Transpose::kNo` (111), `Transpose::kYes` (112), or `Transpose::kConjugate` (113) for a complex-conjugate transpose.",
|
||||
'side': "The position of the triangular matrix in the operation, either on the `Side::kLeft` (141) or `Side::kRight` (142).",
|
||||
'triangle': "The part of the array of the triangular matrix to be used, either `Triangle::kUpper` (121) or `Triangle::kLower` (122).",
|
||||
'diagonal': "The property of the diagonal matrix, either `Diagonal::kNonUnit` (131) for non-unit values on the diagonal or `Diagonal::kUnit` (132) for unit values on the diagonal.",
|
||||
}[x]
|
257
scripts/generator/generator/cpp.py
Normal file
257
scripts/generator/generator/cpp.py
Normal file
|
@ -0,0 +1,257 @@
|
|||
|
||||
# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
|
||||
# PEP8 Python style guide and uses a max-width of 120 characters per line.
|
||||
#
|
||||
# Author(s):
|
||||
# Cedric Nugteren <www.cedricnugteren.nl>
|
||||
|
||||
import generator.datatype as datatype
|
||||
import generator.convert as convert
|
||||
|
||||
|
||||
NL = "\n"
|
||||
SEPARATOR = "// ================================================================================================="
|
||||
|
||||
# Separators for the BLAS levels
|
||||
LEVEL_SEPARATORS = [
|
||||
NL + SEPARATOR + NL + "// BLAS level-1 (vector-vector) routines" + NL + SEPARATOR,
|
||||
NL + SEPARATOR + NL + "// BLAS level-2 (matrix-vector) routines" + NL + SEPARATOR,
|
||||
NL + SEPARATOR + NL + "// BLAS level-3 (matrix-matrix) routines" + NL + SEPARATOR,
|
||||
NL + SEPARATOR + NL + "// Extra non-BLAS routines (level-X)" + NL + SEPARATOR
|
||||
]
|
||||
|
||||
# Names of the level sub-folders
|
||||
LEVEL_NAMES = ["1", "2", "3", "x"]
|
||||
|
||||
# Main header/footer for source files
|
||||
FOOTER = NL + SEPARATOR + NL
|
||||
HEADER = NL + SEPARATOR + """
|
||||
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
|
||||
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
|
||||
// width of 100 characters per line.
|
||||
//
|
||||
// Author(s):
|
||||
// Cedric Nugteren <www.cedricnugteren.nl>
|
||||
//
|
||||
""" + SEPARATOR + NL
|
||||
|
||||
|
||||
def clblast_h(routine):
|
||||
"""The C++ API header (.h)"""
|
||||
result = NL + "// " + routine.description + ": " + routine.short_names() + NL
|
||||
result += routine.routine_header_cpp(12, " = nullptr") + ";" + NL
|
||||
return result
|
||||
|
||||
|
||||
def clblast_cc(routine):
|
||||
"""The C++ API implementation (.cpp)"""
|
||||
indent1 = " " * (20 + routine.length())
|
||||
result = NL + "// " + routine.description + ": " + routine.short_names() + NL
|
||||
if routine.implemented:
|
||||
result += routine.routine_header_cpp(12, "") + " {" + NL
|
||||
result += " auto queue_cpp = Queue(*queue);" + NL
|
||||
result += " auto routine = X" + routine.name + "<" + routine.template.template + ">(queue_cpp, event);" + NL
|
||||
result += " auto status = routine.SetUp();" + NL
|
||||
result += " if (status != StatusCode::kSuccess) { return status; }" + NL
|
||||
result += " return routine.Do" + routine.name.capitalize() + "("
|
||||
result += ("," + NL + indent1).join([a for a in routine.arguments_clcudaapi()])
|
||||
result += ");" + NL
|
||||
else:
|
||||
result += routine.routine_header_type_cpp(12) + " {" + NL
|
||||
result += " return StatusCode::kNotImplemented;" + NL
|
||||
result += "}" + NL
|
||||
for flavour in routine.flavours:
|
||||
indent2 = " " * (34 + routine.length() + len(flavour.template))
|
||||
result += "template StatusCode PUBLIC_API " + routine.name.capitalize() + "<" + flavour.template + ">("
|
||||
result += ("," + NL + indent2).join([a for a in routine.arguments_type(flavour)])
|
||||
result += "," + NL + indent2 + "cl_command_queue*, cl_event*);" + NL
|
||||
return result
|
||||
|
||||
|
||||
def clblast_c_h(routine):
|
||||
"""The C API header (.h)"""
|
||||
result = NL + "// " + routine.description + ": " + routine.short_names() + NL
|
||||
for flavour in routine.flavours:
|
||||
result += routine.routine_header_c(flavour, 31, " PUBLIC_API") + ";" + NL
|
||||
return result
|
||||
|
||||
|
||||
def clblast_c_cc(routine):
|
||||
"""The C API implementation (.cpp)"""
|
||||
result = NL + "// " + routine.name.upper() + NL
|
||||
for flavour in routine.flavours:
|
||||
template = "<" + flavour.template + ">" if routine.no_scalars() else ""
|
||||
indent = " " * (26 + routine.length() + len(template))
|
||||
result += routine.routine_header_c(flavour, 20, "") + " {" + NL
|
||||
result += " auto status = clblast::" + routine.name.capitalize() + template + "("
|
||||
result += ("," + NL + indent).join([a for a in routine.arguments_cast(flavour, indent)])
|
||||
result += "," + NL + indent + "queue, event);"
|
||||
result += NL + " return static_cast<StatusCode>(status);" + NL + "}" + NL
|
||||
return result
|
||||
|
||||
|
||||
def wrapper_clblas(routine):
|
||||
"""The wrapper to the reference clBLAS routines (for performance/correctness testing)"""
|
||||
result = ""
|
||||
if routine.has_tests:
|
||||
result += NL + "// Forwards the clBLAS calls for %s" % routine.short_names_tested() + NL
|
||||
if routine.no_scalars():
|
||||
result += routine.routine_header_wrapper_clblas(routine.template, True, 21) + ";" + NL
|
||||
for flavour in routine.flavours:
|
||||
result += routine.routine_header_wrapper_clblas(flavour, False, 21) + " {" + NL
|
||||
|
||||
# There is a version available in clBLAS
|
||||
if flavour.precision_name in ["S", "D", "C", "Z"]:
|
||||
indent = " " * (17 + routine.length())
|
||||
arguments = routine.arguments_wrapper_clblas(flavour)
|
||||
if routine.scratch:
|
||||
result += " auto queue = Queue(queues[0]);" + NL
|
||||
result += " auto context = queue.GetContext();" + NL
|
||||
result += " auto scratch_buffer = Buffer<" + flavour.template + ">"
|
||||
result += "(context, " + routine.scratch + ");" + NL
|
||||
arguments += ["scratch_buffer()"]
|
||||
result += " return clblas" + flavour.name + routine.name + "("
|
||||
result += ("," + NL + indent).join([a for a in arguments])
|
||||
result += "," + NL + indent + "num_queues, queues, num_wait_events, wait_events, events);"
|
||||
|
||||
# There is no clBLAS available, forward the call to one of the available functions
|
||||
else: # Half-precision
|
||||
indent = " " * (24 + routine.length())
|
||||
|
||||
# Convert to float (note: also integer buffers are stored as half/float)
|
||||
for buf in routine.inputs + routine.outputs:
|
||||
result += " auto " + buf + "_buffer_bis = HalfToFloatBuffer(" + buf + "_buffer, queues[0]);" + NL
|
||||
|
||||
# Call the float routine
|
||||
result += " auto status = clblasX" + routine.name + "("
|
||||
result += ("," + NL + indent).join([a for a in routine.arguments_half()])
|
||||
result += "," + NL + indent + "num_queues, queues, num_wait_events, wait_events, events);"
|
||||
result += NL
|
||||
|
||||
# Convert back to half
|
||||
for buf in routine.outputs:
|
||||
result += " FloatToHalfBuffer(" + buf + "_buffer, " + buf + "_buffer_bis, queues[0]);" + NL
|
||||
result += " return status;"
|
||||
|
||||
# Complete
|
||||
result += NL + "}" + NL
|
||||
return result
|
||||
|
||||
|
||||
def wrapper_cblas(routine):
|
||||
"""The wrapper to the reference CBLAS routines (for performance/correctness testing)"""
|
||||
result = ""
|
||||
if routine.has_tests:
|
||||
result += NL + "// Forwards the Netlib BLAS calls for %s" % routine.short_names_tested() + NL
|
||||
for flavour in routine.flavours:
|
||||
result += routine.routine_header_wrapper_cblas(flavour, 12) + " {" + NL
|
||||
|
||||
# There is a version available in CBLAS
|
||||
if flavour.precision_name in ["S", "D", "C", "Z"]:
|
||||
indent = " " * (10 + routine.length())
|
||||
arguments = routine.arguments_wrapper_cblas(flavour)
|
||||
|
||||
# Complex scalars
|
||||
for scalar in routine.scalars:
|
||||
if flavour.is_complex(scalar):
|
||||
result += " const auto " + scalar + "_array = std::vector<" + flavour.buffer_type[:-1] + ">"
|
||||
result += "{" + scalar + ".real(), " + scalar + ".imag()};" + NL
|
||||
|
||||
# Special case for scalar outputs
|
||||
assignment = ""
|
||||
postfix = ""
|
||||
end_of_line = ""
|
||||
extra_argument = ""
|
||||
for output_buffer in routine.outputs:
|
||||
if output_buffer in routine.scalar_buffers_first():
|
||||
if flavour in [datatype.C, datatype.Z]:
|
||||
postfix += "_sub"
|
||||
indent += " "
|
||||
extra_argument += "," + NL + indent
|
||||
extra_argument += "reinterpret_cast<return_pointer_" + flavour.buffer_type[:-1] + ">"
|
||||
extra_argument += "(&" + output_buffer + "_buffer[" + output_buffer + "_offset])"
|
||||
elif output_buffer in routine.index_buffers():
|
||||
assignment = "((int*)&" + output_buffer + "_buffer[0])[" + output_buffer + "_offset] = "
|
||||
indent += " " * len(assignment)
|
||||
else:
|
||||
assignment = output_buffer + "_buffer[" + output_buffer + "_offset]"
|
||||
if flavour.name in ["Sc", "Dz"]:
|
||||
assignment += ".real("
|
||||
end_of_line += ")"
|
||||
else:
|
||||
assignment += " = "
|
||||
indent += " " * len(assignment)
|
||||
|
||||
result += " " + assignment + "cblas_" + flavour.name.lower() + routine.name + postfix + "("
|
||||
result += ("," + NL + indent).join([a for a in arguments])
|
||||
result += extra_argument + end_of_line + ");" + NL
|
||||
|
||||
# There is no CBLAS available, forward the call to one of the available functions
|
||||
else: # Half-precision
|
||||
indent = " " * (9 + routine.length())
|
||||
|
||||
# Convert to float (note: also integer buffers are stored as half/float)
|
||||
for buf in routine.inputs + routine.outputs:
|
||||
result += " auto " + buf + "_buffer_bis = HalfToFloatBuffer(" + buf + "_buffer);" + NL
|
||||
|
||||
# Call the float routine
|
||||
result += " cblasX" + routine.name + "("
|
||||
result += ("," + NL + indent).join([a for a in routine.arguments_half()])
|
||||
result += ");" + NL
|
||||
|
||||
# Convert back to half
|
||||
for buf in routine.outputs:
|
||||
result += " FloatToHalfBuffer(" + buf + "_buffer, " + buf + "_buffer_bis);" + NL
|
||||
|
||||
# Complete
|
||||
result += "}" + NL
|
||||
return result
|
||||
|
||||
|
||||
def performance_test(routine, level_string):
|
||||
"""Generates the body of a performance test for a specific routine"""
|
||||
result = ""
|
||||
result += "#include \"test/performance/client.hpp\"" + NL
|
||||
result += "#include \"test/routines/level" + level_string + "/x" + routine.name + ".hpp\"" + NL + NL
|
||||
result += "// Shortcuts to the clblast namespace" + NL
|
||||
result += "using float2 = clblast::float2;" + NL
|
||||
result += "using double2 = clblast::double2;" + NL + NL
|
||||
result += "// Main function (not within the clblast namespace)" + NL
|
||||
result += "int main(int argc, char *argv[]) {" + NL
|
||||
default = convert.precision_to_full_name(routine.flavours[0].precision_name)
|
||||
result += " switch(clblast::GetPrecision(argc, argv, clblast::Precision::k" + default + ")) {" + NL
|
||||
for precision in ["H", "S", "D", "C", "Z"]:
|
||||
result += " case clblast::Precision::k" + convert.precision_to_full_name(precision) + ":"
|
||||
found = False
|
||||
for flavour in routine.flavours:
|
||||
if flavour.precision_name == precision:
|
||||
result += NL + " clblast::RunClient<clblast::TestX" + routine.name + flavour.test_template()
|
||||
result += ">(argc, argv); break;" + NL
|
||||
found = True
|
||||
if not found:
|
||||
result += " throw std::runtime_error(\"Unsupported precision mode\");" + NL
|
||||
result += " }" + NL
|
||||
result += " return 0;" + NL
|
||||
result += "}" + NL
|
||||
return result
|
||||
|
||||
|
||||
def correctness_test(routine, level_string):
|
||||
"""Generates the body of a correctness test for a specific routine"""
|
||||
result = ""
|
||||
result += "#include \"test/correctness/testblas.hpp\"" + NL
|
||||
result += "#include \"test/routines/level" + level_string + "/x" + routine.name + ".hpp\"" + NL + NL
|
||||
result += "// Shortcuts to the clblast namespace" + NL
|
||||
result += "using float2 = clblast::float2;" + NL
|
||||
result += "using double2 = clblast::double2;" + NL + NL
|
||||
result += "// Main function (not within the clblast namespace)" + NL
|
||||
result += "int main(int argc, char *argv[]) {" + NL
|
||||
result += " auto errors = size_t{0};" + NL
|
||||
not_first = "false"
|
||||
for flavour in routine.flavours:
|
||||
result += " errors += clblast::RunTests<clblast::TestX" + routine.name + flavour.test_template()
|
||||
result += ">(argc, argv, " + not_first + ", \"" + flavour.name + routine.name.upper() + "\");" + NL
|
||||
not_first = "true"
|
||||
result += " if (errors > 0) { return 1; } else { return 0; }" + NL
|
||||
result += "}" + NL
|
||||
return result
|
92
scripts/generator/generator/datatype.py
Normal file
92
scripts/generator/generator/datatype.py
Normal file
|
@ -0,0 +1,92 @@
|
|||
|
||||
# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
|
||||
# PEP8 Python style guide and uses a max-width of 120 characters per line.
|
||||
#
|
||||
# Author(s):
|
||||
# Cedric Nugteren <www.cedricnugteren.nl>
|
||||
|
||||
|
||||
# Short-hands for data-types
|
||||
D_HALF = "half"
|
||||
D_FLOAT = "float"
|
||||
D_DOUBLE = "double"
|
||||
D_FLOAT2 = "float2"
|
||||
D_DOUBLE2 = "double2"
|
||||
D_HALF_OPENCL = "cl_half"
|
||||
D_FLOAT2_OPENCL = "cl_float2"
|
||||
D_DOUBLE2_OPENCL = "cl_double2"
|
||||
|
||||
|
||||
class DataType:
|
||||
"""Class holding data-type and precision information"""
|
||||
|
||||
def __init__(self, precision_name, name, template, scalars, buffer_type):
|
||||
self.precision_name = precision_name
|
||||
self.name = name
|
||||
self.template = template
|
||||
self.alpha_cpp = scalars[0]
|
||||
self.beta_cpp = scalars[1]
|
||||
self.alpha_cl = scalars[2]
|
||||
self.beta_cl = scalars[3]
|
||||
self.buffer_type = buffer_type
|
||||
|
||||
def use_alpha(self):
|
||||
"""Outputs the name of the data-type (alpha/beta), possibly transforming into the right type"""
|
||||
if self.alpha_cpp in [D_FLOAT2, D_DOUBLE2]:
|
||||
return self.alpha_cpp + "{alpha.s[0], alpha.s[1]}"
|
||||
return "alpha"
|
||||
|
||||
def use_beta(self):
|
||||
"""As above, but for beta instead of alpha"""
|
||||
if self.beta_cpp in [D_FLOAT2, D_DOUBLE2]:
|
||||
return self.beta_cpp + "{beta.s[0], beta.s[1]}"
|
||||
return "beta"
|
||||
|
||||
def use_alpha_opencl(self):
|
||||
"""As above, but the transformation is in the opposite direction"""
|
||||
if self.alpha_cpp in [D_FLOAT2, D_DOUBLE2]:
|
||||
return self.alpha_cl + "{{alpha.real(), alpha.imag()}}"
|
||||
return "alpha"
|
||||
|
||||
def use_beta_opencl(self):
|
||||
"""As above, but for beta instead of alpha"""
|
||||
if self.beta_cpp in [D_FLOAT2, D_DOUBLE2]:
|
||||
return self.beta_cl + "{{beta.real(), beta.imag()}}"
|
||||
return "beta"
|
||||
|
||||
def test_template(self):
|
||||
"""Returns the template as used in the correctness/performance tests"""
|
||||
if self.buffer_type != self.beta_cpp:
|
||||
return "<" + self.buffer_type + "," + self.beta_cpp + ">, " + self.buffer_type + ", " + self.beta_cpp
|
||||
return "<" + self.buffer_type + ">, " + self.buffer_type + ", " + self.beta_cpp
|
||||
|
||||
def is_complex(self, scalar):
|
||||
"""Current scalar is complex"""
|
||||
return ((scalar == "alpha" and self.alpha_cpp in [D_FLOAT2, D_DOUBLE2]) or
|
||||
(scalar == "beta" and self.beta_cpp in [D_FLOAT2, D_DOUBLE2]))
|
||||
|
||||
|
||||
# Regular data-types
|
||||
H = DataType("H", "H", D_HALF, [D_HALF] * 2 + [D_HALF_OPENCL] * 2, D_HALF) # half (16)
|
||||
S = DataType("S", "S", D_FLOAT, [D_FLOAT] * 4, D_FLOAT) # single (32)
|
||||
D = DataType("D", "D", D_DOUBLE, [D_DOUBLE] * 4, D_DOUBLE) # double (64)
|
||||
C = DataType("C", "C", D_FLOAT2, [D_FLOAT2] * 2 + [D_FLOAT2_OPENCL] * 2, D_FLOAT2) # single-complex (3232)
|
||||
Z = DataType("Z", "Z", D_DOUBLE2, [D_DOUBLE2] * 2 + [D_DOUBLE2_OPENCL] * 2, D_DOUBLE2) # double-complex (6464)
|
||||
|
||||
# Special cases
|
||||
Sc = DataType("C", "Sc", D_FLOAT2, [D_FLOAT2] * 4, D_FLOAT2) # As C, but with real output
|
||||
Dz = DataType("Z", "Dz", D_DOUBLE2, [D_DOUBLE2] * 4, D_DOUBLE2) # As Z, but with real output
|
||||
iH = DataType("H", "iH", D_HALF, [D_HALF] * 4, D_HALF) # As H, but with integer output
|
||||
iS = DataType("S", "iS", D_FLOAT, [D_FLOAT] * 4, D_FLOAT) # As S, but with integer output
|
||||
iD = DataType("D", "iD", D_DOUBLE, [D_DOUBLE] * 4, D_DOUBLE) # As D, but with integer output
|
||||
iC = DataType("C", "iC", D_FLOAT2, [D_FLOAT2] * 2 + [D_FLOAT2_OPENCL] * 2, D_FLOAT2) # As C, but with integer output
|
||||
iZ = DataType("Z", "iZ", D_DOUBLE2, [D_DOUBLE2] * 2 + [D_DOUBLE2_OPENCL] * 2, D_DOUBLE2) # As Z, but with int output
|
||||
Css = DataType("C", "C", D_FLOAT, [D_FLOAT, D_FLOAT, D_FLOAT, D_FLOAT], D_FLOAT2) # As C, but with constants from S
|
||||
Zdd = DataType("Z", "Z", D_DOUBLE, [D_DOUBLE] * 4, D_DOUBLE2) # As Z, but with constants from D
|
||||
Ccs = DataType("C", "C", D_FLOAT2 + "," + D_FLOAT, [D_FLOAT2, D_FLOAT, D_FLOAT2_OPENCL, D_FLOAT], D_FLOAT2) # As C, but with one constant from S
|
||||
Zzd = DataType("Z", "Z", D_DOUBLE2 + "," + D_DOUBLE, [D_DOUBLE2, D_DOUBLE, D_DOUBLE2_OPENCL, D_DOUBLE], D_DOUBLE2) # As Z, but with one constant from D
|
||||
|
||||
# C++ template data-types
|
||||
T = DataType("T", "typename T", "T", ["T", "T", "T", "T"], "T") # regular routine
|
||||
Tc = DataType("Tc", "typename T", "std::complex<T>,T", ["T", "T", "T", "T"], "std::complex<T>") # for herk
|
||||
TU = DataType("TU", "typename T, typename U", "T,U", ["T", "U", "T", "U"], "T") # for her2k
|
57
scripts/generator/generator/doc.py
Normal file
57
scripts/generator/generator/doc.py
Normal file
|
@ -0,0 +1,57 @@
|
|||
|
||||
# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
|
||||
# PEP8 Python style guide and uses a max-width of 120 characters per line.
|
||||
#
|
||||
# Author(s):
|
||||
# Cedric Nugteren <www.cedricnugteren.nl>
|
||||
|
||||
NL = "\n"
|
||||
|
||||
|
||||
def header():
|
||||
"""Generates the header for the API documentation"""
|
||||
result = "CLBlast: API reference" + NL
|
||||
result += "================" + NL + NL + NL
|
||||
return result
|
||||
|
||||
|
||||
def generate(routine):
|
||||
"""Generates the API documentation for a given routine"""
|
||||
result = ""
|
||||
|
||||
# Routine header
|
||||
result += "x" + routine.name.upper() + ": " + routine.description + NL
|
||||
result += "-------------" + NL + NL
|
||||
result += routine.details + NL + NL
|
||||
|
||||
# Routine API
|
||||
result += "C++ API:" + NL
|
||||
result += "```" + NL
|
||||
result += routine.routine_header_cpp(12, "") + NL
|
||||
result += "```" + NL + NL
|
||||
result += "C API:" + NL
|
||||
result += "```" + NL
|
||||
for flavour in routine.flavours:
|
||||
result += routine.routine_header_c(flavour, 20, "") + NL
|
||||
result += "```" + NL + NL
|
||||
|
||||
# Routine arguments
|
||||
result += "Arguments to " + routine.name.upper() + ":" + NL + NL
|
||||
for argument in routine.arguments_doc():
|
||||
result += "* " + argument + NL
|
||||
result += "* `cl_command_queue* queue`: "
|
||||
result += "Pointer to an OpenCL command queue associated with a context and device to execute the routine on." + NL
|
||||
result += "* `cl_event* event`: "
|
||||
result += "Pointer to an OpenCL event to be able to wait for completion of the routine's OpenCL kernel(s). "
|
||||
result += "This is an optional argument." + NL + NL
|
||||
|
||||
# Routine requirements
|
||||
if len(routine.requirements_doc()) > 0:
|
||||
result += "Requirements for " + routine.name.upper() + ":" + NL + NL
|
||||
for requirement in routine.requirements_doc():
|
||||
result += "* " + requirement + NL
|
||||
result += NL
|
||||
|
||||
# Routine footer
|
||||
result += NL + NL
|
||||
return result
|
552
scripts/generator/generator/routine.py
Normal file
552
scripts/generator/generator/routine.py
Normal file
|
@ -0,0 +1,552 @@
|
|||
|
||||
# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
|
||||
# PEP8 Python style guide and uses a max-width of 120 characters per line.
|
||||
#
|
||||
# Author(s):
|
||||
# Cedric Nugteren <www.cedricnugteren.nl>
|
||||
|
||||
from itertools import chain
|
||||
|
||||
import generator.convert as convert
|
||||
|
||||
|
||||
class Routine:
|
||||
"""Class holding routine-specific information (e.g. name, which arguments, which precisions)"""
|
||||
def __init__(self, implemented, has_tests, level, name, template, flavours, sizes, options,
|
||||
inputs, outputs, scalars, scratch, description, details, requirements):
|
||||
self.implemented = implemented
|
||||
self.has_tests = has_tests
|
||||
self.level = level
|
||||
self.name = name
|
||||
self.template = template
|
||||
self.flavours = flavours
|
||||
self.sizes = sizes
|
||||
self.options = options
|
||||
self.inputs = inputs
|
||||
self.outputs = outputs
|
||||
self.scalars = scalars
|
||||
self.scratch = scratch # Scratch buffer (e.g. for xDOT)
|
||||
self.description = description
|
||||
self.details = details
|
||||
self.requirements = requirements
|
||||
|
||||
@staticmethod
|
||||
def scalar_buffers_first():
|
||||
"""List of scalar buffers"""
|
||||
return ["dot", "nrm2", "asum", "sum", "imax", "imin"]
|
||||
|
||||
@staticmethod
|
||||
def scalar_buffers_second():
|
||||
"""List of scalar buffers"""
|
||||
return ["sa", "sb", "sc", "ss", "sd1", "sd2", "sx1", "sy1", "sparam"]
|
||||
|
||||
@staticmethod
|
||||
def other_scalars():
|
||||
"""List of scalars other than alpha and beta"""
|
||||
return ["cos", "sin"]
|
||||
|
||||
@staticmethod
|
||||
def index_buffers():
|
||||
"""List of buffers with unsigned int type"""
|
||||
return ["imax", "imin"]
|
||||
|
||||
@staticmethod
|
||||
def postfix(name):
|
||||
"""Retrieves the postfix for a buffer"""
|
||||
return "inc" if (name in ["x", "y"]) else "ld"
|
||||
|
||||
@staticmethod
|
||||
def buffers_vector():
|
||||
"""Distinguish between vectors and matrices"""
|
||||
return ["x", "y"]
|
||||
|
||||
@staticmethod
|
||||
def buffers_matrix():
|
||||
"""Distinguish between vectors and matrices"""
|
||||
return ["a", "b", "c", "ap"]
|
||||
|
||||
def non_index_inputs(self):
|
||||
"""Lists of input/output buffers not index (integer)"""
|
||||
buffers = self.inputs[:] # make a copy
|
||||
for i in self.index_buffers():
|
||||
if i in buffers:
|
||||
buffers.remove(i)
|
||||
return buffers
|
||||
|
||||
def non_index_outputs(self):
|
||||
"""Lists of input/output buffers not index (integer)"""
|
||||
buffers = self.outputs[:] # make a copy
|
||||
for i in self.index_buffers():
|
||||
if i in buffers:
|
||||
buffers.remove(i)
|
||||
return buffers
|
||||
|
||||
def buffers_without_ld_inc(self):
|
||||
"""List of buffers without 'inc' or 'ld'"""
|
||||
return self.scalar_buffers_first() + self.scalar_buffers_second() + ["ap"]
|
||||
|
||||
def length(self):
|
||||
"""Retrieves the number of characters in the routine's name"""
|
||||
return len(self.name)
|
||||
|
||||
def no_scalars(self):
|
||||
"""Determines whether or not this routine has scalar arguments (alpha/beta)"""
|
||||
return self.scalars == []
|
||||
|
||||
def short_names(self):
|
||||
"""Returns the upper-case names of these routines (all flavours)"""
|
||||
return "/".join([f.name + self.name.upper() for f in self.flavours])
|
||||
|
||||
def short_names_tested(self):
|
||||
"""As above, but excludes some"""
|
||||
names = [f.name + self.name.upper() for f in self.flavours]
|
||||
if "H" + self.name.upper() in names:
|
||||
names.remove("H" + self.name.upper())
|
||||
return "/".join(names)
|
||||
|
||||
def buffers_first(self):
|
||||
"""Determines which buffers go first (between alpha and beta) and which ones go after"""
|
||||
if self.level == "2b":
|
||||
return ["x", "y"]
|
||||
return ["ap", "a", "b", "x"]
|
||||
|
||||
def buffers_second(self):
|
||||
if self.level == "2b":
|
||||
return ["ap", "a", "b", "c"]
|
||||
return ["y", "c"]
|
||||
|
||||
def buffer(self, name):
|
||||
"""Retrieves a variable name for a specific input/output vector/matrix (e.g. 'x')"""
|
||||
if name in self.inputs or name in self.outputs:
|
||||
a = [name + "_buffer"]
|
||||
b = [name + "_offset"]
|
||||
c = [name + "_" + self.postfix(name)] if (name not in self.buffers_without_ld_inc()) else []
|
||||
return [", ".join(a + b + c)]
|
||||
return []
|
||||
|
||||
def buffer_bis(self, name):
|
||||
"""As above but with a '_bis' suffix for the buffer name"""
|
||||
if name in self.inputs or name in self.outputs:
|
||||
a = [name + "_buffer_bis"]
|
||||
b = [name + "_offset"]
|
||||
c = [name + "_" + self.postfix(name)] if name not in self.buffers_without_ld_inc() else []
|
||||
return [", ".join(a + b + c)]
|
||||
return []
|
||||
|
||||
def buffer_def(self, name):
|
||||
"""As above but with data-types"""
|
||||
prefix = "const " if name in self.inputs else ""
|
||||
if name in self.inputs or name in self.outputs:
|
||||
a = [prefix + "cl_mem " + name + "_buffer"]
|
||||
b = ["const size_t " + name + "_offset"]
|
||||
c = ["const size_t " + name + "_" + self.postfix(name)] if name not in self.buffers_without_ld_inc() else []
|
||||
return [", ".join(a + b + c)]
|
||||
return []
|
||||
|
||||
def buffer_def_wrapper_cl(self, name, flavour):
|
||||
"""As above but with data-types"""
|
||||
prefix = "const " if name in self.inputs else ""
|
||||
if name in self.inputs or name in self.outputs:
|
||||
a = [prefix + "Buffer<" + flavour.buffer_type + ">& " + name + "_buffer"]
|
||||
b = ["const size_t " + name + "_offset"]
|
||||
c = ["const size_t " + name + "_" + self.postfix(name)] if name not in self.buffers_without_ld_inc() else []
|
||||
return [", ".join(a + b + c)]
|
||||
return []
|
||||
|
||||
def buffer_def_vector(self, name, flavour):
|
||||
"""As above but as vectors"""
|
||||
prefix = "const " if name in self.inputs else ""
|
||||
if name in self.inputs or name in self.outputs:
|
||||
a = [prefix + "std::vector<" + flavour.buffer_type + ">& " + name + "_buffer"]
|
||||
b = ["const size_t " + name + "_offset"]
|
||||
c = ["const size_t " + name + "_" + self.postfix(name)] if name not in self.buffers_without_ld_inc() else []
|
||||
return [", ".join(a + b + c)]
|
||||
return []
|
||||
|
||||
def buffer_clcudaapi(self, name):
|
||||
"""As above but with CLCudaAPI buffers"""
|
||||
if name in self.inputs or name in self.outputs:
|
||||
buffer_type = "unsigned int" if (name in self.index_buffers()) else self.template.buffer_type
|
||||
a = ["Buffer<" + buffer_type + ">(" + name + "_buffer)"]
|
||||
b = [name + "_offset"]
|
||||
c = [name + "_" + self.postfix(name)] if (name not in self.buffers_without_ld_inc()) else []
|
||||
return [", ".join(a + b + c)]
|
||||
return []
|
||||
|
||||
def buffer_wrapper_clblas(self, name):
|
||||
"""As above but with a static cast for clBLAS wrapper"""
|
||||
if name in self.inputs or name in self.outputs:
|
||||
a = [name + "_buffer()"]
|
||||
b = [name + "_offset"]
|
||||
c = []
|
||||
if name in ["x", "y"]:
|
||||
c = ["static_cast<int>(" + name + "_" + self.postfix(name) + ")"]
|
||||
elif name in ["a", "b", "c"]:
|
||||
c = [name + "_" + self.postfix(name)]
|
||||
return [", ".join(a + b + c)]
|
||||
return []
|
||||
|
||||
def buffer_wrapper_cblas(self, name, flavour):
|
||||
"""As above but with a static cast for CBLAS wrapper"""
|
||||
prefix = "const " if name in self.inputs else ""
|
||||
if name in self.inputs or name in self.outputs:
|
||||
if name == "sy1":
|
||||
a = [name + "_buffer[" + name + "_offset]"]
|
||||
elif flavour.precision_name in ["C", "Z"]:
|
||||
a = ["reinterpret_cast<" + prefix + flavour.buffer_type[:-1] + "*>" +
|
||||
"(&" + name + "_buffer[" + name + "_offset])"]
|
||||
else:
|
||||
a = ["&" + name + "_buffer[" + name + "_offset]"]
|
||||
c = []
|
||||
if name in ["x", "y"]:
|
||||
c = ["static_cast<int>(" + name + "_" + self.postfix(name) + ")"]
|
||||
elif name in ["a", "b", "c"]:
|
||||
c = [name + "_" + self.postfix(name)]
|
||||
return [", ".join(a + c)]
|
||||
return []
|
||||
|
||||
def buffer_type(self, name):
|
||||
"""As above, but only data-types"""
|
||||
prefix = "const " if (name in self.inputs) else ""
|
||||
if (name in self.inputs) or (name in self.outputs):
|
||||
a = [prefix + "cl_mem"]
|
||||
b = ["const size_t"]
|
||||
c = ["const size_t"] if (name not in self.buffers_without_ld_inc()) else []
|
||||
return [", ".join(a + b + c)]
|
||||
return []
|
||||
|
||||
def buffer_doc(self, name):
|
||||
"""Retrieves the documentation of the buffers"""
|
||||
prefix = "const " if (name in self.inputs) else ""
|
||||
inout = "input" if (name in self.inputs) else "output"
|
||||
if (name in self.inputs) or (name in self.outputs):
|
||||
math_name = name.upper() + " matrix" if (name in self.buffers_matrix()) else name + " vector"
|
||||
inc_ld_description = "Leading dimension " if (name in self.buffers_matrix()) else "Stride/increment "
|
||||
a = ["`" + prefix + "cl_mem " + name + "_buffer`: OpenCL buffer to store the " + inout + " " + math_name + "."]
|
||||
b = ["`const size_t " + name + "_offset`: The offset in elements from the start of the " + inout + " " + math_name + "."]
|
||||
if name not in self.buffers_without_ld_inc():
|
||||
c = ["`const size_t " + name + "_" + self.postfix(name) + "`: " +
|
||||
inc_ld_description + "of the " + inout + " " + math_name + ". This value must be greater than 0."]
|
||||
else:
|
||||
c = []
|
||||
return a + b + c
|
||||
return []
|
||||
|
||||
def scalar(self, name):
|
||||
"""Retrieves the name of a scalar (alpha/beta)"""
|
||||
if name in self.scalars:
|
||||
return [name]
|
||||
return []
|
||||
|
||||
def scalar_half_to_float(self, name):
|
||||
"""As above, but converts from float to half"""
|
||||
if name in self.scalars:
|
||||
return ["HalfToFloat(" + name + ")"]
|
||||
return []
|
||||
|
||||
def scalar_use(self, name, flavour):
|
||||
"""Retrieves the use of a scalar (alpha/beta)"""
|
||||
if name in self.scalars:
|
||||
if name == "alpha":
|
||||
return [flavour.use_alpha()]
|
||||
elif name == "beta":
|
||||
return [flavour.use_beta()]
|
||||
return [name]
|
||||
return []
|
||||
|
||||
def scalar_use_wrapper(self, name, flavour):
|
||||
"""As above, but for the clBLAS wrapper"""
|
||||
if name in self.scalars:
|
||||
if name == "alpha":
|
||||
return [flavour.use_alpha_opencl()]
|
||||
elif name == "beta":
|
||||
return [flavour.use_beta_opencl()]
|
||||
return [name]
|
||||
return []
|
||||
|
||||
def scalar_use_wrapper_cblas(self, name, flavour):
|
||||
"""As above, but for the CBLAS wrapper"""
|
||||
if name in self.scalars:
|
||||
if flavour.is_complex(name):
|
||||
return [name + "_array.data()"]
|
||||
return [name]
|
||||
return []
|
||||
|
||||
def scalar_def(self, name, flavour):
|
||||
"""Retrieves the definition of a scalar (alpha/beta)"""
|
||||
if name in self.scalars:
|
||||
if name == "alpha":
|
||||
return ["const " + flavour.alpha_cl + " " + name]
|
||||
return ["const " + flavour.beta_cl + " " + name]
|
||||
return []
|
||||
|
||||
def scalar_def_plain(self, name, flavour):
|
||||
"""As above, but without 'cl_' prefix"""
|
||||
if name in self.scalars:
|
||||
if name == "alpha":
|
||||
return ["const " + flavour.alpha_cpp + " " + name]
|
||||
return ["const " + flavour.beta_cpp + " " + name]
|
||||
return []
|
||||
|
||||
def scalar_type(self, name, flavour):
|
||||
"""Retrieves the type of a scalar (alpha/beta)"""
|
||||
if name in self.scalars:
|
||||
if name == "alpha":
|
||||
return ["const " + flavour.alpha_cpp]
|
||||
return ["const " + flavour.beta_cpp]
|
||||
return []
|
||||
|
||||
def scalar_doc(self, name):
|
||||
"""Retrieves the documentation of a scalar"""
|
||||
if name in self.scalars:
|
||||
if name == "alpha":
|
||||
return ["`const " + self.template.alpha_cpp + " " + name + "`: Input scalar constant."]
|
||||
return ["`const " + self.template.beta_cpp + " " + name + "`: Input scalar constant."]
|
||||
return []
|
||||
|
||||
def sizes_list(self):
|
||||
"""Retrieves a list of comma-separated sizes (m, n, k)"""
|
||||
if self.sizes:
|
||||
return [", ".join([s for s in self.sizes])]
|
||||
return []
|
||||
|
||||
def sizes_def(self):
|
||||
"""Retrieves the definition of the sizes (m,n,k)"""
|
||||
if self.sizes:
|
||||
return [", ".join(["const size_t " + s for s in self.sizes])]
|
||||
return []
|
||||
|
||||
def sizes_type(self):
|
||||
"""Retrieves the types of the sizes (m,n,k)"""
|
||||
if self.sizes:
|
||||
return [", ".join(["const size_t" for s in self.sizes])]
|
||||
return []
|
||||
|
||||
def sizes_doc(self):
|
||||
"""# Retrieves the documentation of the sizes"""
|
||||
if self.sizes:
|
||||
definitions = ["`const size_t " + s + "`: Integer size argument. This value must be positive." for s in self.sizes]
|
||||
return definitions
|
||||
return []
|
||||
|
||||
def options_list(self):
|
||||
"""Retrieves a list of options"""
|
||||
if self.options:
|
||||
return [", ".join(self.options)]
|
||||
return []
|
||||
|
||||
def options_cast(self, indent):
|
||||
"""As above, but now casted to CLBlast data-types"""
|
||||
if self.options:
|
||||
options = ["static_cast<clblast::" + convert.option_to_clblast(o) + ">(" + o + ")" for o in self.options]
|
||||
return [(",\n" + indent).join(options)]
|
||||
return []
|
||||
|
||||
def options_def(self):
|
||||
"""Retrieves the definitions of the options (layout, transpose, side, etc.)"""
|
||||
if self.options:
|
||||
definitions = ["const " + convert.option_to_clblast(o) + " " + o for o in self.options]
|
||||
return [", ".join(definitions)]
|
||||
return []
|
||||
|
||||
def options_def_wrapper_clblas(self):
|
||||
"""As above, but now using clBLAS data-types"""
|
||||
if self.options:
|
||||
definitions = ["const " + convert.option_to_clblas(o) + " " + o for o in self.options]
|
||||
return [", ".join(definitions)]
|
||||
return []
|
||||
|
||||
def options_def_wrapper_cblas(self):
|
||||
"""As above, but now using CBLAS data-types"""
|
||||
if self.options:
|
||||
definitions = ["const " + convert.option_to_cblas(o) + " " + o for o in self.options]
|
||||
return [", ".join(definitions)]
|
||||
return []
|
||||
|
||||
def options_type(self):
|
||||
"""Retrieves the types of the options (layout, transpose, side, etc.)"""
|
||||
if self.options:
|
||||
definitions = ["const " + convert.option_to_clblast(o) for o in self.options]
|
||||
return [", ".join(definitions)]
|
||||
return []
|
||||
|
||||
def options_doc(self):
|
||||
"""Retrieves the documentation of the options"""
|
||||
if self.options:
|
||||
definitions = ["`const " + convert.option_to_clblast(o) + " " + o + "`: " + convert.option_to_documentation(o) for o in self.options]
|
||||
return definitions
|
||||
return []
|
||||
|
||||
def arguments(self):
|
||||
"""Retrieves a combination of all the argument names (no types)"""
|
||||
return (self.options_list() + self.sizes_list() +
|
||||
list(chain(*[self.buffer(b) for b in self.scalar_buffers_first()])) +
|
||||
self.scalar("alpha") +
|
||||
list(chain(*[self.buffer(b) for b in self.buffers_first()])) +
|
||||
self.scalar("beta") +
|
||||
list(chain(*[self.buffer(b) for b in self.buffers_second()])) +
|
||||
list(chain(*[self.buffer(b) for b in self.scalar_buffers_second()])) +
|
||||
list(chain(*[self.scalar(s) for s in self.other_scalars()])))
|
||||
|
||||
def arguments_half(self):
|
||||
"""As above, but with conversions from half to float"""
|
||||
return (self.options_list() + self.sizes_list() +
|
||||
list(chain(*[self.buffer_bis(b) for b in self.scalar_buffers_first()])) +
|
||||
self.scalar_half_to_float("alpha") +
|
||||
list(chain(*[self.buffer_bis(b) for b in self.buffers_first()])) +
|
||||
self.scalar_half_to_float("beta") +
|
||||
list(chain(*[self.buffer_bis(b) for b in self.buffers_second()])) +
|
||||
list(chain(*[self.buffer_bis(b) for b in self.scalar_buffers_second()])) +
|
||||
list(chain(*[self.scalar(s) for s in self.other_scalars()])))
|
||||
|
||||
def arguments_clcudaapi(self):
|
||||
"""Retrieves a combination of all the argument names, with CLCudaAPI casts"""
|
||||
return (self.options_list() + self.sizes_list() +
|
||||
list(chain(*[self.buffer_clcudaapi(b) for b in self.scalar_buffers_first()])) +
|
||||
self.scalar("alpha") +
|
||||
list(chain(*[self.buffer_clcudaapi(b) for b in self.buffers_first()])) +
|
||||
self.scalar("beta") +
|
||||
list(chain(*[self.buffer_clcudaapi(b) for b in self.buffers_second()])) +
|
||||
list(chain(*[self.buffer_clcudaapi(b) for b in self.scalar_buffers_second()])) +
|
||||
list(chain(*[self.scalar(s) for s in self.other_scalars()])))
|
||||
|
||||
def arguments_cast(self, flavour, indent):
|
||||
"""As above, but with CLBlast casts"""
|
||||
return (self.options_cast(indent) + self.sizes_list() +
|
||||
list(chain(*[self.buffer(b) for b in self.scalar_buffers_first()])) +
|
||||
self.scalar_use("alpha", flavour) +
|
||||
list(chain(*[self.buffer(b) for b in self.buffers_first()])) +
|
||||
self.scalar_use("beta", flavour) +
|
||||
list(chain(*[self.buffer(b) for b in self.buffers_second()])) +
|
||||
list(chain(*[self.buffer(b) for b in self.scalar_buffers_second()])) +
|
||||
list(chain(*[self.scalar_use(s, flavour) for s in self.other_scalars()])))
|
||||
|
||||
def arguments_wrapper_clblas(self, flavour):
|
||||
"""As above, but for the clBLAS wrapper"""
|
||||
return (self.options_list() + self.sizes_list() +
|
||||
list(chain(*[self.buffer_wrapper_clblas(b) for b in self.scalar_buffers_first()])) +
|
||||
self.scalar_use_wrapper("alpha", flavour) +
|
||||
list(chain(*[self.buffer_wrapper_clblas(b) for b in self.buffers_first()])) +
|
||||
self.scalar_use_wrapper("beta", flavour) +
|
||||
list(chain(*[self.buffer_wrapper_clblas(b) for b in self.buffers_second()])) +
|
||||
list(chain(*[self.buffer_wrapper_clblas(b) for b in self.scalar_buffers_second()])) +
|
||||
list(chain(*[self.scalar_use_wrapper(s, flavour) for s in self.other_scalars()])))
|
||||
|
||||
def arguments_wrapper_cblas(self, flavour):
|
||||
"""As above, but for the CBLAS wrapper"""
|
||||
return (self.options_list() + self.sizes_list() +
|
||||
self.scalar_use_wrapper_cblas("alpha", flavour) +
|
||||
list(chain(*[self.buffer_wrapper_cblas(b, flavour) for b in self.buffers_first()])) +
|
||||
self.scalar_use_wrapper_cblas("beta", flavour) +
|
||||
list(chain(*[self.buffer_wrapper_cblas(b, flavour) for b in self.buffers_second()])) +
|
||||
list(chain(*[self.buffer_wrapper_cblas(b, flavour) for b in self.scalar_buffers_second()])) +
|
||||
list(chain(*[self.scalar_use_wrapper_cblas(s, flavour) for s in self.other_scalars()])))
|
||||
|
||||
def arguments_def(self, flavour):
|
||||
"""Retrieves a combination of all the argument definitions"""
|
||||
return (self.options_def() + self.sizes_def() +
|
||||
list(chain(*[self.buffer_def(b) for b in self.scalar_buffers_first()])) +
|
||||
self.scalar_def("alpha", flavour) +
|
||||
list(chain(*[self.buffer_def(b) for b in self.buffers_first()])) +
|
||||
self.scalar_def("beta", flavour) +
|
||||
list(chain(*[self.buffer_def(b) for b in self.buffers_second()])) +
|
||||
list(chain(*[self.buffer_def(b) for b in self.scalar_buffers_second()])) +
|
||||
list(chain(*[self.scalar_def(s, flavour) for s in self.other_scalars()])))
|
||||
|
||||
def arguments_def_wrapper_clblas(self, flavour):
|
||||
"""As above, but clBLAS wrapper plain data-types"""
|
||||
return (self.options_def_wrapper_clblas() + self.sizes_def() +
|
||||
list(chain(*[self.buffer_def_wrapper_cl(b, flavour) for b in self.scalar_buffers_first()])) +
|
||||
self.scalar_def_plain("alpha", flavour) +
|
||||
list(chain(*[self.buffer_def_wrapper_cl(b, flavour) for b in self.buffers_first()])) +
|
||||
self.scalar_def_plain("beta", flavour) +
|
||||
list(chain(*[self.buffer_def_wrapper_cl(b, flavour) for b in self.buffers_second()])) +
|
||||
list(chain(*[self.buffer_def_wrapper_cl(b, flavour) for b in self.scalar_buffers_second()])) +
|
||||
list(chain(*[self.scalar_def_plain(s, flavour) for s in self.other_scalars()])))
|
||||
|
||||
def arguments_def_wrapper_cblas(self, flavour):
|
||||
"""As above, but CBLAS wrapper plain data-types"""
|
||||
return (self.options_def_wrapper_cblas() + self.sizes_def() +
|
||||
list(chain(*[self.buffer_def_vector(b, flavour) for b in self.scalar_buffers_first()])) +
|
||||
self.scalar_def_plain("alpha", flavour) +
|
||||
list(chain(*[self.buffer_def_vector(b, flavour) for b in self.buffers_first()])) +
|
||||
self.scalar_def_plain("beta", flavour) +
|
||||
list(chain(*[self.buffer_def_vector(b, flavour) for b in self.buffers_second()])) +
|
||||
list(chain(*[self.buffer_def_vector(b, flavour) for b in self.scalar_buffers_second()])) +
|
||||
list(chain(*[self.scalar_def_plain(s, flavour) for s in self.other_scalars()])))
|
||||
|
||||
def arguments_type(self, flavour):
|
||||
"""Retrieves a combination of all the argument types"""
|
||||
return (self.options_type() + self.sizes_type() +
|
||||
list(chain(*[self.buffer_type(b) for b in self.scalar_buffers_first()])) +
|
||||
self.scalar_type("alpha", flavour) +
|
||||
list(chain(*[self.buffer_type(b) for b in self.buffers_first()])) +
|
||||
self.scalar_type("beta", flavour) +
|
||||
list(chain(*[self.buffer_type(b) for b in self.buffers_second()])) +
|
||||
list(chain(*[self.buffer_type(b) for b in self.scalar_buffers_second()])) +
|
||||
list(chain(*[self.scalar_type(s, flavour) for s in self.other_scalars()])))
|
||||
|
||||
def arguments_doc(self):
|
||||
"""Retrieves a combination of all the argument types"""
|
||||
return (self.options_doc() + self.sizes_doc() +
|
||||
list(chain(*[self.buffer_doc(b) for b in self.scalar_buffers_first()])) +
|
||||
list(chain(*[self.buffer_doc(b) for b in self.scalar_buffers_first()])) +
|
||||
self.scalar_doc("alpha") +
|
||||
list(chain(*[self.buffer_doc(b) for b in self.buffers_first()])) +
|
||||
self.scalar_doc("beta") +
|
||||
list(chain(*[self.buffer_doc(b) for b in self.buffers_second()])) +
|
||||
list(chain(*[self.buffer_doc(b) for b in self.scalar_buffers_second()])) +
|
||||
list(chain(*[self.scalar_doc(s) for s in self.other_scalars()])))
|
||||
|
||||
def requirements_doc(self):
|
||||
"""Retrieves a list of routine requirements for documentation"""
|
||||
return self.requirements
|
||||
|
||||
def routine_header_cpp(self, spaces, default_event):
|
||||
"""Retrieves the C++ templated definition for a routine"""
|
||||
indent = " " * (spaces + self.length())
|
||||
result = "template <" + self.template.name + ">\n"
|
||||
result += "StatusCode " + self.name.capitalize() + "("
|
||||
result += (",\n" + indent).join([a for a in self.arguments_def(self.template)])
|
||||
result += ",\n" + indent + "cl_command_queue* queue, cl_event* event" + default_event + ")"
|
||||
return result
|
||||
|
||||
def routine_header_type_cpp(self, spaces):
|
||||
"""As above, but now without variable names"""
|
||||
indent = " " * (spaces + self.length())
|
||||
result = "template <" + self.template.name + ">\n"
|
||||
result += "StatusCode " + self.name.capitalize() + "("
|
||||
result += (",\n" + indent).join([a for a in self.arguments_type(self.template)])
|
||||
result += ",\n" + indent + "cl_command_queue*, cl_event*)"
|
||||
return result
|
||||
|
||||
def routine_header_c(self, flavour, spaces, extra_qualifier):
|
||||
"""As above, but now for C"""
|
||||
indent = " " * (spaces + self.length())
|
||||
result = "StatusCode" + extra_qualifier + " CLBlast" + flavour.name + self.name + "("
|
||||
result += (",\n" + indent).join([a for a in self.arguments_def(flavour)])
|
||||
result += ",\n" + indent + "cl_command_queue* queue, cl_event* event)"
|
||||
return result
|
||||
|
||||
def routine_header_wrapper_clblas(self, flavour, def_only, spaces):
|
||||
"""As above, but now for the clBLAS wrapper"""
|
||||
template = "<" + flavour.template + ">" if self.no_scalars() and not def_only else ""
|
||||
indent = " " * (spaces + self.length() + len(template))
|
||||
result = ""
|
||||
if self.no_scalars():
|
||||
result += "template <"
|
||||
if def_only:
|
||||
result += flavour.name
|
||||
result += ">\n"
|
||||
result += "clblasStatus clblasX" + self.name + template + "("
|
||||
result += (",\n" + indent).join([a for a in self.arguments_def_wrapper_clblas(flavour)])
|
||||
result += ",\n" + indent + "cl_uint num_queues, cl_command_queue *queues"
|
||||
result += ",\n" + indent + "cl_uint num_wait_events, const cl_event *wait_events, cl_event *events)"
|
||||
return result
|
||||
|
||||
def routine_header_wrapper_cblas(self, flavour, spaces):
|
||||
"""As above, but now for the CBLAS wrapper"""
|
||||
indent = " " * (spaces + self.length())
|
||||
result = "void cblasX" + self.name + "("
|
||||
result += (",\n" + indent).join([a for a in self.arguments_def_wrapper_cblas(flavour)]) + ")"
|
||||
return result
|
|
@ -1,603 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# ==================================================================================================
|
||||
# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
|
||||
# project loosely follows the Google C++ styleguide and uses a max-width of 100 characters per line.
|
||||
#
|
||||
# Author(s):
|
||||
# Cedric Nugteren <www.cedricnugteren.nl>
|
||||
#
|
||||
# This file contains the 'Routine' class, used in the generator script to generate the CLBlast API
|
||||
# interface and implementation.
|
||||
#
|
||||
# ==================================================================================================
|
||||
|
||||
# System modules
|
||||
from itertools import chain
|
||||
|
||||
# Translates an option name to a CLBlast data-type
|
||||
def OptionToCLBlast(x):
|
||||
return {
|
||||
'layout': "Layout",
|
||||
'a_transpose': "Transpose",
|
||||
'b_transpose': "Transpose",
|
||||
'ab_transpose': "Transpose",
|
||||
'side': "Side",
|
||||
'triangle': "Triangle",
|
||||
'diagonal': "Diagonal",
|
||||
}[x]
|
||||
|
||||
# As above, but for clBLAS data-types
|
||||
def OptionToWrapperCL(x):
|
||||
return {
|
||||
'layout': "clblasOrder",
|
||||
'a_transpose': "clblasTranspose",
|
||||
'b_transpose': "clblasTranspose",
|
||||
'ab_transpose': "clblasTranspose",
|
||||
'side': "clblasSide",
|
||||
'triangle': "clblasUplo",
|
||||
'diagonal': "clblasDiag",
|
||||
}[x]
|
||||
|
||||
# As above, but for CBLAS data-types
|
||||
def OptionToWrapperC(x):
|
||||
return {
|
||||
'layout': "CBLAS_ORDER",
|
||||
'a_transpose': "CBLAS_TRANSPOSE",
|
||||
'b_transpose': "CBLAS_TRANSPOSE",
|
||||
'ab_transpose': "CBLAS_TRANSPOSE",
|
||||
'side': "CBLAS_SIDE",
|
||||
'triangle': "CBLAS_UPLO",
|
||||
'diagonal': "CBLAS_DIAG",
|
||||
}[x]
|
||||
|
||||
# Translates an option name to a documentation string
|
||||
def OptionToDoc(x):
|
||||
return {
|
||||
'layout': "Data-layout of the matrices, either `Layout::kRowMajor` (101) for row-major layout or `Layout::kColMajor` (102) for column-major data-layout.",
|
||||
'a_transpose': "Transposing the input matrix A, either `Transpose::kNo` (111), `Transpose::kYes` (112), or `Transpose::kConjugate` (113) for a complex-conjugate transpose.",
|
||||
'b_transpose': "Transposing the input matrix B, either `Transpose::kNo` (111), `Transpose::kYes` (112), or `Transpose::kConjugate` (113) for a complex-conjugate transpose.",
|
||||
'ab_transpose': "Transposing the packed input matrix AP, either `Transpose::kNo` (111), `Transpose::kYes` (112), or `Transpose::kConjugate` (113) for a complex-conjugate transpose.",
|
||||
'side': "The position of the triangular matrix in the operation, either on the `Side::kLeft` (141) or `Side::kRight` (142).",
|
||||
'triangle': "The part of the array of the triangular matrix to be used, either `Triangle::kUpper` (121) or `Triangle::kLower` (122).",
|
||||
'diagonal': "The property of the diagonal matrix, either `Diagonal::kNonUnit` (131) for non-unit values on the diagonal or `Diagonal::kUnit` (132) for unit values on the diagonal.",
|
||||
}[x]
|
||||
|
||||
# ==================================================================================================
|
||||
|
||||
# Class holding routine-specific information (e.g. name, which arguments, which precisions)
|
||||
class Routine():
|
||||
def __init__(self, implemented, has_tests, level, name, template, flavours, sizes, options,
|
||||
inputs, outputs, scalars, scratch, description, details, requirements):
|
||||
self.implemented = implemented
|
||||
self.has_tests = has_tests
|
||||
self.level = level
|
||||
self.name = name
|
||||
self.template = template
|
||||
self.flavours = flavours
|
||||
self.sizes = sizes
|
||||
self.options = options
|
||||
self.inputs = inputs
|
||||
self.outputs = outputs
|
||||
self.scalars = scalars
|
||||
self.scratch = scratch # Scratch buffer (e.g. for xDOT)
|
||||
self.description = description
|
||||
self.details = details
|
||||
self.requirements = requirements
|
||||
|
||||
# List of scalar buffers
|
||||
def ScalarBuffersFirst(self):
|
||||
return ["dot","nrm2","asum","sum","imax","imin"]
|
||||
def ScalarBuffersSecond(self):
|
||||
return ["sa","sb","sc","ss","sd1","sd2","sx1","sy1","sparam"]
|
||||
|
||||
# List of scalars other than alpha and beta
|
||||
def OtherScalars(self):
|
||||
return ["cos","sin"]
|
||||
|
||||
# List of buffers with unsigned int type
|
||||
def IndexBuffers(self):
|
||||
return ["imax","imin"]
|
||||
|
||||
# Lists of input/output buffers not index (integer)
|
||||
def NonIndexInputs(self):
|
||||
buffers = self.inputs[:] # make a copy
|
||||
for i in self.IndexBuffers():
|
||||
if i in buffers: buffers.remove(i)
|
||||
return buffers
|
||||
def NonIndexOutputs(self):
|
||||
buffers = self.outputs[:] # make a copy
|
||||
for i in self.IndexBuffers():
|
||||
if i in buffers: buffers.remove(i)
|
||||
return buffers
|
||||
|
||||
# List of buffers without 'inc' or 'ld'
|
||||
def BuffersWithoutLdInc(self):
|
||||
return self.ScalarBuffersFirst() + self.ScalarBuffersSecond() + ["ap"]
|
||||
|
||||
# Retrieves the number of characters in the routine's name
|
||||
def Length(self):
|
||||
return len(self.name)
|
||||
|
||||
# Retrieves the postfix for a buffer
|
||||
def Postfix(self, name):
|
||||
return "inc" if (name in ["x","y"]) else "ld"
|
||||
|
||||
# Determines whether or not this routine has scalar arguments (alpha/beta)
|
||||
def NoScalars(self):
|
||||
return self.scalars == []
|
||||
|
||||
# Returns the upper-case names of these routines (all flavours)
|
||||
def ShortNames(self):
|
||||
return "/".join([f.name+self.name.upper() for f in self.flavours])
|
||||
|
||||
# As above, but excludes some
|
||||
def ShortNamesTested(self):
|
||||
names = [f.name+self.name.upper() for f in self.flavours]
|
||||
if "H"+self.name.upper() in names: names.remove("H"+self.name.upper())
|
||||
return "/".join(names)
|
||||
|
||||
# Determines which buffers go first (between alpha and beta) and which ones go after
|
||||
def BuffersFirst(self):
|
||||
if self.level == "2b":
|
||||
return ["x","y"]
|
||||
return ["ap","a","b","x"]
|
||||
def BuffersSecond(self):
|
||||
if self.level == "2b":
|
||||
return ["ap","a","b","c"]
|
||||
return ["y","c"]
|
||||
|
||||
# Distinguish between vectors and matrices
|
||||
def BuffersVector(self):
|
||||
return ["x","y"]
|
||||
def BuffersMatrix(self):
|
||||
return ["a","b","c","ap"]
|
||||
|
||||
# ==============================================================================================
|
||||
|
||||
# Retrieves a variable name for a specific input/output vector/matrix (e.g. 'x')
|
||||
def Buffer(self, name):
|
||||
if (name in self.inputs) or (name in self.outputs):
|
||||
a = [name+"_buffer"]
|
||||
b = [name+"_offset"]
|
||||
c = [name+"_"+self.Postfix(name)] if (name not in self.BuffersWithoutLdInc()) else []
|
||||
return [", ".join(a+b+c)]
|
||||
return []
|
||||
|
||||
# As above but with a '_bis' suffix for the buffer name
|
||||
def BufferBis(self, name):
|
||||
#if (name in self.IndexBuffers()):
|
||||
# return self.Buffer(name)
|
||||
if (name in self.inputs) or (name in self.outputs):
|
||||
a = [name+"_buffer_bis"]
|
||||
b = [name+"_offset"]
|
||||
c = [name+"_"+self.Postfix(name)] if (name not in self.BuffersWithoutLdInc()) else []
|
||||
return [", ".join(a+b+c)]
|
||||
return []
|
||||
|
||||
# As above but with data-types
|
||||
def BufferDef(self, name):
|
||||
prefix = "const " if (name in self.inputs) else ""
|
||||
if (name in self.inputs) or (name in self.outputs):
|
||||
a = [prefix+"cl_mem "+name+"_buffer"]
|
||||
b = ["const size_t "+name+"_offset"]
|
||||
c = ["const size_t "+name+"_"+self.Postfix(name)] if (name not in self.BuffersWithoutLdInc()) else []
|
||||
return [", ".join(a+b+c)]
|
||||
return []
|
||||
|
||||
# As above but with data-types
|
||||
def BufferDefWrapperCL(self, name, flavour):
|
||||
prefix = "const " if (name in self.inputs) else ""
|
||||
if (name in self.inputs) or (name in self.outputs):
|
||||
a = [prefix+"Buffer<"+flavour.buffertype+">& "+name+"_buffer"]
|
||||
b = ["const size_t "+name+"_offset"]
|
||||
c = ["const size_t "+name+"_"+self.Postfix(name)] if (name not in self.BuffersWithoutLdInc()) else []
|
||||
return [", ".join(a+b+c)]
|
||||
return []
|
||||
|
||||
# As above but as vectors
|
||||
def BufferDefVector(self, name, flavour):
|
||||
prefix = "const " if (name in self.inputs) else ""
|
||||
if (name in self.inputs) or (name in self.outputs):
|
||||
a = [prefix+"std::vector<"+flavour.buffertype+">& "+name+"_buffer"]
|
||||
b = ["const size_t "+name+"_offset"]
|
||||
c = ["const size_t "+name+"_"+self.Postfix(name)] if (name not in self.BuffersWithoutLdInc()) else []
|
||||
return [", ".join(a+b+c)]
|
||||
return []
|
||||
|
||||
# As above but with Claduc buffers
|
||||
def BufferCladuc(self, name):
|
||||
if (name in self.inputs) or (name in self.outputs):
|
||||
buffertype = "unsigned int" if (name in self.IndexBuffers()) else self.template.buffertype
|
||||
a = ["Buffer<"+buffertype+">("+name+"_buffer)"]
|
||||
b = [name+"_offset"]
|
||||
c = [name+"_"+self.Postfix(name)] if (name not in self.BuffersWithoutLdInc()) else []
|
||||
return [", ".join(a+b+c)]
|
||||
return []
|
||||
|
||||
# As above but with a static cast for clBLAS wrapper
|
||||
def BufferWrapperCL(self, name):
|
||||
if (name in self.inputs) or (name in self.outputs):
|
||||
a = [name+"_buffer()"]
|
||||
b = [name+"_offset"]
|
||||
c = []
|
||||
if (name in ["x","y"]):
|
||||
c = ["static_cast<int>("+name+"_"+self.Postfix(name)+")"]
|
||||
elif (name in ["a","b","c"]):
|
||||
c = [name+"_"+self.Postfix(name)]
|
||||
return [", ".join(a+b+c)]
|
||||
return []
|
||||
|
||||
# As above but with a static cast for CBLAS wrapper
|
||||
def BufferWrapperC(self, name, flavour):
|
||||
prefix = "const " if (name in self.inputs) else ""
|
||||
if (name in self.inputs) or (name in self.outputs):
|
||||
if name == "sy1":
|
||||
a = [name+"_buffer["+name+"_offset]"]
|
||||
elif flavour.precision_name in ["C","Z"]:
|
||||
a = ["reinterpret_cast<"+prefix+flavour.buffertype[:-1]+"*>(&"+name+"_buffer["+name+"_offset])"]
|
||||
else:
|
||||
a = ["&"+name+"_buffer["+name+"_offset]"]
|
||||
c = []
|
||||
if (name in ["x","y"]):
|
||||
c = ["static_cast<int>("+name+"_"+self.Postfix(name)+")"]
|
||||
elif (name in ["a","b","c"]):
|
||||
c = [name+"_"+self.Postfix(name)]
|
||||
return [", ".join(a+c)]
|
||||
return []
|
||||
|
||||
# As above, but only data-types
|
||||
def BufferType(self, name):
|
||||
prefix = "const " if (name in self.inputs) else ""
|
||||
if (name in self.inputs) or (name in self.outputs):
|
||||
a = [prefix+"cl_mem"]
|
||||
b = ["const size_t"]
|
||||
c = ["const size_t"] if (name not in self.BuffersWithoutLdInc()) else []
|
||||
return [", ".join(a+b+c)]
|
||||
return []
|
||||
|
||||
# Retrieves the documentation of the buffers
|
||||
def BufferDoc(self, name):
|
||||
prefix = "const " if (name in self.inputs) else ""
|
||||
inout = "input" if (name in self.inputs) else "output"
|
||||
if (name in self.inputs) or (name in self.outputs):
|
||||
math_name = name.upper()+" matrix" if (name in self.BuffersMatrix()) else name+" vector"
|
||||
incld_description = "Leading dimension " if (name in self.BuffersMatrix()) else "Stride/increment "
|
||||
a = ["`"+prefix+"cl_mem "+name+"_buffer`: OpenCL buffer to store the "+inout+" "+math_name+"."]
|
||||
b = ["`const size_t "+name+"_offset`: The offset in elements from the start of the "+inout+" "+math_name+"."]
|
||||
c = ["`const size_t "+name+"_"+self.Postfix(name)+"`: "+incld_description+"of the "+inout+" "+math_name+". This value must be greater than 0."] if (name not in self.BuffersWithoutLdInc()) else []
|
||||
return a+b+c
|
||||
return []
|
||||
|
||||
# ==============================================================================================
|
||||
|
||||
# Retrieves the name of a scalar (alpha/beta)
|
||||
def Scalar(self, name):
|
||||
if (name in self.scalars):
|
||||
return [name]
|
||||
return []
|
||||
|
||||
# As above, but converts from float to half
|
||||
def ScalarHalfToFloat(self, name):
|
||||
if name in self.scalars:
|
||||
return ["HalfToFloat("+name+")"]
|
||||
return []
|
||||
|
||||
# Retrieves the use of a scalar (alpha/beta)
|
||||
def ScalarUse(self, name, flavour):
|
||||
if name in self.scalars:
|
||||
if name == "alpha":
|
||||
return [flavour.UseAlpha()]
|
||||
elif name == "beta":
|
||||
return [flavour.UseBeta()]
|
||||
return [name]
|
||||
return []
|
||||
|
||||
# As above, but for the clBLAS wrapper
|
||||
def ScalarUseWrapper(self, name, flavour):
|
||||
if name in self.scalars:
|
||||
if name == "alpha":
|
||||
return [flavour.UseAlphaCL()]
|
||||
elif name == "beta":
|
||||
return [flavour.UseBetaCL()]
|
||||
return [name]
|
||||
return []
|
||||
|
||||
# As above, but for the CBLAS wrapper
|
||||
def ScalarUseWrapperC(self, name, flavour):
|
||||
if name in self.scalars:
|
||||
if flavour.IsComplex(name):
|
||||
return [name+"_array.data()"]
|
||||
return [name]
|
||||
return []
|
||||
|
||||
# Retrieves the definition of a scalar (alpha/beta)
|
||||
def ScalarDef(self, name, flavour):
|
||||
if name in self.scalars:
|
||||
if name == "alpha":
|
||||
return ["const "+flavour.alpha_cl+" "+name]
|
||||
return ["const "+flavour.beta_cl+" "+name]
|
||||
return []
|
||||
|
||||
# As above, but without 'cl_' prefix
|
||||
def ScalarDefPlain(self, name, flavour):
|
||||
if name in self.scalars:
|
||||
if name == "alpha":
|
||||
return ["const "+flavour.alpha_cpp+" "+name]
|
||||
return ["const "+flavour.beta_cpp+" "+name]
|
||||
return []
|
||||
|
||||
# Retrieves the type of a scalar (alpha/beta)
|
||||
def ScalarType(self, name, flavour):
|
||||
if name in self.scalars:
|
||||
if name == "alpha":
|
||||
return ["const "+flavour.alpha_cpp]
|
||||
return ["const "+flavour.beta_cpp]
|
||||
return []
|
||||
|
||||
# Retrieves the documentation of a scalar
|
||||
def ScalarDoc(self, name):
|
||||
if name in self.scalars:
|
||||
if name == "alpha":
|
||||
return ["`const "+self.template.alpha_cpp+" "+name+"`: Input scalar constant."]
|
||||
return ["`const "+self.template.beta_cpp+" "+name+"`: Input scalar constant."]
|
||||
return []
|
||||
|
||||
# ==============================================================================================
|
||||
|
||||
# Retrieves a list of comma-separated sizes (m, n, k)
|
||||
def Sizes(self):
|
||||
if self.sizes:
|
||||
return [", ".join([s for s in self.sizes])]
|
||||
return []
|
||||
|
||||
# Retrieves the definition of the sizes (m,n,k)
|
||||
def SizesDef(self):
|
||||
if self.sizes:
|
||||
return [", ".join(["const size_t "+s for s in self.sizes])]
|
||||
return []
|
||||
|
||||
# Retrieves the types of the sizes (m,n,k)
|
||||
def SizesType(self):
|
||||
if self.sizes:
|
||||
return [", ".join(["const size_t" for s in self.sizes])]
|
||||
return []
|
||||
|
||||
# Retrieves the documentation of the sizes
|
||||
def SizesDoc(self):
|
||||
if self.sizes:
|
||||
definitions = ["`const size_t "+s+"`: Integer size argument. This value must be positive." for s in self.sizes]
|
||||
return definitions
|
||||
return []
|
||||
|
||||
# ==============================================================================================
|
||||
|
||||
# Retrieves a list of options
|
||||
def Options(self):
|
||||
if self.options:
|
||||
return [", ".join(self.options)]
|
||||
return []
|
||||
|
||||
# As above, but now casted to CLBlast data-types
|
||||
def OptionsCast(self, indent):
|
||||
if self.options:
|
||||
options = ["static_cast<clblast::"+OptionToCLBlast(o)+">("+o+")" for o in self.options]
|
||||
return [(",\n"+indent).join(options)]
|
||||
return []
|
||||
|
||||
# Retrieves the definitions of the options (layout, transpose, side, etc.)
|
||||
def OptionsDef(self):
|
||||
if self.options:
|
||||
definitions = ["const "+OptionToCLBlast(o)+" "+o for o in self.options]
|
||||
return [", ".join(definitions)]
|
||||
return []
|
||||
|
||||
# As above, but now using clBLAS data-types
|
||||
def OptionsDefWrapperCL(self):
|
||||
if self.options:
|
||||
definitions = ["const "+OptionToWrapperCL(o)+" "+o for o in self.options]
|
||||
return [", ".join(definitions)]
|
||||
return []
|
||||
|
||||
# As above, but now using CBLAS data-types
|
||||
def OptionsDefWrapperC(self):
|
||||
if self.options:
|
||||
definitions = ["const "+OptionToWrapperC(o)+" "+o for o in self.options]
|
||||
return [", ".join(definitions)]
|
||||
return []
|
||||
|
||||
# Retrieves the types of the options (layout, transpose, side, etc.)
|
||||
def OptionsType(self):
|
||||
if self.options:
|
||||
definitions = ["const "+OptionToCLBlast(o) for o in self.options]
|
||||
return [", ".join(definitions)]
|
||||
return []
|
||||
|
||||
# Retrieves the documentation of the options
|
||||
def OptionsDoc(self):
|
||||
if self.options:
|
||||
definitions = ["`const "+OptionToCLBlast(o)+" "+o+"`: "+OptionToDoc(o) for o in self.options]
|
||||
return definitions
|
||||
return []
|
||||
|
||||
# ==============================================================================================
|
||||
|
||||
# Retrieves a combination of all the argument names (no types)
|
||||
def Arguments(self):
|
||||
return (self.Options() + self.Sizes() +
|
||||
list(chain(*[self.Buffer(b) for b in self.ScalarBuffersFirst()])) +
|
||||
self.Scalar("alpha") +
|
||||
list(chain(*[self.Buffer(b) for b in self.BuffersFirst()])) +
|
||||
self.Scalar("beta") +
|
||||
list(chain(*[self.Buffer(b) for b in self.BuffersSecond()])) +
|
||||
list(chain(*[self.Buffer(b) for b in self.ScalarBuffersSecond()])) +
|
||||
list(chain(*[self.Scalar(s) for s in self.OtherScalars()])))
|
||||
|
||||
# As above, but with conversions from half to float
|
||||
def ArgumentsHalf(self):
|
||||
return (self.Options() + self.Sizes() +
|
||||
list(chain(*[self.BufferBis(b) for b in self.ScalarBuffersFirst()])) +
|
||||
self.ScalarHalfToFloat("alpha") +
|
||||
list(chain(*[self.BufferBis(b) for b in self.BuffersFirst()])) +
|
||||
self.ScalarHalfToFloat("beta") +
|
||||
list(chain(*[self.BufferBis(b) for b in self.BuffersSecond()])) +
|
||||
list(chain(*[self.BufferBis(b) for b in self.ScalarBuffersSecond()])) +
|
||||
list(chain(*[self.Scalar(s) for s in self.OtherScalars()])))
|
||||
|
||||
# Retrieves a combination of all the argument names, with Claduc casts
|
||||
def ArgumentsCladuc(self, flavour, indent):
|
||||
return (self.Options() + self.Sizes() +
|
||||
list(chain(*[self.BufferCladuc(b) for b in self.ScalarBuffersFirst()])) +
|
||||
self.Scalar("alpha") +
|
||||
list(chain(*[self.BufferCladuc(b) for b in self.BuffersFirst()])) +
|
||||
self.Scalar("beta") +
|
||||
list(chain(*[self.BufferCladuc(b) for b in self.BuffersSecond()])) +
|
||||
list(chain(*[self.BufferCladuc(b) for b in self.ScalarBuffersSecond()])) +
|
||||
list(chain(*[self.Scalar(s) for s in self.OtherScalars()])))
|
||||
|
||||
# As above, but with CLBlast casts
|
||||
def ArgumentsCast(self, flavour, indent):
|
||||
return (self.OptionsCast(indent) + self.Sizes() +
|
||||
list(chain(*[self.Buffer(b) for b in self.ScalarBuffersFirst()])) +
|
||||
self.ScalarUse("alpha", flavour) +
|
||||
list(chain(*[self.Buffer(b) for b in self.BuffersFirst()])) +
|
||||
self.ScalarUse("beta", flavour) +
|
||||
list(chain(*[self.Buffer(b) for b in self.BuffersSecond()])) +
|
||||
list(chain(*[self.Buffer(b) for b in self.ScalarBuffersSecond()])) +
|
||||
list(chain(*[self.ScalarUse(s, flavour) for s in self.OtherScalars()])))
|
||||
|
||||
# As above, but for the clBLAS wrapper
|
||||
def ArgumentsWrapperCL(self, flavour):
|
||||
return (self.Options() + self.Sizes() +
|
||||
list(chain(*[self.BufferWrapperCL(b) for b in self.ScalarBuffersFirst()])) +
|
||||
self.ScalarUseWrapper("alpha", flavour) +
|
||||
list(chain(*[self.BufferWrapperCL(b) for b in self.BuffersFirst()])) +
|
||||
self.ScalarUseWrapper("beta", flavour) +
|
||||
list(chain(*[self.BufferWrapperCL(b) for b in self.BuffersSecond()])) +
|
||||
list(chain(*[self.BufferWrapperCL(b) for b in self.ScalarBuffersSecond()])) +
|
||||
list(chain(*[self.ScalarUseWrapper(s, flavour) for s in self.OtherScalars()])))
|
||||
|
||||
# As above, but for the CBLAS wrapper
|
||||
def ArgumentsWrapperC(self, flavour):
|
||||
return (self.Options() + self.Sizes() +
|
||||
self.ScalarUseWrapperC("alpha", flavour) +
|
||||
list(chain(*[self.BufferWrapperC(b, flavour) for b in self.BuffersFirst()])) +
|
||||
self.ScalarUseWrapperC("beta", flavour) +
|
||||
list(chain(*[self.BufferWrapperC(b, flavour) for b in self.BuffersSecond()])) +
|
||||
list(chain(*[self.BufferWrapperC(b, flavour) for b in self.ScalarBuffersSecond()])) +
|
||||
list(chain(*[self.ScalarUseWrapperC(s, flavour) for s in self.OtherScalars()])))
|
||||
|
||||
# Retrieves a combination of all the argument definitions
|
||||
def ArgumentsDef(self, flavour):
|
||||
return (self.OptionsDef() + self.SizesDef() +
|
||||
list(chain(*[self.BufferDef(b) for b in self.ScalarBuffersFirst()])) +
|
||||
self.ScalarDef("alpha", flavour) +
|
||||
list(chain(*[self.BufferDef(b) for b in self.BuffersFirst()])) +
|
||||
self.ScalarDef("beta", flavour) +
|
||||
list(chain(*[self.BufferDef(b) for b in self.BuffersSecond()])) +
|
||||
list(chain(*[self.BufferDef(b) for b in self.ScalarBuffersSecond()])) +
|
||||
list(chain(*[self.ScalarDef(s, flavour) for s in self.OtherScalars()])))
|
||||
|
||||
# As above, but clBLAS wrapper plain datatypes
|
||||
def ArgumentsDefWrapperCL(self, flavour):
|
||||
return (self.OptionsDefWrapperCL() + self.SizesDef() +
|
||||
list(chain(*[self.BufferDefWrapperCL(b, flavour) for b in self.ScalarBuffersFirst()])) +
|
||||
self.ScalarDefPlain("alpha", flavour) +
|
||||
list(chain(*[self.BufferDefWrapperCL(b, flavour) for b in self.BuffersFirst()])) +
|
||||
self.ScalarDefPlain("beta", flavour) +
|
||||
list(chain(*[self.BufferDefWrapperCL(b, flavour) for b in self.BuffersSecond()])) +
|
||||
list(chain(*[self.BufferDefWrapperCL(b, flavour) for b in self.ScalarBuffersSecond()])) +
|
||||
list(chain(*[self.ScalarDefPlain(s, flavour) for s in self.OtherScalars()])))
|
||||
|
||||
# As above, but CBLAS wrapper plain datatypes
|
||||
def ArgumentsDefWrapperC(self, flavour):
|
||||
return (self.OptionsDefWrapperC() + self.SizesDef() +
|
||||
list(chain(*[self.BufferDefVector(b, flavour) for b in self.ScalarBuffersFirst()])) +
|
||||
self.ScalarDefPlain("alpha", flavour) +
|
||||
list(chain(*[self.BufferDefVector(b, flavour) for b in self.BuffersFirst()])) +
|
||||
self.ScalarDefPlain("beta", flavour) +
|
||||
list(chain(*[self.BufferDefVector(b, flavour) for b in self.BuffersSecond()])) +
|
||||
list(chain(*[self.BufferDefVector(b, flavour) for b in self.ScalarBuffersSecond()])) +
|
||||
list(chain(*[self.ScalarDefPlain(s, flavour) for s in self.OtherScalars()])))
|
||||
|
||||
# Retrieves a combination of all the argument types
|
||||
def ArgumentsType(self, flavour):
|
||||
return (self.OptionsType() + self.SizesType() +
|
||||
list(chain(*[self.BufferType(b) for b in self.ScalarBuffersFirst()])) +
|
||||
self.ScalarType("alpha", flavour) +
|
||||
list(chain(*[self.BufferType(b) for b in self.BuffersFirst()])) +
|
||||
self.ScalarType("beta", flavour) +
|
||||
list(chain(*[self.BufferType(b) for b in self.BuffersSecond()])) +
|
||||
list(chain(*[self.BufferType(b) for b in self.ScalarBuffersSecond()])) +
|
||||
list(chain(*[self.ScalarType(s, flavour) for s in self.OtherScalars()])))
|
||||
|
||||
# Retrieves a combination of all the argument types
|
||||
def ArgumentsDoc(self):
|
||||
return (self.OptionsDoc() + self.SizesDoc() +
|
||||
list(chain(*[self.BufferDoc(b) for b in self.ScalarBuffersFirst()])) +
|
||||
list(chain(*[self.BufferDoc(b) for b in self.ScalarBuffersFirst()])) +
|
||||
self.ScalarDoc("alpha") +
|
||||
list(chain(*[self.BufferDoc(b) for b in self.BuffersFirst()])) +
|
||||
self.ScalarDoc("beta") +
|
||||
list(chain(*[self.BufferDoc(b) for b in self.BuffersSecond()])) +
|
||||
list(chain(*[self.BufferDoc(b) for b in self.ScalarBuffersSecond()])) +
|
||||
list(chain(*[self.ScalarDoc(s) for s in self.OtherScalars()])))
|
||||
|
||||
# ==============================================================================================
|
||||
|
||||
# Retrieves a list of routine requirements for documentation
|
||||
def RequirementsDoc(self):
|
||||
return self.requirements
|
||||
|
||||
# ==============================================================================================
|
||||
|
||||
# Retrieves the C++ templated definition for a routine
|
||||
def RoutineHeaderCPP(self, spaces, default_event):
|
||||
indent = " "*(spaces + self.Length())
|
||||
result = "template <"+self.template.name+">\n"
|
||||
result += "StatusCode "+self.name.capitalize()+"("
|
||||
result += (",\n"+indent).join([a for a in self.ArgumentsDef(self.template)])
|
||||
result += ",\n"+indent+"cl_command_queue* queue, cl_event* event"+default_event+")"
|
||||
return result
|
||||
|
||||
# As above, but now without variable names
|
||||
def RoutineHeaderTypeCPP(self, spaces):
|
||||
indent = " "*(spaces + self.Length())
|
||||
result = "template <"+self.template.name+">\n"
|
||||
result += "StatusCode "+self.name.capitalize()+"("
|
||||
result += (",\n"+indent).join([a for a in self.ArgumentsType(self.template)])
|
||||
result += ",\n"+indent+"cl_command_queue*, cl_event*)"
|
||||
return result
|
||||
|
||||
# As above, but now for C
|
||||
def RoutineHeaderC(self, flavour, spaces, extra_qualifier):
|
||||
indent = " "*(spaces + self.Length())
|
||||
result = "StatusCode"+extra_qualifier+" CLBlast"+flavour.name+self.name+"("
|
||||
result += (",\n"+indent).join([a for a in self.ArgumentsDef(flavour)])
|
||||
result += ",\n"+indent+"cl_command_queue* queue, cl_event* event)"
|
||||
return result
|
||||
|
||||
# As above, but now for the clBLAS wrapper
|
||||
def RoutineHeaderWrapperCL(self, flavour, def_only, spaces):
|
||||
template = "<"+flavour.template+">" if self.NoScalars() and not def_only else ""
|
||||
indent = " "*(spaces + self.Length() + len(template))
|
||||
result = ""
|
||||
if self.NoScalars():
|
||||
result += "template <"
|
||||
if def_only:
|
||||
result += flavour.name
|
||||
result += ">\n"
|
||||
result += "clblasStatus clblasX"+self.name+template+"("
|
||||
result += (",\n"+indent).join([a for a in self.ArgumentsDefWrapperCL(flavour)])
|
||||
result += ",\n"+indent+"cl_uint num_queues, cl_command_queue *queues"
|
||||
result += ",\n"+indent+"cl_uint num_wait_events, const cl_event *wait_events, cl_event *events)"
|
||||
return result
|
||||
|
||||
# As above, but now for the CBLAS wrapper
|
||||
def RoutineHeaderWrapperC(self, flavour, def_only, spaces):
|
||||
indent = " "*(spaces + self.Length())
|
||||
result = "void cblasX"+self.name+"("
|
||||
result += (",\n"+indent).join([a for a in self.ArgumentsDefWrapperC(flavour)])+")"
|
||||
return result
|
||||
|
||||
# ==================================================================================================
|
Loading…
Reference in a new issue