mirror of
https://github.com/CNugteren/CLBlast.git
synced 2024-07-07 12:23:46 +02:00
Layed the groundwork for cuBLAS comparisons in the clients
This commit is contained in:
parent
c5461d77e5
commit
b24d364743
|
@ -130,17 +130,23 @@ if(TUNERS)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Locates the reference BLAS libraries in case the tests need to be compiled. The "FindclBLAS.cmake"
|
# Locates the reference BLAS libraries in case the tests need to be compiled. The "FindclBLAS.cmake"
|
||||||
# and "FindCBLAS.cmake" are included.
|
# and "FindCBLAS.cmake" are included, "FindCUDA.cmake" is provided by CMake.
|
||||||
if(CLIENTS OR TESTS)
|
if(CLIENTS OR TESTS)
|
||||||
find_package(clBLAS)
|
find_package(clBLAS)
|
||||||
find_package(CBLAS)
|
find_package(CBLAS)
|
||||||
if(NOT CLBLAS_FOUND AND NOT CBLAS_FOUND)
|
find_package(CUDA QUIET) # for cuBLAS
|
||||||
|
if(CUDA_FOUND)
|
||||||
|
message(STATUS "CUDA and cuBLAS found")
|
||||||
|
else()
|
||||||
|
message(STATUS "Could not find cuBLAS as a reference")
|
||||||
|
endif()
|
||||||
|
if(NOT CLBLAS_FOUND AND NOT CBLAS_FOUND AND NOT CUDA_FOUND)
|
||||||
if(TESTS)
|
if(TESTS)
|
||||||
message(STATUS "Could NOT find clBLAS nor a CPU BLAS, disabling the compilation of the tests")
|
message(STATUS "Could NOT find clBLAS nor a CPU BLAS nor cuBLAS, disabling the compilation of the tests")
|
||||||
set(TESTS OFF)
|
set(TESTS OFF)
|
||||||
endif()
|
endif()
|
||||||
if(CLIENTS)
|
if(CLIENTS)
|
||||||
message(STATUS "Could NOT find clBLAS nor a CPU BLAS, head-to-head performance comparison not supported in the clients")
|
message(STATUS "Could NOT find clBLAS nor a CPU BLAS nor cuBLAS, head-to-head performance comparison not supported in the clients")
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
@ -320,13 +326,22 @@ if(CLIENTS OR TESTS)
|
||||||
add_definitions(" -DCLBLAST_REF_CBLAS")
|
add_definitions(" -DCLBLAST_REF_CBLAS")
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
if(CUDA_FOUND)
|
||||||
|
set(REF_INCLUDES ${REF_INCLUDES} ${CUDA_INCLUDE_DIRS})
|
||||||
|
set(REF_LIBRARIES ${REF_LIBRARIES} ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES})
|
||||||
|
if(MSVC)
|
||||||
|
add_definitions(" /DCLBLAST_REF_CUBLAS")
|
||||||
|
else()
|
||||||
|
add_definitions(" -DCLBLAST_REF_CUBLAS")
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# ==================================================================================================
|
# ==================================================================================================
|
||||||
|
|
||||||
# Section for the performance tests (i.e. the client). These compare against optionally a reference
|
# Section for the performance tests (i.e. the client). These compare against optionally a reference
|
||||||
# library, either clBLAS or a CPU BLAS.
|
# library, either clBLAS, a CPU BLAS, or CUDA's cuBLAS.
|
||||||
if(CLIENTS)
|
if(CLIENTS)
|
||||||
|
|
||||||
# Visual Studio requires the sources of non-exported objects/libraries
|
# Visual Studio requires the sources of non-exported objects/libraries
|
||||||
|
@ -372,7 +387,7 @@ endif()
|
||||||
# ==================================================================================================
|
# ==================================================================================================
|
||||||
|
|
||||||
# Section for the correctness tests. Note that these tests require the presence of clBLAS and/or a
|
# Section for the correctness tests. Note that these tests require the presence of clBLAS and/or a
|
||||||
# CPU BLAS library to act as a reference.
|
# CPU BLAS library, and/or cuBLAS to act as a reference.
|
||||||
if(TESTS)
|
if(TESTS)
|
||||||
enable_testing()
|
enable_testing()
|
||||||
|
|
||||||
|
|
|
@ -81,6 +81,7 @@ constexpr auto kArgFraction = "fraction";
|
||||||
// The client-specific arguments in string form
|
// The client-specific arguments in string form
|
||||||
constexpr auto kArgCompareclblas = "clblas";
|
constexpr auto kArgCompareclblas = "clblas";
|
||||||
constexpr auto kArgComparecblas = "cblas";
|
constexpr auto kArgComparecblas = "cblas";
|
||||||
|
constexpr auto kArgComparecublas = "cublas";
|
||||||
constexpr auto kArgStepSize = "step";
|
constexpr auto kArgStepSize = "step";
|
||||||
constexpr auto kArgNumSteps = "num_steps";
|
constexpr auto kArgNumSteps = "num_steps";
|
||||||
constexpr auto kArgNumRuns = "runs";
|
constexpr auto kArgNumRuns = "runs";
|
||||||
|
@ -188,6 +189,7 @@ struct Arguments {
|
||||||
// Client-specific arguments
|
// Client-specific arguments
|
||||||
int compare_clblas = 1;
|
int compare_clblas = 1;
|
||||||
int compare_cblas = 1;
|
int compare_cblas = 1;
|
||||||
|
int compare_cublas = 1;
|
||||||
size_t step = 1;
|
size_t step = 1;
|
||||||
size_t num_steps = 0;
|
size_t num_steps = 0;
|
||||||
size_t num_runs = 10;
|
size_t num_runs = 10;
|
||||||
|
|
|
@ -116,24 +116,38 @@ Tester<T,U>::Tester(const std::vector<std::string> &arguments, const bool silent
|
||||||
tests_failed_{0} {
|
tests_failed_{0} {
|
||||||
options_ = options;
|
options_ = options;
|
||||||
|
|
||||||
|
// Determines which reference is the default
|
||||||
|
auto default_clblas = 0;
|
||||||
|
auto default_cblas = 0;
|
||||||
|
auto default_cublas = 0;
|
||||||
|
#if defined(CLBLAST_REF_CBLAS)
|
||||||
|
default_cblas = 1;
|
||||||
|
#elif defined(CLBLAST_REF_CLBLAS)
|
||||||
|
default_clblas = 1;
|
||||||
|
#elif defined(CLBLAST_REF_CUBLAS)
|
||||||
|
default_cublas = 1;
|
||||||
|
#endif
|
||||||
|
|
||||||
// Determines which reference to test against
|
// Determines which reference to test against
|
||||||
#if defined(CLBLAST_REF_CLBLAS) && defined(CLBLAST_REF_CBLAS)
|
|
||||||
compare_clblas_ = GetArgument(arguments, help_, kArgCompareclblas, 0);
|
|
||||||
compare_cblas_ = GetArgument(arguments, help_, kArgComparecblas, 1);
|
|
||||||
#elif CLBLAST_REF_CLBLAS
|
|
||||||
compare_clblas_ = GetArgument(arguments, help_, kArgCompareclblas, 1);
|
|
||||||
compare_cblas_ = 0;
|
|
||||||
#elif CLBLAST_REF_CBLAS
|
|
||||||
compare_clblas_ = 0;
|
|
||||||
compare_cblas_ = GetArgument(arguments, help_, kArgComparecblas, 1);
|
|
||||||
#else
|
|
||||||
compare_clblas_ = 0;
|
compare_clblas_ = 0;
|
||||||
compare_cblas_ = 0;
|
compare_cblas_ = 0;
|
||||||
|
compare_cublas_ = 0;
|
||||||
|
#if defined(CLBLAST_REF_CBLAS)
|
||||||
|
compare_cblas_ = GetArgument(arguments, help_, kArgComparecblas, default_cblas);
|
||||||
|
#endif
|
||||||
|
#if defined(CLBLAST_REF_CLBLAS)
|
||||||
|
compare_clblas_ = GetArgument(arguments, help_, kArgCompareclblas, default_clblas);
|
||||||
|
#endif
|
||||||
|
#if defined(CLBLAST_REF_CUBLAS)
|
||||||
|
compare_cublas_ = GetArgument(arguments, help_, kArgComparecublas, default_cublas);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Prints the help message (command-line arguments)
|
// Prints the help message (command-line arguments)
|
||||||
if (!silent) { fprintf(stdout, "\n* %s\n", help_.c_str()); }
|
if (!silent) { fprintf(stdout, "\n* %s\n", help_.c_str()); }
|
||||||
|
|
||||||
|
// Support for cuBLAS not available yet
|
||||||
|
if (compare_cublas_) { throw std::runtime_error("Cannot test against cuBLAS; not implemented yet"); }
|
||||||
|
|
||||||
// Can only test against a single reference (not two, not zero)
|
// Can only test against a single reference (not two, not zero)
|
||||||
if (compare_clblas_ && compare_cblas_) {
|
if (compare_clblas_ && compare_cblas_) {
|
||||||
throw std::runtime_error("Cannot test against both clBLAS and CBLAS references; choose one using the -cblas and -clblas arguments");
|
throw std::runtime_error("Cannot test against both clBLAS and CBLAS references; choose one using the -cblas and -clblas arguments");
|
||||||
|
|
|
@ -113,6 +113,7 @@ class Tester {
|
||||||
// Testing against reference implementations
|
// Testing against reference implementations
|
||||||
int compare_cblas_;
|
int compare_cblas_;
|
||||||
int compare_clblas_;
|
int compare_clblas_;
|
||||||
|
int compare_cublas_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
||||||
|
|
|
@ -30,13 +30,14 @@ template <typename T, typename U> const int Client<T,U>::kSeed = 42; // fixed se
|
||||||
template <typename T, typename U>
|
template <typename T, typename U>
|
||||||
Client<T,U>::Client(const Routine run_routine,
|
Client<T,U>::Client(const Routine run_routine,
|
||||||
const Reference1 run_reference1, const Reference2 run_reference2,
|
const Reference1 run_reference1, const Reference2 run_reference2,
|
||||||
const std::vector<std::string> &options,
|
const Reference3 run_reference3, const std::vector<std::string> &options,
|
||||||
const std::vector<std::string> &buffers_in,
|
const std::vector<std::string> &buffers_in,
|
||||||
const std::vector<std::string> &buffers_out,
|
const std::vector<std::string> &buffers_out,
|
||||||
const GetMetric get_flops, const GetMetric get_bytes):
|
const GetMetric get_flops, const GetMetric get_bytes):
|
||||||
run_routine_(run_routine),
|
run_routine_(run_routine),
|
||||||
run_reference1_(run_reference1),
|
run_reference1_(run_reference1),
|
||||||
run_reference2_(run_reference2),
|
run_reference2_(run_reference2),
|
||||||
|
run_reference3_(run_reference3),
|
||||||
options_(options),
|
options_(options),
|
||||||
buffers_in_(buffers_in),
|
buffers_in_(buffers_in),
|
||||||
buffers_out_(buffers_out),
|
buffers_out_(buffers_out),
|
||||||
|
@ -119,6 +120,11 @@ Arguments<U> Client<T,U>::ParseArguments(int argc, char *argv[], const size_t le
|
||||||
#else
|
#else
|
||||||
args.compare_cblas = 0;
|
args.compare_cblas = 0;
|
||||||
#endif
|
#endif
|
||||||
|
#ifdef CLBLAST_REF_CUBLAS
|
||||||
|
args.compare_cublas = GetArgument(command_line_args, help, kArgComparecublas, 1);
|
||||||
|
#else
|
||||||
|
args.compare_cublas = 0;
|
||||||
|
#endif
|
||||||
args.step = GetArgument(command_line_args, help, kArgStepSize, size_t{1});
|
args.step = GetArgument(command_line_args, help, kArgStepSize, size_t{1});
|
||||||
args.num_steps = GetArgument(command_line_args, help, kArgNumSteps, size_t{0});
|
args.num_steps = GetArgument(command_line_args, help, kArgNumSteps, size_t{0});
|
||||||
args.num_runs = GetArgument(command_line_args, help, kArgNumRuns, size_t{10});
|
args.num_runs = GetArgument(command_line_args, help, kArgNumRuns, size_t{10});
|
||||||
|
@ -133,24 +139,26 @@ Arguments<U> Client<T,U>::ParseArguments(int argc, char *argv[], const size_t le
|
||||||
|
|
||||||
// Comparison against a non-BLAS routine is not supported
|
// Comparison against a non-BLAS routine is not supported
|
||||||
if (level == 4) { // level-4 == level-X
|
if (level == 4) { // level-4 == level-X
|
||||||
if (args.compare_clblas != 0 || args.compare_cblas != 0) {
|
if (args.compare_clblas != 0 || args.compare_cblas != 0 || args.compare_cublas != 0) {
|
||||||
if (!args.silent) {
|
if (!args.silent) {
|
||||||
fprintf(stdout, "* Disabling clBLAS and CPU BLAS comparisons for this non-BLAS routine\n\n");
|
fprintf(stdout, "* Disabling clBLAS/CBLAS/cuBLAS comparisons for this non-BLAS routine\n\n");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
args.compare_clblas = 0;
|
args.compare_clblas = 0;
|
||||||
args.compare_cblas = 0;
|
args.compare_cblas = 0;
|
||||||
|
args.compare_cublas = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Comparison against clBLAS or a CPU BLAS library is not supported in case of half-precision
|
// Comparison against other BLAS libraries is not supported in case of half-precision
|
||||||
if (args.precision == Precision::kHalf) {
|
if (args.precision == Precision::kHalf) {
|
||||||
if (args.compare_clblas != 0 || args.compare_cblas != 0) {
|
if (args.compare_clblas != 0 || args.compare_cblas != 0 || args.compare_cublas != 0) {
|
||||||
if (!args.silent) {
|
if (!args.silent) {
|
||||||
fprintf(stdout, "* Disabling clBLAS and CPU BLAS comparisons for half-precision\n\n");
|
fprintf(stdout, "* Disabling clBLAS/CBLAS/cuBLAS comparisons for half-precision\n\n");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
args.compare_clblas = 0;
|
args.compare_clblas = 0;
|
||||||
args.compare_cblas = 0;
|
args.compare_cblas = 0;
|
||||||
|
args.compare_cublas = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the arguments
|
// Returns the arguments
|
||||||
|
@ -174,6 +182,9 @@ void Client<T,U>::PerformanceTest(Arguments<U> &args, const SetMetric set_sizes)
|
||||||
#ifdef CLBLAST_REF_CLBLAS
|
#ifdef CLBLAST_REF_CLBLAS
|
||||||
if (args.compare_clblas) { clblasSetup(); }
|
if (args.compare_clblas) { clblasSetup(); }
|
||||||
#endif
|
#endif
|
||||||
|
#ifdef CLBLAST_REF_CUBLAS
|
||||||
|
cudaSetDevice(static_cast<int>(args.device_id));
|
||||||
|
#endif
|
||||||
|
|
||||||
// Iterates over all "num_step" values jumping by "step" each time
|
// Iterates over all "num_step" values jumping by "step" each time
|
||||||
auto s = size_t{0};
|
auto s = size_t{0};
|
||||||
|
@ -232,6 +243,16 @@ void Client<T,U>::PerformanceTest(Arguments<U> &args, const SetMetric set_sizes)
|
||||||
HostToDevice(args, buffers, buffers_host, queue, buffers_out_);
|
HostToDevice(args, buffers, buffers_host, queue, buffers_out_);
|
||||||
timings.push_back(std::pair<std::string, double>("CPU BLAS", ms_cblas));
|
timings.push_back(std::pair<std::string, double>("CPU BLAS", ms_cblas));
|
||||||
}
|
}
|
||||||
|
if (args.compare_cublas) {
|
||||||
|
auto buffers_host = BuffersHost<T>();
|
||||||
|
auto buffers_cuda = BuffersCUDA<T>();
|
||||||
|
DeviceToHost(args, buffers, buffers_host, queue, buffers_in_);
|
||||||
|
HostToCUDA(args, buffers_cuda, buffers_host, buffers_in_);
|
||||||
|
auto ms_cublas = TimedExecution(args.num_runs, args, buffers_cuda, queue, run_reference3_, "cuBLAS");
|
||||||
|
CUDAToHost(args, buffers_cuda, buffers_host, buffers_out_);
|
||||||
|
HostToDevice(args, buffers, buffers_host, queue, buffers_out_);
|
||||||
|
timings.push_back(std::pair<std::string, double>("cuBLAS", ms_cublas));
|
||||||
|
}
|
||||||
|
|
||||||
// Prints the performance of the tested libraries
|
// Prints the performance of the tested libraries
|
||||||
PrintTableRow(args, timings);
|
PrintTableRow(args, timings);
|
||||||
|
@ -307,6 +328,7 @@ void Client<T,U>::PrintTableHeader(const Arguments<U>& args) {
|
||||||
fprintf(stdout, " | <-- CLBlast -->");
|
fprintf(stdout, " | <-- CLBlast -->");
|
||||||
if (args.compare_clblas) { fprintf(stdout, " | <-- clBLAS -->"); }
|
if (args.compare_clblas) { fprintf(stdout, " | <-- clBLAS -->"); }
|
||||||
if (args.compare_cblas) { fprintf(stdout, " | <-- CPU BLAS -->"); }
|
if (args.compare_cblas) { fprintf(stdout, " | <-- CPU BLAS -->"); }
|
||||||
|
if (args.compare_cublas) { fprintf(stdout, " | <-- cuBLAS -->"); }
|
||||||
fprintf(stdout, " |\n");
|
fprintf(stdout, " |\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -315,6 +337,7 @@ void Client<T,U>::PrintTableHeader(const Arguments<U>& args) {
|
||||||
fprintf(stdout, "%9s;%9s;%9s", "ms_1", "GFLOPS_1", "GBs_1");
|
fprintf(stdout, "%9s;%9s;%9s", "ms_1", "GFLOPS_1", "GBs_1");
|
||||||
if (args.compare_clblas) { fprintf(stdout, ";%9s;%9s;%9s", "ms_2", "GFLOPS_2", "GBs_2"); }
|
if (args.compare_clblas) { fprintf(stdout, ";%9s;%9s;%9s", "ms_2", "GFLOPS_2", "GBs_2"); }
|
||||||
if (args.compare_cblas) { fprintf(stdout, ";%9s;%9s;%9s", "ms_3", "GFLOPS_3", "GBs_3"); }
|
if (args.compare_cblas) { fprintf(stdout, ";%9s;%9s;%9s", "ms_3", "GFLOPS_3", "GBs_3"); }
|
||||||
|
if (args.compare_cublas) { fprintf(stdout, ";%9s;%9s;%9s", "ms_4", "GFLOPS_4", "GBs_4"); }
|
||||||
fprintf(stdout, "\n");
|
fprintf(stdout, "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,7 @@
|
||||||
#ifdef CLBLAST_REF_CLBLAS
|
#ifdef CLBLAST_REF_CLBLAS
|
||||||
#include <clBLAS.h>
|
#include <clBLAS.h>
|
||||||
#endif
|
#endif
|
||||||
|
#include "test/wrapper_cuda.hpp"
|
||||||
#include "clblast.h"
|
#include "clblast.h"
|
||||||
|
|
||||||
namespace clblast {
|
namespace clblast {
|
||||||
|
@ -46,12 +47,13 @@ class Client {
|
||||||
using Routine = std::function<StatusCode(const Arguments<U>&, Buffers<T>&, Queue&)>;
|
using Routine = std::function<StatusCode(const Arguments<U>&, Buffers<T>&, Queue&)>;
|
||||||
using Reference1 = std::function<StatusCode(const Arguments<U>&, Buffers<T>&, Queue&)>;
|
using Reference1 = std::function<StatusCode(const Arguments<U>&, Buffers<T>&, Queue&)>;
|
||||||
using Reference2 = std::function<StatusCode(const Arguments<U>&, BuffersHost<T>&, Queue&)>;
|
using Reference2 = std::function<StatusCode(const Arguments<U>&, BuffersHost<T>&, Queue&)>;
|
||||||
|
using Reference3 = std::function<StatusCode(const Arguments<U>&, BuffersCUDA<T>&, Queue&)>;
|
||||||
using SetMetric = std::function<void(Arguments<U>&)>;
|
using SetMetric = std::function<void(Arguments<U>&)>;
|
||||||
using GetMetric = std::function<size_t(const Arguments<U>&)>;
|
using GetMetric = std::function<size_t(const Arguments<U>&)>;
|
||||||
|
|
||||||
// The constructor
|
// The constructor
|
||||||
Client(const Routine run_routine, const Reference1 run_reference1, const Reference2 run_reference2,
|
Client(const Routine run_routine, const Reference1 run_reference1, const Reference2 run_reference2,
|
||||||
const std::vector<std::string> &options,
|
const Reference3 run_reference3, const std::vector<std::string> &options,
|
||||||
const std::vector<std::string> &buffers_in, const std::vector<std::string> &buffers_out,
|
const std::vector<std::string> &buffers_in, const std::vector<std::string> &buffers_out,
|
||||||
const GetMetric get_flops, const GetMetric get_bytes);
|
const GetMetric get_flops, const GetMetric get_bytes);
|
||||||
|
|
||||||
|
@ -84,6 +86,7 @@ class Client {
|
||||||
const Routine run_routine_;
|
const Routine run_routine_;
|
||||||
const Reference1 run_reference1_;
|
const Reference1 run_reference1_;
|
||||||
const Reference2 run_reference2_;
|
const Reference2 run_reference2_;
|
||||||
|
const Reference3 run_reference3_;
|
||||||
const std::vector<std::string> options_;
|
const std::vector<std::string> options_;
|
||||||
const std::vector<std::string> buffers_in_;
|
const std::vector<std::string> buffers_in_;
|
||||||
const std::vector<std::string> buffers_out_;
|
const std::vector<std::string> buffers_out_;
|
||||||
|
@ -118,9 +121,14 @@ void RunClient(int argc, char *argv[]) {
|
||||||
#else
|
#else
|
||||||
auto reference2 = ReferenceNotAvailable<T,U,BuffersHost<T>>;
|
auto reference2 = ReferenceNotAvailable<T,U,BuffersHost<T>>;
|
||||||
#endif
|
#endif
|
||||||
|
#ifdef CLBLAST_REF_CUBLAS
|
||||||
|
auto reference3 = C::RunReference3; // cuBLAS when available
|
||||||
|
#else
|
||||||
|
auto reference3 = ReferenceNotAvailable<T,U,BuffersCUDA<T>>;
|
||||||
|
#endif
|
||||||
|
|
||||||
// Creates a new client
|
// Creates a new client
|
||||||
auto client = Client<T,U>(C::RunRoutine, reference1, reference2, C::GetOptions(),
|
auto client = Client<T,U>(C::RunRoutine, reference1, reference2, reference3, C::GetOptions(),
|
||||||
C::BuffersIn(), C::BuffersOut(), C::GetFlops, C::GetBytes);
|
C::BuffersIn(), C::BuffersOut(), C::GetFlops, C::GetBytes);
|
||||||
|
|
||||||
// Simple command line argument parser with defaults
|
// Simple command line argument parser with defaults
|
||||||
|
|
111
test/wrapper_cuda.hpp
Normal file
111
test/wrapper_cuda.hpp
Normal file
|
@ -0,0 +1,111 @@
|
||||||
|
|
||||||
|
// =================================================================================================
|
||||||
|
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
|
||||||
|
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
|
||||||
|
// width of 100 characters per line.
|
||||||
|
//
|
||||||
|
// Author(s):
|
||||||
|
// Cedric Nugteren <www.cedricnugteren.nl>
|
||||||
|
//
|
||||||
|
// This file contains all the CUDA related code; used only in case of testing against cuBLAS
|
||||||
|
//
|
||||||
|
// =================================================================================================
|
||||||
|
|
||||||
|
#ifndef CLBLAST_TEST_WRAPPER_CUDA_H_
|
||||||
|
#define CLBLAST_TEST_WRAPPER_CUDA_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <stdexcept>
|
||||||
|
|
||||||
|
#include "utilities/utilities.hpp"
|
||||||
|
|
||||||
|
#ifdef CLBLAST_REF_CUBLAS
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cublas_v2.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace clblast {
|
||||||
|
// =================================================================================================
|
||||||
|
|
||||||
|
// Copies data from the CUDA device to the host and frees-up the CUDA memory afterwards
|
||||||
|
#ifdef CLBLAST_REF_CUBLAS
|
||||||
|
template <typename T>
|
||||||
|
void CUDAToHost(const T* buffer_cuda, const std::vector<T> &buffer_host, const size_t size) {
|
||||||
|
cudaMemcpy(
|
||||||
|
std::reinterpret_cast<void*>(buffer_host.data()),
|
||||||
|
std::reinterpret_cast<void*>(buffer_cuda),
|
||||||
|
size*sizeof(T),
|
||||||
|
cudaMemcpyDeviceToHost
|
||||||
|
);
|
||||||
|
cudaFree(buffer_cuda);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
template <typename T> void CUDAToHost(const T*, const std::vector<T>&, const size_t) { }
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Allocates space on the CUDA device and copies in data from the host
|
||||||
|
#ifdef CLBLAST_REF_CUBLAS
|
||||||
|
template <typename T>
|
||||||
|
void HostToCUDA(const T* buffer_cuda, const std::vector<T> &buffer_host, const size_t size) {
|
||||||
|
cudaMalloc(std::reinterpret_cast<void**>&buffer_cuda, size*sizeof(T));
|
||||||
|
cudaMemcpy(
|
||||||
|
std::reinterpret_cast<void*>(buffer_cuda),
|
||||||
|
std::reinterpret_cast<void*>(buffer_host.data()),
|
||||||
|
size*sizeof(T),
|
||||||
|
cudaMemcpyHostToDevice
|
||||||
|
);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
template <typename T> void HostToCUDA(const T*, const std::vector<T>&, const size_t) { }
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// =================================================================================================
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct BuffersCUDA {
|
||||||
|
T* x_vec;
|
||||||
|
T* y_vec;
|
||||||
|
T* a_mat;
|
||||||
|
T* b_mat;
|
||||||
|
T* c_mat;
|
||||||
|
T* ap_mat;
|
||||||
|
T* scalar;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename U>
|
||||||
|
void CUDAToHost(const Arguments<U> &args, BuffersCUDA<T> &buffers, BuffersHost<T> &buffers_host,
|
||||||
|
const std::vector<std::string> &names) {
|
||||||
|
for (auto &name: names) {
|
||||||
|
if (name == kBufVecX) { buffers_host.x_vec = std::vector<T>(args.x_size, static_cast<T>(0)); CUDAToHost(buffers.x_vec, buffers_host.x_vec, args.x_size); }
|
||||||
|
else if (name == kBufVecY) { buffers_host.y_vec = std::vector<T>(args.y_size, static_cast<T>(0)); CUDAToHost(buffers.y_vec, buffers_host.y_vec, args.y_size); }
|
||||||
|
else if (name == kBufMatA) { buffers_host.a_mat = std::vector<T>(args.a_size, static_cast<T>(0)); CUDAToHost(buffers.a_mat, buffers_host.a_mat, args.a_size); }
|
||||||
|
else if (name == kBufMatB) { buffers_host.b_mat = std::vector<T>(args.b_size, static_cast<T>(0)); CUDAToHost(buffers.b_mat, buffers_host.b_mat, args.b_size); }
|
||||||
|
else if (name == kBufMatC) { buffers_host.c_mat = std::vector<T>(args.c_size, static_cast<T>(0)); CUDAToHost(buffers.c_mat, buffers_host.c_mat, args.c_size); }
|
||||||
|
else if (name == kBufMatAP) { buffers_host.ap_mat = std::vector<T>(args.ap_size, static_cast<T>(0)); CUDAToHost(buffers.ap_mat, buffers_host.ap_mat, args.ap_size); }
|
||||||
|
else if (name == kBufScalar) { buffers_host.scalar = std::vector<T>(args.scalar_size, static_cast<T>(0)); CUDAToHost(buffers.scalar, buffers_host.scalar, args.scalar_size); }
|
||||||
|
else { throw std::runtime_error("Invalid buffer name"); }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename U>
|
||||||
|
void HostToCUDA(const Arguments<U> &args, BuffersCUDA<T> &buffers, BuffersHost<T> &buffers_host,
|
||||||
|
const std::vector<std::string> &names) {
|
||||||
|
for (auto &name: names) {
|
||||||
|
if (name == kBufVecX) { HostToCUDA(buffers.x_vec, buffers_host.x_vec, args.x_size); }
|
||||||
|
else if (name == kBufVecY) { HostToCUDA(buffers.y_vec, buffers_host.y_vec, args.y_size); }
|
||||||
|
else if (name == kBufMatA) { HostToCUDA(buffers.a_mat, buffers_host.a_mat, args.a_size); }
|
||||||
|
else if (name == kBufMatB) { HostToCUDA(buffers.b_mat, buffers_host.b_mat, args.b_size); }
|
||||||
|
else if (name == kBufMatC) { HostToCUDA(buffers.c_mat, buffers_host.c_mat, args.c_size); }
|
||||||
|
else if (name == kBufMatAP) { HostToCUDA(buffers.ap_mat, buffers_host.ap_mat, args.ap_size); }
|
||||||
|
else if (name == kBufScalar) { HostToCUDA(buffers.scalar, buffers_host.scalar, args.scalar_size); }
|
||||||
|
else { throw std::runtime_error("Invalid buffer name"); }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// =================================================================================================
|
||||||
|
} // namespace clblast
|
||||||
|
|
||||||
|
// CLBLAST_TEST_WRAPPER_CUDA_H_
|
||||||
|
#endif
|
Loading…
Reference in a new issue