diff --git a/src/kernel_preprocessor.cpp b/src/kernel_preprocessor.cpp index 999e80e2..d6966836 100644 --- a/src/kernel_preprocessor.cpp +++ b/src/kernel_preprocessor.cpp @@ -9,18 +9,238 @@ // // This file implements the OpenCL kernel preprocessor (see the header for more information). // +// Restrictions: +// - Use comments only single-line "//" style, not "/*" and "*/" +// - Don't use strings with characters parsed (e.g. '//', '}', '#ifdef') +// - Supports conditionals: #if #ifdef #ifndef #else #elif #endif +// - ...with the operators: == +// - "#pragma unroll" requires next loop in the form "for (int w = 0; w < 4; w += 1) {" +// The above also requires the spaces in that exact form +// // ================================================================================================= +#include +#include +#include +#include #include -#include #include "kernel_preprocessor.hpp" namespace clblast { // ================================================================================================= +void FindReplace(std::string &subject, const std::string &search, const std::string &replace) +{ + auto pos = size_t{0}; + while ((pos = subject.find(search, pos)) != std::string::npos) { + subject.replace(pos, search.length(), replace); + pos += replace.length(); + } +} + +bool EvaluateCondition(std::string condition, + const std::unordered_map &defines) { + + // Replace macros in the string + for (const auto &define : defines) { + FindReplace(condition, define.first, std::to_string(define.second)); + } + + // Process the equality sign + const auto equal_pos = condition.find(" == "); + if (equal_pos != std::string::npos) { + const auto left = condition.substr(0, equal_pos); + const auto right = condition.substr(equal_pos + 4); + return (left == right); + } + return false; // unknown error +} + +// ================================================================================================= + +// First pass: detect defines and comments +std::vector PreprocessDefinesAndComments(const std::string& source, + std::unordered_map& defines) { + auto lines = std::vector(); + + // Parse the input string into a vector of lines + auto disabled = false; + auto source_stream = std::stringstream(source); + auto line = std::string{""}; + while (std::getline(source_stream, line)) { + + // Decide whether or not to remain in 'disabled' mode + if (line.find("#endif") != std::string::npos || + line.find("#elif") != std::string::npos) { + disabled = false; + } + + // Not in a disabled-block + if (!disabled) { + + // Skip empty lines + if (line == "") { continue; } + + // Single line comments + const auto comment_pos = line.find("//"); + if (comment_pos != std::string::npos) { + if (comment_pos == 0) { continue; } + line.erase(comment_pos); + } + + // Detect #define macros + const auto define_pos = line.find("#define "); + if (define_pos != std::string::npos) { + const auto define = line.substr(define_pos + 8); // length of "#define " + const auto value_pos = define.find(" "); + const auto value = define.substr(value_pos + 1); + const auto name = define.substr(0, value_pos); + defines.emplace(name, std::stoi(value)); + //continue; + } + + // Detect #ifndef blocks + const auto ifndef_pos = line.find("#ifndef "); + if (ifndef_pos != std::string::npos) { + const auto define = line.substr(ifndef_pos + 8); // length of "#ifndef " + if (defines.find(define) != defines.end()) { disabled = true; } + continue; + } + + // Detect #ifdef blocks + const auto ifdef_pos = line.find("#ifdef "); + if (ifdef_pos != std::string::npos) { + const auto define = line.substr(ifdef_pos + 7); // length of "#ifdef " + if (defines.find(define) == defines.end()) { disabled = true; } + continue; + } + + // Detect #if blocks + const auto if_pos = line.find("#if "); + if (if_pos != std::string::npos) { + const auto condition = line.substr(if_pos + 4); // length of "#if " + if (!EvaluateCondition(condition, defines)) { disabled = true; } + continue; + } + + // Detect #elif blocks + const auto elif_pos = line.find("#elif "); + if (elif_pos != std::string::npos) { + const auto condition = line.substr(elif_pos + 6); // length of "#elif " + if (!EvaluateCondition(condition, defines)) { disabled = true; } + continue; + } + + // Discard #endif statements + if (line.find("#endif") != std::string::npos) { + continue; + } + + lines.push_back(line); + } + } + return lines; +} + +// ================================================================================================= + +// Second pass: unroll loops +std::vector PreprocessUnrollLoops(const std::vector& source_lines, + const std::unordered_map& defines) { + auto lines = std::vector(); + + auto brackets = 0; + auto unroll_next_loop = false; + + for (auto line_id = size_t{0}; line_id < source_lines.size(); ++line_id) { + const auto line = source_lines[line_id]; + + // Detect #pragma unroll directives + if (line.find("#pragma unroll") != std::string::npos) { + unroll_next_loop = true; + continue; + } + + // Brackets + brackets += std::count(line.begin(), line.end(), '{'); + brackets -= std::count(line.begin(), line.end(), '}'); + + // Loop unrolling assuming it to be in the form "for (int w = 0; w < 4; w += 1) {" + if (unroll_next_loop) { + unroll_next_loop = false; + + // Parses loop structure + const auto for_pos = line.find("for ("); + if (for_pos == std::string::npos) { throw Error("Mis-formatted for-loop #0"); } + const auto remainder = line.substr(for_pos + 5); // length of "for (" + const auto line_split = split(remainder, ' '); + if (line_split.size() != 11) { throw Error("Mis-formatted for-loop #1"); } + + // Retrieves loop information (and checks for assumptions) + const auto variable_type = line_split[0]; + const auto variable_name = line_split[1]; + if (variable_name != line_split[4]) { throw Error("Mis-formatted for-loop #2"); } + if (variable_name != line_split[7]) { throw Error("Mis-formatted for-loop #3"); } + auto loop_start_string = line_split[3]; + auto loop_end_string = line_split[6]; + remove_character(loop_start_string, ';'); + remove_character(loop_end_string, ';'); + + // Parses loop information + const auto loop_start = std::stoi(loop_start_string); + if (defines.count(loop_end_string) == 1) { + loop_end_string = ToString(defines.at(loop_end_string)); + } + const auto loop_end = std::stoi(loop_end_string); + auto indent = std::string{""}; + for (auto i = size_t{0}; i < for_pos; ++i) { indent += " "; } + + // Start of the loop + line_id++; + const auto loop_num_brackets = brackets; + const auto line_id_start = line_id; + for (auto loop_iter = loop_start; loop_iter < loop_end; ++loop_iter) { + line_id = line_id_start; + brackets = loop_num_brackets; + lines.emplace_back(indent + "{"); + + // Body of the loop + lines.emplace_back(indent + " " + variable_type + " " + variable_name + " = " + ToString(loop_iter) + ";"); + while (brackets >= loop_num_brackets) { + const auto loop_line = source_lines[line_id]; + brackets += std::count(loop_line.begin(), loop_line.end(), '{'); + brackets -= std::count(loop_line.begin(), loop_line.end(), '}'); + lines.emplace_back(loop_line); + line_id++; + } + line_id--; + } + } + else { + lines.emplace_back(line); + } + } + return lines; +} + +// ================================================================================================= + std::string PreprocessKernelSource(const std::string& kernel_source) { - const auto processed_kernel = kernel_source; + + // Retrieves the defines and removes comments from the source lines + auto defines = std::unordered_map(); + auto lines = PreprocessDefinesAndComments(kernel_source, defines); + + // Unrolls loops (single level each call) + lines = PreprocessUnrollLoops(lines, defines); + lines = PreprocessUnrollLoops(lines, defines); + + // Gather the results + auto processed_kernel = std::string{""}; + for (const auto& line : lines) { + processed_kernel += line + "\n"; + } return processed_kernel; } diff --git a/src/kernel_preprocessor.hpp b/src/kernel_preprocessor.hpp index 6a9a8a38..81873b4f 100644 --- a/src/kernel_preprocessor.hpp +++ b/src/kernel_preprocessor.hpp @@ -17,7 +17,6 @@ #define CLBLAST_KERNEL_PREPROCESSOR_H_ #include -#include #include "utilities/utilities.hpp" diff --git a/src/kernels/level1/xaxpy.opencl b/src/kernels/level1/xaxpy.opencl index d30d4e55..3a574ec2 100644 --- a/src/kernels/level1/xaxpy.opencl +++ b/src/kernels/level1/xaxpy.opencl @@ -29,8 +29,7 @@ void Xaxpy(const int n, const real_arg arg_alpha, const real alpha = GetRealArg(arg_alpha); // Loops over the work that needs to be done (allows for an arbitrary number of threads) - #pragma unroll - for (int id = get_global_id(0); id