Akarshan Biswas committed on
Commit
a29a2c3
·
1 Parent(s): e4d1f59

SYCL: Add non-contiguous support in ROPE (llama/12993)

Browse files
ggml/src/ggml-sycl/ggml-sycl.cpp CHANGED
@@ -3168,11 +3168,6 @@ static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor
3168
  ggml_sycl_op_diag_mask_inf(ctx, dst);
3169
  }
3170
 
3171
- static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3172
- GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented
3173
- ggml_sycl_op_rope(ctx, dst);
3174
- }
3175
-
3176
  static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3177
  ggml_sycl_op_pool2d(ctx, dst);
3178
  }
@@ -4002,7 +3997,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4002
  if (mode == GGML_ROPE_TYPE_MROPE) {
4003
  return false;
4004
  }
4005
- return ggml_is_contiguous(op->src[0]);
4006
  }
4007
  case GGML_OP_IM2COL:
4008
  return true;
 
3168
  ggml_sycl_op_diag_mask_inf(ctx, dst);
3169
  }
3170
 
 
 
 
 
 
3171
  static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3172
  ggml_sycl_op_pool2d(ctx, dst);
3173
  }
 
3997
  if (mode == GGML_ROPE_TYPE_MROPE) {
3998
  return false;
3999
  }
4000
+ return true;
4001
  }
4002
  case GGML_OP_IM2COL:
4003
  return true;
ggml/src/ggml-sycl/rope.cpp CHANGED
@@ -34,23 +34,21 @@ static void rope_yarn(
34
  *sin_theta = sycl::sin(theta) * mscale;
35
  }
36
 
37
- template<typename T, bool has_ff>
38
- static void rope_norm(
39
- const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
40
- float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
41
- const sycl::nd_item<3> &item_ct1) {
42
- const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
43
- item_ct1.get_local_id(1));
44
 
45
  if (i0 >= ne0) {
46
  return;
47
  }
48
 
49
- const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
50
- item_ct1.get_local_id(2);
51
 
52
  if (i0 >= n_dims) {
53
- const int i = row*ne0 + i0;
54
 
55
  dst[i + 0] = x[i + 0];
56
  dst[i + 1] = x[i + 1];
@@ -58,42 +56,43 @@ static void rope_norm(
58
  return;
59
  }
60
 
61
- const int i = row*ne0 + i0;
62
- const int i2 = row/p_delta_rows;
 
 
 
63
 
64
- const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f);
65
 
66
- const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
67
 
68
  float cos_theta;
69
  float sin_theta;
70
 
71
- rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
72
 
73
- const float x0 = x[i + 0];
74
- const float x1 = x[i + 1];
75
 
76
- dst[i + 0] = x0*cos_theta - x1*sin_theta;
77
- dst[i + 1] = x0*sin_theta + x1*cos_theta;
78
  }
79
 
80
- template<typename T, bool has_ff>
81
- static void rope_neox(
82
- const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
83
- float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
84
- const sycl::nd_item<3> &item_ct1) {
85
- const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
86
- item_ct1.get_local_id(1));
87
 
88
  if (i0 >= ne0) {
89
  return;
90
  }
91
 
92
- const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
93
- item_ct1.get_local_id(2);
94
 
95
  if (i0 >= n_dims) {
96
- const int i = row*ne0 + i0;
97
 
98
  dst[i + 0] = x[i + 0];
99
  dst[i + 1] = x[i + 1];
@@ -101,23 +100,26 @@ static void rope_neox(
101
  return;
102
  }
103
 
104
- const int i = row*ne0 + i0/2;
105
- const int i2 = row/p_delta_rows;
 
 
 
106
 
107
- const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f);
108
 
109
- const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
110
 
111
  float cos_theta;
112
  float sin_theta;
113
 
114
- rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
115
 
116
- const float x0 = x[i + 0];
117
- const float x1 = x[i + n_dims/2];
118
 
