Spaces:
Running
Running
graph : make FA compatible with MLA + add initial Metal kernels (llama/12953)
Browse files* graph : make mla compatible with FA
* metal : add exp FA kernels for DeepSeek models
ggml-ci
* llama : minor naming updates
ggml-ci
* ggml : disable FA for DS head sizes
* tests : add FA tests for MLA shapes
ggml-ci
ggml/src/ggml-cuda/ggml-cuda.cu
CHANGED
|
@@ -3237,6 +3237,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|
| 3237 |
if (op->src[0]->ne[0] == 192) {
|
| 3238 |
return false;
|
| 3239 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3240 |
if (op->src[0]->ne[3] != 1) {
|
| 3241 |
return false;
|
| 3242 |
}
|
|
|
|
| 3237 |
if (op->src[0]->ne[0] == 192) {
|
| 3238 |
return false;
|
| 3239 |
}
|
| 3240 |
+
if (op->src[0]->ne[0] == 576) {
|
| 3241 |
+
// DeepSeek MLA
|
| 3242 |
+
return false;
|
| 3243 |
+
}
|
| 3244 |
if (op->src[0]->ne[3] != 1) {
|
| 3245 |
return false;
|
| 3246 |
}
|
ggml/src/ggml-metal/ggml-metal.m
CHANGED
|
@@ -354,6 +354,7 @@ enum ggml_metal_kernel_type {
|
|
| 354 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,
|
| 355 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
|
| 356 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
|
|
|
| 357 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
|
| 358 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
|
| 359 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
|
|
@@ -362,6 +363,7 @@ enum ggml_metal_kernel_type {
|
|
| 362 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,
|
| 363 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
|
| 364 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
|
|
|
|
| 365 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
|
| 366 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
|
| 367 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
|
|
@@ -370,6 +372,7 @@ enum ggml_metal_kernel_type {
|
|
| 370 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,
|
| 371 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
|
| 372 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
|
|
|
|
| 373 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
|
| 374 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
|
| 375 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
|
|
@@ -378,6 +381,7 @@ enum ggml_metal_kernel_type {
|
|
| 378 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,
|
| 379 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
|
| 380 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
|
|
|
|
| 381 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
|
| 382 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
|
| 383 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
|
|
@@ -386,6 +390,7 @@ enum ggml_metal_kernel_type {
|
|
| 386 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,
|
| 387 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
|
| 388 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
|
|
|
|
| 389 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
|
| 390 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
|
| 391 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
|
|
@@ -394,6 +399,7 @@ enum ggml_metal_kernel_type {
|
|
| 394 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,
|
| 395 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
|
| 396 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
|
|
|
|
| 397 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
|
| 398 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
|
| 399 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
|
|
@@ -402,6 +408,7 @@ enum ggml_metal_kernel_type {
|
|
| 402 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,
|
| 403 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
|
| 404 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
|
|
|
|
| 405 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
|
| 406 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,
|
| 407 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,
|
|
@@ -437,6 +444,13 @@ enum ggml_metal_kernel_type {
|
|
| 437 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
|
| 438 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
|
| 439 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 440 |
GGML_METAL_KERNEL_TYPE_SET_I32,
|
| 441 |
GGML_METAL_KERNEL_TYPE_SET_F32,
|
| 442 |
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
|
@@ -1018,6 +1032,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 1018 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm);
|
| 1019 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm);
|
| 1020 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
|
|
|
|
| 1021 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
|
| 1022 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
|
| 1023 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
|
|
@@ -1026,6 +1041,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 1026 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat);
|
| 1027 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
|
| 1028 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
|
|
|
|
| 1029 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
|
| 1030 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
|
| 1031 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
|
|
@@ -1034,6 +1050,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 1034 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm);
|
| 1035 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
|
| 1036 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
|
|
|
|
| 1037 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
|
| 1038 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
|
| 1039 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
|
|
@@ -1042,6 +1059,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 1042 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm);
|
| 1043 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
|
| 1044 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
|
|
|
|
| 1045 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
|
| 1046 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
|
| 1047 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
|
|
@@ -1050,6 +1068,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 1050 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm);
|
| 1051 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
|
| 1052 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
|
|
|
|
| 1053 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
|
| 1054 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
|
| 1055 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
|
|
@@ -1058,6 +1077,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 1058 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm);
|
| 1059 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
|
| 1060 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
|
|
|
|
| 1061 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
|
| 1062 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
|
| 1063 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
|
|
@@ -1066,6 +1086,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 1066 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm);
|
| 1067 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
|
| 1068 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
|
|
|
|
| 1069 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction);
|
| 1070 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat);
|
| 1071 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction);
|
|
@@ -1101,6 +1122,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 1101 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
|
| 1102 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
|
| 1103 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1104 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
|
| 1105 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
|
| 1106 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
|
@@ -1365,6 +1393,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
| 1365 |
// TODO: not sure if it is worth adding kernels for this size
|
| 1366 |
return false;
|
| 1367 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1368 |
if (op->src[1]->type != op->src[2]->type) {
|
| 1369 |
return false;
|
| 1370 |
}
|
|
@@ -3857,12 +3890,14 @@ static void ggml_metal_encode_node(
|
|
| 3857 |
// TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
|
| 3858 |
// for now avoiding mainly to keep the number of templates/kernels a bit lower
|
| 3859 |
// these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
|
| 3860 |
-
if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192)) {
|
| 3861 |
switch (src1->type) {
|
| 3862 |
case GGML_TYPE_F16:
|
| 3863 |
{
|
| 3864 |
if (ne00 == 192 && ne20 == 128) {
|
| 3865 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline;
|
|
|
|
|
|
|
| 3866 |
} else {
|
| 3867 |
switch (ne00) {
|
| 3868 |
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
|
@@ -3885,6 +3920,8 @@ static void ggml_metal_encode_node(
|
|
| 3885 |
{
|
| 3886 |
if (ne00 == 192 && ne20 == 128) {
|
| 3887 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline;
|
|
|
|
|
|
|
| 3888 |
} else {
|
| 3889 |
switch (ne00) {
|
| 3890 |
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
|
|
@@ -3907,6 +3944,8 @@ static void ggml_metal_encode_node(
|
|
| 3907 |
{
|
| 3908 |
if (ne00 == 192 && ne20 == 128) {
|
| 3909 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline;
|
|
|
|
|
|
|
| 3910 |
} else {
|
| 3911 |
switch (ne00) {
|
| 3912 |
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
|
|
@@ -3929,6 +3968,8 @@ static void ggml_metal_encode_node(
|
|
| 3929 |
{
|
| 3930 |
if (ne00 == 192 && ne20 == 128) {
|
| 3931 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline;
|
|
|
|
|
|
|
| 3932 |
} else {
|
| 3933 |
switch (ne00) {
|
| 3934 |
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
|
|
@@ -3951,6 +3992,8 @@ static void ggml_metal_encode_node(
|
|
| 3951 |
{
|
| 3952 |
if (ne00 == 192 && ne20 == 128) {
|
| 3953 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline;
|
|
|
|
|
|
|
| 3954 |
} else {
|
| 3955 |
switch (ne00) {
|
| 3956 |
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
|
|
@@ -3973,6 +4016,8 @@ static void ggml_metal_encode_node(
|
|
| 3973 |
{
|
| 3974 |
if (ne00 == 192 && ne20 == 128) {
|
| 3975 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline;
|
|
|
|
|
|
|
| 3976 |
} else {
|
| 3977 |
switch (ne00) {
|
| 3978 |
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
|
|
@@ -3995,6 +4040,8 @@ static void ggml_metal_encode_node(
|
|
| 3995 |
{
|
| 3996 |
if (ne00 == 192 && ne20 == 128) {
|
| 3997 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline;
|
|
|
|
|
|
|
| 3998 |
} else {
|
| 3999 |
switch (ne00) {
|
| 4000 |
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
|
|
@@ -4114,12 +4161,36 @@ static void ggml_metal_encode_node(
|
|
| 4114 |
}
|
| 4115 |
}
|
| 4116 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4117 |
default:
|
| 4118 |
-
|
| 4119 |
-
|
| 4120 |
-
|
| 4121 |
-
|
| 4122 |
-
|
| 4123 |
}
|
| 4124 |
}
|
| 4125 |
|
|
|
|
| 354 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,
|
| 355 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
|
| 356 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
| 357 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,
|
| 358 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
|
| 359 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
|
| 360 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
|
|
|
|
| 363 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,
|
| 364 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
|
| 365 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
|
| 366 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512,
|
| 367 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
|
| 368 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
|
| 369 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
|
|
|
|
| 372 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,
|
| 373 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
|
| 374 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
|
| 375 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512,
|
| 376 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
|
| 377 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
|
| 378 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
|
|
|
|
| 381 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,
|
| 382 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
|
| 383 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
|
| 384 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512,
|
| 385 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
|
| 386 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
|
| 387 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
|
|
|
|
| 390 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,
|
| 391 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
|
| 392 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
|
| 393 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512,
|
| 394 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
|
| 395 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
|
| 396 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
|
|
|
|
| 399 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,
|
| 400 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
|
| 401 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
|
| 402 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512,
|
| 403 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
|
| 404 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
|
| 405 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
|
|
|
|
| 408 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,
|
| 409 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
|
| 410 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
|
| 411 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
|
| 412 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
|
| 413 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,
|
| 414 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,
|
|
|
|
| 444 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
|
| 445 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
|
| 446 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
|
| 447 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512,
|
| 448 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512,
|
| 449 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512,
|
| 450 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512,
|
| 451 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
|
| 452 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
|
| 453 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
|
| 454 |
GGML_METAL_KERNEL_TYPE_SET_I32,
|
| 455 |
GGML_METAL_KERNEL_TYPE_SET_F32,
|
| 456 |
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
|
|
|
| 1032 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm);
|
| 1033 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm);
|
| 1034 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
|
| 1035 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm);
|
| 1036 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
|
| 1037 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
|
| 1038 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
|
|
|
|
| 1041 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat);
|
| 1042 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
|
| 1043 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
|
| 1044 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat);
|
| 1045 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
|
| 1046 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
|
| 1047 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
|
|
|
|
| 1050 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm);
|
| 1051 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
|
| 1052 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
|
| 1053 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm);
|
| 1054 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
|
| 1055 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
|
| 1056 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
|
|
|
|
| 1059 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm);
|
| 1060 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
|
| 1061 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
|
| 1062 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm);
|
| 1063 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
|
| 1064 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
|
| 1065 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
|
|
|
|
| 1068 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm);
|
| 1069 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
|
| 1070 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
|
| 1071 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm);
|
| 1072 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
|
| 1073 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
|
| 1074 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
|
|
|
|
| 1077 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm);
|
| 1078 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
|
| 1079 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
|
| 1080 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm);
|
| 1081 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
|
| 1082 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
|
| 1083 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
|
|
|
|
| 1086 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm);
|
| 1087 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
|
| 1088 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
|
| 1089 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
|
| 1090 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction);
|
| 1091 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat);
|
| 1092 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction);
|
|
|
|
| 1122 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
|
| 1123 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
|
| 1124 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
|
| 1125 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, flash_attn_ext_vec_f16_hk576_hv512, has_simdgroup_reduction);
|
| 1126 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, flash_attn_ext_vec_bf16_hk576_hv512, has_simdgroup_reduction && use_bfloat);
|
| 1127 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, flash_attn_ext_vec_q4_0_hk576_hv512, has_simdgroup_reduction);
|
| 1128 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, flash_attn_ext_vec_q4_1_hk576_hv512, has_simdgroup_reduction);
|
| 1129 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction);
|
| 1130 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction);
|
| 1131 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction);
|
| 1132 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
|
| 1133 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
|
| 1134 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
|
|
|
| 1393 |
// TODO: not sure if it is worth adding kernels for this size
|
| 1394 |
return false;
|
| 1395 |
}
|
| 1396 |
+
if (op->src[0]->ne[0] == 576) {
|
| 1397 |
+
// DeepSeek sizes
|
| 1398 |
+
// TODO: disabled for now, until optmized
|
| 1399 |
+
return false;
|
| 1400 |
+
}
|
| 1401 |
if (op->src[1]->type != op->src[2]->type) {
|
| 1402 |
return false;
|
| 1403 |
}
|
|
|
|
| 3890 |
// TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
|
| 3891 |
// for now avoiding mainly to keep the number of templates/kernels a bit lower
|
| 3892 |
// these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
|
| 3893 |
+
if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
|
| 3894 |
switch (src1->type) {
|
| 3895 |
case GGML_TYPE_F16:
|
| 3896 |
{
|
| 3897 |
if (ne00 == 192 && ne20 == 128) {
|
| 3898 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline;
|
| 3899 |
+
} else if (ne00 == 576 && ne20 == 512) {
|
| 3900 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline;
|
| 3901 |
} else {
|
| 3902 |
switch (ne00) {
|
| 3903 |
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
|
|
|
| 3920 |
{
|
| 3921 |
if (ne00 == 192 && ne20 == 128) {
|
| 3922 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline;
|
| 3923 |
+
} else if (ne00 == 576 && ne20 == 512) {
|
| 3924 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline;
|
| 3925 |
} else {
|
| 3926 |
switch (ne00) {
|
| 3927 |
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
|
|
|
|
| 3944 |
{
|
| 3945 |
if (ne00 == 192 && ne20 == 128) {
|
| 3946 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline;
|
| 3947 |
+
} else if (ne00 == 576 && ne20 == 512) {
|
| 3948 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline;
|
| 3949 |
} else {
|
| 3950 |
switch (ne00) {
|
| 3951 |
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
|
|
|
|
| 3968 |
{
|
| 3969 |
if (ne00 == 192 && ne20 == 128) {
|
| 3970 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline;
|
| 3971 |
+
} else if (ne00 == 576 && ne20 == 512) {
|
| 3972 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline;
|
| 3973 |
} else {
|
| 3974 |
switch (ne00) {
|
| 3975 |
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
|
|
|
|
| 3992 |
{
|
| 3993 |
if (ne00 == 192 && ne20 == 128) {
|
| 3994 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline;
|
| 3995 |
+
} else if (ne00 == 576 && ne20 == 512) {
|
| 3996 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline;
|
| 3997 |
} else {
|
| 3998 |
switch (ne00) {
|
| 3999 |
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
|
|
|
|
| 4016 |
{
|
| 4017 |
if (ne00 == 192 && ne20 == 128) {
|
| 4018 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline;
|
| 4019 |
+
} else if (ne00 == 576 && ne20 == 512) {
|
| 4020 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline;
|
| 4021 |
} else {
|
| 4022 |
switch (ne00) {
|
| 4023 |
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
|
|
|
|
| 4040 |
{
|
| 4041 |
if (ne00 == 192 && ne20 == 128) {
|
| 4042 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline;
|
| 4043 |
+
} else if (ne00 == 576 && ne20 == 512) {
|
| 4044 |
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline;
|
| 4045 |
} else {
|
| 4046 |
switch (ne00) {
|
| 4047 |
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
|
|
|
|
| 4161 |
}
|
| 4162 |
}
|
| 4163 |
} break;
|
| 4164 |
+
case 576:
|
| 4165 |
+
{
|
| 4166 |
+
if (ne20 == 512) {
|
| 4167 |
+
switch (src1->type) {
|
| 4168 |
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline; break;
|
| 4169 |
+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512].pipeline; break;
|
| 4170 |
+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512].pipeline; break;
|
| 4171 |
+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512].pipeline; break;
|
| 4172 |
+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512].pipeline; break;
|
| 4173 |
+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512].pipeline; break;
|
| 4174 |
+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; break;
|
| 4175 |
+
default:
|
| 4176 |
+
{
|
| 4177 |
+
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
|
| 4178 |
+
GGML_LOG_ERROR("add template specialization for this type\n");
|
| 4179 |
+
GGML_ABORT("add template specialization for this type");
|
| 4180 |
+
}
|
| 4181 |
+
}
|
| 4182 |
+
} else {
|
| 4183 |
+
GGML_LOG_ERROR("unsupported size: %lld\n", ne20);
|
| 4184 |
+
GGML_LOG_ERROR("add template specialization for this size\n");
|
| 4185 |
+
GGML_ABORT("add template specialization for this size");
|
| 4186 |
+
}
|
| 4187 |
+
} break;
|
| 4188 |
default:
|
| 4189 |
+
{
|
| 4190 |
+
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
| 4191 |
+
GGML_LOG_ERROR("add template specialization for this size\n");
|
| 4192 |
+
GGML_ABORT("add template specialization for this size");
|
| 4193 |
+
}
|
| 4194 |
}
|
| 4195 |
}
|
| 4196 |
|
ggml/src/ggml-metal/ggml-metal.metal
CHANGED
|
@@ -3546,6 +3546,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_at
|
|
| 3546 |
template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 192>;
|
| 3547 |
template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>;
|
| 3548 |
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>;
|
|
|
|
| 3549 |
|
| 3550 |
#if defined(GGML_METAL_USE_BF16)
|
| 3551 |
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
|
|
@@ -3556,6 +3557,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_at
|
|
| 3556 |
template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
|
| 3557 |
template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
|
| 3558 |
template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
|
|
|
|
| 3559 |
#endif
|
| 3560 |
|
| 3561 |
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
|
|
@@ -3566,6 +3568,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_at
|
|
| 3566 |
template [[host_name("kernel_flash_attn_ext_q4_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;
|
| 3567 |
template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
|
| 3568 |
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
|
|
|
|
| 3569 |
|
| 3570 |
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
|
| 3571 |
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
|
|
@@ -3575,6 +3578,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_at
|
|
| 3575 |
template [[host_name("kernel_flash_attn_ext_q4_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;
|
| 3576 |
template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
|
| 3577 |
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
|
|
|
|
| 3578 |
|
| 3579 |
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
|
| 3580 |
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
|
|
@@ -3584,6 +3588,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_at
|
|
| 3584 |
template [[host_name("kernel_flash_attn_ext_q5_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;
|
| 3585 |
template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
|
| 3586 |
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
|
|
|
|
| 3587 |
|
| 3588 |
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
|
| 3589 |
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
|
|
@@ -3593,6 +3598,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_at
|
|
| 3593 |
template [[host_name("kernel_flash_attn_ext_q5_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;
|
| 3594 |
template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
|
| 3595 |
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
|
|
|
|
| 3596 |
|
| 3597 |
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
|
| 3598 |
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
|
|
@@ -3602,6 +3608,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_at
|
|
| 3602 |
template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
|
| 3603 |
template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
|
| 3604 |
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
|
|
|
|
| 3605 |
|
| 3606 |
#undef FA_TYPES
|
| 3607 |
|
|
@@ -4009,6 +4016,16 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_
|
|
| 4009 |
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 4>;
|
| 4010 |
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 4>;
|
| 4011 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4012 |
#undef FA_TYPES
|
| 4013 |
|
| 4014 |
template<typename T>
|
|
|
|
| 3546 |
template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 192>;
|
| 3547 |
template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>;
|
| 3548 |
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>;
|
| 3549 |
+
template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
|
| 3550 |
|
| 3551 |
#if defined(GGML_METAL_USE_BF16)
|
| 3552 |
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
|
|
|
|
| 3557 |
template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
|
| 3558 |
template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
|
| 3559 |
template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
|
| 3560 |
+
template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
|
| 3561 |
#endif
|
| 3562 |
|
| 3563 |
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
|
|
|
|
| 3568 |
template [[host_name("kernel_flash_attn_ext_q4_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;
|
| 3569 |
template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
|
| 3570 |
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
|
| 3571 |
+
template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
|
| 3572 |
|
| 3573 |
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
|
| 3574 |
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
|
|
|
|
| 3578 |
template [[host_name("kernel_flash_attn_ext_q4_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;
|
| 3579 |
template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
|
| 3580 |
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
|
| 3581 |
+
template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
|
| 3582 |
|
| 3583 |
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
|
| 3584 |
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
|
|
|
|
| 3588 |
template [[host_name("kernel_flash_attn_ext_q5_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;
|
| 3589 |
template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
|
| 3590 |
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
|
| 3591 |
+
template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
|
| 3592 |
|
| 3593 |
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
|
| 3594 |
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
|
|
|
|
| 3598 |
template [[host_name("kernel_flash_attn_ext_q5_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;
|
| 3599 |
template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
|
| 3600 |
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
|
| 3601 |
+
template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
|
| 3602 |
|
| 3603 |
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
|
| 3604 |
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
|
|
|
|
| 3608 |
template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
|
| 3609 |
template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
|
| 3610 |
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
|
| 3611 |
+
template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
|
| 3612 |
|
| 3613 |
#undef FA_TYPES
|
| 3614 |
|
|
|
|
| 4016 |
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 4>;
|
| 4017 |
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 4>;
|
| 4018 |
|
| 4019 |
+
template [[host_name("kernel_flash_attn_ext_vec_f16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
|
| 4020 |
+
#if defined(GGML_METAL_USE_BF16)
|
| 4021 |
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;
|
| 4022 |
+
#endif
|
| 4023 |
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 576, 512, 2>;
|
| 4024 |
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 576, 512, 2>;
|
| 4025 |
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 576, 512, 2>;
|
| 4026 |
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 576, 512, 2>;
|
| 4027 |
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 576, 512, 2>;
|
| 4028 |
+
|
| 4029 |
#undef FA_TYPES
|
| 4030 |
|
| 4031 |
template<typename T>
|
ggml/src/ggml-vulkan/ggml-vulkan.cpp
CHANGED
|
@@ -9261,6 +9261,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
| 9261 |
case 112:
|
| 9262 |
case 128:
|
| 9263 |
case 256:
|
|
|
|
| 9264 |
break;
|
| 9265 |
default:
|
| 9266 |
return false;
|
|
|
|
| 9261 |
case 112:
|
| 9262 |
case 128:
|
| 9263 |
case 256:
|
| 9264 |
+
case 575: // DeepSeek MLA
|
| 9265 |
break;
|
| 9266 |
default:
|
| 9267 |
return false;
|