File size: 707 Bytes
92455fe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
#pragma once
#ifdef TEST_ON_CUDA
#define __FP8_TYPE __nv_fp8_e4m3
#define __FP8x4_TYPE __nv_fp8x4_e4m3
#define __BF16_TYPE __nv_bfloat16
#define __BF16x2_TYPE __nv_bfloat162
#else
#ifdef TEST_ON_RDNA4
#define __FP8_TYPE __hip_fp8_e4m3
#define __FP8x4_TYPE __hip_fp8x4_e4m3
constexpr const inline int WAVE_SIZE = 32;
constexpr const inline int XCD_SWIZZLE = 1;
#else
#define __FP8_TYPE __hip_fp8_e4m3_fnuz
#define __FP8x4_TYPE __hip_fp8x4_e4m3_fnuz
constexpr const inline int WAVE_SIZE = 64;
constexpr const inline int XCD_SWIZZLE = 8;
#endif
#define __BF16_TYPE __hip_bfloat16
#define __BF16x2_TYPE __hip_bfloat162
#define __FP16_TYPE __half
#define __INT16_TYPE int16_t
#endif
#define __FP32_TYPE float |