2017-11-25 10:24:05 +01:00
|
|
|
|
|
|
|
// =================================================================================================
|
|
|
|
// This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This
|
|
|
|
// project loosely follows the Google C++ styleguide and uses a tab-size of two spaces and a max-
|
|
|
|
// width of 100 characters per line.
|
|
|
|
//
|
|
|
|
// Author(s):
|
|
|
|
// Cedric Nugteren <www.cedricnugteren.nl>
|
|
|
|
//
|
|
|
|
// This file implements the OpenCL kernel preprocessor (see the header for more information).
|
|
|
|
//
|
2017-11-25 17:46:01 +01:00
|
|
|
// Restrictions:
|
|
|
|
// - Use comments only single-line "//" style, not "/*" and "*/"
|
|
|
|
// - Don't use strings with characters parsed (e.g. '//', '}', '#ifdef')
|
|
|
|
// - Supports conditionals: #if #ifdef #ifndef #else #elif #endif
|
|
|
|
// - ...with the operators: ==
|
|
|
|
// - "#pragma unroll" requires next loop in the form "for (int w = 0; w < 4; w += 1) {"
|
|
|
|
// The above also requires the spaces in that exact form
|
2017-11-26 13:32:48 +01:00
|
|
|
// - The loop variable should be a unique string within the code in the for-loop body (e.g. don't
// use 'i' or 'w' but rather '_w' or a longer name).
|
2017-11-25 17:46:01 +01:00
|
|
|
//
|
2017-11-25 10:24:05 +01:00
|
|
|
// =================================================================================================
|
|
|
|
|
2017-11-25 17:46:01 +01:00
|
|
|
#include <algorithm>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

