File pytorch-add-cmake-variable-USE_ROCM_CK.patch of Package python-torch

diff --git a/CMakeLists.txt b/CMakeLists.txt
index f3fee2f7f..73903acce 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -249,6 +249,7 @@ cmake_dependent_option(
   BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON
   "USE_CUDA AND LINUX AND BUILD_PYTHON" OFF)
 cmake_dependent_option(USE_ROCM "Use ROCm" ON "LINUX" OFF)
+cmake_dependent_option(USE_ROCM_CK "Use ROCm Composable Kernel" ON "USE_ROCM" ON)
 option(CAFFE2_STATIC_LINK_CUDA "Statically link CUDA libraries" OFF)
 cmake_dependent_option(USE_CUDNN "Use cuDNN" ON "USE_CUDA" OFF)
 cmake_dependent_option(USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
index 085af373e..c88b3ee97 100644
--- a/aten/src/ATen/CMakeLists.txt
+++ b/aten/src/ATen/CMakeLists.txt
@@ -344,8 +344,8 @@ endif()
 
 if(USE_ROCM)
   list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
-  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
-  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
+  #list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
+  #list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
   # Next two lines are needed because TunableOp uses third-party/fmt
   list(APPEND ATen_HIP_INCLUDE $<TARGET_PROPERTY:fmt::fmt-header-only,INTERFACE_INCLUDE_DIRECTORIES>)
   list(APPEND ATen_HIP_DEPENDENCY_LIBS fmt::fmt-header-only)
@@ -361,7 +361,7 @@ endif()
     ${native_quantized_hip_hip}
     ${native_transformers_hip_hip} ${native_transformers_src_hip_hip}
   )
-  if(WIN32) # Windows doesn't support Composable Kernels and Triton
+  if(WIN32 OR NOT USE_ROCM_CK) # Windows doesn't support Composable Kernels and Triton
     file(GLOB native_hip_bgemm "native/hip/bgemm_kernels/*.hip")
     file(GLOB native_hip_ck "native/hip/ck*.hip")
     exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index a62b028fd..07ecaa5d6 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -636,7 +636,7 @@ template <>
 void bgemm_internal<double>(CUDABLAS_BGEMM_ARGTYPES(double))
 {
   if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
-#ifdef USE_ROCM
+#ifdef USE_ROCM_CK
     // hipblaslt does not support double gemm yet
     bgemm_internal_cublas<double>(CUDABLAS_BGEMM_ARGS(double));
 #else
@@ -708,7 +708,7 @@ void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16))
   if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
     bgemm_internal_cublaslt<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
   }
-#ifdef USE_ROCM
+#ifdef defined(USE_ROCM) && defined(USE_ROCM_CK)
   else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
     at::native::bgemm_internal_ck<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
   }
@@ -1054,14 +1054,14 @@ template <>
 void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
 {
   if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
-#ifdef USE_ROCM
+#ifdef USE_ROCM_CK
     // hipblaslt does not support double gemm yet
     gemm_internal_cublas<double>(CUDABLAS_GEMM_ARGS(double));
 #else
     gemm_internal_cublaslt<double>(CUDABLAS_GEMM_ARGS(double));
 #endif
   }
-#ifdef USE_ROCM
+#ifdef defined(USE_ROCM) && defined(USE_ROCM_CK)
   else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
     at::native::gemm_internal_ck<double>(CUDABLAS_GEMM_ARGS(double));
   }
@@ -1077,7 +1077,7 @@ void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
   if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
     gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
   }
-#ifdef USE_ROCM
+#ifdef USE_ROCM && defined(USE_ROCM_CK)
   else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
     at::native::gemm_internal_ck<float>(CUDABLAS_GEMM_ARGS(float));
   }
@@ -1091,7 +1091,7 @@ template <>
 void gemm_internal<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>))
 {
   if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
-#ifdef USE_ROCM
+#ifdef USE_ROCM_CK
     // hipblaslt does not support complex gemm yet
     gemm_internal_cublas<c10::complex<double>>(CUDABLAS_GEMM_ARGS(c10::complex<double>));
 #else
@@ -1107,7 +1107,7 @@ template <>
 void gemm_internal<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>))
 {
   if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
-#ifdef USE_ROCM
+#ifdef USE_ROCM_CK
     // hipblaslt does not support complex gemm yet
     gemm_internal_cublas<c10::complex<float>>(CUDABLAS_GEMM_ARGS(c10::complex<float>));
 #else
@@ -1125,7 +1125,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
   if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
     gemm_internal_cublaslt<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
   }
-#ifdef USE_ROCM
+#ifdef USE_ROCM && defined(USE_ROCM_CK)
   else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
     at::native::gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
   }
@@ -1141,7 +1141,7 @@ void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16))
   if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
     gemm_internal_cublaslt<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
   }
-#ifdef USE_ROCM
+#ifdef USE_ROCM && defined(USE_ROCM_CK)
   else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
     at::native::gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
   }
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index 5227204b0..824b365ba 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1062,6 +1062,9 @@ if(USE_ROCM)
     if(HIPBLASLT_VEC_EXT)
       list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_VEC_EXT)
     endif()
+    if(USE_ROCM_CK)
+      list(APPEND HIP_CXX_FLAGS -DUSE_ROCM_CK)
+    endif()
     list(APPEND HIP_HIPCC_FLAGS --offload-compress)
     if(WIN32)
       add_definitions(-DROCM_ON_WINDOWS)
openSUSE Build Service is sponsored by