File pytorch-add-cmake-variable-USE_ROCM_CK.patch of Package python-torch
diff --git a/CMakeLists.txt b/CMakeLists.txt
index f3fee2f7f..73903acce 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -249,6 +249,7 @@ cmake_dependent_option(
BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON
"USE_CUDA AND LINUX AND BUILD_PYTHON" OFF)
cmake_dependent_option(USE_ROCM "Use ROCm" ON "LINUX" OFF)
+cmake_dependent_option(USE_ROCM_CK "Use ROCm Composable Kernel" ON "USE_ROCM" ON)
option(CAFFE2_STATIC_LINK_CUDA "Statically link CUDA libraries" OFF)
cmake_dependent_option(USE_CUDNN "Use cuDNN" ON "USE_CUDA" OFF)
cmake_dependent_option(USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
index 085af373e..c88b3ee97 100644
--- a/aten/src/ATen/CMakeLists.txt
+++ b/aten/src/ATen/CMakeLists.txt
@@ -344,8 +344,8 @@ endif()
if(USE_ROCM)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
- list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
- list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
+ #list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
+ #list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
# Next two lines are needed because TunableOp uses third-party/fmt
list(APPEND ATen_HIP_INCLUDE $<TARGET_PROPERTY:fmt::fmt-header-only,INTERFACE_INCLUDE_DIRECTORIES>)
list(APPEND ATen_HIP_DEPENDENCY_LIBS fmt::fmt-header-only)
@@ -361,7 +361,7 @@ endif()
${native_quantized_hip_hip}
${native_transformers_hip_hip} ${native_transformers_src_hip_hip}
)
- if(WIN32) # Windows doesn't support Composable Kernels and Triton
+ if(WIN32 OR NOT USE_ROCM_CK) # Windows doesn't support Composable Kernels and Triton
file(GLOB native_hip_bgemm "native/hip/bgemm_kernels/*.hip")
file(GLOB native_hip_ck "native/hip/ck*.hip")
exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index a62b028fd..07ecaa5d6 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -636,7 +636,7 @@ template <>
void bgemm_internal<double>(CUDABLAS_BGEMM_ARGTYPES(double))
{
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
-#ifdef USE_ROCM
+#ifdef USE_ROCM_CK
// hipblaslt does not support double gemm yet
bgemm_internal_cublas<double>(CUDABLAS_BGEMM_ARGS(double));
#else
@@ -708,7 +708,7 @@ void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
bgemm_internal_cublaslt<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
}
-#ifdef USE_ROCM
+#ifdef defined(USE_ROCM) && defined(USE_ROCM_CK)
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::bgemm_internal_ck<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
}
@@ -1054,14 +1054,14 @@ template <>
void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
{
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
-#ifdef USE_ROCM
+#ifdef USE_ROCM_CK
// hipblaslt does not support double gemm yet
gemm_internal_cublas<double>(CUDABLAS_GEMM_ARGS(double));
#else
gemm_internal_cublaslt<double>(CUDABLAS_GEMM_ARGS(double));
#endif
}
-#ifdef USE_ROCM
+#ifdef defined(USE_ROCM) && defined(USE_ROCM_CK)
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<double>(CUDABLAS_GEMM_ARGS(double));
}
@@ -1077,7 +1077,7 @@ void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
}
-#ifdef USE_ROCM
+#ifdef USE_ROCM && defined(USE_ROCM_CK)
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<float>(CUDABLAS_GEMM_ARGS(float));
}
@@ -1091,7 +1091,7 @@ template <>
void gemm_internal<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>))
{
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
-#ifdef USE_ROCM
+#ifdef USE_ROCM_CK
// hipblaslt does not support complex gemm yet
gemm_internal_cublas<c10::complex<double>>(CUDABLAS_GEMM_ARGS(c10::complex<double>));
#else
@@ -1107,7 +1107,7 @@ template <>
void gemm_internal<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>))
{
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
-#ifdef USE_ROCM
+#ifdef USE_ROCM_CK
// hipblaslt does not support complex gemm yet
gemm_internal_cublas<c10::complex<float>>(CUDABLAS_GEMM_ARGS(c10::complex<float>));
#else
@@ -1125,7 +1125,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
}
-#ifdef USE_ROCM
+#ifdef USE_ROCM && defined(USE_ROCM_CK)
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
}
@@ -1141,7 +1141,7 @@ void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
-#ifdef USE_ROCM
+#ifdef USE_ROCM && defined(USE_ROCM_CK)
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index 5227204b0..824b365ba 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1062,6 +1062,9 @@ if(USE_ROCM)
if(HIPBLASLT_VEC_EXT)
list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_VEC_EXT)
endif()
+ if(USE_ROCM_CK)
+ list(APPEND HIP_CXX_FLAGS -DUSE_ROCM_CK)
+ endif()
list(APPEND HIP_HIPCC_FLAGS --offload-compress)
if(WIN32)
add_definitions(-DROCM_ON_WINDOWS)