ngxson HF Staff committed on
Commit
6c9cd9a
·
1 Parent(s): 9506ebb

ggml : add ggml_gelu_erf() (llama/13667)

Browse files

* ggml : add ggml_gelu_na (not approximated)

* fix naming order

* rename na --> erf

* apply review suggestions

* revert naming order

ggml/include/ggml.h CHANGED
@@ -528,14 +528,15 @@ extern "C" {
528
  GGML_UNARY_OP_STEP,
529
  GGML_UNARY_OP_TANH,
530
  GGML_UNARY_OP_ELU,
531
- GGML_UNARY_OP_RELU,
532
  GGML_UNARY_OP_SIGMOID,
533
  GGML_UNARY_OP_GELU,
 
534
  GGML_UNARY_OP_GELU_QUICK,
535
  GGML_UNARY_OP_SILU,
536
  GGML_UNARY_OP_HARDSWISH,
537
  GGML_UNARY_OP_HARDSIGMOID,
538
  GGML_UNARY_OP_EXP,
 
539
 
540
  GGML_UNARY_OP_COUNT,
541
  };
@@ -1024,6 +1025,16 @@ extern "C" {
1024
  struct ggml_context * ctx,
1025
  struct ggml_tensor * a);
1026
 
 
 
 
 
 
 
 
 
 
 
1027
  GGML_API struct ggml_tensor * ggml_gelu_quick(
1028
  struct ggml_context * ctx,
1029
  struct ggml_tensor * a);
 
528
  GGML_UNARY_OP_STEP,
529
  GGML_UNARY_OP_TANH,
530
  GGML_UNARY_OP_ELU,
 
531
  GGML_UNARY_OP_SIGMOID,
532
  GGML_UNARY_OP_GELU,
533
+ GGML_UNARY_OP_GELU_ERF,
534
  GGML_UNARY_OP_GELU_QUICK,
535
  GGML_UNARY_OP_SILU,
536
  GGML_UNARY_OP_HARDSWISH,
537
  GGML_UNARY_OP_HARDSIGMOID,
538
  GGML_UNARY_OP_EXP,
539
+ GGML_UNARY_OP_RELU,
540
 
541
  GGML_UNARY_OP_COUNT,
542
  };
 
1025
  struct ggml_context * ctx,
1026
  struct ggml_tensor * a);
1027
 
1028
+ // GELU using erf (error function) when possible
1029
+ // some backends may fallback to approximation based on Abramowitz and Stegun formula
1030
+ GGML_API struct ggml_tensor * ggml_gelu_erf(
1031
+ struct ggml_context * ctx,
1032
+ struct ggml_tensor * a);
1033
+
1034
+ GGML_API struct ggml_tensor * ggml_gelu_erf_inplace(
1035
+ struct ggml_context * ctx,
1036
+ struct ggml_tensor * a);
1037
+
1038
  GGML_API struct ggml_tensor * ggml_gelu_quick(
1039
  struct ggml_context * ctx,
1040
  struct ggml_tensor * a);
ggml/src/ggml-cpu/ggml-cpu.c CHANGED
@@ -2202,6 +2202,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
2202
  } break;
2203
 
2204
  case GGML_UNARY_OP_GELU:
 
2205
  case GGML_UNARY_OP_GELU_QUICK:
2206
  case GGML_UNARY_OP_SILU:
2207
  {
 
2202
  } break;
2203
 
2204
  case GGML_UNARY_OP_GELU:
2205
+ case GGML_UNARY_OP_GELU_ERF:
2206
  case GGML_UNARY_OP_GELU_QUICK:
2207
  case GGML_UNARY_OP_SILU:
2208
  {
ggml/src/ggml-cpu/ops.cpp CHANGED
@@ -2691,6 +2691,109 @@ static void ggml_compute_forward_gelu(
2691
  }
2692
  }
2693
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2694
  // ggml_compute_forward_gelu_quick
2695
 