119
- dst[i + 0] = x0*cos_theta - x1*sin_theta;
120
- dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
121
  }
122
 
123
  template <typename T, bool has_ff>
@@ -163,18 +165,18 @@ static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, cons
163
  }
164
 
165
  template <typename T>
166
- static void rope_norm_sycl(
167
- const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
168
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
 
169
  GGML_ASSERT(ne0 % 2 == 0);
170
  const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
171
- const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
172
  const sycl::range<3> block_nums(1, num_blocks_x, nr);
173
 
174
- const float theta_scale = powf(freq_base, -2.0f/n_dims);
175
 
176
- dpct::has_capability_or_fail(stream->get_device(),
177
- {sycl::aspect::fp16});
178
 
179
  if (freq_factors == nullptr) {
180
  /*
@@ -182,61 +184,47 @@ static void rope_norm_sycl(
182
  the limit. To get the device limit, query
183
  info::device::max_work_group_size. Adjust the work-group size if needed.
184
  */
185
- stream->parallel_for(
186
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
187
- [=](sycl::nd_item<3> item_ct1) {
188
- rope_norm<T, false>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
189
- ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
190
- item_ct1);
191
- });
192
  } else {
193
  /*
194
  DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
195
  the limit. To get the device limit, query
196
  info::device::max_work_group_size. Adjust the work-group size if needed.
197
  */
198
- stream->parallel_for(
199
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
200
- [=](sycl::nd_item<3> item_ct1) {
201
- rope_norm<T, true>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
202
- ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
203
- item_ct1);
204
- });
205
  }
206
  }
207
 
208
  template <typename T>
209
- static void rope_neox_sycl(
210
- const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
211
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
 
212
  GGML_ASSERT(ne0 % 2 == 0);
213
  const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
214
- const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
215
  const sycl::range<3> block_nums(1, num_blocks_x, nr);
216
 
217
- const float theta_scale = powf(freq_base, -2.0f/n_dims);
218
 
219
- dpct::has_capability_or_fail(stream->get_device(),
220
- {sycl::aspect::fp16});
221
 
222
  if (freq_factors == nullptr) {
223
- stream->parallel_for(
224
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
225
- [=](sycl::nd_item<3> item_ct1) {
226
- rope_neox<T, false>(x, dst, ne0, n_dims, pos, freq_scale,
227
- p_delta_rows, ext_factor, attn_factor,
228
- corr_dims, theta_scale, freq_factors,
229
- item_ct1);
230
- });
231
  } else {
232
- stream->parallel_for(
233
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
234
- [=](sycl::nd_item<3> item_ct1) {
235
- rope_neox<T, true>(x, dst, ne0, n_dims, pos, freq_scale,
236
- p_delta_rows, ext_factor, attn_factor,
237
- corr_dims, theta_scale, freq_factors,
238
- item_ct1);
239
- });
240
  }
241
  }
242
 
@@ -272,7 +260,7 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1,
272
  }
273
  }
274
 
275
- void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
276
 
277
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
278
  GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
