Implemented first simple pre-processor: defines parser and loop unrolling based on assumptions
parent
f01bcded1e
commit
69aa3b35ed
|
@ -9,18 +9,238 @@
|
|||
//
|
||||
// This file implements the OpenCL kernel preprocessor (see the header for more information).
|
||||
//
|
||||
// Restrictions:
|
||||
// - Use comments only single-line "//" style, not "/*" and "*/"
|
||||
// - Don't use strings with characters parsed (e.g. '//', '}', '#ifdef')
|
||||
// - Supports conditionals: #if #ifdef #ifndef #else #elif #endif
|
||||
// - ...with the operators: ==
|
||||
// - "#pragma unroll" requires next loop in the form "for (int w = 0; w < 4; w += 1) {"
|
||||
// The above also requires the spaces in that exact form
|
||||
//
|
||||
// =================================================================================================
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <algorithm>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <chrono>
|
||||
|
||||
#include "kernel_preprocessor.hpp"
|
||||
|
||||
namespace clblast {
|
||||
// =================================================================================================
|
||||
|
||||
void FindReplace(std::string &subject, const std::string &search, const std::string &replace)
|
||||
{
|
||||
auto pos = size_t{0};
|
||||
while ((pos = subject.find(search, pos)) != std::string::npos) {
|
||||
subject.replace(pos, search.length(), replace);
|
||||
pos += replace.length();
|
||||
}
|
||||
}
|
||||
|
||||
bool EvaluateCondition(std::string condition,
|
||||
const std::unordered_map<std::string, int> &defines) {
|
||||
|
||||
// Replace macros in the string
|
||||
for (const auto &define : defines) {
|
||||
FindReplace(condition, define.first, std::to_string(define.second));
|
||||
}
|
||||
|
||||
// Process the equality sign
|
||||
const auto equal_pos = condition.find(" == ");
|
||||
if (equal_pos != std::string::npos) {
|
||||
const auto left = condition.substr(0, equal_pos);
|
||||
const auto right = condition.substr(equal_pos + 4);
|
||||
return (left == right);
|
||||
}
|
||||
return false; // unknown error
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
// First pass: detect defines and comments
|
||||
std::vector<std::string> PreprocessDefinesAndComments(const std::string& source,
|
||||
std::unordered_map<std::string, int>& defines) {
|
||||
auto lines = std::vector<std::string>();
|
||||
|
||||
// Parse the input string into a vector of lines
|
||||
auto disabled = false;
|
||||
auto source_stream = std::stringstream(source);
|
||||
auto line = std::string{""};
|
||||
while (std::getline(source_stream, line)) {
|
||||
|
||||
// Decide whether or not to remain in 'disabled' mode
|
||||
if (line.find("#endif") != std::string::npos ||
|
||||
line.find("#elif") != std::string::npos) {
|
||||
disabled = false;
|
||||
}
|
||||
|
||||
// Not in a disabled-block
|
||||
if (!disabled) {
|
||||
|
||||
// Skip empty lines
|
||||
if (line == "") { continue; }
|
||||
|
||||
// Single line comments
|
||||
const auto comment_pos = line.find("//");
|
||||
if (comment_pos != std::string::npos) {
|
||||
if (comment_pos == 0) { continue; }
|
||||
line.erase(comment_pos);
|
||||
}
|
||||
|
||||
// Detect #define macros
|
||||
const auto define_pos = line.find("#define ");
|
||||
if (define_pos != std::string::npos) {
|
||||
const auto define = line.substr(define_pos + 8); // length of "#define "
|
||||
const auto value_pos = define.find(" ");
|
||||
const auto value = define.substr(value_pos + 1);
|
||||
const auto name = define.substr(0, value_pos);
|
||||
defines.emplace(name, std::stoi(value));
|
||||
//continue;
|
||||
}
|
||||
|
||||
// Detect #ifndef blocks
|
||||
const auto ifndef_pos = line.find("#ifndef ");
|
||||
if (ifndef_pos != std::string::npos) {
|
||||
const auto define = line.substr(ifndef_pos + 8); // length of "#ifndef "
|
||||
if (defines.find(define) != defines.end()) { disabled = true; }
|
||||
continue;
|
||||
}
|
||||
|
||||
// Detect #ifdef blocks
|
||||
const auto ifdef_pos = line.find("#ifdef ");
|
||||
if (ifdef_pos != std::string::npos) {
|
||||
const auto define = line.substr(ifdef_pos + 7); // length of "#ifdef "
|
||||
if (defines.find(define) == defines.end()) { disabled = true; }
|
||||
continue;
|
||||
}
|
||||
|
||||
// Detect #if blocks
|
||||
const auto if_pos = line.find("#if ");
|
||||
if (if_pos != std::string::npos) {
|
||||
const auto condition = line.substr(if_pos + 4); // length of "#if "
|
||||
if (!EvaluateCondition(condition, defines)) { disabled = true; }
|
||||
continue;
|
||||
}
|
||||
|
||||
// Detect #elif blocks
|
||||
const auto elif_pos = line.find("#elif ");
|
||||
if (elif_pos != std::string::npos) {
|
||||
const auto condition = line.substr(elif_pos + 6); // length of "#elif "
|
||||
if (!EvaluateCondition(condition, defines)) { disabled = true; }
|
||||
continue;
|
||||
}
|
||||
|
||||
// Discard #endif statements
|
||||
if (line.find("#endif") != std::string::npos) {
|
||||
continue;
|
||||
}
|
||||
|
||||
lines.push_back(line);
|
||||
}
|
||||
}
|
||||
return lines;
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
// Second pass: unroll loops
|
||||
std::vector<std::string> PreprocessUnrollLoops(const std::vector<std::string>& source_lines,
|
||||
const std::unordered_map<std::string, int>& defines) {
|
||||
auto lines = std::vector<std::string>();
|
||||
|
||||
auto brackets = 0;
|
||||
auto unroll_next_loop = false;
|
||||
|
||||
for (auto line_id = size_t{0}; line_id < source_lines.size(); ++line_id) {
|
||||
const auto line = source_lines[line_id];
|
||||
|
||||
// Detect #pragma unroll directives
|
||||
if (line.find("#pragma unroll") != std::string::npos) {
|
||||
unroll_next_loop = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Brackets
|
||||
brackets += std::count(line.begin(), line.end(), '{');
|
||||
brackets -= std::count(line.begin(), line.end(), '}');
|
||||
|
||||
// Loop unrolling assuming it to be in the form "for (int w = 0; w < 4; w += 1) {"
|
||||
if (unroll_next_loop) {
|
||||
unroll_next_loop = false;
|
||||
|
||||
// Parses loop structure
|
||||
const auto for_pos = line.find("for (");
|
||||
if (for_pos == std::string::npos) { throw Error<std::runtime_error>("Mis-formatted for-loop #0"); }
|
||||
const auto remainder = line.substr(for_pos + 5); // length of "for ("
|
||||
const auto line_split = split(remainder, ' ');
|
||||
if (line_split.size() != 11) { throw Error<std::runtime_error>("Mis-formatted for-loop #1"); }
|
||||
|
||||
// Retrieves loop information (and checks for assumptions)
|
||||
const auto variable_type = line_split[0];
|
||||
const auto variable_name = line_split[1];
|
||||
if (variable_name != line_split[4]) { throw Error<std::runtime_error>("Mis-formatted for-loop #2"); }
|
||||
if (variable_name != line_split[7]) { throw Error<std::runtime_error>("Mis-formatted for-loop #3"); }
|
||||
auto loop_start_string = line_split[3];
|
||||
auto loop_end_string = line_split[6];
|
||||
remove_character(loop_start_string, ';');
|
||||
remove_character(loop_end_string, ';');
|
||||
|
||||
// Parses loop information
|
||||
const auto loop_start = std::stoi(loop_start_string);
|
||||
if (defines.count(loop_end_string) == 1) {
|
||||
loop_end_string = ToString(defines.at(loop_end_string));
|
||||
}
|
||||
const auto loop_end = std::stoi(loop_end_string);
|
||||
auto indent = std::string{""};
|
||||
for (auto i = size_t{0}; i < for_pos; ++i) { indent += " "; }
|
||||
|
||||
// Start of the loop
|
||||
line_id++;
|
||||
const auto loop_num_brackets = brackets;
|
||||
const auto line_id_start = line_id;
|
||||
for (auto loop_iter = loop_start; loop_iter < loop_end; ++loop_iter) {
|
||||
line_id = line_id_start;
|
||||
brackets = loop_num_brackets;
|
||||
lines.emplace_back(indent + "{");
|
||||
|
||||
// Body of the loop
|
||||
lines.emplace_back(indent + " " + variable_type + " " + variable_name + " = " + ToString(loop_iter) + ";");
|
||||
while (brackets >= loop_num_brackets) {
|
||||
const auto loop_line = source_lines[line_id];
|
||||
brackets += std::count(loop_line.begin(), loop_line.end(), '{');
|
||||
brackets -= std::count(loop_line.begin(), loop_line.end(), '}');
|
||||
lines.emplace_back(loop_line);
|
||||
line_id++;
|
||||
}
|
||||
line_id--;
|
||||
}
|
||||
}
|
||||
else {
|
||||
lines.emplace_back(line);
|
||||
}
|
||||
}
|
||||
return lines;
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
std::string PreprocessKernelSource(const std::string& kernel_source) {
|
||||
const auto processed_kernel = kernel_source;
|
||||
|
||||
// Retrieves the defines and removes comments from the source lines
|
||||
auto defines = std::unordered_map<std::string, int>();
|
||||
auto lines = PreprocessDefinesAndComments(kernel_source, defines);
|
||||
|
||||
// Unrolls loops (single level each call)
|
||||
lines = PreprocessUnrollLoops(lines, defines);
|
||||
lines = PreprocessUnrollLoops(lines, defines);
|
||||
|
||||
// Gather the results
|
||||
auto processed_kernel = std::string{""};
|
||||
for (const auto& line : lines) {
|
||||
processed_kernel += line + "\n";
|
||||
}
|
||||
return processed_kernel;
|
||||
}
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
#define CLBLAST_KERNEL_PREPROCESSOR_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "utilities/utilities.hpp"
|
||||
|
||||
|
|
|
@ -29,7 +29,6 @@ void Xaxpy(const int n, const real_arg arg_alpha,
|
|||
const real alpha = GetRealArg(arg_alpha);
|
||||
|
||||
// Loops over the work that needs to be done (allows for an arbitrary number of threads)
|
||||
#pragma unroll
|
||||
for (int id = get_global_id(0); id < n; id += get_global_size(0)) {
|
||||
real xvalue = xgm[id*x_inc + x_offset];
|
||||
MultiplyAdd(ygm[id*y_inc + y_offset], alpha, xvalue);
|
||||
|
@ -46,7 +45,7 @@ void XaxpyFaster(const int n, const real_arg arg_alpha,
|
|||
|
||||
if (get_global_id(0) < n / (VW)) {
|
||||
#pragma unroll
|
||||
for (int w=0; w<WPT; ++w) {
|
||||
for (int w = 0; w < WPT; w += 1) {
|
||||
const int id = w*get_global_size(0) + get_global_id(0);
|
||||
realV xvalue = xgm[id];
|
||||
realV yvalue = ygm[id];
|
||||
|
@ -64,7 +63,7 @@ void XaxpyFastest(const int n, const real_arg arg_alpha,
|
|||
const real alpha = GetRealArg(arg_alpha);
|
||||
|
||||
#pragma unroll
|
||||
for (int w=0; w<WPT; ++w) {
|
||||
for (int w = 0; w < WPT; w += 1) {
|
||||
const int id = w*get_global_size(0) + get_global_id(0);
|
||||
realV xvalue = xgm[id];
|
||||
realV yvalue = ygm[id];
|
||||
|
@ -83,7 +82,6 @@ void XaxpyBatched(const int n, const __constant real_arg* arg_alphas,
|
|||
const real alpha = GetRealArg(arg_alphas[batch]);
|
||||
|
||||
// Loops over the work that needs to be done (allows for an arbitrary number of threads)
|
||||
#pragma unroll
|
||||
for (int id = get_global_id(0); id < n; id += get_global_size(0)) {
|
||||
real xvalue = xgm[id*x_inc + x_offsets[batch]];
|
||||
MultiplyAdd(ygm[id*y_inc + y_offsets[batch]], alpha, xvalue);
|
||||
|
|
Loading…
Reference in New Issue