2696
  static void ggml_compute_forward_gelu_quick_f32(
@@ -7749,6 +7852,10 @@ void ggml_compute_forward_unary(
7749
  {
7750
  ggml_compute_forward_gelu(params, dst);
7751
  } break;
 
 
 
 
7752
  case GGML_UNARY_OP_GELU_QUICK:
7753
  {
7754
  ggml_compute_forward_gelu_quick(params, dst);
 
2691
  }
2692
  }
2693
 
2694
+ // ggml_compute_forward_gelu_erf
2695
+
2696
+ static void ggml_compute_forward_gelu_erf_f32(
2697
+ const ggml_compute_params * params,
2698
+ ggml_tensor * dst) {
2699
+
2700
+ const ggml_tensor * src0 = dst->src[0];
2701
+
2702
+ assert(ggml_is_contiguous_1(src0));
2703
+ assert(ggml_is_contiguous_1(dst));
2704
+ assert(ggml_are_same_shape(src0, dst));
2705
+
2706
+ const int ith = params->ith;
2707
+ const int nth = params->nth;
2708
+
2709
+ const int nc = src0->ne[0];
2710
+ const int nr = ggml_nrows(src0);
2711
+
2712
+ // rows per thread
2713
+ const int dr = (nr + nth - 1)/nth;
2714
+
2715
+ // row range for this thread
2716
+ const int ir0 = dr*ith;
2717
+ const int ir1 = MIN(ir0 + dr, nr);
2718
+
2719
+ for (int i1 = ir0; i1 < ir1; i1++) {
2720
+ ggml_vec_gelu_erf_f32(nc,
2721
+ (float *) ((char *) dst->data + i1*( dst->nb[1])),
2722
+ (float *) ((char *) src0->data + i1*(src0->nb[1])));
2723
+
2724
+ #ifndef NDEBUG
2725
+ for (int k = 0; k < nc; k++) {
2726
+ const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2727
+ GGML_UNUSED(x);
2728
+ assert(!isnan(x));
2729
+ assert(!isinf(x));
2730
+ }
2731
+ #endif
2732
+ }
2733
+ }
2734
+
2735
+ static void ggml_compute_forward_gelu_erf_f16(
2736
+ const ggml_compute_params * params,
2737
+ ggml_tensor * dst) {
2738
+
2739
+ const ggml_tensor * src0 = dst->src[0];
2740
+
2741
+ assert(ggml_is_contiguous_1(src0));
2742
+ assert(ggml_is_contiguous_1(dst));
2743
+ assert(ggml_are_same_shape(src0, dst));
2744
+
2745
+ const int ith = params->ith;
2746
+ const int nth = params->nth;
2747
+
2748
+ const int nc = src0->ne[0];
2749
+ const int nr = ggml_nrows(src0);
2750
+
2751
+ // rows per thread
2752
+ const int dr = (nr + nth - 1)/nth;
2753
+
2754
+ // row range for this thread
2755
+ const int ir0 = dr*ith;
2756
+ const int ir1 = MIN(ir0 + dr, nr);
2757
+
2758
+ for (int i1 = ir0; i1 < ir1; i1++) {
2759
+ ggml_vec_gelu_erf_f16(nc,
2760
+ (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
2761
+ (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2762
+
2763
+ #ifndef NDEBUG
2764
+ for (int k = 0; k < nc; k++) {
2765
+ const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2766
+ const float v = GGML_FP16_TO_FP32(x);
2767
+ GGML_UNUSED(v);
2768
+ assert(!isnan(v));
2769
+ assert(!isinf(v));
2770
+ }
2771
+ #endif
2772
+ }
2773
+ }
2774
+
2775
+ static void ggml_compute_forward_gelu_erf(
2776
+ const ggml_compute_params * params,
2777
+ ggml_tensor * dst) {
2778
+
2779
+ const ggml_tensor * src0 = dst->src[0];
2780
+
2781
+ switch (src0->type) {
2782
+ case GGML_TYPE_F32:
2783
+ {
2784
+ ggml_compute_forward_gelu_erf_f32(params, dst);
2785
+ } break;
2786
+ case GGML_TYPE_F16:
2787
+ {
2788
+ ggml_compute_forward_gelu_erf_f16(params, dst);
2789
+ } break;
2790
+ default:
2791
+ {
2792
+ GGML_ABORT("fatal error");
2793
+ }
2794
+ }
2795
+ }
2796
+
2797
  // ggml_compute_forward_gelu_quick
2798
 
2799
  static void ggml_compute_forward_gelu_quick_f32(
 
7852
  {
7853
  ggml_compute_forward_gelu(params, dst);
7854
  } break;
7855
+ case GGML_UNARY_OP_GELU_ERF:
7856
+ {
7857
+ ggml_compute_forward_gelu_erf(params, dst);
7858
+ } break;
7859
  case GGML_UNARY_OP_GELU_QUICK:
7860
  {
7861
  ggml_compute_forward_gelu_quick(params, dst);
ggml/src/ggml-cpu/vec.h CHANGED
@@ -428,6 +428,7 @@ inline static void ggml_vec_exp_f16 (const int n, ggml_fp16_t * y, const ggml_fp
428
  static const float GELU_COEF_A = 0.044715f;
429
  static const float GELU_QUICK_COEF = -1.702f;
430
  static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
431
 
432
  inline static float ggml_gelu_f32(float x) {
433
  return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
@@ -440,6 +441,14 @@ inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp
440
  }
441
  }
442
 
 
 
 
 
 
 
 
 
443
  #ifdef GGML_GELU_FP16
444
  inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
445
  uint16_t t;
@@ -463,6 +472,13 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
463
  }
464
  #endif
465
 
 
 
 
 
 
 
 
466
  inline static float ggml_gelu_quick_f32(float x) {
467
  return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
468
  }
 
428
  static const float GELU_COEF_A = 0.044715f;
429
  static const float GELU_QUICK_COEF = -1.702f;
430
  static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
431
+ static const float SQRT_2_INV = 0.70710678118654752440084436210484f;
432
 
433
  inline static float ggml_gelu_f32(float x) {
434
  return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
 
441
  }
442
  }
443
 
444
+ inline static void ggml_vec_gelu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
445
+ for (int i = 0; i < n; ++i) {
446
+ float xi = GGML_FP16_TO_FP32(x[i]);
447
+ float res = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV));
448
+ y[i] = GGML_FP32_TO_FP16(res);
449
+ }
450
+ }
451
+
452
  #ifdef GGML_GELU_FP16
453
  inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
454
  uint16_t t;
 
472
  }
