Improved the kernel pre-processor in various ways
parent
35956f9db1
commit
14047861ce
|
@ -32,6 +32,29 @@
|
|||
namespace clblast {
|
||||
// =================================================================================================
|
||||
|
||||
void RaiseError(const std::string& source_line, const std::string& exception_message) {
|
||||
printf("Error in source line: %s\n", source_line.c_str());
|
||||
throw Error<std::runtime_error>(exception_message);
|
||||
}
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
bool HasOnlyDigits(const std::string& str) {
|
||||
return str.find_first_not_of("0123456789") == std::string::npos;
|
||||
}
|
||||
|
||||
size_t StringToDigit(const std::string& str, const std::string& source_line) {
|
||||
const auto split_dividers = split(str, '/');
|
||||
if (split_dividers.size() == 2) {
|
||||
return StringToDigit(split_dividers[0], source_line) / StringToDigit(split_dividers[1], source_line);
|
||||
}
|
||||
if (not HasOnlyDigits(str)) { RaiseError(source_line, "Not a digit: " + str); }
|
||||
return static_cast<size_t>(std::stoi(str));
|
||||
}
|
||||
|
||||
|
||||
// =================================================================================================
|
||||
|
||||
void FindReplace(std::string &subject, const std::string &search, const std::string &replace)
|
||||
{
|
||||
auto pos = size_t{0};
|
||||
|
@ -41,13 +64,18 @@ void FindReplace(std::string &subject, const std::string &search, const std::str
|
|||
}
|
||||
}
|
||||
|
||||
void SubstituteDefines(const std::unordered_map<std::string, int>& defines,
|
||||
std::string& source_string) {
|
||||
for (const auto &define : defines) {
|
||||
FindReplace(source_string, define.first, std::to_string(define.second));
|
||||
}
|
||||
}
|
||||
|
||||
bool EvaluateCondition(std::string condition,
|
||||
const std::unordered_map<std::string, int> &defines) {
|
||||
|
||||
// Replace macros in the string
|
||||
for (const auto &define : defines) {
|
||||
FindReplace(condition, define.first, std::to_string(define.second));
|
||||
}
|
||||
SubstituteDefines(defines, condition);
|
||||
|
||||
// Process the equality sign
|
||||
const auto equal_pos = condition.find(" == ");
|
||||
|
@ -68,17 +96,32 @@ std::vector<std::string> PreprocessDefinesAndComments(const std::string& source,
|
|||
|
||||
// Parse the input string into a vector of lines
|
||||
auto disabled = false;
|
||||
auto depth = 0;
|
||||
auto source_stream = std::stringstream(source);
|
||||
auto line = std::string{""};
|
||||
while (std::getline(source_stream, line)) {
|
||||
|
||||
// Decide whether or not to remain in 'disabled' mode
|
||||
if (line.find("#endif") != std::string::npos ||
|
||||
line.find("#elif") != std::string::npos) {
|
||||
disabled = false;
|
||||
if (line.find("#endif") != std::string::npos) {
|
||||
if (depth == 1) {
|
||||
disabled = false;
|
||||
}
|
||||
depth--;
|
||||
}
|
||||
if (line.find("#else") != std::string::npos) {
|
||||
disabled = !disabled;
|
||||
if (depth == 1) {
|
||||
if (line.find("#elif") != std::string::npos) {
|
||||
disabled = false;
|
||||
}
|
||||
if (line.find("#else") != std::string::npos) {
|
||||
disabled = !disabled;
|
||||
}
|
||||
}
|
||||
|
||||
// Measures the depth of pre-processor defines
|
||||
if ((line.find("#ifndef ") != std::string::npos) ||
|
||||
(line.find("#ifdef ") != std::string::npos) ||
|
||||
(line.find("#if ") != std::string::npos)) {
|
||||
depth++;
|
||||
}
|
||||
|
||||
// Not in a disabled-block
|
||||
|
@ -150,13 +193,6 @@ std::vector<std::string> PreprocessDefinesAndComments(const std::string& source,
|
|||
|
||||
// =================================================================================================
|
||||
|
||||
inline void SubstituteDefines(const std::unordered_map<std::string, int>& defines,
|
||||
std::string& source_string) {
|
||||
if (defines.count(source_string) == 1) {
|
||||
source_string = ToString(defines.at(source_string));
|
||||
}
|
||||
}
|
||||
|
||||
// Second pass: unroll loops
|
||||
std::vector<std::string> PreprocessUnrollLoops(const std::vector<std::string>& source_lines,
|
||||
const std::unordered_map<std::string, int>& defines,
|
||||
|
@ -193,19 +229,19 @@ std::vector<std::string> PreprocessUnrollLoops(const std::vector<std::string>& s
|
|||
if (promote_next_array_to_registers) {
|
||||
promote_next_array_to_registers = false;
|
||||
const auto line_split1 = split(line, '[');
|
||||
if (line_split1.size() != 2) { throw Error<std::runtime_error>("Mis-formatted array declaration #0"); }
|
||||
if (line_split1.size() != 2) { RaiseError(line, "Mis-formatted array declaration #0"); }
|
||||
const auto line_split2 = split(line_split1[1], ']');
|
||||
if (line_split2.size() != 2) { throw Error<std::runtime_error>("Mis-formatted array declaration #1"); }
|
||||
if (line_split2.size() != 2) { RaiseError(line, "Mis-formatted array declaration #1"); }
|
||||
auto array_size_string = line_split2[0];
|
||||
SubstituteDefines(defines, array_size_string);
|
||||
const auto array_size = std::stoi(array_size_string);
|
||||
for (auto loop_iter = 0; loop_iter < array_size; ++loop_iter) {
|
||||
const auto array_size = StringToDigit(array_size_string, line);
|
||||
for (auto loop_iter = size_t{0}; loop_iter < array_size; ++loop_iter) {
|
||||
lines.emplace_back(line_split1[0] + "_" + ToString(loop_iter) + line_split2[1]);
|
||||
}
|
||||
|
||||
// Stores the array name
|
||||
const auto array_name_split = split(line_split1[0], ' ');
|
||||
if (array_name_split.size() < 2) { throw Error<std::runtime_error>("Mis-formatted array declaration #2"); }
|
||||
if (array_name_split.size() < 2) { RaiseError(line, "Mis-formatted array declaration #2"); }
|
||||
const auto array_name = array_name_split[array_name_split.size() - 1];
|
||||
arrays_to_registers[array_name] = brackets; // TODO: bracket count not used currently for scope checking
|
||||
continue;
|
||||
|
@ -218,16 +254,16 @@ std::vector<std::string> PreprocessUnrollLoops(const std::vector<std::string>& s
|
|||
|
||||
// Parses loop structure
|
||||
const auto for_pos = line.find("for (");
|
||||
if (for_pos == std::string::npos) { throw Error<std::runtime_error>("Mis-formatted for-loop #0"); }
|
||||
if (for_pos == std::string::npos) { RaiseError(line, "Mis-formatted for-loop #0"); }
|
||||
const auto remainder = line.substr(for_pos + 5); // length of "for ("
|
||||
const auto line_split = split(remainder, ' ');
|
||||
if (line_split.size() != 11) { throw Error<std::runtime_error>("Mis-formatted for-loop #1"); }
|
||||
if (line_split.size() != 11) { RaiseError(line, "Mis-formatted for-loop #1"); }
|
||||
|
||||
// Retrieves loop information (and checks for assumptions)
|
||||
const auto variable_type = line_split[0];
|
||||
const auto variable_name = line_split[1];
|
||||
if (variable_name != line_split[4]) { throw Error<std::runtime_error>("Mis-formatted for-loop #2"); }
|
||||
if (variable_name != line_split[7]) { throw Error<std::runtime_error>("Mis-formatted for-loop #3"); }
|
||||
if (variable_name != line_split[4]) { RaiseError(line, "Mis-formatted for-loop #2"); }
|
||||
if (variable_name != line_split[7]) { RaiseError(line, "Mis-formatted for-loop #3"); }
|
||||
auto loop_start_string = line_split[3];
|
||||
auto loop_end_string = line_split[6];
|
||||
auto loop_increment_string = line_split[9];
|
||||
|
@ -239,9 +275,9 @@ std::vector<std::string> PreprocessUnrollLoops(const std::vector<std::string>& s
|
|||
SubstituteDefines(defines, loop_start_string);
|
||||
SubstituteDefines(defines, loop_end_string);
|
||||
SubstituteDefines(defines, loop_increment_string);
|
||||
const auto loop_start = std::stoi(loop_start_string);
|
||||
const auto loop_end = std::stoi(loop_end_string);
|
||||
const auto loop_increment = std::stoi(loop_increment_string);
|
||||
const auto loop_start = StringToDigit(loop_start_string, line);
|
||||
const auto loop_end = StringToDigit(loop_end_string, line);
|
||||
const auto loop_increment = StringToDigit(loop_increment_string, line);
|
||||
auto indent = std::string{""};
|
||||
for (auto i = size_t{0}; i < for_pos; ++i) { indent += " "; }
|
||||
|
||||
|
@ -264,7 +300,6 @@ std::vector<std::string> PreprocessUnrollLoops(const std::vector<std::string>& s
|
|||
// Array to register promotion, e.g. arr[w] to {arr_0, arr_1}
|
||||
if (array_to_register_promotion) {
|
||||
for (const auto array_name_map : arrays_to_registers) { // only if marked to be promoted
|
||||
printf("%s: %s\n", loop_line.c_str(), array_name_map.first.c_str());
|
||||
FindReplace(loop_line, array_name_map.first + "[" + variable_name + "]",
|
||||
array_name_map.first + "_" + ToString(loop_iter));
|
||||
}
|
||||
|
@ -303,6 +338,13 @@ std::string PreprocessKernelSource(const std::string& kernel_source) {
|
|||
for (const auto& line : lines) {
|
||||
processed_kernel += line + "\n";
|
||||
}
|
||||
|
||||
// Debugging
|
||||
if (false) {
|
||||
for (auto i = size_t{0}; i < lines.size(); ++i) {
|
||||
printf("[%zu] %s\n", i, lines[i].c_str());
|
||||
}
|
||||
}
|
||||
return processed_kernel;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue