mirror of
https://github.com/CNugteren/CLBlast.git
synced 2024-07-07 12:23:46 +02:00
Fixed some things in the tuner: bugs, style, and defaults to random search
This commit is contained in:
parent
6e95752054
commit
54e160cd88
|
@ -42,8 +42,8 @@ class TuneXgemm {
|
||||||
// The list of arguments relevant for this routine
|
// The list of arguments relevant for this routine
|
||||||
static std::vector<std::string> GetOptions() {
|
static std::vector<std::string> GetOptions() {
|
||||||
return {kArgM, kArgN, kArgK, kArgAlpha, kArgBeta, kArgFraction,
|
return {kArgM, kArgN, kArgK, kArgAlpha, kArgBeta, kArgFraction,
|
||||||
kArgHeuristicSelection, kArgPsoSwarmSize,
|
kArgHeuristicSelection, kArgPsoSwarmSize,
|
||||||
kArgPsoInfGlobal, kArgPsoInfLocal, kArgPsoInfRandom};
|
kArgPsoInfGlobal, kArgPsoInfLocal, kArgPsoInfRandom};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tests for valid arguments
|
// Tests for valid arguments
|
||||||
|
@ -60,7 +60,7 @@ class TuneXgemm {
|
||||||
static double DefaultInfluenceGlobalPSO(){ return 0.1; }
|
static double DefaultInfluenceGlobalPSO(){ return 0.1; }
|
||||||
static double DefaultInfluenceLocalPSO(){ return 0.3; }
|
static double DefaultInfluenceLocalPSO(){ return 0.3; }
|
||||||
static double DefaultInfluenceRandomPSO(){ return 0.6; }
|
static double DefaultInfluenceRandomPSO(){ return 0.6; }
|
||||||
static size_t DefaultHeuristic(){ return static_cast<size_t> (cltune::SearchMethod::PSO);}
|
static size_t DefaultHeuristic(){ return static_cast<size_t>(cltune::SearchMethod::RandomSearch); }
|
||||||
static double DefaultMaxTempAnn(){ return 1.0;}
|
static double DefaultMaxTempAnn(){ return 1.0;}
|
||||||
|
|
||||||
// Describes how to obtain the sizes of the buffers
|
// Describes how to obtain the sizes of the buffers
|
||||||
|
@ -180,13 +180,15 @@ class TuneXgemm {
|
||||||
|
|
||||||
// Returns which Heuristic to run
|
// Returns which Heuristic to run
|
||||||
static size_t GetHeuristic(const Arguments<T> &args){
|
static size_t GetHeuristic(const Arguments<T> &args){
|
||||||
// Use full-search to explore all parameter combinations or another strategy to search only a
|
if (V==1) { return static_cast<size_t>(cltune::SearchMethod::FullSearch); }
|
||||||
// part of the parameter values. The fraction is set as a command-line argument.
|
|
||||||
if (args.fraction == 1.0 || args.fraction == 0.0) {
|
|
||||||
return static_cast<size_t> (cltune::SearchMethod::FullSearch);
|
|
||||||
}
|
|
||||||
else {
|
else {
|
||||||
return args.heuristic_selection;
|
// Use full-search to explore all parameter combinations or another strategy to search only a
|
||||||
|
// part of the parameter values. The fraction is set as a command-line argument.
|
||||||
|
if (args.fraction == 1.0 || args.fraction == 0.0) {
|
||||||
|
return static_cast<size_t>(cltune::SearchMethod::FullSearch);
|
||||||
|
} else {
|
||||||
|
return args.heuristic_selection;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -42,8 +42,8 @@ class TuneXgemmDirect {
|
||||||
// The list of arguments relevant for this routine
|
// The list of arguments relevant for this routine
|
||||||
static std::vector<std::string> GetOptions() {
|
static std::vector<std::string> GetOptions() {
|
||||||
return {kArgM, kArgN, kArgK, kArgAlpha, kArgBeta, kArgFraction,
|
return {kArgM, kArgN, kArgK, kArgAlpha, kArgBeta, kArgFraction,
|
||||||
kArgHeuristicSelection, kArgPsoSwarmSize,
|
kArgHeuristicSelection, kArgPsoSwarmSize,
|
||||||
kArgPsoInfGlobal, kArgPsoInfLocal, kArgPsoInfRandom};
|
kArgPsoInfGlobal, kArgPsoInfLocal, kArgPsoInfRandom};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tests for valid arguments
|
// Tests for valid arguments
|
||||||
|
@ -60,7 +60,7 @@ class TuneXgemmDirect {
|
||||||
static double DefaultInfluenceGlobalPSO(){ return 0.1; }
|
static double DefaultInfluenceGlobalPSO(){ return 0.1; }
|
||||||
static double DefaultInfluenceLocalPSO(){ return 0.3; }
|
static double DefaultInfluenceLocalPSO(){ return 0.3; }
|
||||||
static double DefaultInfluenceRandomPSO(){ return 0.6; }
|
static double DefaultInfluenceRandomPSO(){ return 0.6; }
|
||||||
static size_t DefaultHeuristic(){ return static_cast<size_t>(cltune::SearchMethod::PSO);}
|
static size_t DefaultHeuristic(){ return static_cast<size_t>(cltune::SearchMethod::RandomSearch);}
|
||||||
static double DefaultMaxTempAnn(){ return 1.0;}
|
static double DefaultMaxTempAnn(){ return 1.0;}
|
||||||
|
|
||||||
// Describes how to obtain the sizes of the buffers
|
// Describes how to obtain the sizes of the buffers
|
||||||
|
@ -177,13 +177,15 @@ class TuneXgemmDirect {
|
||||||
|
|
||||||
// Returns which Heuristic to run
|
// Returns which Heuristic to run
|
||||||
static size_t GetHeuristic(const Arguments<T> &args){
|
static size_t GetHeuristic(const Arguments<T> &args){
|
||||||
// Use full-search to explore all parameter combinations or another strategy to search only a
|
if (V==1) { return static_cast<size_t>(cltune::SearchMethod::FullSearch); }
|
||||||
// part of the parameter values. The fraction is set as a command-line argument.
|
|
||||||
if (args.fraction == 1.0 || args.fraction == 0.0) {
|
|
||||||
return static_cast<size_t> (cltune::SearchMethod::FullSearch);
|
|
||||||
}
|
|
||||||
else {
|
else {
|
||||||
return args.heuristic_selection;
|
// Use full-search to explore all parameter combinations or another strategy to search only a
|
||||||
|
// part of the parameter values. The fraction is set as a command-line argument.
|
||||||
|
if (args.fraction == 1.0 || args.fraction == 0.0) {
|
||||||
|
return static_cast<size_t>(cltune::SearchMethod::FullSearch);
|
||||||
|
} else {
|
||||||
|
return args.heuristic_selection;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -48,12 +48,12 @@ void Tuner(int argc, char* argv[]) {
|
||||||
if (o == kArgBeta) { args.beta = GetArgument(command_line_args, help, kArgBeta, GetScalar<T>()); }
|
if (o == kArgBeta) { args.beta = GetArgument(command_line_args, help, kArgBeta, GetScalar<T>()); }
|
||||||
if (o == kArgFraction) { args.fraction = GetArgument(command_line_args, help, kArgFraction, C::DefaultFraction()); }
|
if (o == kArgFraction) { args.fraction = GetArgument(command_line_args, help, kArgFraction, C::DefaultFraction()); }
|
||||||
if (o == kArgBatchCount) { args.batch_count = GetArgument(command_line_args, help, kArgBatchCount, C::DefaultBatchCount()); }
|
if (o == kArgBatchCount) { args.batch_count = GetArgument(command_line_args, help, kArgBatchCount, C::DefaultBatchCount()); }
|
||||||
if (o == kArgHeuristicSelection) {args.heuristic_selection = GetArgument(command_line_args, help, kArgHeuristicSelection, C::DefaultHeuristic()); }
|
if (o == kArgHeuristicSelection) {args.heuristic_selection = GetArgument(command_line_args, help, kArgHeuristicSelection, C::DefaultHeuristic()); }
|
||||||
if (o == kArgPsoSwarmSize) {args.pso_swarm_size = GetArgument(command_line_args, help, kArgPsoSwarmSize , C::DefaultSwarmSizePSO()); }
|
if (o == kArgPsoSwarmSize) {args.pso_swarm_size = GetArgument(command_line_args, help, kArgPsoSwarmSize , C::DefaultSwarmSizePSO()); }
|
||||||
if (o == kArgPsoInfGlobal) {args.pso_inf_global = GetArgument(command_line_args, help, kArgPsoInfGlobal, C::DefaultInfluenceGlobalPSO()); }
|
if (o == kArgPsoInfGlobal) {args.pso_inf_global = GetArgument(command_line_args, help, kArgPsoInfGlobal, C::DefaultInfluenceGlobalPSO()); }
|
||||||
if (o == kArgPsoInfLocal) {args.pso_inf_local = GetArgument(command_line_args, help, kArgPsoInfLocal, C::DefaultInfluenceLocalPSO()); }
|
if (o == kArgPsoInfLocal) {args.pso_inf_local = GetArgument(command_line_args, help, kArgPsoInfLocal, C::DefaultInfluenceLocalPSO()); }
|
||||||
if (o == kArgPsoInfRandom) {args.pso_inf_random = GetArgument(command_line_args, help, kArgPsoInfRandom, C::DefaultInfluenceRandomPSO()); }
|
if (o == kArgPsoInfRandom) {args.pso_inf_random = GetArgument(command_line_args, help, kArgPsoInfRandom, C::DefaultInfluenceRandomPSO()); }
|
||||||
if (o == kArgAnnMaxTemp) {args.ann_max_temperature = GetArgument(command_line_args, help, kArgAnnMaxTemp, C::DefaultMaxTempAnn());}
|
if (o == kArgAnnMaxTemp) {args.ann_max_temperature = GetArgument(command_line_args, help, kArgAnnMaxTemp, C::DefaultMaxTempAnn());}
|
||||||
}
|
}
|
||||||
const auto num_runs = GetArgument(command_line_args, help, kArgNumRuns, C::DefaultNumRuns());
|
const auto num_runs = GetArgument(command_line_args, help, kArgNumRuns, C::DefaultNumRuns());
|
||||||
|
|
||||||
|
@ -102,9 +102,9 @@ void Tuner(int argc, char* argv[]) {
|
||||||
auto method = C::GetHeuristic(args);
|
auto method = C::GetHeuristic(args);
|
||||||
|
|
||||||
if (method == 1) { tuner.UseRandomSearch(1.0/args.fraction); }
|
if (method == 1) { tuner.UseRandomSearch(1.0/args.fraction); }
|
||||||
else if (method == 2) { tuner.UseAnnealing(args.fraction, args.ann_max_temperature); }
|
else if (method == 2) { tuner.UseAnnealing(1.0/args.fraction, args.ann_max_temperature); }
|
||||||
else if (method == 3) {
|
else if (method == 3) {
|
||||||
tuner.UsePSO(args.fraction, args.pso_swarm_size, args.pso_inf_global, args.pso_inf_local, args.pso_inf_random);
|
tuner.UsePSO(1.0/args.fraction, args.pso_swarm_size, args.pso_inf_global, args.pso_inf_local, args.pso_inf_random);
|
||||||
}
|
}
|
||||||
else { tuner.UseFullSearch(); }
|
else { tuner.UseFullSearch(); }
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue