File size: 1,862 Bytes
92455fe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
#ifdef TEST_ON_CUDA
#include <mma.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
// CUDA build: make `wmma::` refer to NVIDIA's nvcuda::wmma API. The non-CUDA
// branch of this #ifdef aliases the same name to rocwmma instead.
namespace wmma = nvcuda::wmma;
/*
 * LIB_CALL(call): evaluate a CUDA runtime call exactly once and abort() the
 * process if the result is anything other than cudaSuccess.
 *
 * Hygiene notes:
 *  - the argument is wrapped in parentheses so it is always treated as one
 *    expression inside the initializer;
 *  - the temporary has a macro-specific name, so a caller expression that
 *    mentions its own variable called `err` is not captured by this local;
 *  - do { ... } while (0) makes the expansion a single statement, which keeps
 *    it safe inside an unbraced if/else.
 */
#define LIB_CALL(call) \
    do { \
        const cudaError_t lib_call_status_ = (call); \
        if (lib_call_status_ != cudaSuccess) { \
            abort(); \
        } \
    } while (0)
/* HOST_TYPE(suffix): token-paste the `cuda` prefix onto `suffix`,
 * e.g. HOST_TYPE(Foo) expands to the identifier cudaFoo. */
#define HOST_TYPE(suffix) cuda##suffix
#else
#ifndef HIP_HEADERS__
#include <hip/hip_runtime.h>
#include <hip/hip_fp8.h>
#include <hip/hip_fp16.h>
#include <rocwmma/rocwmma.hpp>
#define HIP_HEADERS__
#endif
// Non-CUDA (HIP) build: make `wmma::` refer to the rocwmma namespace. The
// TEST_ON_CUDA branch of this #ifdef aliases the same name to nvcuda::wmma.
namespace wmma = rocwmma;
/*
 * LIB_CALL(call): evaluate a HIP runtime call exactly once and abort() the
 * process if the result is anything other than hipSuccess.
 *
 * Hygiene notes:
 *  - the argument is wrapped in parentheses so it is always treated as one
 *    expression inside the initializer;
 *  - the temporary has a macro-specific name, so a caller expression that
 *    mentions its own variable called `err` is not captured by this local;
 *  - do { ... } while (0) makes the expansion a single statement, which keeps
 *    it safe inside an unbraced if/else.
 */
#define LIB_CALL(call) \
    do { \
        const hipError_t lib_call_status_ = (call); \
        if (lib_call_status_ != hipSuccess) { \
            abort(); \
        } \
    } while (0)
/* HOST_TYPE(suffix): token-paste the `hip` prefix onto `suffix`,
 * e.g. HOST_TYPE(Foo) expands to the identifier hipFoo. */
#define HOST_TYPE(suffix) hip##suffix
#endif
|