tests : sync test-grad0 from ggml

This commit is contained in:
Georgi Gerganov 2023-06-24 19:40:18 +03:00
parent fdd1860911
commit 65bdd52a86
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -1,3 +1,4 @@
#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
#include "ggml.h"
#include <math.h>
@ -5,6 +6,10 @@
#include <stdlib.h>
#include <assert.h>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
#define MAX_NARGS 3
#undef MIN
@ -197,8 +202,23 @@ bool check_gradient(
float max_error_abs,
float max_error_rel) {
static int n_threads = -1;
if (n_threads < 0) {
n_threads = GGML_DEFAULT_N_THREADS;
const char *env = getenv("GGML_N_THREADS");
if (env) {
n_threads = atoi(env);
}
printf("GGML_N_THREADS = %d\n", n_threads);
}
struct ggml_cgraph gf = ggml_build_forward (f);
gf.n_threads = n_threads;
struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
gb.n_threads = n_threads;
ggml_graph_compute(ctx0, &gf);
ggml_graph_reset (&gf);