473
  #endif
474
 
475
+ inline static void ggml_vec_gelu_erf_f32(const int n, float * y, const float * x) {
476
+ for (int i = 0; i < n; ++i) {
477
+ float xi = x[i];
478
+ y[i] = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV));
479
+ }
480
+ }
481
+
482
  inline static float ggml_gelu_quick_f32(float x) {
483
  return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
484
  }
ggml/src/ggml-metal/ggml-metal.m CHANGED
@@ -149,6 +149,8 @@ enum ggml_metal_kernel_type {
149
  GGML_METAL_KERNEL_TYPE_SIGMOID,
150
  GGML_METAL_KERNEL_TYPE_GELU,
151
  GGML_METAL_KERNEL_TYPE_GELU_4,
 
 
152
  GGML_METAL_KERNEL_TYPE_GELU_QUICK,
153
  GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
154
  GGML_METAL_KERNEL_TYPE_SILU,
@@ -1103,6 +1105,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1103
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
1104
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
1105
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
 
 
1106
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
1107
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
1108
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
@@ -1613,6 +1617,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1613
  case GGML_UNARY_OP_RELU:
1614
  case GGML_UNARY_OP_SIGMOID:
1615
  case GGML_UNARY_OP_GELU:
 
1616
  case GGML_UNARY_OP_GELU_QUICK:
1617
  case GGML_UNARY_OP_SILU:
1618
  case GGML_UNARY_OP_ELU:
@@ -2251,6 +2256,25 @@ static bool ggml_metal_encode_node(
2251
 
2252
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2253
  } break;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2254
  case GGML_UNARY_OP_GELU_QUICK:
2255
  {
2256
  int64_t n = ggml_nelements(dst);
 
149
  GGML_METAL_KERNEL_TYPE_SIGMOID,
150
  GGML_METAL_KERNEL_TYPE_GELU,
151
  GGML_METAL_KERNEL_TYPE_GELU_4,
152
+ GGML_METAL_KERNEL_TYPE_GELU_ERF,
153
+ GGML_METAL_KERNEL_TYPE_GELU_ERF_4,
154
  GGML_METAL_KERNEL_TYPE_GELU_QUICK,
155
  GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
156
  GGML_METAL_KERNEL_TYPE_SILU,
 
1105
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
1106
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
1107
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
1108
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF, gelu_erf, true);
1109
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF_4, gelu_erf_4, true);
1110
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
1111
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
1112
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
 
1617
  case GGML_UNARY_OP_RELU:
1618
  case GGML_UNARY_OP_SIGMOID:
1619
  case GGML_UNARY_OP_GELU:
1620
+ case GGML_UNARY_OP_GELU_ERF:
1621
  case GGML_UNARY_OP_GELU_QUICK:
1622
  case GGML_UNARY_OP_SILU:
1623
  case GGML_UNARY_OP_ELU:
 
2256
 
2257
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2258
  } break;