@@ -329,43 +317,46 @@ void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
329
  if (is_neox) {
330
  GGML_SYCL_DEBUG("%s: neox path\n", __func__);
331
  if (dst->src[0]->type == GGML_TYPE_F32) {
332
- rope_neox_sycl(
333
- (const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
334
- attn_factor, corr_dims, freq_factors, main_stream
335
- );
336
  } else if (dst->src[0]->type == GGML_TYPE_F16) {
337
- rope_neox_sycl(
338
- (const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
339
- attn_factor, corr_dims, freq_factors, main_stream
340
- );
341
  } else {
342
  GGML_ABORT("fatal error");
343
  }
344
  } else if (is_vision) {
345
  GGML_SYCL_DEBUG("%s: vision path\n", __func__);
346
  if (dst->src[0]->type == GGML_TYPE_F16) {
347
- rope_vision_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
348
- freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, main_stream);
 
349
  } else if (dst->src[0]->type == GGML_TYPE_F32) {
350
- rope_vision_sycl((const float *) dst->src[0]->data, (float *)dst->data, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
351
- freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, main_stream);
 
352
  } else {
353
  GGML_ABORT("Fatal error: Tensor type unsupported!");
354
  }
355
  } else {
356
  GGML_SYCL_DEBUG("%s: norm path\n", __func__);
357
  if (dst->src[0]->type == GGML_TYPE_F32) {
358
- rope_norm_sycl(
359
- (const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
360
- attn_factor, corr_dims, freq_factors, main_stream
361
- );
362
  } else if (dst->src[0]->type == GGML_TYPE_F16) {
363
- rope_norm_sycl(
364
- (const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
365
- attn_factor, corr_dims, freq_factors, main_stream
366
- );
367
  } else {
368
  GGML_ABORT("fatal error");
369
  }
370
  }
371
  }
 
 
 
 
 
 
 
 
34
  *sin_theta = sycl::sin(theta) * mscale;
35
  }
36
 
37
+ template <typename T, bool has_ff>
38
+ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
39
+ const int32_t * pos, float freq_scale, float ext_factor, float attn_factor,
40
+ const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
41
+ const sycl::nd_item<3> & item_ct1) {
42
+ const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
 
43
 
44
  if (i0 >= ne0) {
45
  return;
46
  }
47
 
48
+ const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
 
49
 
50
  if (i0 >= n_dims) {
51
+ const int i = row * ne0 + i0;
52
 
53
  dst[i + 0] = x[i + 0];
54
  dst[i + 1] = x[i + 1];
 
56
  return;
57
  }
58
 
59
+ const int row0 = row % ne1;
60
+ const int channel0 = row / ne1;
61
+
62
+ const int i = row * ne0 + i0;
63
+ const int i2 = channel0 * s2 + row0 * s1 + i0;
64
 
65
+ const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
66
 
67
+ const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
68
 
69
  float cos_theta;
70
  float sin_theta;
71
 
72
+ rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
73
 
74
+ const float x0 = x[i2 + 0];
75
+ const float x1 = x[i2 + 1];
76
 
77
+ dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
78
+ dst[i + 1] = x0 * sin_theta + x1 * cos_theta;
79
  }
80
 
81
+ template <typename T, bool has_ff>
82
+ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
83
+ const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
84
+ const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
85
+ const sycl::nd_item<3> & item_ct1) {
86
+ const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
 
87
 
88
  if (i0 >= ne0) {
89
  return;
90
  }
91
 
92
+ const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
 
93
 
94
  if (i0 >= n_dims) {
95
+ const int i = row * ne0 + i0;
96
 
97
  dst[i + 0] = x[i + 0];
98
  dst[i + 1] = x[i + 1];
 
100
  return;
101
  }
102
 
103
+ const int row0 = row % ne1;
104
+ const int channel0 = row / ne1;
105
+
106
+ const int i = row * ne0 + i0 / 2;
107
+ const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
108
 
109
+ const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
110
 
111
+ const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
112
 
113
  float cos_theta;
114
  float sin_theta;
115
 
116
+ rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
117
 
118
+ const float x0 = x[i2 + 0];
119
+ const float x1 = x[i2 + n_dims / 2];
120
 
121
+ dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
122
+ dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
123
  }
124
 
125
  template <typename T, bool has_ff>
 
165
  }
166
 
167
  template <typename T>