#include "kernel_preprocessor.hpp"
|
|
|
|
|
|
|
|
namespace clblast {
|
|
|
|
// =================================================================================================
|
|
|
|
|
2017-11-28 20:52:08 +01:00
|
|
|
void RaiseError(const std::string& source_line, const std::string& exception_message) {
|
|
|
|
printf("Error in source line: %s\n", source_line.c_str());
|
|
|
|
throw Error<std::runtime_error>(exception_message);
|
|
|
|
}
|
|
|
|
|
|
|
|
// =================================================================================================
|
|
|
|
|
|
|
|
// Returns true if 'str' consists of one or more decimal digits ('0'-'9') and nothing else. An empty
// string is rejected: it holds no number, and treating it as valid would let callers pass it on to
// a string-to-integer conversion that cannot handle it.
bool HasOnlyDigits(const std::string& str) {
  return !str.empty() && str.find_first_not_of("0123456789") == std::string::npos;
}
|
|
|
|
|
|
|
|
size_t StringToDigit(const std::string& str, const std::string& source_line) {
|
|
|
|
const auto split_dividers = split(str, '/');
|
|
|
|
if (split_dividers.size() == 2) {
|
|
|
|
return StringToDigit(split_dividers[0], source_line) / StringToDigit(split_dividers[1], source_line);
|
|
|
|
}
|
|
|
|
if (not HasOnlyDigits(str)) { RaiseError(source_line, "Not a digit: " + str); }
|
|
|
|
return static_cast<size_t>(std::stoi(str));
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// =================================================================================================
|
|
|
|
|
2017-11-25 17:46:01 +01:00
|
|
|
// Replaces, in-place, every occurrence of 'search' within 'subject' by 'replace'. The scan
// continues after each inserted replacement, so text produced by a replacement is never matched
// again. An empty 'search' string leaves 'subject' untouched: it would otherwise match at every
// position and the loop would never terminate.
void FindReplace(std::string &subject, const std::string &search, const std::string &replace)
{
  if (search.empty()) { return; }
  auto pos = size_t{0};
  while ((pos = subject.find(search, pos)) != std::string::npos) {
    subject.replace(pos, search.length(), replace);
    pos += replace.length();
  }
}
|
|
|
|
|
2017-11-28 20:52:08 +01:00
|
|
|
void SubstituteDefines(const std::unordered_map<std::string, int>& defines,
|
|
|
|
std::string& source_string) {
|
|
|
|
for (const auto &define : defines) {
|
|
|
|
FindReplace(source_string, define.first, std::to_string(define.second));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-11-25 17:46:01 +01:00
|
|
|
bool EvaluateCondition(std::string condition,
|
|
|
|
const std::unordered_map<std::string, int> &defines) {
|
|
|
|
|
|
|
|
// Replace macros in the string
|
2017-11-28 20:52:08 +01:00
|
|
|
SubstituteDefines(defines, condition);
|
2017-11-25 17:46:01 +01:00
|
|
|
|
|
|
|
// Process the equality sign
|
|
|
|
const auto equal_pos = condition.find(" == ");
|
|
|
|
if (equal_pos != std::string::npos) {
|
|
|
|
const auto left = condition.substr(0, equal_pos);
|
|
|
|
const auto right = condition.substr(equal_pos + 4);
|
|
|
|
return (left == right);
|
|
|
|
}
|
|
|
|
return false; // unknown error
|
|
|
|
}
|
|
|
|
|
|
|
|
// =================================================================================================
|
|
|
|
|
|
|
|
// First pass: detect defines and comments
|
|
|
|
std::vector<std::string> PreprocessDefinesAndComments(const std::string& source,
|
|
|
|
std::unordered_map<std::string, int>& defines) {
|
|
|
|
auto lines = std::vector<std::string>();
|
|
|
|
|
|
|
|
// Parse the input string into a vector of lines
|
|
|
|
auto disabled = false;
|
2017-11-28 20:52:08 +01:00
|
|
|
auto depth = 0;
|
2017-11-25 17:46:01 +01:00
|
|
|
auto source_stream = std::stringstream(source);
|
|
|
|
auto line = std::string{""};
|
|
|
|
while (std::getline(source_stream, line)) {
|
|
|
|
|
|
|
|
// Decide whether or not to remain in 'disabled' mode
|
2017-11-28 20:52:08 +01:00
|
|
|
if (line.find("#endif") != std::string::npos) {
|
|
|
|
if (depth == 1) {
|
|
|
|
disabled = false;
|
|
|
|
}
|
|
|
|
depth--;
|
|
|
|
}
|
|
|
|
if (depth == 1) {
|
|
|
|
if (line.find("#elif") != std::string::npos) {
|
|
|
|
disabled = false;
|
|
|
|
}
|
|
|
|
if (line.find("#else") != std::string::npos) {
|
|
|
|
disabled = !disabled;
|
|
|
|
}
|
2017-11-25 17:46:01 +01:00
|
|
|
}
|
2017-11-28 20:52:08 +01:00
|
|
|
|
|
|
|
// Measures the depth of pre-processor defines
|
|
|
|
if ((line.find("#ifndef ") != std::string::npos) ||
|
|
|
|
(line.find("#ifdef ") != std::string::npos) ||
|
|
|
|
(line.find("#if ") != std::string::npos)) {
|
|
|
|
depth++;
|
2017-11-26 13:32:48 +01:00
|
|
|
}
|
2017-11-25 17:46:01 +01:00
|
|
|
|
|
|
|
// Not in a disabled-block
|
|
|
|
if (!disabled) {
|
|
|
|
|
|
|
|
// Skip empty lines
|
|
|
|
if (line == "") { continue; }
|
|
|
|
|
|
|
|
// Single line comments
|
|
|
|
const auto comment_pos = line.find("//");
|
|
|
|
if (comment_pos != std::string::npos) {
|
|
|
|
if (comment_pos == 0) { continue; }
|
|
|
|
line.erase(comment_pos);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Detect #define macros
|
|
|
|
const auto define_pos = line.find("#define ");
|
|
|
|
if (define_pos != std::string::npos) {
|
|
|
|
const auto define = line.substr(define_pos + 8); // length of "#define "
|
|
|
|
const auto value_pos = define.find(" ");
|
|
|
|
const auto value = define.substr(value_pos + 1);
|
|
|
|
const auto name = define.substr(0, value_pos);
|
|
|
|
defines.emplace(name, std::stoi(value));
|
|
|
|
//continue;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Detect #ifndef blocks
|
|
|
|
const auto ifndef_pos = line.find("#ifndef ");
|
|
|
|
if (ifndef_pos != std::string::npos) {
|
|
|
|
const auto define = line.substr(ifndef_pos + 8); // length of "#ifndef "
|
|
|
|
if (defines.find(define) != defines.end()) { disabled = true; }
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Detect #ifdef blocks
|
|
|
|
const auto ifdef_pos = line.find("#ifdef ");
|
|
|
|
if (ifdef_pos != std::string::npos) {
|
|
|
|
const auto define = line.substr(ifdef_pos + 7); // length of "#ifdef "
|
|
|
|
if (defines.find(define) == defines.end()) { disabled = true; }
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Detect #if blocks
|
|
|
|
const auto if_pos = line.find("#if ");
|
|
|
|
if (if_pos != std::string::npos) {
|
|
|
|
const auto condition = line.substr(if_pos + 4); // length of "#if "
|
|
|
|
if (!EvaluateCondition(condition, defines)) { disabled = true; }
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Detect #elif blocks
|
|
|
|
const auto elif_pos = line.find("#elif ");
|
|
|
|
if (elif_pos != std::string::npos) {
|
|
|
|
const auto condition = line.substr(elif_pos + 6); // length of "#elif "
|
|
|
|
if (!EvaluateCondition(condition, defines)) { disabled = true; }
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
|
2017-11-26 13:32:48 +01:00
|
|
|
// Discard #else and #endif statements
|
|
|
|
if (line.find("#endif") != std::string::npos || line.find("#else") != std::string::npos) {
|
2017-11-25 17:46:01 +01:00
|
|
|
continue;
|
|
|
|
}
|
|
|
|
|
|
|
|
lines.push_back(line);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return lines;
|
|
|
|
}
|
|
|
|
|
|
|
|
// =================================================================================================
|
|
|
|
|
|
|
|
// Second pass: unroll loops
|
|
|
|
std::vector<std::string> PreprocessUnrollLoops(const std::vector<std::string>& source_lines,
|
2017-11-27 20:26:30 +01:00
|
|
|
const std::unordered_map<std::string, int>& defines,
|
|
|
|
std::unordered_map<std::string, size_t>& arrays_to_registers,
|
|
|
|
const bool array_to_register_promotion) {
|
2017-11-25 17:46:01 +01:00
|
|
|
auto lines = std::vector<std::string>();
|
|
|
|
|
2017-11-27 20:26:30 +01:00
|
|
|
auto brackets = size_t{0};
|
2017-11-25 17:46:01 +01:00
|
|
|
auto unroll_next_loop = false;
|
2017-11-27 20:26:30 +01:00
|
|
|
auto promote_next_array_to_registers = false;
|
2017-11-25 17:46:01 +01:00
|
|
|
|
|
|
|
for (auto line_id = size_t{0}; line_id < source_lines.size(); ++line_id) {
|
|
|
|
const auto line = source_lines[line_id];
|
|
|
|
|
2017-11-27 20:26:30 +01:00
|
|
|
// Detect #pragma promote_to_registers directives (unofficial pragma)
|
|
|
|
if (array_to_register_promotion) {
|
|
|
|
if (line.find("#pragma promote_to_registers") != std::string::npos) {
|
|
|
|
promote_next_array_to_registers = true;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-11-25 17:46:01 +01:00
|
|
|
// Detect #pragma unroll directives
|
|
|
|
if (line.find("#pragma unroll") != std::string::npos) {
|
|
|
|
unroll_next_loop = true;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Brackets
|
|
|
|
brackets += std::count(line.begin(), line.end(), '{');
|
|
|
|
brackets -= std::count(line.begin(), line.end(), '}');
|
|
|
|
|
2017-11-27 20:26:30 +01:00
|
|
|
// Promote array declarations to registers
|
|
|
|
if (promote_next_array_to_registers) {
|
|
|
|
promote_next_array_to_registers = false;
|
|
|
|
const auto line_split1 = split(line, '[');
|
2017-11-28 20:52:08 +01:00
|
|
|
if (line_split1.size() != 2) { RaiseError(line, "Mis-formatted array declaration #0"); }
|
2017-11-27 20:26:30 +01:00
|
|
|
const auto line_split2 = split(line_split1[1], ']');
|
2017-11-28 20:52:08 +01:00
|
|
|
if (line_split2.size() != 2) { RaiseError(line, "Mis-formatted array declaration #1"); }
|
2017-11-27 20:26:30 +01:00
|
|
|
auto array_size_string = line_split2[0];
|
|
|
|
SubstituteDefines(defines, array_size_string);
|
2017-11-28 20:52:08 +01:00
|
|
|
const auto array_size = StringToDigit(array_size_string, line);
|
|
|
|
for (auto loop_iter = size_t{0}; loop_iter < array_size; ++loop_iter) {
|
2017-11-27 20:26:30 +01:00
|
|
|
lines.emplace_back(line_split1[0] + "_" + ToString(loop_iter) + line_split2[1]);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Stores the array name
|
|
|
|
const auto array_name_split = split(line_split1[0], ' ');
|
2017-11-28 20:52:08 +01:00
|
|
|
if (array_name_split.size() < 2) { RaiseError(line, "Mis-formatted array declaration #2"); }
|
2017-11-27 20:26:30 +01:00
|
|
|
const auto array_name = array_name_split[array_name_split.size() - 1];
|
|
|
|
arrays_to_registers[array_name] = brackets; // TODO: bracket count not used currently for scope checking
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
2017-11-25 17:46:01 +01:00
|
|
|
// Loop unrolling assuming it to be in the form "for (int w = 0; w < 4; w += 1) {"
|
|
|
|
if (unroll_next_loop) {
|
|
|
|
unroll_next_loop = false;
|
|
|
|
|
|
|
|
// Parses loop structure
|
|
|
|
const auto for_pos = line.find("for (");
|
2017-11-28 20:52:08 +01:00
|
|
|
if (for_pos == std::string::npos) { RaiseError(line, "Mis-formatted for-loop #0"); }
|
2017-11-25 17:46:01 +01:00
|
|
|
const auto remainder = line.substr(for_pos + 5); // length of "for ("
|
|
|
|
const auto line_split = split(remainder, ' ');
|
2017-11-28 20:52:08 +01:00
|
|
|
if (line_split.size() != 11) { RaiseError(line, "Mis-formatted for-loop #1"); }
|
2017-11-25 17:46:01 +01:00
|
|
|
|
|
|
|
// Retrieves loop information (and checks for assumptions)
|
|
|
|
const auto variable_type = line_split[0];
|
|
|
|
const auto variable_name = line_split[1];
|
2017-11-28 20:52:08 +01:00
|
|
|
if (variable_name != line_split[4]) { RaiseError(line, "Mis-formatted for-loop #2"); }
|
|
|
|
if (variable_name != line_split[7]) { RaiseError(line, "Mis-formatted for-loop #3"); }
|
2017-11-25 17:46:01 +01:00
|
|
|
auto loop_start_string = line_split[3];
|
|
|
|
auto loop_end_string = line_split[6];
|
2017-11-26 13:32:48 +01:00
|
|
|
auto loop_increment_string = line_split[9];
|
2017-11-25 17:46:01 +01:00
|
|
|
remove_character(loop_start_string, ';');
|
|
|
|
remove_character(loop_end_string, ';');
|
2017-11-26 13:32:48 +01:00
|
|
|
remove_character(loop_increment_string, ')');
|
2017-11-25 17:46:01 +01:00
|
|
|
|
|
|
|
// Parses loop information
|
2017-11-26 13:32:48 +01:00
|
|
|
SubstituteDefines(defines, loop_start_string);
|
|
|
|
SubstituteDefines(defines, loop_end_string);
|
|
|
|
SubstituteDefines(defines, loop_increment_string);
|
2017-11-28 20:52:08 +01:00
|
|
|
const auto loop_start = StringToDigit(loop_start_string, line);
|
|
|
|
const auto loop_end = StringToDigit(loop_end_string, line);
|
|
|
|
const auto loop_increment = StringToDigit(loop_increment_string, line);
|
2017-11-25 17:46:01 +01:00
|
|
|
auto indent = std::string{""};
|
|
|
|
for (auto i = size_t{0}; i < for_pos; ++i) { indent += " "; }
|
|
|
|
|
|
|
|
// Start of the loop
|
|
|
|
line_id++;
|
|
|
|
const auto loop_num_brackets = brackets;
|
|
|
|
const auto line_id_start = line_id;
|
2017-11-26 13:32:48 +01:00
|
|
|
for (auto loop_iter = loop_start; loop_iter < loop_end; loop_iter += loop_increment) {
|
2017-11-25 17:46:01 +01:00
|
|
|
line_id = line_id_start;
|
|
|
|
brackets = loop_num_brackets;
|
|
|
|
lines.emplace_back(indent + "{");
|
|
|
|
|
|
|
|
// Body of the loop
|
2017-11-26 13:32:48 +01:00
|
|
|
//lines.emplace_back(indent + " " + variable_type + " " + variable_name + " = " + ToString(loop_iter) + ";");
|
2017-11-25 17:46:01 +01:00
|
|
|
while (brackets >= loop_num_brackets) {
|
2017-11-26 13:32:48 +01:00
|
|
|
auto loop_line = source_lines[line_id];
|
2017-11-25 17:46:01 +01:00
|
|
|
brackets += std::count(loop_line.begin(), loop_line.end(), '{');
|
|
|
|
brackets -= std::count(loop_line.begin(), loop_line.end(), '}');
|
2017-11-27 20:26:30 +01:00
|
|
|
|
|
|
|
// Array to register promotion, e.g. arr[w] to {arr_0, arr_1}
|
|
|
|
if (array_to_register_promotion) {
|
|
|
|
for (const auto array_name_map : arrays_to_registers) { // only if marked to be promoted
|
|
|
|
FindReplace(loop_line, array_name_map.first + "[" + variable_name + "]",
|
|
|
|
array_name_map.first + "_" + ToString(loop_iter));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Regular variable substitution
|
2017-11-26 13:32:48 +01:00
|
|
|
FindReplace(loop_line, variable_name, ToString(loop_iter));
|
2017-11-25 17:46:01 +01:00
|
|
|
lines.emplace_back(loop_line);
|
|
|
|
line_id++;
|
|
|
|
}
|
|
|
|
line_id--;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
else {
|
|
|
|
lines.emplace_back(line);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return lines;
|
|
|
|
}
|
|
|
|
|
|
|
|
// =================================================================================================
|
|
|
|
|
2017-11-25 10:24:05 +01:00
|
|
|
std::string PreprocessKernelSource(const std::string& kernel_source) {
|
2017-11-25 17:46:01 +01:00
|
|
|
|
|
|
|
// Retrieves the defines and removes comments from the source lines
|
|
|
|
auto defines = std::unordered_map<std::string, int>();
|
|
|
|
auto lines = PreprocessDefinesAndComments(kernel_source, defines);
|
|
|
|
|
|
|
|
// Unrolls loops (single level each call)
|
2017-11-27 20:26:30 +01:00
|
|
|
auto arrays_to_registers = std::unordered_map<std::string, size_t>();
|
|
|
|
lines = PreprocessUnrollLoops(lines, defines, arrays_to_registers, true);
|
|
|
|
lines = PreprocessUnrollLoops(lines, defines, arrays_to_registers, true);
|
2017-11-25 17:46:01 +01:00
|
|
|
|
|
|
|
// Gather the results
|
|
|
|
auto processed_kernel = std::string{""};
|
|
|
|
for (const auto& line : lines) {
|
|
|
|
processed_kernel += line + "\n";
|
|
|
|
}
|
2017-11-28 20:52:08 +01:00
|
|
|
|
|
|
|
// Debugging
|
|
|
|
if (false) {
|
|
|
|
for (auto i = size_t{0}; i < lines.size(); ++i) {
|
|
|
|
printf("[%zu] %s\n", i, lines[i].c_str());
|
|
|
|
}
|
|
|
|
}
|
2017-11-25 10:24:05 +01:00
|
|
|
return processed_kernel;
|
|
|
|
}
|
|
|
|
|
|
|
|
// =================================================================================================
|
|
|
|
} // namespace clblast
|