2259
+ case GGML_UNARY_OP_GELU_ERF:
2260
+ {
2261
+ int64_t n = ggml_nelements(dst);
2262
+
2263
+ id<MTLComputePipelineState> pipeline = nil;
2264
+
2265
+ if (n % 4 == 0) {
2266
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF_4].pipeline;
2267
+ n /= 4;
2268
+ } else {
2269
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF].pipeline;
2270
+ }
2271
+
2272
+ [encoder setComputePipelineState:pipeline];
2273
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2274
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2275
+
2276
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2277
+ } break;
2278
  case GGML_UNARY_OP_GELU_QUICK:
2279
  {
2280
  int64_t n = ggml_nelements(dst);
ggml/src/ggml-metal/ggml-metal.metal CHANGED
@@ -856,6 +856,7 @@ kernel void kernel_tanh(
856
  constant float GELU_COEF_A = 0.044715f;
857
  constant float GELU_QUICK_COEF = -1.702f;
858
  constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
859
 
860
  kernel void kernel_gelu(
861
  device const float * src0,
@@ -897,6 +898,42 @@ kernel void kernel_gelu_quick_4(
897
  dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
898
  }
899
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
900
  kernel void kernel_silu(
901
  device const float * src0,
902
  device float * dst,
 
856
  constant float GELU_COEF_A = 0.044715f;
857
  constant float GELU_QUICK_COEF = -1.702f;
858
  constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
859
+ constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
860
 
861
  kernel void kernel_gelu(
862
  device const float * src0,
 
898
  dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
899
  }
900
 
901
+ // based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
902
+ // ref: https://www.johndcook.com/blog/python_erf/
903
+ constant float p_erf = 0.3275911f;
904
+ constant float a1_erf = 0.254829592f;
905
+ constant float a2_erf = -0.284496736f;
906
+ constant float a3_erf = 1.421413741f;
907
+ constant float a4_erf = -1.453152027f;
908
+ constant float a5_erf = 1.061405429f;
909
+
910
+ template<typename T>
911
+ T erf_approx(T x) {
912
+ T sign_x = sign(x);
913
+ x = fabs(x);
914
+ T t = 1.0f / (1.0f + p_erf * x);
915
+ T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
916
+ return sign_x * y;
917
+ }
918
+
919
+ kernel void kernel_gelu_erf(
920
+ device const float * src0,
921
+ device float * dst,
922
+ uint tpig[[thread_position_in_grid]]) {
923
+ device const float & x = src0[tpig];
924
+
925
+ dst[tpig] = 0.5f*x*(1.0f+erf_approx<float>(x*SQRT_2_INV));
926
+ }
927
+
928
+ kernel void kernel_gelu_erf_4(
929
+ device const float4 * src0,
930
+ device float4 * dst,
931
+ uint tpig[[thread_position_in_grid]]) {
932
+ device const float4 & x = src0[tpig];
933
+
934
+ dst[tpig] = 0.5f*x*(1.0f+erf_approx<float4>(x*SQRT_2_INV));
935
+ }
936
+
937
  kernel void kernel_silu(
938
  device const float * src0,
939
  device float * dst,
ggml/src/ggml.c CHANGED
@@ -1099,9 +1099,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
1099
  "HARDSWISH",
1100
  "HARDSIGMOID",
1101
  "EXP",
 
1102
  };
1103
 
1104
- static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14");
1105
 
1106
 
1107
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -2501,6 +2502,20 @@ struct ggml_tensor * ggml_gelu_inplace(
2501
  return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU);
2502
  }
2503
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2504
  // ggml_gelu_quick
2505
 
2506
  struct ggml_tensor * ggml_gelu_quick(
 
1099
  "HARDSWISH",
1100
  "HARDSIGMOID",
1101
  "EXP",
1102
+ "GELU_ERF",
1103
  };
1104
 
1105
+ static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
1106
 
1107
 
1108
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 
2502
  return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU);
2503
  }
2504
 
2505
+ // ggml_gelu_erf
2506
+
2507
+ struct ggml_tensor * ggml_gelu_erf(
2508
+ struct ggml_context * ctx,
2509
+ struct ggml_tensor * a) {
2510
+ return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_ERF);
2511
+ }
2512
+
2513
+ struct ggml_tensor * ggml_gelu_erf_inplace(
2514
+ struct ggml_context * ctx,
2515
+ struct ggml_tensor * a) {
2516
+ return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_ERF);
2517
+ }
2518
+
2519
  // ggml_gelu_quick
2520
 
2521
  struct ggml_tensor * ggml_gelu_quick(