168
+ static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
169
+ const int n_dims, int nr, const int32_t * pos, const float freq_scale, const float freq_base,
170
+ const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
171
+ const float * freq_factors, queue_ptr stream) {
172
  GGML_ASSERT(ne0 % 2 == 0);
173
  const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
174
+ const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
175
  const sycl::range<3> block_nums(1, num_blocks_x, nr);
176
 
177
+ const float theta_scale = powf(freq_base, -2.0f / n_dims);
178
 
179
+ dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
 
180
 
181
  if (freq_factors == nullptr) {
182
  /*
 
184
  the limit. To get the device limit, query
185
  info::device::max_work_group_size. Adjust the work-group size if needed.
186
  */
187
+ stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
188
+ rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
189
+ theta_scale, freq_factors, item_ct1);
190
+ });
 
 
 
191
  } else {
192
  /*
193
  DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
194
  the limit. To get the device limit, query
195
  info::device::max_work_group_size. Adjust the work-group size if needed.
196
  */
197
+ stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
198
+ rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
199
+ theta_scale, freq_factors, item_ct1);
200
+ });
 
 
 
201
  }
202
  }
203
 
204
  template <typename T>
205
+ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
206
+ const int n_dims, const int nr, const int32_t * pos, const float freq_scale,
207
+ const float freq_base, const float ext_factor, const float attn_factor,
208
+ const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
209
  GGML_ASSERT(ne0 % 2 == 0);
210
  const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
211
+ const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
212
  const sycl::range<3> block_nums(1, num_blocks_x, nr);
213
 
214
+ const float theta_scale = powf(freq_base, -2.0f / n_dims);
215
 
216
+ dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
 
217
 
218
  if (freq_factors == nullptr) {
219
+ stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
220
+ rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
221
+ theta_scale, freq_factors, item_ct1);
222
+ });
 
 
 
 
223
  } else {
224
+ stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
225
+ rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
226
+ theta_scale, freq_factors, item_ct1);
227
+ });
 
 
 
 
228
  }
229
  }
230
 
 
260
  }
261
  }
262
 
263
+ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
264
 
265
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
266
  GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
 
317
  if (is_neox) {
318
  GGML_SYCL_DEBUG("%s: neox path\n", __func__);
319
  if (dst->src[0]->type == GGML_TYPE_F32) {
320
+ rope_neox_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
321
+ pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
 
 
322
  } else if (dst->src[0]->type == GGML_TYPE_F16) {
323
+ rope_neox_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
324
+ n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
325
+ main_stream);
 
326
  } else {
327
  GGML_ABORT("fatal error");
328
  }
329
  } else if (is_vision) {
330
  GGML_SYCL_DEBUG("%s: vision path\n", __func__);
331
  if (dst->src[0]->type == GGML_TYPE_F16) {
332
+ rope_vision_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, ne02, s01,
333
+ s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
334
+ freq_factors, sections, main_stream);
335
  } else if (dst->src[0]->type == GGML_TYPE_F32) {
336
+ rope_vision_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
337
+ nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
338
+ main_stream);
339
  } else {
340
  GGML_ABORT("Fatal error: Tensor type unsupported!");
341
  }
342
  } else {
343
  GGML_SYCL_DEBUG("%s: norm path\n", __func__);
344
  if (dst->src[0]->type == GGML_TYPE_F32) {
345
+ rope_norm_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
346
+ pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
 
 
347
  } else if (dst->src[0]->type == GGML_TYPE_F16) {
348
+ rope_norm_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
349
+ n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
350
+ main_stream);
 
351
  } else {
352
  GGML_ABORT("fatal error");
353
  }
354
  }
355
  }
356
+
357
+ void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
358
+ GGML_SYCL_DEBUG("call %s\n", __func__);
359
+ ggml_sycl_op_rope(ctx, dst);
360
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
361
+ }
362
+
ggml/src/ggml-sycl/rope.hpp CHANGED
@@ -15,6 +15,6 @@
15
 
16
  #include "common.hpp"
17
 
18
- void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
19
 
20
  #endif // GGML_SYCL_ROPE_HPP
 
15
 
16
  #include "common.hpp"
17
 
18
+ void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
19
 
20
  #endif // GGML_SYCL_ROPE_HPP