File Add-support-for-SPV_INTEL_function_variants.patch of Package spirv-llvm-translator
From aaa23785dcbd7cabebad860edd9a3aebfc6bbb41 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jakub=20=C5=BD=C3=A1dn=C3=ADk?= <jakub.zadnik@intel.com>
Date: Tue, 15 Jul 2025 16:57:03 +0300
Subject: [PATCH] Add support for SPV_INTEL_function_variants (#3246)
This PR implements
[SPV_INTEL_function_variants](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/spirv-extensions/SPV_INTEL_function_variants.asciidoc).
It adds an optional SPIR-V to SPIR-V specialization pass that converts a
multitarget module into a targeted one.
The multitarget module does not have an LLVM IR representation; the
extension only describes the specialization algorithm that takes place
before converting the SPIR-V module into LLVM IR. For this reason, it is
only implemented as a part of SPIRVReader and not SPIRVWriter.
The specialization is controlled by the user supplying the target device
category, family, architecture, target ISA, supported features and/or
supported capabilities via CLI flags. For example, to specialize for an
Intel x86_64 CPU with Lion Cove microarchitecture that supports SSE,
SSE2, SSE3, SSE4.1, SSE4.2, SSE4a, AVX, AVX2 and AVX512f features and
Addresses, Linkage, Kernel, Int64 and Int8 capabilities, the user needs
to provide the following flags:
```
llvm-spirv -r \
--spirv-ext=+SPV_INTEL_function_variants \
--fnvar-spec-enable \
--fnvar-spv-out targeted.spv \
--fnvar-category 1 --fnvar-family 1 --fnvar-arch 15 \
--fnvar-target 4 --fnvar-features '4,5,6,7,8,9,10,11,12' \
--fnvar-capabilities '4,5,6,11,39' \
multitarget.spv -o targeted.bc
```
Omitting a flag means that the target device supports all values for the
flag. For instance, in the above example, leaving out the
`--fnvar-features` flag means that the target device supports all
features available for the x86_64 target.
The integer values passed to the CLI flags are taken from a proposed
[targets _registry_](https://github.com/intel/llvm/pull/18822)
accompanying the extension. (Capabilities correspond directly to the
values defined in the SPIR-V specification.) The specialization pass
compares these CLI-supplied integers with
the operands of `OpSpecConstantTargetINTEL`,
`OpSpecConstantArchitectureINTEL` and `OpSpecConstantCapabilitiesINTEL`
instructions in the input multitarget module, converts these
instructions to constant true/false and proceeds with the specialization
according to the rules described in the extension.
Providing the CLI values as raw integers is not very user-friendly,
and the translator does not validate the values in any way (e.g.,
checking that feature X is allowed for target Y). This can be improved
after the _registry_ is merged and more mature (version >0).
Note: `--spirv-debug` can be used to print out details about what's
happening when evaluating the above spec constants. It's useful for
getting an insight into why a certain function variant got selected if
the selection does not match the expected outcome.
---
include/LLVMSPIRVExtensions.inc | 1 +
include/LLVMSPIRVOpts.h | 53 +++
lib/SPIRV/CMakeLists.txt | 1 +
lib/SPIRV/LLVMSPIRVOpts.cpp | 23 +
lib/SPIRV/SPIRVReader.cpp | 40 ++
lib/SPIRV/libSPIRV/SPIRVEntry.cpp | 1 +
lib/SPIRV/libSPIRV/SPIRVEntry.h | 3 +
lib/SPIRV/libSPIRV/SPIRVErrorEnum.h | 2 +
lib/SPIRV/libSPIRV/SPIRVFnVar.cpp | 378 +++++++++++++++
lib/SPIRV/libSPIRV/SPIRVFnVar.h | 439 ++++++++++++++++++
lib/SPIRV/libSPIRV/SPIRVModule.cpp | 180 ++++++-
lib/SPIRV/libSPIRV/SPIRVModule.h | 60 +++
lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h | 3 +
lib/SPIRV/libSPIRV/SPIRVOpCode.h | 13 +-
lib/SPIRV/libSPIRV/SPIRVOpCodeEnum.h | 7 +
spirv-headers-tag.conf | 2 +-
.../addsubmul_asm.spt | 277 +++++++++++
.../SPV_INTEL_function_variants/cl_dot4.spt | 353 ++++++++++++++
tools/llvm-spirv/llvm-spirv.cpp | 136 +++++-
19 files changed, 1949 insertions(+), 23 deletions(-)
create mode 100644 lib/SPIRV/libSPIRV/SPIRVFnVar.cpp
create mode 100644 lib/SPIRV/libSPIRV/SPIRVFnVar.h
create mode 100644 test/extensions/INTEL/SPV_INTEL_function_variants/addsubmul_asm.spt
create mode 100644 test/extensions/INTEL/SPV_INTEL_function_variants/cl_dot4.spt
diff --git a/include/LLVMSPIRVExtensions.inc b/include/LLVMSPIRVExtensions.inc
index b86791922d..5231209d4d 100644
--- a/include/LLVMSPIRVExtensions.inc
+++ b/include/LLVMSPIRVExtensions.inc
@@ -79,3 +79,4 @@ EXT(SPV_INTEL_subgroup_matrix_multiply_accumulate)
EXT(SPV_KHR_bfloat16)
EXT(SPV_INTEL_ternary_bitwise_function)
EXT(SPV_INTEL_int4)
+EXT(SPV_INTEL_function_variants)
diff --git a/include/LLVMSPIRVOpts.h b/include/LLVMSPIRVOpts.h
index 08ef7b0e9e..bdbcc814bb 100644
--- a/include/LLVMSPIRVOpts.h
+++ b/include/LLVMSPIRVOpts.h
@@ -252,6 +252,50 @@ class TranslatorOpts {
void setUseLLVMTarget(bool Flag) noexcept { UseLLVMTarget = Flag; }
bool getUseLLVMTarget() const noexcept { return UseLLVMTarget; }
+ // Device category of the specialization target. The getter yields
+ // std::nullopt until the setter has been called.
+ void setFnVarCategory(uint32_t Category) noexcept {
+   FnVarCategory = Category;
+ }
+ std::optional<uint32_t> getFnVarCategory() const noexcept {
+   return FnVarCategory;
+ }
+
+ // Device family of the specialization target; std::nullopt until set.
+ void setFnVarFamily(uint32_t Family) noexcept {
+   FnVarFamily = Family;
+ }
+ std::optional<uint32_t> getFnVarFamily() const noexcept {
+   return FnVarFamily;
+ }
+
+ // Device architecture of the specialization target; std::nullopt until set.
+ void setFnVarArch(uint32_t Arch) noexcept {
+   FnVarArch = Arch;
+ }
+ std::optional<uint32_t> getFnVarArch() const noexcept {
+   return FnVarArch;
+ }
+
+ // Target (ISA) of the specialization target; std::nullopt until set.
+ void setFnVarTarget(uint32_t Target) noexcept {
+   FnVarTarget = Target;
+ }
+ std::optional<uint32_t> getFnVarTarget() const noexcept {
+   return FnVarTarget;
+ }
+
+ // Features supported by the target device. The list is taken by value and
+ // moved into place: a copy-assignment could throw (and terminate, because
+ // the setter is noexcept), whereas the move cannot.
+ void setFnVarFeatures(std::vector<uint32_t> Features) noexcept {
+   FnVarFeatures = std::move(Features);
+ }
+ // Returns a copy of the feature list; empty until the setter is called.
+ std::vector<uint32_t> getFnVarFeatures() const noexcept {
+   return FnVarFeatures;
+ }
+
+ // Capabilities supported by the target device. Taken by value and moved
+ // into place for the same reason as the feature list.
+ void setFnVarCapabilities(std::vector<uint32_t> Capabilities) noexcept {
+   FnVarCapabilities = std::move(Capabilities);
+ }
+ // Returns a copy of the capability list; empty until the setter is called.
+ std::vector<uint32_t> getFnVarCapabilities() const noexcept {
+   return FnVarCapabilities;
+ }
+
+ // Enables or disables the function-variant specialization pass.
+ void setFnVarSpecEnable(bool Val) noexcept { FnVarSpecEnable = Val; }
+ bool getFnVarSpecEnable() const noexcept { return FnVarSpecEnable; }
+
+ // Output path for the specialized SPIR-V module; empty until set. The
+ // string is moved into place so that the noexcept setter cannot throw.
+ void setFnVarSpvOut(std::string Val) noexcept {
+   FnVarSpvOut = std::move(Val);
+ }
+ std::string getFnVarSpvOut() const noexcept { return FnVarSpvOut; }
+
+ // Check that options passed to --fnvar-xxx flags make sense. Return true on
+ // success, false on failure.
+ bool validateFnVarOpts() const;
+
private:
// Common translation options
VersionNumber MaxVersion = VersionNumber::MaximumVersion;
@@ -301,6 +345,15 @@ class TranslatorOpts {
bool PreserveAuxData = false;
+ // Description of the target device used when specializing function
+ // variants. The optional values stay std::nullopt and the lists stay empty
+ // until the corresponding setter is called.
+ std::optional<uint32_t> FnVarCategory = std::nullopt;
+ std::optional<uint32_t> FnVarFamily = std::nullopt;
+ std::optional<uint32_t> FnVarArch = std::nullopt;
+ std::optional<uint32_t> FnVarTarget = std::nullopt;
+ std::vector<uint32_t> FnVarFeatures = {};
+ std::vector<uint32_t> FnVarCapabilities = {};
+ // Output path for the specialized module (presumably the file readSpirv
+ // writes the targeted SPIR-V to; confirm against SPIRVModule).
+ std::string FnVarSpvOut = "";
+ // Whether the specialization pass is requested at all.
+ bool FnVarSpecEnable = false;
+
BuiltinFormat SPIRVBuiltinFormat = BuiltinFormat::Function;
// Convert LLVM to SPIR-V using the LLVM SPIR-V Backend target
diff --git a/lib/SPIRV/CMakeLists.txt b/lib/SPIRV/CMakeLists.txt
index e155a44f4e..833eb59e33 100644
--- a/lib/SPIRV/CMakeLists.txt
+++ b/lib/SPIRV/CMakeLists.txt
@@ -39,6 +39,7 @@ set(SRC_LIST
libSPIRV/SPIRVType.cpp
libSPIRV/SPIRVValue.cpp
libSPIRV/SPIRVError.cpp
+ libSPIRV/SPIRVFnVar.cpp
)
add_llvm_library(LLVMSPIRVLib
${SRC_LIST}
diff --git a/lib/SPIRV/LLVMSPIRVOpts.cpp b/lib/SPIRV/LLVMSPIRVOpts.cpp
index 7f0c6e32a3..a96fe0c9df 100644
--- a/lib/SPIRV/LLVMSPIRVOpts.cpp
+++ b/lib/SPIRV/LLVMSPIRVOpts.cpp
@@ -44,6 +44,7 @@
#include <llvm/ADT/SmallVector.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/IR/IntrinsicInst.h>
+#include <optional>
using namespace llvm;
using namespace SPIRV;
@@ -89,3 +90,25 @@ std::vector<std::string> TranslatorOpts::getAllowedSPIRVExtensionNames(
}
return AllowExtNames;
}
+
+// Check that the function-variant target description is self-consistent.
+//
+// A more specific property may only be given together with the property it
+// refines: family and architecture require a category, architecture requires
+// a family, and features require a target. On the first violation a
+// diagnostic is printed to errs() and false is returned; otherwise the
+// result is true.
+bool TranslatorOpts::validateFnVarOpts() const {
+  if (!FnVarCategory.has_value() &&
+      (FnVarFamily.has_value() || FnVarArch.has_value())) {
+    errs() << "FnVar: Device category must be specified if the family or "
+              "architecture are specified.\n";
+    return false;
+  }
+
+  if (!FnVarFamily.has_value() && FnVarArch.has_value()) {
+    errs() << "FnVar: Device family must be specified if the architecture is "
+              "specified.\n";
+    return false;
+  }
+
+  // The member is read directly: going through getFnVarFeatures() would copy
+  // the whole vector just to test whether it is empty.
+  if (!FnVarTarget.has_value() && !FnVarFeatures.empty()) {
+    errs() << "FnVar: Device target must be specified if the features are "
+              "specified.\n";
+    return false;
+  }
+
+  return true;
+}
diff --git a/lib/SPIRV/SPIRVReader.cpp b/lib/SPIRV/SPIRVReader.cpp
index 66f94ce7cf..b247868386 100644
--- a/lib/SPIRV/SPIRVReader.cpp
+++ b/lib/SPIRV/SPIRVReader.cpp
@@ -41,6 +41,7 @@
#include "SPIRVAsm.h"
#include "SPIRVBasicBlock.h"
#include "SPIRVExtInst.h"
+#include "SPIRVFnVar.h"
#include "SPIRVFunction.h"
#include "SPIRVInstruction.h"
#include "SPIRVInternal.h"
@@ -1761,6 +1762,20 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
case OpLabel:
return mapValue(BV, BasicBlock::Create(*Context, BV->getName(), F));
+ case OpSpecConstantArchitectureINTEL:
+ llvm_unreachable(
+ "Encountered non-specialized OpSpecConstantArchitectureINTEL");
+ return nullptr;
+
+ case OpSpecConstantTargetINTEL:
+ llvm_unreachable("Encountered non-specialized OpSpecConstantTargetINTEL");
+ return nullptr;
+
+ case OpSpecConstantCapabilitiesINTEL:
+ llvm_unreachable(
+ "Encountered non-specialized OpSpecConstantCapabilitiesINTEL");
+ return nullptr;
+
default:
// do nothing
break;
@@ -5607,6 +5622,31 @@ bool llvm::readSpirv(LLVMContext &C, const SPIRV::TranslatorOpts &Opts,
if (!BM)
return false;
+ if (Opts.getFnVarSpecEnable()) {
+   // Turn the multitarget module into a targeted one. On failure the
+   // specialization has already stored the reason in ErrMsg.
+   if (!specializeFnVariants(BM.get(), ErrMsg)) {
+     return false;
+   }
+
+   // Write out the specialized/targeted module
+   if (!BM->getFnVarSpvOut().empty()) {
+     auto OFSSpv = std::ofstream(BM->getFnVarSpvOut(), std::ios::binary);
+     if (!OFSSpv) {
+       // Report the failure instead of silently producing no output file.
+       ErrMsg = "Failed to open file for writing the specialized module: ";
+       ErrMsg += BM->getFnVarSpvOut();
+       return false;
+     }
+
+     // The targeted module is emitted in binary form. The global format
+     // switch is restored right after the write so that the early return
+     // below cannot leave it modified.
+     auto SaveOpt = SPIRVUseTextFormat;
+     SPIRVUseTextFormat = false;
+     OFSSpv << *BM;
+     SPIRVUseTextFormat = SaveOpt;
+     if (BM->getError(ErrMsg) != SPIRVEC_Success) {
+       return false;
+     }
+   }
+ }
+
+ if (BM->getExtension().find("SPV_INTEL_function_variants") !=
+ BM->getExtension().end()) {
+ ErrMsg = "Instructions from SPV_INTEL_function_variants are not "
+ "convertible to LLVM IR.";
+ return false;
+ }
+
M = convertSpirvToLLVM(C, *BM, Opts, ErrMsg).release();
if (!M)
diff --git a/lib/SPIRV/libSPIRV/SPIRVEntry.cpp b/lib/SPIRV/libSPIRV/SPIRVEntry.cpp
index 43f75cbfd7..36fcd55de5 100644
--- a/lib/SPIRV/libSPIRV/SPIRVEntry.cpp
+++ b/lib/SPIRV/libSPIRV/SPIRVEntry.cpp
@@ -42,6 +42,7 @@
#include "SPIRVBasicBlock.h"
#include "SPIRVDebug.h"
#include "SPIRVDecorate.h"
+#include "SPIRVFnVar.h"
#include "SPIRVFunction.h"
#include "SPIRVInstruction.h"
#include "SPIRVMemAliasingINTEL.h"
diff --git a/lib/SPIRV/libSPIRV/SPIRVEntry.h b/lib/SPIRV/libSPIRV/SPIRVEntry.h
index 97758e74f5..14e469ad73 100644
--- a/lib/SPIRV/libSPIRV/SPIRVEntry.h
+++ b/lib/SPIRV/libSPIRV/SPIRVEntry.h
@@ -913,6 +913,9 @@ class SPIRVCapability : public SPIRVEntryNoId<OpCapability> {
return ExtensionID::SPV_INTEL_subgroup_requirements;
case CapabilityFPFastMathModeINTEL:
return ExtensionID::SPV_INTEL_fp_fast_math_mode;
+ case CapabilityFunctionVariantsINTEL:
+ case CapabilitySpecConditionalINTEL:
+ return ExtensionID::SPV_INTEL_function_variants;
default:
return {};
}
diff --git a/lib/SPIRV/libSPIRV/SPIRVErrorEnum.h b/lib/SPIRV/libSPIRV/SPIRVErrorEnum.h
index ccdc4cc1e5..ff424aa0f4 100644
--- a/lib/SPIRV/libSPIRV/SPIRVErrorEnum.h
+++ b/lib/SPIRV/libSPIRV/SPIRVErrorEnum.h
@@ -30,6 +30,8 @@ _SPIRV_OP(UnsupportedVarArgFunction,
"Variadic functions other than 'printf' are not supported in SPIR-V.")
_SPIRV_OP(DeprecatedExtension,
"Feature requires the following deprecated SPIR-V extension:\n")
+_SPIRV_OP(InvalidNumberOfOperands,
+ "Number of operands does not match the expected count.")
/* This is the last error code to have a maximum valid value to compare to */
_SPIRV_OP(InternalMaxErrorCode, "Unknown error code")
diff --git a/lib/SPIRV/libSPIRV/SPIRVFnVar.cpp b/lib/SPIRV/libSPIRV/SPIRVFnVar.cpp
new file mode 100644
index 0000000000..1eb853093c
--- /dev/null
+++ b/lib/SPIRV/libSPIRV/SPIRVFnVar.cpp
@@ -0,0 +1,378 @@
+//===- SPIRVFnVar.cpp -===//
+//
+// The LLVM/SPIRV Translator
+//
+// Copyright (c) 2025 The Khronos Group Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+/// \file
+///
+/// This file implements functions declared in its header file with the help of
+/// additional helper functions.
+///
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVFnVar.h"
+
+using namespace SPIRV;
+
+namespace {
+
+// Swap the value referenced by 'OldVal' for an OpConstantTrue (Val == true)
+// or OpConstantFalse (Val == false) that reuses the old value's type and ID.
+// On return 'OldVal' holds whatever BM->addConstant() returned.
+void replaceWithBoolConst(SPIRVModule *BM, SPIRVValue *&OldVal, bool Val) {
+  // The replacement is built first because it needs the type and ID of the
+  // value that is erased right afterwards.
+  SPIRVValue *Replacement = nullptr;
+  if (Val) {
+    Replacement =
+        new SPIRVConstantTrue(BM, OldVal->getType(), OldVal->getId());
+  } else {
+    Replacement =
+        new SPIRVConstantFalse(BM, OldVal->getType(), OldVal->getId());
+  }
+
+  [[maybe_unused]] bool Erased = BM->eraseValue(OldVal);
+  assert(Erased);
+  OldVal = BM->addConstant(Replacement);
+}
+
+// Evaluate the boolean constant identified by 'Id' and store the result in
+// 'Res'. Every evaluated spec constant (including OpSpecConstantOp) is
+// replaced in the module with OpConstantTrue/False according to the result;
+// plain OpConstantTrue/False are left untouched.
+//
+// To keep this simple, only boolean constants and a subset of OpSpecConstantOp
+// operations are allowed. Support for more can be implemented as necessary.
+//
+// Returns true on success. On failure returns false and describes the problem
+// in 'ErrMsg'; 'Res' is then not meaningful.
+bool evaluateConstant(SPIRVModule *BM, SPIRVId Id, bool &Res,
+                      std::string &ErrMsg) {
+  auto *BV = BM->getValue(Id);
+  const Op OpCode = BV->getOpCode();
+
+  assert(isConstantOpCode(OpCode));
+  assert(BV->getType()->getOpCode() == spv::OpTypeBool);
+
+  // A SpecId decoration takes precedence over the instruction's own value.
+  SPIRVWord SpecId = 0;
+  if (BV->hasDecorate(DecorationSpecId, 0, &SpecId)) {
+    if (OpCode != OpSpecConstantTrue && OpCode != OpSpecConstantFalse &&
+        OpCode != OpSpecConstantArchitectureINTEL &&
+        OpCode != OpSpecConstantTargetINTEL &&
+        OpCode != OpSpecConstantCapabilitiesINTEL) {
+      ErrMsg = "Setting only boolean spec constants is supported";
+      return false;
+    }
+
+    // Without a value registered for this SpecId, only OpSpecConstantTrue
+    // yields true.
+    // NOTE(review): the INTEL spec constants therefore evaluate to false here
+    // instead of going through matchesDevice() - confirm this is intended.
+    bool IsTrue = OpCode == OpSpecConstantTrue;
+    uint64_t ConstValue = 0;
+    if (BM->getSpecializationConstant(SpecId, ConstValue)) {
+      IsTrue = ConstValue != 0;
+    }
+    Res = IsTrue;
+    replaceWithBoolConst(BM, BV, Res);
+    return true;
+  }
+
+  switch (OpCode) {
+  case OpConstantTrue: {
+    Res = true;
+    break;
+  }
+  case OpConstantFalse: {
+    Res = false;
+    break;
+  }
+  case OpSpecConstantTrue: {
+    Res = true;
+    replaceWithBoolConst(BM, BV, Res);
+    break;
+  }
+  case OpSpecConstantFalse: {
+    Res = false;
+    replaceWithBoolConst(BM, BV, Res);
+    break;
+  }
+  case OpSpecConstantArchitectureINTEL: {
+    Res =
+        static_cast<SPIRVSpecConstantArchitectureINTEL *>(BV)->matchesDevice();
+    replaceWithBoolConst(BM, BV, Res);
+    break;
+  }
+  case OpSpecConstantTargetINTEL: {
+    Res = static_cast<SPIRVSpecConstantTargetINTEL *>(BV)->matchesDevice();
+    replaceWithBoolConst(BM, BV, Res);
+    break;
+  }
+  case OpSpecConstantCapabilitiesINTEL: {
+    Res =
+        static_cast<SPIRVSpecConstantCapabilitiesINTEL *>(BV)->matchesDevice();
+    replaceWithBoolConst(BM, BV, Res);
+    break;
+  }
+  case OpSpecConstantOp: {
+    // OpWords holds the opcode of the wrapped operation followed by the IDs
+    // of its operands.
+    const auto OpWords = static_cast<SPIRVSpecConstantOp *>(BV)->getOpWords();
+    if (OpWords.empty()) {
+      ErrMsg = "OpSpecConstantOp is missing the opcode of its operation.";
+      return false;
+    }
+
+    const auto OpOpCode = static_cast<Op>(OpWords[0]);
+    if (OpOpCode != OpLogicalOr && OpOpCode != OpLogicalAnd &&
+        OpOpCode != OpLogicalNot) {
+      ErrMsg = "Unsupported operation: Only OpLogicalOr/And/Not are allowed.";
+      return false;
+    }
+
+    // Check the operand count at run time before indexing: an assert alone
+    // would let a malformed module read out of bounds in release builds.
+    const size_t ExpectedWords = OpOpCode == OpLogicalNot ? 2 : 3;
+    if (OpWords.size() != ExpectedWords) {
+      ErrMsg = "Number of operands does not match the expected count.";
+      return false;
+    }
+
+    bool Val1 = false;
+    if (!evaluateConstant(BM, OpWords[1], Val1, ErrMsg)) {
+      return false;
+    }
+
+    if (OpOpCode == OpLogicalNot) {
+      Res = !Val1;
+    } else {
+      // Both operands are always evaluated (no short-circuit) because the
+      // evaluation also rewrites each operand in the module.
+      bool Val2 = false;
+      if (!evaluateConstant(BM, OpWords[2], Val2, ErrMsg)) {
+        return false;
+      }
+      Res = OpOpCode == OpLogicalOr ? (Val1 || Val2) : (Val1 && Val2);
+    }
+
+    replaceWithBoolConst(BM, BV, Res);
+    break;
+  }
+  default: {
+    std::ostringstream S;
+    S << "Evaluating unsupported instruction, opcode: " << OpCode;
+    ErrMsg = S.str();
+    return false;
+  }
+  }
+
+  return true;
+}
+
+} // anonymous namespace
+
+namespace SPIRV {
+
+// Specialize the multitarget module 'BM' in place. The stages run in this
+// order: conditional capabilities, conditional extensions, conditional entry
+// points, OpConditionalCopyObjectINTEL instructions, and finally IDs decorated
+// with ConditionalINTEL. Afterwards the capabilities and the extension name of
+// SPV_INTEL_function_variants are erased from the module.
+//
+// Returns true on success. Returns false as soon as a condition cannot be
+// evaluated or an entity cannot be removed; 'ErrMsg' then holds the reason.
+bool specializeFnVariants(SPIRVModule *BM, std::string &ErrMsg) {
+ // Specialize conditional capabilities
+ // The results are collected first and applied in a second loop (presumably
+ // so the container is not modified while it is being iterated).
+ std::vector<std::pair<std::pair<SPIRVId, Capability>, bool>> CondCapabilities;
+ for (const auto &CondCap : BM->getConditionalCapabilities()) {
+ const SPIRVId Condition = CondCap.first.first;
+ const Capability Cap = CondCap.first.second;
+ const SPIRVConditionalCapabilityINTEL *Entry = CondCap.second;
+ bool ShouldKeep = false;
+ if (!evaluateConstant(BM, Entry->getCondition(), ShouldKeep, ErrMsg)) {
+ return false;
+ }
+ CondCapabilities.emplace_back(
+ std::make_pair(std::make_pair(Condition, Cap), ShouldKeep));
+ }
+
+ // A capability whose condition holds becomes unconditional; otherwise it is
+ // erased. In both cases the conditional entry is dropped.
+ for (const auto &CondCap : CondCapabilities) {
+ const SPIRVId Condition = CondCap.first.first;
+ const Capability Cap = CondCap.first.second;
+ const bool ShouldKeep = CondCap.second;
+ if (ShouldKeep) {
+ BM->addCapability(Cap);
+ } else {
+ // In case the capability was auto-added by other instruction
+ BM->eraseCapability(Cap);
+ }
+ BM->eraseConditionalCapability(Condition, Cap);
+ }
+
+ // Specialize conditional extensions
+ // Same collect-then-apply scheme as for the capabilities above.
+ std::vector<std::pair<std::pair<uint32_t, std::string>, bool>> CondExtensions;
+ for (const auto &CondExt : BM->getConditionalExtensions()) {
+ const SPIRVId Cond = CondExt.first;
+ const std::string Ext = CondExt.second;
+ bool ShouldKeep = false;
+ if (!evaluateConstant(BM, Cond, ShouldKeep, ErrMsg)) {
+ return false;
+ }
+ CondExtensions.emplace_back(
+ std::make_pair(std::make_pair(Cond, Ext), ShouldKeep));
+ }
+
+ for (const auto &CondExt : CondExtensions) {
+ // Ext is the (condition ID, extension name) pair.
+ const auto Ext = CondExt.first;
+ const bool ShouldKeep = CondExt.second;
+ if (ShouldKeep) {
+ BM->getExtension().insert(Ext.second);
+ } else {
+ // In case the extension was auto-added by other instruction
+ BM->getExtension().erase(Ext.second);
+ }
+ BM->getConditionalExtensions().erase(Ext);
+ }
+
+ // Specialize conditional entry points
+ std::vector<std::pair<SPIRVId, bool>> CondEPs;
+ for (const auto &CondEP : BM->getConditionalEntryPoints()) {
+ const SPIRVId Cond = CondEP->getCondition();
+ bool ShouldKeep = false;
+ if (!evaluateConstant(BM, Cond, ShouldKeep, ErrMsg)) {
+ return false;
+ }
+ CondEPs.emplace_back(std::make_pair(Cond, ShouldKeep));
+ }
+
+ // The module resolves its conditional entry points per condition ID.
+ for (const auto &CondEP : CondEPs) {
+ const SPIRVId Cond = CondEP.first;
+ const bool ShouldKeep = CondEP.second;
+ BM->specializeConditionalEntryPoints(Cond, ShouldKeep);
+ }
+
+ // Specialize conditional copy object
+ // ToReplace pairs each OpConditionalCopyObjectINTEL with the ID of the
+ // operand selected for it.
+ std::vector<std::pair<SPIRVInstruction *, SPIRVId>> ToReplace;
+ for (unsigned IF = 0; IF < BM->getNumFunctions(); ++IF) {
+ const auto *Fun = BM->getFunction(IF);
+ for (unsigned IB = 0; IB < Fun->getNumBasicBlock(); ++IB) {
+ const auto *BB = Fun->getBasicBlock(IB);
+ for (unsigned II = 0; II < BB->getNumInst(); ++II) {
+ auto *Inst = BB->getInst(II);
+ if (Inst->getOpCode() == OpConditionalCopyObjectINTEL) {
+ const auto OperandIds =
+ static_cast<SPIRVConditionalCopyObjectINTEL *>(Inst)
+ ->getOperandIds();
+ // Operands come in (condition, value) pairs: even indices hold the
+ // conditions, the following odd index holds the value to copy.
+ std::optional<unsigned> ITrue = std::nullopt;
+ for (unsigned IO = 0; IO < OperandIds.size(); IO += 2) {
+ const auto CondId = OperandIds[IO];
+ bool Res;
+ if (!evaluateConstant(BM, CondId, Res, ErrMsg)) {
+ return false;
+ }
+ if (Res) {
+ // Stop at the first condition operand that evaluates to true
+ ITrue = IO;
+ break;
+ }
+ }
+ if (!ITrue.has_value()) {
+ ErrMsg = "At least one conditional of OpConditionalCopyObjectINTEL "
+ "must be true. This could mean that all function variants "
+ "have been removed.";
+ return false;
+ }
+ ToReplace.emplace_back(
+ std::make_pair(Inst, OperandIds[ITrue.value() + 1]));
+ }
+ }
+ }
+ }
+
+ // Replace every collected instruction with a SPIRVCopyObject of the selected
+ // operand that keeps the original type and result ID. The new instruction
+ // is added relative to the old instruction's successor (presumably inserted
+ // before it, i.e. at the old position - confirm against addInstruction).
+ for (auto &It : ToReplace) {
+ auto *OldInst = It.first;
+ auto *BB = OldInst->getBasicBlock();
+ auto *NextInst = OldInst->getNext();
+ auto *Operand = BM->getValue(It.second);
+ auto *NewInst =
+ new SPIRVCopyObject(OldInst->getType(), OldInst->getId(), Operand, BB);
+ BM->eraseInstruction(OldInst, BB);
+ BB->addInstruction(NewInst, NextInst);
+ }
+
+ // Specialize IDs annotated with ConditionalINTEL decorations
+ // The first literal of the decoration is the ID of the condition; targets
+ // whose condition evaluates to false are queued for removal.
+ auto *Decors = BM->getDecorateVec();
+ std::vector<SPIRVId> IdsToRemove;
+ for (const auto &D : *Decors) {
+ if (D->getDecorateKind() == DecorationConditionalINTEL) {
+ const SPIRVId ConstId = static_cast<SPIRVWord>(D->getLiteral(0));
+ bool ShouldKeep = false;
+ if (!evaluateConstant(BM, ConstId, ShouldKeep, ErrMsg)) {
+ return false;
+ }
+ if (!ShouldKeep) {
+ IdsToRemove.push_back(D->getTargetId());
+ }
+ }
+ }
+
+ // For each queued ID, references to it are erased first, then the entity
+ // itself is removed according to its opcode.
+ for (const auto &Id : IdsToRemove) {
+ if (!BM->eraseReferencesOfInst(Id)) {
+ ErrMsg = "Error removing references of instruction decorated with "
+ "ConditionalINTEL";
+ return false;
+ }
+ auto *Val = BM->getValue(Id);
+ if (Val->getOpCode() == OpFunctionCall) {
+ auto *Call = static_cast<SPIRVFunctionCall *>(Val);
+ auto *BB = Call->getBasicBlock();
+ BM->eraseInstruction(Call, BB);
+ } else if (Val->getOpCode() == OpFunction) {
+ // For a function, references to its arguments, variables, instructions
+ // with an ID and basic block labels are erased before the function is
+ // taken out of the module's function vector.
+ const auto *Fun = static_cast<const SPIRVFunction *>(Val);
+ for (unsigned I = 0; I < Fun->getNumArguments(); ++I) {
+ const auto ArgId = Fun->getArgumentId(I);
+ if (!BM->eraseReferencesOfInst(ArgId)) {
+ ErrMsg = "Error erasing references of argument of a function "
+ "annotated with ConditionalINTEL";
+ return false;
+ }
+ }
+ for (const auto &VarId : Fun->getVariables()) {
+ if (!BM->eraseReferencesOfInst(VarId)) {
+ ErrMsg = "Error erasing references of variable within function "
+ "annotated with ConditionalINTEL";
+ return false;
+ }
+ }
+ for (unsigned IB = 0; IB < Fun->getNumBasicBlock(); ++IB) {
+ const auto *const BB = Fun->getBasicBlock(IB);
+ for (unsigned II = 0; II < BB->getNumInst(); ++II) {
+ const auto *const Inst = BB->getInst(II);
+ if (Inst->hasId()) {
+ const auto InstId = Inst->getId();
+ if (!BM->eraseReferencesOfInst(InstId)) {
+ ErrMsg = "Error erasing references of instruction within "
+ "function annotated with ConditionalINTEL";
+ return false;
+ }
+ }
+ }
+ if (!BM->eraseReferencesOfInst(BB->getId())) {
+ ErrMsg = "Error erasing references of basic block label within "
+ "function annotated with ConditionalINTEL";
+ return false;
+ }
+ }
+ erase_if(*BM->getFuncVec(), [Id](auto F) { return F->getId() == Id; });
+ } else if (Val->getOpCode() == OpVariable ||
+ isTypeOpCode(Val->getOpCode()) ||
+ Val->getOpCode() == OpExtInstImport ||
+ isConstantOpCode(Val->getOpCode()) ||
+ Val->getOpCode() == OpAsmINTEL ||
+ Val->getOpCode() == OpAsmTargetINTEL) {
+ if (!BM->eraseValue(Val)) {
+ ErrMsg = "Error erasing value annotated with ConditionalINTEL";
+ return false;
+ }
+ } else {
+ ErrMsg = "Unsupported instruction annotated with ConditionalINTEL";
+ return false;
+ }
+ }
+
+ // Remove any leftover ConditionalINTEL decorations
+ erase_if(*Decors, [](auto D) {
+ return D->getDecorateKind() == DecorationConditionalINTEL;
+ });
+
+ // Remove capabilities/extensions of SPV_INTEL_function_variants
+ BM->eraseCapability(CapabilityFunctionVariantsINTEL);
+ BM->eraseCapability(CapabilitySpecConditionalINTEL);
+ BM->getExtension().erase("SPV_INTEL_function_variants");
+
+ return true;
+}
+
+} // namespace SPIRV
diff --git a/lib/SPIRV/libSPIRV/SPIRVFnVar.h b/lib/SPIRV/libSPIRV/SPIRVFnVar.h
new file mode 100644
index 0000000000..acf904e6da
--- /dev/null
+++ b/lib/SPIRV/libSPIRV/SPIRVFnVar.h
@@ -0,0 +1,439 @@
+//===- SPIRVFnVar.h -===//
+//
+// The LLVM/SPIRV Translator
+//
+// Copyright (c) 2025 The Khronos Group Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+/// \file
+///
+/// This file defines entries from the SPV_INTEL_function_variants extension. It
+/// also provides a function for specializing a multi-target module into a
+/// targeted one.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef SPIRV_LIBSPIRV_SPIRVFNVAR_H
+#define SPIRV_LIBSPIRV_SPIRVFNVAR_H
+
+#include "SPIRVEntry.h"
+#include "SPIRVInstruction.h"
+#include "SPIRVValue.h"
+
+#if _SPIRVDBG
+#include <iomanip>
+#endif
+
+namespace SPIRV {
+
+// Specialize multitarget module into a targeted one according to
+// SPV_INTEL_function_variants.
+//
+// This is a SPIR-V to SPIR-V transform. After the transform, the module should
+// not contain any notion of SPV_INTEL_function_variants and can be processed
+// even by consumers that do not support this extension.
+bool specializeFnVariants(SPIRVModule *BM, std::string &ErrMsg);
+
+// Below are entries defined by the extension:
+
+// OpConditionalEntryPointINTEL: an entry point guarded by a condition ID.
+// Operands are encoded in the order Condition, ExecModel, Target, Name,
+// Variables.
+class SPIRVConditionalEntryPointINTEL : public SPIRVAnnotation {
+public:
+  static const SPIRVWord FixedWC = 5;
+
+  // Incomplete constructor; the fields are filled in by decode().
+  SPIRVConditionalEntryPointINTEL()
+      : SPIRVAnnotation(OpConditionalEntryPointINTEL) {}
+
+  // Complete constructor. 'TheId' names the function acting as entry point.
+  SPIRVConditionalEntryPointINTEL(SPIRVModule *TheModule, SPIRVId TheCondition,
+                                  SPIRVExecutionModelKind TheExecModel,
+                                  SPIRVId TheId, const std::string &TheName,
+                                  std::vector<SPIRVId> TheVariables)
+      : SPIRVAnnotation(OpConditionalEntryPointINTEL,
+                        TheModule->get<SPIRVFunction>(TheId),
+                        getSizeInWords(TheName) + TheVariables.size() + 4),
+        Condition(TheCondition), ExecModel(TheExecModel), Name(TheName),
+        Variables(TheVariables) {}
+
+  SPIRVId getCondition() const { return Condition; }
+  SPIRVExecutionModelKind getExecModel() const { return ExecModel; }
+  std::string getName() const { return Name; }
+  std::vector<SPIRVId> getVariables() const { return Variables; }
+
+protected:
+  void encode(spv_ostream &O) const override {
+    getEncoder(O) << Condition << ExecModel << Target << Name << Variables;
+  }
+
+  void decode(std::istream &I) override {
+    getDecoder(I) >> Condition >> ExecModel >> Target >> Name;
+    // Whatever remains of the instruction after the fixed words and the name
+    // is the list of variable IDs.
+    Variables.resize(WordCount - FixedWC - getSizeInWords(Name) + 1);
+    getDecoder(I) >> Variables;
+    // Name the target and register this entry point with the module.
+    Module->setName(getOrCreateTarget(), Name);
+    Module->addConditionalEntryPoint(Condition, ExecModel, Target, Name,
+                                     Variables);
+  }
+
+private:
+  SPIRVId Condition;
+  SPIRVExecutionModelKind ExecModel = ExecutionModelMax;
+  std::string Name;
+  std::vector<SPIRVId> Variables;
+};
+
+// OpConditionalExtensionINTEL: an extension name guarded by a condition ID.
+class SPIRVConditionalExtensionINTEL
+    : public SPIRVEntryNoId<OpConditionalExtensionINTEL> {
+public:
+  // Incomplete constructor; the fields are filled in by decode().
+  SPIRVConditionalExtensionINTEL() {}
+
+  // Complete constructor: two words plus the encoded extension name.
+  SPIRVConditionalExtensionINTEL(SPIRVModule *TheModule, SPIRVId TheCondition,
+                                 const std::string &TheName)
+      : SPIRVEntryNoId(TheModule, 2 + getSizeInWords(TheName)),
+        Condition(TheCondition), S(TheName) {}
+
+  SPIRVId getCondition() const { return Condition; }
+  std::string getExtensionName() const { return S; }
+
+protected:
+  void encode(spv_ostream &O) const override {
+    getEncoder(O) << Condition << S;
+  }
+
+  void decode(std::istream &I) override {
+    getDecoder(I) >> Condition >> S;
+    // Record the (condition, extension name) pair in the module.
+    Module->getConditionalExtensions().insert(std::make_pair(Condition, S));
+  }
+
+private:
+  SPIRVId Condition;
+  std::string S;
+};
+
+// OpConditionalCapabilityINTEL: a capability guarded by a condition ID.
+class SPIRVConditionalCapabilityINTEL
+    : public SPIRVEntryNoId<OpConditionalCapabilityINTEL> {
+public:
+  static const SPIRVWord FixedWC = 3;
+
+  // Incomplete constructor; the fields are filled in by decode().
+  SPIRVConditionalCapabilityINTEL() {}
+
+  // Complete constructor; also refreshes the module version.
+  SPIRVConditionalCapabilityINTEL(SPIRVModule *TheModule, SPIRVId TheCondition,
+                                  SPIRVCapabilityKind TheKind)
+      : SPIRVEntryNoId(TheModule, 3), Condition(TheCondition), Kind(TheKind) {
+    updateModuleVersion();
+  }
+
+  SPIRVId getCondition() const { return Condition; }
+
+protected:
+  void encode(spv_ostream &O) const override {
+    getEncoder(O) << Condition << Kind;
+  }
+
+  void decode(std::istream &I) override {
+    getDecoder(I) >> Condition >> Kind;
+    // Register the decoded pair with the module.
+    Module->addConditionalCapability(Condition, Kind);
+  }
+
+private:
+  SPIRVId Condition;
+  SPIRVCapabilityKind Kind;
+};
+
+// OpConditionalCopyObjectINTEL: result type, result ID and a flat list of
+// (condition, operand) ID pairs stored in 'Constituents'.
+class SPIRVConditionalCopyObjectINTEL : public SPIRVInstruction {
+public:
+  const static Op OC = OpConditionalCopyObjectINTEL;
+  const static SPIRVWord FixedWordCount = 3;
+
+  // Incomplete constructor
+  SPIRVConditionalCopyObjectINTEL() : SPIRVInstruction(OC) {}
+
+  // Complete constructor
+  SPIRVConditionalCopyObjectINTEL(SPIRVType *TheType, SPIRVId TheId,
+                                  const std::vector<SPIRVId> &TheConstituents,
+                                  SPIRVBasicBlock *TheBB)
+      : SPIRVInstruction(FixedWordCount + TheConstituents.size(), OC, TheType,
+                         TheId, TheBB),
+        Constituents(TheConstituents) {
+    validate();
+    assert(TheBB && "Invalid BB");
+  }
+
+  // Raw operand IDs, conditions and operands interleaved.
+  std::vector<SPIRVId> getOperandIds() { return Constituents; }
+
+  // The same operands resolved to values.
+  std::vector<SPIRVValue *> getOperands() override {
+    return getValues(Constituents);
+  }
+
+protected:
+  void setWordCount(SPIRVWord TheWordCount) override {
+    SPIRVEntry::setWordCount(TheWordCount);
+    // Every word beyond the fixed part is one operand ID.
+    Constituents.resize(TheWordCount - FixedWordCount);
+  }
+  _SPIRV_DEF_ENCDEC3(Type, Id, Constituents)
+  void validate() const override {
+    SPIRVInstruction::validate();
+    [[maybe_unused]] size_t TypeOpCode = getType()->getOpCode();
+    assert(TypeOpCode != OpTypeVoid && "Conditional copy type cannot be void");
+    assert(Constituents.size() % 2 == 0 &&
+           "Conditional copy requires condition-operand pairs");
+    assert(Constituents.size() >= 2 &&
+           "Conditional copy requires at least one condition-operand pair");
+  }
+  std::vector<SPIRVId> Constituents;
+};
+
+// OpSpecConstantTargetINTEL: a boolean spec constant described by a target
+// and a list of features. matchesDevice() compares both against the target
+// and features configured on the module.
+class SPIRVSpecConstantTargetINTEL : public SPIRVValue {
+public:
+  constexpr static SPIRVWord FixedWC = 4;
+  constexpr static spv::Op OC = OpSpecConstantTargetINTEL;
+
+  // Complete constructor
+  SPIRVSpecConstantTargetINTEL(SPIRVModule *M, SPIRVType *TheType,
+                               SPIRVId TheId, SPIRVWord TheTarget,
+                               const std::vector<SPIRVWord> &TheFeatures)
+      : SPIRVValue(M, TheFeatures.size() + FixedWC, OC, TheType, TheId),
+        Features(TheFeatures), Target(TheTarget) {
+    validate();
+  }
+  // Incomplete constructor
+  SPIRVSpecConstantTargetINTEL() : SPIRVValue(OC) {}
+
+  SPIRVWord getTarget() const { return Target; }
+
+  // Returns true when the instruction is compatible with the device
+  // configured on the module:
+  //  - a device target, if set, must equal this instruction's target;
+  //  - a non-empty device feature list must contain every feature listed by
+  //    this instruction.
+  // An unset device target or an empty device feature list restricts nothing.
+  bool matchesDevice() {
+    std::optional<SPIRVWord> DeviceTarget = getModule()->getFnVarTarget();
+    std::vector<SPIRVWord> DeviceFeatures = getModule()->getFnVarFeatures();
+    bool Res = true;
+    if (DeviceTarget != std::nullopt && DeviceTarget.value() != Target) {
+      Res = false;
+    }
+    if (Res && !DeviceFeatures.empty()) {
+      for (const auto &Feature : Features) {
+        if (std::find(DeviceFeatures.cbegin(), DeviceFeatures.cend(),
+                      Feature) == DeviceFeatures.cend()) {
+          // One unsupported feature settles the result.
+          Res = false;
+          break;
+        }
+      }
+    }
+
+    SPIRVDBG(
+        spvdbgs() << "[FnVar] match instr Target: ";
+        spvdbgs() << std::setw(4) << Target;
+        spvdbgs() << ", Features:";
+        if (Features.empty()) {
+          spvdbgs() << " none";
+        } else {
+          for (const auto &Feat : Features) {
+            spvdbgs() << " " << Feat;
+          }
+        }
+        spvdbgs() << " | ID: %" << getId() << std::endl;
+        spvdbgs() << "[FnVar] device Target: ";
+        if (DeviceTarget == std::nullopt) {
+          spvdbgs() << "none";
+        } else {
+          spvdbgs() << std::setw(4) << DeviceTarget.value();
+        }
+        spvdbgs() << ", Features:";
+        for (const auto &Feat : DeviceFeatures) {
+          spvdbgs() << " " << Feat;
+        }
+        spvdbgs() << std::endl;
+        spvdbgs() << "[FnVar] result: " << Res << std::endl;);
+
+    return Res;
+  }
+
+protected:
+  _SPIRV_DEF_ENCDEC4(Type, Id, Target, Features)
+  void setWordCount(SPIRVWord WordCount) override {
+    SPIRVEntry::setWordCount(WordCount);
+    // Every word beyond the fixed part is one feature. A word count below
+    // the fixed part is clamped so the subtraction cannot wrap around into
+    // a huge allocation.
+    Features.resize(WordCount >= FixedWC ? WordCount - FixedWC : 0);
+  }
+
+private:
+  std::vector<SPIRVWord> Features;
+  SPIRVWord Target = 0;
+};
+
+/// Spec constant for OpSpecConstantArchitectureINTEL. It stores a device
+/// category, a family, a comparison opcode and an architecture value, and can
+/// be tested against the device described by the module's function-variant
+/// options (getFnVarCategory(), getFnVarFamily(), getFnVarArch()).
+class SPIRVSpecConstantArchitectureINTEL : public SPIRVValue {
+public:
+  constexpr static SPIRVWord FixedWC = 7;
+  constexpr static spv::Op OC = OpSpecConstantArchitectureINTEL;
+
+  // Complete constructor
+  SPIRVSpecConstantArchitectureINTEL(SPIRVModule *M, SPIRVType *TheType,
+                                     SPIRVId TheId, SPIRVWord TheCategory,
+                                     SPIRVWord TheFamily, spv::Op TheCmpOp,
+                                     SPIRVWord TheArchitecture)
+      : SPIRVValue(M, FixedWC, OC, TheType, TheId), Category(TheCategory),
+        Family(TheFamily), CmpOp(TheCmpOp), Architecture(TheArchitecture) {
+    validate();
+  }
+  // Incomplete constructor
+  SPIRVSpecConstantArchitectureINTEL() : SPIRVValue(OC) {}
+
+  SPIRVWord getCategory() const { return Category; }
+  SPIRVWord getFamily() const { return Family; }
+  spv::Op getCmpOp() const { return CmpOp; }
+  SPIRVWord getArchitecture() const { return Architecture; }
+
+  /// Returns true when this instruction matches the configured device:
+  ///  - a configured device category must equal Category;
+  ///  - a configured device family must equal Family;
+  ///  - a configured device architecture must satisfy
+  ///    "DeviceArchitecture <CmpOp> Architecture".
+  /// All configured properties must match; a property that is not
+  /// configured does not restrict the match.
+  bool matchesDevice() {
+    std::optional<SPIRVWord> DeviceCategory = getModule()->getFnVarCategory();
+    std::optional<SPIRVWord> DeviceFamily = getModule()->getFnVarFamily();
+    std::optional<SPIRVWord> DeviceArchitecture = getModule()->getFnVarArch();
+    bool Res = true;
+
+    if (DeviceCategory != std::nullopt && DeviceCategory.value() != Category) {
+      Res = false;
+    }
+    if (DeviceFamily != std::nullopt && DeviceFamily.value() != Family) {
+      Res = false;
+    }
+    if (DeviceArchitecture != std::nullopt) {
+      const SPIRVWord DevArch = DeviceArchitecture.value();
+      bool ArchMatches = false;
+      switch (CmpOp) {
+      case OpIEqual:
+        ArchMatches = DevArch == Architecture;
+        break;
+      case OpINotEqual:
+        ArchMatches = DevArch != Architecture;
+        break;
+      case OpULessThan:
+        ArchMatches = DevArch < Architecture;
+        break;
+      case OpULessThanEqual:
+        ArchMatches = DevArch <= Architecture;
+        break;
+      case OpUGreaterThan:
+        ArchMatches = DevArch > Architecture;
+        break;
+      case OpUGreaterThanEqual:
+        ArchMatches = DevArch >= Architecture;
+        break;
+      default:
+        assert(false && "Invalid checked CmpOp");
+        ArchMatches = false;
+        break;
+      }
+      // Combine with the category/family verdict instead of overwriting it:
+      // a successful architecture comparison must not mask an earlier
+      // category or family mismatch.
+      Res = Res && ArchMatches;
+    }
+
+    SPIRVDBG(
+        spvdbgs() << "[FnVar] match instr Category: " << std::setw(4)
+                  << Category << ", Family: " << std::setw(4) << Family
+                  << ", CmpOp: " << std::setw(4) << CmpOp
+                  << ", Architecture: " << std::setw(4) << Architecture
+                  << " | ID: %" << getId() << std::endl;
+        spvdbgs() << "[FnVar] device Category: ";
+        if (DeviceCategory == std::nullopt) {
+          spvdbgs() << "none";
+        } else {
+          spvdbgs() << std::setw(4) << DeviceCategory.value();
+        }
+        spvdbgs() << ", Family: ";
+        if (DeviceFamily == std::nullopt) {
+          spvdbgs() << "none";
+        } else {
+          spvdbgs() << std::setw(4) << DeviceFamily.value();
+        }
+        spvdbgs() << ", Architecture: ";
+        if (DeviceArchitecture == std::nullopt) {
+          spvdbgs() << "none";
+        } else {
+          spvdbgs() << std::setw(4) << DeviceArchitecture.value();
+        }
+        spvdbgs() << std::endl;
+        spvdbgs() << "[FnVar] result: " << Res << std::endl;);
+
+    return Res;
+  }
+
+protected:
+  _SPIRV_DEF_ENCDEC6(Type, Id, Category, Family, CmpOp, Architecture);
+
+private:
+  // Default-initialised so the getters never read indeterminate values on an
+  // object built with the incomplete constructor.
+  SPIRVWord Category = 0;
+  SPIRVWord Family = 0;
+  spv::Op CmpOp = OpNop;
+  SPIRVWord Architecture = 0;
+};
+
+/// Spec constant for OpSpecConstantCapabilitiesINTEL. It stores a non-empty
+/// list of capability ids and can be tested against the device capabilities
+/// configured on the module (getFnVarCapabilities()).
+class SPIRVSpecConstantCapabilitiesINTEL : public SPIRVValue {
+public:
+  // Word count of the instruction without any capability operand.
+  constexpr static SPIRVWord FixedWC = 3;
+  constexpr static spv::Op OC = OpSpecConstantCapabilitiesINTEL;
+
+  // Complete constructor. The capability list is taken by const reference so
+  // that it is copied exactly once (into the member).
+  SPIRVSpecConstantCapabilitiesINTEL(
+      SPIRVModule *M, SPIRVType *TheType, SPIRVId TheId,
+      const std::vector<SPIRVWord> &TheCapabilities)
+      : SPIRVValue(M, TheCapabilities.size() + FixedWC, OC, TheType, TheId),
+        Capabilities(TheCapabilities) {
+    validate();
+  }
+  // Incomplete constructor
+  SPIRVSpecConstantCapabilitiesINTEL() : SPIRVValue(OC) {}
+
+  std::vector<SPIRVWord> getCapabilities() const { return Capabilities; }
+
+  /// Returns true when every capability listed by this instruction is in the
+  /// configured device capability list. An empty device list does not
+  /// restrict the match.
+  bool matchesDevice() {
+    const std::vector<SPIRVWord> DeviceCapabilities =
+        getModule()->getFnVarCapabilities();
+    bool Res = true;
+
+    if (!DeviceCapabilities.empty()) {
+      for (const auto &Capability : Capabilities) {
+        if (std::find(DeviceCapabilities.cbegin(), DeviceCapabilities.cend(),
+                      Capability) == DeviceCapabilities.cend()) {
+          Res = false;
+          break; // one missing capability is enough to reject
+        }
+      }
+    }
+
+    SPIRVDBG(
+        spvdbgs() << "[FnVar] match instr Capabilities: ";
+        if (Capabilities.empty()) {
+          spvdbgs() << "none";
+        } else {
+          for (const auto &Cap : Capabilities) {
+            spvdbgs() << " " << Cap;
+          }
+        }
+        spvdbgs() << " | ID: %" << getId() << std::endl;
+        spvdbgs() << "[FnVar] device Capabilities: ";
+        for (const auto &Cap : DeviceCapabilities) {
+          spvdbgs() << " " << Cap;
+        }
+        spvdbgs() << std::endl;
+        spvdbgs() << "[FnVar] result: " << Res << std::endl;);
+
+    return Res;
+  }
+
+protected:
+  _SPIRV_DEF_ENCDEC3(Type, Id, Capabilities);
+
+  // Runs the base validation and additionally reports
+  // SPIRVEC_InvalidNumberOfOperands when no capability is listed.
+  void validate() const override {
+    SPIRVValue::validate();
+    if (Capabilities.empty()) {
+      std::stringstream SS;
+      SS << "Id: " << Id << ", OpCode: " << OpCodeNameMap::map(OpCode)
+         << ", Name: \"" << Name << "\". Expected at least one capability.\n";
+      getErrorLog().checkError(false, SPIRVEC_InvalidNumberOfOperands,
+                               SS.str());
+    }
+  }
+
+  // Sizes the capability list from the instruction's word count so that it
+  // can receive the trailing operands.
+  void setWordCount(SPIRVWord WordCount) override {
+    SPIRVEntry::setWordCount(WordCount);
+    Capabilities.resize(WordCount - FixedWC);
+  }
+
+private:
+  std::vector<SPIRVWord> Capabilities;
+};
+
+} // namespace SPIRV
+#endif // SPIRV_LIBSPIRV_SPIRVFNVAR_H
diff --git a/lib/SPIRV/libSPIRV/SPIRVModule.cpp b/lib/SPIRV/libSPIRV/SPIRVModule.cpp
index eda598ee18..ea82be35e9 100644
--- a/lib/SPIRV/libSPIRV/SPIRVModule.cpp
+++ b/lib/SPIRV/libSPIRV/SPIRVModule.cpp
@@ -41,11 +41,13 @@
#include "SPIRVAsm.h"
#include "SPIRVDebug.h"
#include "SPIRVEntry.h"
-#include "SPIRVExtInst.h"
+#include "SPIRVEnum.h"
+#include "SPIRVFnVar.h"
#include "SPIRVFunction.h"
#include "SPIRVInstruction.h"
#include "SPIRVMemAliasingINTEL.h"
#include "SPIRVNameMapEnum.h"
+#include "SPIRVOpCode.h"
#include "SPIRVStream.h"
#include "SPIRVType.h"
#include "SPIRVValue.h"
@@ -122,14 +124,29 @@ class SPIRVModuleImpl : public SPIRVModule {
SPIRVAddressingModelKind getAddressingModel() override { return AddrModel; }
SPIRVExtInstSetKind getBuiltinSet(SPIRVId SetId) const override;
const SPIRVCapMap &getCapability() const override { return CapMap; }
+ const SPIRVConditionalCapMap &getConditionalCapabilities() const override { // Read-only view of the conditional capabilities, keyed by (condition id, capability).
+ return ConditionalCapMap;
+ }
+ const SPIRVConditionalEntryPointVec &
+ getConditionalEntryPoints() const override { // Read-only view of the conditional entry points in insertion order.
+ return ConditionalEntryPointVec;
+ }
bool hasCapability(SPIRVCapabilityKind Cap) const override {
return CapMap.find(Cap) != CapMap.end();
}
std::set<std::string> &getExtension() override { return SPIRVExt; }
+ SPIRVConditionalExtensionSet &getConditionalExtensions() override { // Mutable set of (condition id, extension name) pairs; callers may insert or erase directly.
+ return SPIRVCondExt;
+ }
SPIRVFunction *getFunction(unsigned I) const override { return FuncVec[I]; }
SPIRVVariableBase *getVariable(unsigned I) const override {
return VariableVec[I];
}
+ SPIRVValue *getConst(unsigned I) const override { return ConstVec[I]; } // Unchecked index into ConstVec; I must be < getNumConsts().
+ std::vector<SPIRVDecorateGeneric *> *getDecorateVec() override { // Exposes the module's own decoration vector (not a copy) for in-place edits.
+ return &DecorateVec;
+ }
+ std::vector<SPIRVFunction *> *getFuncVec() override { return &FuncVec; } // Exposes the module's own function vector (not a copy) for in-place edits.
SPIRVValue *getValue(SPIRVId TheId) const override;
std::vector<SPIRVValue *>
getValues(const std::vector<SPIRVId> &) const override;
@@ -142,6 +159,7 @@ class SPIRVModuleImpl : public SPIRVModule {
SPIRVConstant *getLiteralAsConstant(unsigned Literal) override;
unsigned getNumFunctions() const override { return FuncVec.size(); }
unsigned getNumVariables() const override { return VariableVec.size(); }
+ unsigned getNumConsts() const override { return ConstVec.size(); } // Number of entries in ConstVec; upper bound for getConst().
std::vector<SPIRVValue *> getFunctionPointers() const override {
std::vector<SPIRVValue *> Res;
for (auto *C : ConstVec)
@@ -204,6 +222,38 @@ class SPIRVModuleImpl : public SPIRVModule {
SPIRVVersion = Ver;
}
+  /// Drops the module-level records that refer to \p Id: its entry in
+  /// NamedId, its member names, its decorations and, when \p Id is an
+  /// OpFunction, its entry points.
+  /// Returns false (and changes nothing) when no entry is found for \p Id or
+  /// the entry carries no id; returns true otherwise.
+  bool eraseReferencesOfInst(SPIRVId Id) override {
+    const auto *const Target = getEntry(Id);
+    if (!Target || !Target->hasId())
+      return false;
+
+    // Shared predicate: does the record point at the instruction being erased?
+    const auto RefersToId = [Id](const auto &Ref) {
+      return Ref->getTargetId() == Id;
+    };
+
+    // OpName bookkeeping; erasing a key that is absent is a no-op.
+    NamedId.erase(Id);
+    // OpMemberName records.
+    erase_if(MemberNameVec, RefersToId);
+    // Decorations.
+    erase_if(DecorateVec, RefersToId);
+    // Entry points of an erased function. The module writer emits execution
+    // modes by walking EntryPointVec, so they are no longer emitted either.
+    // NOTE(review): EntryPointSet is not updated here - confirm that
+    // isEntryPoint() is not expected to change for the erased function.
+    if (Target->getOpCode() == OpFunction)
+      erase_if(EntryPointVec, RefersToId);
+
+    return true;
+  }
+
+  /// Removes \p CapKind from the module's capability map and frees the
+  /// associated SPIRVCapability object. Does nothing when the capability is
+  /// not present.
+  void eraseCapability(SPIRVCapabilityKind CapKind) override {
+    auto Loc = CapMap.find(CapKind);
+    if (Loc == CapMap.end())
+      return;
+    // CapMap owns its values (the module destructor deletes whatever is
+    // still in the map), so the object must be freed before the entry is
+    // dropped or it would leak.
+    delete Loc->second;
+    CapMap.erase(Loc);
+  }
+
// Object creation functions
template <class T> void addTo(std::vector<T *> &V, SPIRVEntry *E);
SPIRVEntry *addEntry(SPIRVEntry *E) override;
@@ -226,6 +276,8 @@ class SPIRVModuleImpl : public SPIRVModule {
const std::shared_ptr<const SPIRVExtInst> &DebugLine) override;
void addCapability(SPIRVCapabilityKind) override;
void addCapabilityInternal(SPIRVCapabilityKind) override;
+ void addConditionalCapability(SPIRVId, SPIRVCapabilityKind) override;
+ void eraseConditionalCapability(SPIRVId, SPIRVCapabilityKind) override;
void addExtension(ExtensionID) override;
const SPIRVDecorateGeneric *addDecorate(SPIRVDecorateGeneric *) override;
SPIRVDecorationGroup *addDecorationGroup() override;
@@ -242,12 +294,17 @@ class SPIRVModuleImpl : public SPIRVModule {
void addEntryPoint(SPIRVExecutionModelKind ExecModel, SPIRVId EntryPoint,
const std::string &Name,
const std::vector<SPIRVId> &Variables) override;
+ void addConditionalEntryPoint(SPIRVId, SPIRVExecutionModelKind ExecModel,
+ SPIRVId EntryPoint, const std::string &Name,
+ const std::vector<SPIRVId> &Variables) override;
+ void specializeConditionalEntryPoints(SPIRVId, bool) override;
SPIRVForward *addForward(SPIRVType *Ty) override;
SPIRVForward *addForward(SPIRVId, SPIRVType *Ty) override;
SPIRVFunction *addFunction(SPIRVFunction *) override;
SPIRVFunction *addFunction(SPIRVTypeFunction *, SPIRVId) override;
SPIRVEntry *replaceForward(SPIRVForward *, SPIRVEntry *) override;
void eraseInstruction(SPIRVInstruction *, SPIRVBasicBlock *) override;
+ bool eraseValue(SPIRVValue *) override;
// Type creation functions
template <class T> T *addType(T *Ty);
@@ -516,6 +573,7 @@ class SPIRVModuleImpl : public SPIRVModule {
SPIRVWord SrcLangVer;
std::set<std::string> SrcExtension;
std::set<std::string> SPIRVExt;
+ SPIRVConditionalExtensionSet SPIRVCondExt;
SPIRVAddressingModelKind AddrModel;
SPIRVMemoryModelKind MemoryModel;
@@ -571,8 +629,11 @@ class SPIRVModuleImpl : public SPIRVModule {
SPIRVAsmVector AsmVec;
SPIRVExecModelIdSetMap EntryPointSet;
SPIRVEntryPointVec EntryPointVec;
+ SPIRVExecModelIdSetMap ConditionalEntryPointSet;
+ SPIRVConditionalEntryPointVec ConditionalEntryPointVec;
SPIRVStringMap StrMap;
SPIRVCapMap CapMap;
+ SPIRVConditionalCapMap ConditionalCapMap;
SPIRVUnknownStructFieldMap UnknownStructFieldMap;
SPIRVTypeBool *BoolTy;
SPIRVTypeVoid *VoidTy;
@@ -605,6 +666,9 @@ SPIRVModuleImpl::~SPIRVModuleImpl() {
for (auto C : CapMap)
delete C.second;
+ for (auto C : ConditionalCapMap)
+ delete C.second;
+
for (auto *M : ModuleProcessedVec)
delete M;
}
@@ -747,6 +811,30 @@ void SPIRVModuleImpl::addCapabilityInternal(SPIRVCapabilityKind Cap) {
}
}
+/// Records capability \p Cap as required only when \p Condition holds.
+/// A (condition, capability) pair that is already registered is left as is.
+void SPIRVModuleImpl::addConditionalCapability(SPIRVId Condition,
+                                               SPIRVCapabilityKind Cap) {
+  SPIRVDBG(spvdbgs() << "addConditionalCapability: "
+                     << SPIRVCapabilityNameMap::map(Cap)
+                     << ", condition: " << Condition << '\n');
+
+  const auto Key = std::make_pair(Condition, Cap);
+  if (ConditionalCapMap.count(Key) != 0)
+    return;
+
+  auto *NewCap = new SPIRVConditionalCapabilityINTEL(this, Condition, Cap);
+  // Extensions are never derived automatically for conditional capabilities.
+  if (AutoAddExtensions) {
+    assert(false && "Auto adding conditional extensions is not supported.");
+  }
+
+  // The map takes ownership of NewCap.
+  ConditionalCapMap.insert(std::make_pair(Key, NewCap));
+}
+
+/// Removes the (\p Condition, \p Cap) conditional capability and frees the
+/// associated object. Does nothing when the pair is not registered.
+void SPIRVModuleImpl::eraseConditionalCapability(SPIRVId Condition,
+                                                 SPIRVCapabilityKind Cap) {
+  auto Loc = ConditionalCapMap.find(std::make_pair(Condition, Cap));
+  if (Loc == ConditionalCapMap.end())
+    return;
+  // ConditionalCapMap owns its values (they are created with new in
+  // addConditionalCapability and deleted by the module destructor), so free
+  // the object before dropping the entry to avoid leaking it.
+  delete Loc->second;
+  ConditionalCapMap.erase(Loc);
+}
+
SPIRVConstant *SPIRVModuleImpl::getLiteralAsConstant(unsigned Literal) {
auto Loc = LiteralMap.find(Literal);
if (Loc != LiteralMap.end())
@@ -917,8 +1005,11 @@ bool SPIRVModuleImpl::isEntryPoint(SPIRVExecutionModelKind ExecModel,
assert(isValid(ExecModel) && "Invalid execution model");
assert(EP != SPIRVID_INVALID && "Invalid function id");
auto Loc = EntryPointSet.find(ExecModel);
- if (Loc == EntryPointSet.end())
- return false;
+  // Consult both sets: a regular entry point with the same execution model
+  // must not hide a conditional entry point, so fall through to the
+  // conditional set whenever EP is not a regular entry point.
+  if (Loc != EntryPointSet.end() && Loc->second.count(EP))
+    return true;
+  Loc = ConditionalEntryPointSet.find(ExecModel);
+  if (Loc == ConditionalEntryPointSet.end())
+    return false;
return Loc->second.count(EP);
}
@@ -1210,6 +1301,45 @@ void SPIRVModuleImpl::addEntryPoint(SPIRVExecutionModelKind ExecModel,
addCapabilities(SPIRV::getCapability(ExecModel));
}
+/// Creates a conditional entry point for function \p EntryPoint guarded by
+/// \p Condition, registers it with the module and records the function id
+/// under \p ExecModel in ConditionalEntryPointSet.
+void SPIRVModuleImpl::addConditionalEntryPoint(
+    SPIRVId Condition, SPIRVExecutionModelKind ExecModel, SPIRVId EntryPoint,
+    const std::string &Name, const std::vector<SPIRVId> &Variables) {
+  assert(isValid(ExecModel) && "Invalid execution model");
+  assert(EntryPoint != SPIRVID_INVALID && "Invalid entry point");
+
+  ConditionalEntryPointVec.push_back(add(new SPIRVConditionalEntryPointINTEL(
+      this, Condition, ExecModel, EntryPoint, Name, Variables)));
+  ConditionalEntryPointSet[ExecModel].insert(EntryPoint);
+}
+
+/// Resolves every conditional entry point guarded by \p Condition.
+/// Each one is removed from the conditional bookkeeping; when \p ShouldKeep
+/// is true it is first re-registered as a regular entry point.
+void SPIRVModuleImpl::specializeConditionalEntryPoints(SPIRVId Condition,
+                                                       bool ShouldKeep) {
+  const auto IsGuardedByCondition = [Condition](const auto *EP) {
+    return EP->getCondition() == Condition;
+  };
+
+  // First pass: promote (if requested) and remember the function ids.
+  std::vector<SPIRVId> ResolvedIds;
+  for (const auto *EP : ConditionalEntryPointVec) {
+    if (!IsGuardedByCondition(EP))
+      continue;
+    ResolvedIds.push_back(EP->getTargetId());
+    if (ShouldKeep) {
+      // add the removed conditional entry point as a normal entry point
+      addEntryPoint(EP->getExecModel(), EP->getTargetId(), EP->getName(),
+                    EP->getVariables());
+    }
+  }
+
+  // Second pass: drop the resolved entry points from the conditional list...
+  erase_if(ConditionalEntryPointVec, IsGuardedByCondition);
+
+  // ...and from the per-execution-model id sets.
+  for (auto &[ExecMode, EPSet] : ConditionalEntryPointSet) {
+    for (const auto &Id : ResolvedIds) {
+      EPSet.erase(Id);
+    }
+  }
+}
+
SPIRVForward *SPIRVModuleImpl::addForward(SPIRVType *Ty) {
return add(new SPIRVForward(this, Ty, getId()));
}
@@ -1249,6 +1379,31 @@ void SPIRVModuleImpl::eraseInstruction(SPIRVInstruction *I,
delete I;
}
+/// Removes \p V from the module-level vector that matches its opcode (types,
+/// variables, constants, inline asm, asm targets), unregisters its id from
+/// IdEntryMap and deletes it.
+/// Returns false, leaving \p V untouched, for any other opcode.
+bool SPIRVModuleImpl::eraseValue(SPIRVValue *V) {
+  const Op OC = V->getOpCode();
+  const SPIRVId Id = V->getId();
+  // Shared predicate: element carries the id of the value being erased.
+  const auto HasSameId = [Id](const auto &Elem) { return Elem->getId() == Id; };
+
+  if (isTypeOpCode(OC))
+    erase_if(TypeVec, HasSameId);
+  else if (OC == OpVariable)
+    erase_if(VariableVec, HasSameId);
+  else if (isConstantOpCode(OC))
+    erase_if(ConstVec, HasSameId);
+  else if (OC == Op::OpAsmINTEL)
+    erase_if(AsmVec, HasSameId);
+  else if (OC == Op::OpAsmTargetINTEL)
+    erase_if(AsmTargetVec, HasSameId);
+  else
+    return false;
+
+  auto Loc = IdEntryMap.find(Id);
+  assert(Loc != IdEntryMap.end());
+  IdEntryMap.erase(Loc);
+  delete V;
+  return true;
+}
+
SPIRVValue *SPIRVModuleImpl::addConstant(SPIRVValue *C) { return add(C); }
SPIRVValue *SPIRVModuleImpl::addConstant(SPIRVType *Ty, uint64_t V) {
@@ -2072,21 +2227,35 @@ spv_ostream &operator<<(spv_ostream &O, SPIRVModule &M) {
for (auto &I : MI.CapMap)
O << *I.second;
+ for (auto &I : MI.ConditionalCapMap)
+ O << *I.second;
+
for (auto &I : M.getExtension()) {
assert(!I.empty() && "Invalid extension");
O << SPIRVExtension(&M, I);
}
+ for (auto &I : M.getConditionalExtensions()) {
+ auto Cond = I.first;
+ auto Ext = I.second;
+ assert(!Ext.empty() && "Invalid conditional extension");
+ O << SPIRVConditionalExtensionINTEL(&M, Cond, Ext);
+ }
+
for (auto &I : MI.IdToInstSetMap)
O << SPIRVExtInstImport(&M, I.first, SPIRVBuiltinSetNameMap::map(I.second));
O << SPIRVMemoryModel(&M);
O << MI.EntryPointVec;
+ O << MI.ConditionalEntryPointVec;
for (auto &I : MI.EntryPointVec)
MI.get<SPIRVFunction>(I->getTargetId())->encodeExecutionModes(O);
+ for (auto &I : MI.ConditionalEntryPointVec)
+ MI.get<SPIRVFunction>(I->getTargetId())->encodeExecutionModes(O);
+
O << MI.StringVec;
for (auto &I : M.getSourceExtension()) {
@@ -2104,6 +2273,11 @@ spv_ostream &operator<<(spv_ostream &O, SPIRVModule &M) {
IsEntryPoint = true;
break;
}
+ for (auto &EPS : MI.ConditionalEntryPointSet)
+ if (EPS.second.count(I)) {
+ IsEntryPoint = true;
+ break;
+ }
if (!IsEntryPoint)
M.getEntry(I)->encodeName(O);
}
diff --git a/lib/SPIRV/libSPIRV/SPIRVModule.h b/lib/SPIRV/libSPIRV/SPIRVModule.h
index a0565c9019..12fe72a3f6 100644
--- a/lib/SPIRV/libSPIRV/SPIRVModule.h
+++ b/lib/SPIRV/libSPIRV/SPIRVModule.h
@@ -99,6 +99,10 @@ class SPIRVTypeTokenINTEL;
class SPIRVTypeJointMatrixINTEL;
class SPIRVTypeCooperativeMatrixKHR;
class SPIRVTypeTaskSequenceINTEL;
+class SPIRVConditionalCapabilityINTEL;
+class SPIRVConditionalExtensionINTEL;
+class SPIRVConditionalEntryPointINTEL;
+class SPIRVConditionalCopyObjectINTEL;
typedef SPIRVBasicBlock SPIRVLabel;
struct SPIRVTypeImageDescriptor;
@@ -106,6 +110,13 @@ struct SPIRVTypeImageDescriptor;
class SPIRVModule {
public:
typedef std::map<SPIRVCapabilityKind, SPIRVCapability *> SPIRVCapMap;
+ typedef std::map<std::pair<SPIRVId, SPIRVCapabilityKind>,
+ SPIRVConditionalCapabilityINTEL *>
+ SPIRVConditionalCapMap;
+ typedef std::vector<SPIRVConditionalEntryPointINTEL *>
+ SPIRVConditionalEntryPointVec;
+ typedef std::set<std::pair<SPIRVId, std::string>>
+ SPIRVConditionalExtensionSet;
static SPIRVModule *createSPIRVModule();
static SPIRVModule *createSPIRVModule(const SPIRV::TranslatorOpts &);
@@ -134,14 +145,22 @@ class SPIRVModule {
// Module query functions
virtual SPIRVAddressingModelKind getAddressingModel() = 0;
virtual const SPIRVCapMap &getCapability() const = 0;
+ virtual const SPIRVConditionalCapMap &getConditionalCapabilities() const = 0;
+ virtual const SPIRVConditionalEntryPointVec &
+ getConditionalEntryPoints() const = 0;
virtual bool hasCapability(SPIRVCapabilityKind) const = 0;
virtual SPIRVExtInstSetKind getBuiltinSet(SPIRVId) const = 0;
virtual std::set<std::string> &getExtension() = 0;
+ virtual SPIRVConditionalExtensionSet &getConditionalExtensions() = 0;
virtual SPIRVFunction *getFunction(unsigned) const = 0;
virtual SPIRVVariableBase *getVariable(unsigned) const = 0;
+ virtual SPIRVValue *getConst(unsigned) const = 0;
+ virtual std::vector<SPIRVDecorateGeneric *> *getDecorateVec() = 0;
+ virtual std::vector<SPIRVFunction *> *getFuncVec() = 0;
virtual SPIRVMemoryModelKind getMemoryModel() const = 0;
virtual unsigned getNumFunctions() const = 0;
virtual unsigned getNumVariables() const = 0;
+ virtual unsigned getNumConsts() const = 0;
virtual std::vector<SPIRVValue *> getFunctionPointers() const = 0;
virtual SourceLanguage getSourceLanguage(SPIRVWord *) const = 0;
virtual std::set<std::string> &getSourceExtension() = 0;
@@ -181,6 +200,8 @@ class SPIRVModule {
virtual void resolveUnknownStructFields() = 0;
virtual void setSPIRVVersion(VersionNumber) = 0;
virtual void insertEntryNoId(SPIRVEntry *Entry) = 0;
+ virtual bool eraseReferencesOfInst(SPIRVId Id) = 0;
+ virtual void eraseCapability(SPIRVCapabilityKind CapKind) = 0;
void setMinSPIRVVersion(VersionNumber Ver) {
setSPIRVVersion(std::max(Ver, getSPIRVVersion()));
@@ -233,6 +254,10 @@ class SPIRVModule {
virtual void addEntryPoint(SPIRVExecutionModelKind, SPIRVId,
const std::string &,
const std::vector<SPIRVId> &) = 0;
+ virtual void addConditionalEntryPoint(SPIRVId, SPIRVExecutionModelKind,
+ SPIRVId, const std::string &,
+ const std::vector<SPIRVId> &) = 0;
+ virtual void specializeConditionalEntryPoints(SPIRVId, bool) = 0;
virtual SPIRVForward *addForward(SPIRVType *Ty) = 0;
virtual SPIRVForward *addForward(SPIRVId, SPIRVType *Ty) = 0;
virtual SPIRVFunction *addFunction(SPIRVFunction *) = 0;
@@ -240,6 +265,7 @@ class SPIRVModule {
SPIRVId Id = SPIRVID_INVALID) = 0;
virtual SPIRVEntry *replaceForward(SPIRVForward *, SPIRVEntry *) = 0;
virtual void eraseInstruction(SPIRVInstruction *, SPIRVBasicBlock *) = 0;
+ virtual bool eraseValue(SPIRVValue *) = 0;
// Type creation functions
virtual SPIRVTypeArray *addArrayType(SPIRVType *, SPIRVValue *) = 0;
@@ -346,6 +372,13 @@ class SPIRVModule {
for (auto I : Caps)
addCapability(I);
}
+ virtual void addConditionalCapability(SPIRVId, SPIRVCapabilityKind) = 0;
+  /// Registers every capability in \p Caps as conditional on \p Condition by
+  /// calling addConditionalCapability() once per element.
+  template <typename T>
+  void addConditionalCapabilities(SPIRVId Condition, const T &Caps) {
+    for (auto Cap : Caps) {
+      addConditionalCapability(Condition, Cap);
+    }
+  }
+ virtual void eraseConditionalCapability(SPIRVId, SPIRVCapabilityKind) = 0;
virtual void addExtension(ExtensionID) = 0;
/// Used by SPIRV entries to add required capability internally.
/// Should not be used by users directly.
@@ -535,6 +568,10 @@ class SPIRVModule {
return TranslationOpts.getSpecializationConstant(SpecId, ConstValue);
}
+ void setSpecializationConstant(SPIRVWord SpecId, uint64_t ConstValue) { // Stores ConstValue for SpecId in the translator options via setSpecConst().
+ TranslationOpts.setSpecConst(SpecId, ConstValue);
+ }
+
FPContractMode getFPContractMode() const {
return TranslationOpts.getFPContractMode();
}
@@ -593,6 +630,29 @@ class SPIRVModule {
return TranslationOpts.getDesiredBIsRepresentation();
}
+ std::optional<uint32_t> getFnVarCategory() const { // Forwards the device category from the translator options; matchesDevice() treats std::nullopt as "no restriction".
+ return TranslationOpts.getFnVarCategory();
+ }
+ std::optional<uint32_t> getFnVarFamily() const { // Forwards the device family from the translator options; matchesDevice() treats std::nullopt as "no restriction".
+ return TranslationOpts.getFnVarFamily();
+ }
+ std::optional<uint32_t> getFnVarArch() const { // Forwards the device architecture from the translator options; std::nullopt skips the architecture comparison in matchesDevice().
+ return TranslationOpts.getFnVarArch();
+ }
+ std::optional<uint32_t> getFnVarTarget() const { // Forwards the device target from the translator options; matchesDevice() treats std::nullopt as "no restriction".
+ return TranslationOpts.getFnVarTarget();
+ }
+ std::vector<uint32_t> getFnVarFeatures() const { // Returns a copy of the device feature list; matchesDevice() skips the feature check when it is empty.
+ return TranslationOpts.getFnVarFeatures();
+ }
+ std::vector<uint32_t> getFnVarCapabilities() const { // Returns a copy of the device capability list; matchesDevice() skips the capability check when it is empty.
+ return TranslationOpts.getFnVarCapabilities();
+ }
+
+ std::string getFnVarSpvOut() const { // Forwards TranslationOpts.getFnVarSpvOut(); presumably the --fnvar-spv-out path for the specialized module - TODO confirm.
+ return TranslationOpts.getFnVarSpvOut();
+ }
+
// I/O functions
friend spv_ostream &operator<<(spv_ostream &O, SPIRVModule &M);
friend std::istream &operator>>(std::istream &I, SPIRVModule &M);
diff --git a/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h b/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h
index dfe2c0e555..31a12a97c3 100644
--- a/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h
+++ b/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h
@@ -210,6 +210,7 @@ template <> inline void SPIRVMap<Decoration, std::string>::init() {
add(DecorationCacheControlLoadINTEL, "CacheControlLoadINTEL");
add(DecorationCacheControlStoreINTEL, "CacheControlStoreINTEL");
+ add(DecorationConditionalINTEL, "DecorationConditionalINTEL");
// From spirv_internal.hpp
add(internal::DecorationRuntimeAlignedINTEL, "RuntimeAlignedINTEL");
@@ -684,6 +685,8 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
add(internal::CapabilityBindlessImagesINTEL, "BindlessImagesINTEL");
add(CapabilityInt4TypeINTEL, "Int4TypeINTEL");
add(CapabilityInt4CooperativeMatrixINTEL, "Int4CooperativeMatrixINTEL");
+ add(CapabilityFunctionVariantsINTEL, "FunctionVariantsINTEL");
+ add(CapabilitySpecConditionalINTEL, "SpecConditionalINTEL");
}
SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap)
diff --git a/lib/SPIRV/libSPIRV/SPIRVOpCode.h b/lib/SPIRV/libSPIRV/SPIRVOpCode.h
index 8e0ff61903..3699edc2af 100644
--- a/lib/SPIRV/libSPIRV/SPIRVOpCode.h
+++ b/lib/SPIRV/libSPIRV/SPIRVOpCode.h
@@ -250,15 +250,24 @@ inline bool isTypeOpCode(Op OpCode) {
OC == OpTypeUntypedPointerKHR;
}
+/// Returns true for the three function-variant spec-constant opcodes:
+/// OpSpecConstantArchitectureINTEL, OpSpecConstantTargetINTEL and
+/// OpSpecConstantCapabilitiesINTEL.
+inline bool isFnVarSpecConstINTEL(Op OpCode) {
+  switch (OpCode) {
+  case OpSpecConstantArchitectureINTEL:
+  case OpSpecConstantTargetINTEL:
+  case OpSpecConstantCapabilitiesINTEL:
+    return true;
+  default:
+    return false;
+  }
+}
+
inline bool isSpecConstantOpCode(Op OpCode) {
unsigned OC = OpCode;
- return OpSpecConstantTrue <= OC && OC <= OpSpecConstantOp;
+ return (OpSpecConstantTrue <= OC && OC <= OpSpecConstantOp) ||
+ isFnVarSpecConstINTEL(OpCode); // also accept the three INTEL function-variant spec-constant opcodes
}
inline bool isConstantOpCode(Op OpCode) {
unsigned OC = OpCode;
return (OpConstantTrue <= OC && OC <= OpSpecConstantOp) || OC == OpUndef ||
- OC == OpConstantPipeStorage || OC == OpConstantFunctionPointerINTEL;
+ OC == OpConstantPipeStorage || OC == OpConstantFunctionPointerINTEL ||
+ isSpecConstantOpCode(OpCode); // brings in the INTEL function-variant spec constants accepted by isSpecConstantOpCode()
}
inline bool isModuleScopeAllowedOpCode(Op OpCode) {
diff --git a/lib/SPIRV/libSPIRV/SPIRVOpCodeEnum.h b/lib/SPIRV/libSPIRV/SPIRVOpCodeEnum.h
index 91d642e558..971abcabe2 100644
--- a/lib/SPIRV/libSPIRV/SPIRVOpCodeEnum.h
+++ b/lib/SPIRV/libSPIRV/SPIRVOpCodeEnum.h
@@ -580,6 +580,13 @@ _SPIRV_OP(Subgroup2DBlockPrefetchINTEL, 6234)
_SPIRV_OP(Subgroup2DBlockStoreINTEL, 6235)
_SPIRV_OP(SubgroupMatrixMultiplyAccumulateINTEL, 6237)
_SPIRV_OP(BitwiseFunctionINTEL, 6242)
+_SPIRV_OP(ConditionalExtensionINTEL, 6248)
+_SPIRV_OP(ConditionalEntryPointINTEL, 6249)
+_SPIRV_OP(ConditionalCapabilityINTEL, 6250)
+_SPIRV_OP(SpecConstantTargetINTEL, 6251)
+_SPIRV_OP(SpecConstantArchitectureINTEL, 6252)
+_SPIRV_OP(SpecConstantCapabilitiesINTEL, 6253)
+_SPIRV_OP(ConditionalCopyObjectINTEL, 6254)
_SPIRV_OP(GroupIMulKHR, 6401)
_SPIRV_OP(GroupFMulKHR, 6402)
_SPIRV_OP(GroupBitwiseAndKHR, 6403)
diff --git a/spirv-headers-tag.conf b/spirv-headers-tag.conf
index e0aafdf191..9e127c1b87 100644
--- a/spirv-headers-tag.conf
+++ b/spirv-headers-tag.conf
@@ -1 +1 @@
-c9aad99f9276817f18f72a4696239237c83cb775
+9e3836d7d6023843a72ecd3fbf3f09b1b6747a9e
diff --git a/test/extensions/INTEL/SPV_INTEL_function_variants/addsubmul_asm.spt b/test/extensions/INTEL/SPV_INTEL_function_variants/addsubmul_asm.spt
new file mode 100644
index 0000000000..dae46bef56
--- /dev/null
+++ b/test/extensions/INTEL/SPV_INTEL_function_variants/addsubmul_asm.spt
@@ -0,0 +1,277 @@
+;;; A function "foo" compiled for three devices and called from one main "work"
+;;; function. To differentiate them, the "foo" functions perform floating-point
+;;; addition, subtraction and multiplication, respectively. In addition, two
+;;; function variants contain inline assembly (to test conditional extensions).
+
+; RUN: llvm-spirv --to-binary %s -o %t_multitarget.spv
+
+;;; The following should select the base variant (FAdd)
+; RUN: llvm-spirv -r \
+; RUN: --spirv-ext=+SPV_INTEL_function_variants,+SPV_INTEL_inline_assembly \
+; RUN: --fnvar-spec-enable \
+; RUN: --fnvar-spv-out %t_targeted.spv \
+; RUN: --fnvar-category 1 --fnvar-family 1 --fnvar-arch 1 \
+; RUN: --fnvar-target 4 --fnvar-features '7,8' \
+; RUN: %t_multitarget.spv -o %t_targeted.bc
+; RUN: llvm-spirv %t_targeted.spv -to-text -o %t_targeted.spt
+; RUN: FileCheck < %t_targeted.spt %s --check-prefixes=CHECK-SPIRV-COMMON,CHECK-SPIRV-BASE
+
+; RUN: llvm-spirv -r \
+; RUN: --spirv-ext=+SPV_INTEL_function_variants,+SPV_INTEL_inline_assembly \
+; RUN: --fnvar-spec-enable \
+; RUN: --fnvar-spv-out %t_targeted.spv \
+; RUN: --fnvar-target 4 --fnvar-features '7,8' \
+; RUN: %t_multitarget.spv -o %t_targeted.bc
+; RUN: llvm-spirv %t_targeted.spv -to-text -o %t_targeted.spt
+; RUN: FileCheck < %t_targeted.spt %s --check-prefixes=CHECK-SPIRV-COMMON,CHECK-SPIRV-BASE
+
+; RUN: llvm-spirv -r \
+; RUN: --spirv-ext=+SPV_INTEL_function_variants,+SPV_INTEL_inline_assembly \
+; RUN: --fnvar-spec-enable \
+; RUN: --fnvar-spv-out %t_targeted.spv \
+; RUN: --fnvar-target 4 --fnvar-features '7,8,9' \
+; RUN: %t_multitarget.spv -o %t_targeted.bc
+; RUN: llvm-spirv %t_targeted.spv -to-text -o %t_targeted.spt
+; RUN: FileCheck < %t_targeted.spt %s --check-prefixes=CHECK-SPIRV-COMMON,CHECK-SPIRV-BASE
+
+; RUN: llvm-spirv -r \
+; RUN: --spirv-ext=+SPV_INTEL_function_variants,+SPV_INTEL_inline_assembly \
+; RUN: --fnvar-spec-enable \
+; RUN: --fnvar-spv-out %t_targeted.spv \
+; RUN: --fnvar-category 1 --fnvar-family 1 --fnvar-arch 2 \
+; RUN: --fnvar-target 4 --fnvar-features '7,8,9,10' \
+; RUN: %t_multitarget.spv -o %t_targeted.bc
+; RUN: llvm-spirv %t_targeted.spv -to-text -o %t_targeted.spt
+; RUN: FileCheck < %t_targeted.spt %s --check-prefixes=CHECK-SPIRV-COMMON,CHECK-SPIRV-BASE
+
+;;; The following should select the ASM1 variant (FSub)
+; RUN: llvm-spirv -r \
+; RUN: --spirv-ext=+SPV_INTEL_function_variants,+SPV_INTEL_inline_assembly \
+; RUN: --fnvar-spec-enable \
+; RUN: --fnvar-spv-out %t_targeted.spv \
+; RUN: --fnvar-category 1 --fnvar-family 1 --fnvar-arch 1 \
+; RUN: --fnvar-target 4 --fnvar-features '7,8,9,10' \
+; RUN: %t_multitarget.spv -o %t_targeted.bc
+; RUN: llvm-spirv %t_targeted.spv -to-text -o %t_targeted.spt
+; RUN: FileCheck < %t_targeted.spt %s --check-prefixes=CHECK-SPIRV-COMMON,CHECK-SPIRV-ASM1
+
+; RUN: llvm-spirv -r \
+; RUN: --spirv-ext=+SPV_INTEL_function_variants,+SPV_INTEL_inline_assembly \
+; RUN: --fnvar-spec-enable \
+; RUN: --fnvar-spv-out %t_targeted.spv \
+; RUN: --fnvar-category 1 --fnvar-family 1 \
+; RUN: --fnvar-target 4 --fnvar-features '7,8,9,10' \
+; RUN: %t_multitarget.spv -o %t_targeted.bc
+; RUN: llvm-spirv %t_targeted.spv -to-text -o %t_targeted.spt
+; RUN: FileCheck < %t_targeted.spt %s --check-prefixes=CHECK-SPIRV-COMMON,CHECK-SPIRV-ASM1
+
+; RUN: llvm-spirv -r \
+; RUN: --spirv-ext=+SPV_INTEL_function_variants,+SPV_INTEL_inline_assembly \
+; RUN: --fnvar-spec-enable \
+; RUN: --fnvar-spv-out %t_targeted.spv \
+; RUN: --fnvar-category 1 \
+; RUN: --fnvar-target 4 --fnvar-features '7,8,9,10' \
+; RUN: %t_multitarget.spv -o %t_targeted.bc
+; RUN: llvm-spirv %t_targeted.spv -to-text -o %t_targeted.spt
+; RUN: FileCheck < %t_targeted.spt %s --check-prefixes=CHECK-SPIRV-COMMON,CHECK-SPIRV-ASM1
+
+; RUN: llvm-spirv -r \
+; RUN: --spirv-ext=+SPV_INTEL_function_variants,+SPV_INTEL_inline_assembly \
+; RUN: --fnvar-spec-enable \
+; RUN: --fnvar-spv-out %t_targeted.spv \
+; RUN: --fnvar-target 4 --fnvar-features '7,8,9,10' \
+; RUN: %t_multitarget.spv -o %t_targeted.bc
+; RUN: llvm-spirv %t_targeted.spv -to-text -o %t_targeted.spt
+; RUN: FileCheck < %t_targeted.spt %s --check-prefixes=CHECK-SPIRV-COMMON,CHECK-SPIRV-ASM1
+
+; RUN: llvm-spirv -r \
+; RUN: --spirv-ext=+SPV_INTEL_function_variants,+SPV_INTEL_inline_assembly \
+; RUN: --fnvar-spec-enable \
+; RUN: --fnvar-spv-out %t_targeted.spv \
+; RUN: --fnvar-target 4 \
+; RUN: %t_multitarget.spv -o %t_targeted.bc
+; RUN: llvm-spirv %t_targeted.spv -to-text -o %t_targeted.spt
+; RUN: FileCheck < %t_targeted.spt %s --check-prefixes=CHECK-SPIRV-COMMON,CHECK-SPIRV-ASM1
+
+; RUN: llvm-spirv -r \
+; RUN: --spirv-ext=+SPV_INTEL_function_variants,+SPV_INTEL_inline_assembly \
+; RUN: --fnvar-spec-enable \
+; RUN: --fnvar-spv-out %t_targeted.spv \
+; RUN: %t_multitarget.spv -o %t_targeted.bc
+; RUN: llvm-spirv %t_targeted.spv -to-text -o %t_targeted.spt
+; RUN: FileCheck < %t_targeted.spt %s --check-prefixes=CHECK-SPIRV-COMMON,CHECK-SPIRV-ASM1
+
+;;; The following should select the ASM2 variant (FMul)
+; RUN: llvm-spirv -r \
+; RUN: --spirv-ext=+SPV_INTEL_function_variants,+SPV_INTEL_inline_assembly \
+; RUN: --fnvar-spec-enable \
+; RUN: --fnvar-spv-out %t_targeted.spv \
+; RUN: --fnvar-category 1 --fnvar-family 21 --fnvar-arch 0 \
+; RUN: --fnvar-target 5 --fnvar-features '2,3,4,5,6' \
+; RUN: %t_multitarget.spv -o %t_targeted.bc
+; RUN: llvm-spirv %t_targeted.spv -to-text -o %t_targeted.spt
+; RUN: FileCheck < %t_targeted.spt %s --check-prefixes=CHECK-SPIRV-COMMON,CHECK-SPIRV-ASM2
+
+; RUN: llvm-spirv -r \
+; RUN: --spirv-ext=+SPV_INTEL_function_variants,+SPV_INTEL_inline_assembly \
+; RUN: --fnvar-spec-enable \
+; RUN: --fnvar-spv-out %t_targeted.spv \
+; RUN: --fnvar-category 1 --fnvar-family 21 --fnvar-arch 0 \
+; RUN: --fnvar-target 6 --fnvar-features '2,3,4,5,6' \
+; RUN: %t_multitarget.spv -o %t_targeted.bc
+; RUN: llvm-spirv %t_targeted.spv -to-text -o %t_targeted.spt
+; RUN: FileCheck < %t_targeted.spt %s --check-prefixes=CHECK-SPIRV-COMMON,CHECK-SPIRV-ASM2
+
+; CHECK-SPIRV-COMMON-NOT: Capability FunctionVariantsINTEL
+; CHECK-SPIRV-COMMON-NOT: Capability SpecConditionalINTEL
+; CHECK-SPIRV-COMMON-NOT: Extension "SPV_INTEL_function_variants"
+; CHECK-SPIRV-COMMON-NOT: ConditionalINTEL
+; CHECK-SPIRV-COMMON-NOT: ConditionalExtension
+; CHECK-SPIRV-COMMON-NOT: ConditionalCapability
+; CHECK-SPIRV-COMMON-NOT: ConditionalEntryPoint
+; CHECK-SPIRV-COMMON-NOT: ConditionalCopyObjectINTEL
+; CHECK-SPIRV-COMMON-NOT: SpecConstant
+
+; CHECK-SPIRV-BASE: FAdd
+; CHECK-SPIRV-BASE-NOT: AsmTargetINTEL
+; CHECK-SPIRV-BASE-NOT: AsmINTEL
+; CHECK-SPIRV-BASE-NOT: AsmCallINTEL
+
+; CHECK-SPIRV-ASM1: AsmTargetINTEL
+; CHECK-SPIRV-ASM1: AsmINTEL
+; CHECK-SPIRV-ASM1: "nop1"
+; CHECK-SPIRV-ASM1: FSub
+; CHECK-SPIRV-ASM1: AsmCallINTEL
+
+; CHECK-SPIRV-ASM2: AsmTargetINTEL
+; CHECK-SPIRV-ASM2: AsmINTEL
+; CHECK-SPIRV-ASM2: "nop2"
+; CHECK-SPIRV-ASM2: FMul
+; CHECK-SPIRV-ASM2: AsmCallINTEL
+
+;;; Input multi-target SPIR-V module
+119734787 67072 458752 61 0
+2 Capability Addresses
+2 Capability Linkage
+2 Capability Kernel
+2 Capability AsmINTEL
+2 Capability SpecConditionalINTEL
+2 Capability FunctionVariantsINTEL
+3 ConditionalCapabilityINTEL 1 Int64
+3 ConditionalCapabilityINTEL 1 Int8
+3 ConditionalCapabilityINTEL 2 AsmINTEL
+8 Extension "SPV_INTEL_function_variants"
+9 ConditionalExtensionINTEL 2 "SPV_INTEL_inline_assembly"
+5 ExtInstImport 3 "OpenCL.std"
+3 MemoryModel 2 2
+3 Source 4 100000
+3 Name 4 "foo"
+4 Name 5 "work"
+3 Name 6 "foo"
+3 Name 10 "foo"
+
+5 Decorate 4 LinkageAttributes "foo" Export
+6 Decorate 5 LinkageAttributes "work" Export
+4 Decorate 4 DecorationConditionalINTEL 1
+5 Decorate 6 LinkageAttributes "foo" Export
+3 Decorate 7 SideEffectsINTEL
+4 Decorate 8 DecorationConditionalINTEL 9
+4 Decorate 7 DecorationConditionalINTEL 9
+4 Decorate 6 DecorationConditionalINTEL 9
+5 Decorate 10 LinkageAttributes "foo" Export
+3 Decorate 11 SideEffectsINTEL
+4 Decorate 12 DecorationConditionalINTEL 13
+4 Decorate 11 DecorationConditionalINTEL 13
+4 Decorate 10 DecorationConditionalINTEL 13
+4 Decorate 14 DecorationConditionalINTEL 1
+4 Decorate 15 DecorationConditionalINTEL 9
+4 Decorate 16 DecorationConditionalINTEL 13
+4 Decorate 17 DecorationConditionalINTEL 1
+4 Decorate 18 DecorationConditionalINTEL 1
+4 Decorate 19 DecorationConditionalINTEL 1
+4 Decorate 20 DecorationConditionalINTEL 2
+4 Decorate 21 DecorationConditionalINTEL 2
+4 TypeInt 17 8 0
+4 TypeInt 18 64 0
+4 TypePointer 19 7 17
+2 TypeVoid 20
+3 TypeFunction 21 20
+3 TypeFloat 22 32
+4 TypePointer 23 7 22
+6 TypeFunction 24 22 23 23 23
+2 TypeBool 25
+4 SpecConstantTargetINTEL 25 26 4
+7 SpecConstantArchitectureINTEL 25 27 1 1 IEqual 1
+6 SpecConstantTargetINTEL 25 28 4 9 10
+6 SpecConstantOp 25 9 167 27 28
+7 SpecConstantTargetINTEL 25 29 5 2 4 5
+7 SpecConstantTargetINTEL 25 30 6 2 4 5
+6 SpecConstantOp 25 13 166 29 30
+6 SpecConstantOp 25 31 166 9 13
+5 SpecConstantOp 25 32 168 31
+6 SpecConstantOp 25 1 167 26 32
+6 SpecConstantOp 25 2 166 9 13
+
+8 AsmTargetINTEL 8 "spirv64-unknown-unknown"
+8 AsmTargetINTEL 12 "spirv64-unknown-unknown"
+8 AsmINTEL 20 7 21 8 "nop1" ""
+8 AsmINTEL 20 11 21 12 "nop2" ""
+
+
+5 Function 22 4 2 24
+3 FunctionParameter 23 33
+3 FunctionParameter 23 34
+3 FunctionParameter 23 35
+
+2 Label 36
+6 Load 22 37 33 2 4
+6 Load 22 38 34 2 4
+5 FAdd 22 39 37 38
+5 Store 35 39 2 4
+2 ReturnValue 39
+
+1 FunctionEnd
+
+5 Function 22 5 0 24
+3 FunctionParameter 23 40
+3 FunctionParameter 23 41
+3 FunctionParameter 23 42
+
+2 Label 43
+7 FunctionCall 22 14 4 40 41 42
+7 FunctionCall 22 15 6 40 41 42
+7 FunctionCall 22 16 10 40 41 42
+9 ConditionalCopyObjectINTEL 22 44 1 14 9 15 13 16
+2 ReturnValue 44
+
+1 FunctionEnd
+
+5 Function 22 6 0 24
+3 FunctionParameter 23 45
+3 FunctionParameter 23 46
+3 FunctionParameter 23 47
+
+2 Label 48
+6 Load 22 49 45 2 4
+6 Load 22 50 46 2 4
+5 FSub 22 51 49 50
+5 Store 47 51 2 4
+4 AsmCallINTEL 20 52 7
+2 ReturnValue 51
+
+1 FunctionEnd
+
+5 Function 22 10 0 24
+3 FunctionParameter 23 53
+3 FunctionParameter 23 54
+3 FunctionParameter 23 55
+
+2 Label 56
+6 Load 22 57 53 2 4
+6 Load 22 58 54 2 4
+5 FMul 22 59 57 58
+5 Store 55 59 2 4
+4 AsmCallINTEL 20 60 11
+2 ReturnValue 59
+
+1 FunctionEnd
diff --git a/test/extensions/INTEL/SPV_INTEL_function_variants/cl_dot4.spt b/test/extensions/INTEL/SPV_INTEL_function_variants/cl_dot4.spt
new file mode 100644
index 0000000000..960d0e76e1
--- /dev/null
+++ b/test/extensions/INTEL/SPV_INTEL_function_variants/cl_dot4.spt
@@ -0,0 +1,353 @@
+;;; 32-bit and 16-bit variants of a dot product kernel generated from OpenCL C.
+
+; RUN: llvm-spirv --to-binary %s -o %t_multitarget.spv
+
+;;; The following should produce the base (32-bit float) variant:
+; RUN: llvm-spirv -r \
+; RUN: --spirv-ext=+SPV_INTEL_function_variants \
+; RUN: --fnvar-spec-enable \
+; RUN: --fnvar-spv-out %t_targeted.spv \
+; RUN: --fnvar-category 2 --fnvar-family 3 --fnvar-arch 1 \
+; RUN: --fnvar-target 7 \
+; RUN: %t_multitarget.spv -o %t_targeted.bc
+; RUN: llvm-spirv %t_targeted.spv -to-text -o %t_targeted.spt
+; RUN: FileCheck < %t_targeted.spt %s --check-prefixes=CHECK-SPIRV-COMMON,CHECK-SPIRV-BASE
+
+; RUN: llvm-spirv -r \
+; RUN: --spirv-ext=+SPV_INTEL_function_variants \
+; RUN: --fnvar-spec-enable \
+; RUN: --fnvar-spv-out %t_targeted.spv \
+; RUN: --fnvar-category 2 --fnvar-family 3 --fnvar-arch 3 \
+; RUN: --fnvar-target 8 --fnvar-features '1,2,3,4'\
+; RUN: %t_multitarget.spv -o %t_targeted.bc
+; RUN: llvm-spirv %t_targeted.spv -to-text -o %t_targeted.spt
+; RUN: FileCheck < %t_targeted.spt %s --check-prefixes=CHECK-SPIRV-COMMON,CHECK-SPIRV-BASE
+
+; RUN: llvm-spirv -r \
+; RUN: --spirv-ext=+SPV_INTEL_function_variants \
+; RUN: --fnvar-spec-enable \
+; RUN: --fnvar-spv-out %t_targeted.spv \
+; RUN: --fnvar-category 2 --fnvar-family 3 --fnvar-arch 4 \
+; RUN: --fnvar-target 7 \
+; RUN: --fnvar-capabilities '4,5,6,11,39' \
+; RUN: %t_multitarget.spv -o %t_targeted.bc
+; RUN: llvm-spirv %t_targeted.spv -to-text -o %t_targeted.spt
+; RUN: FileCheck < %t_targeted.spt %s --check-prefixes=CHECK-SPIRV-COMMON,CHECK-SPIRV-BASE
+
+;;; The following should produce the 16-bit floating-point variant:
+; RUN: llvm-spirv -r \
+; RUN: --spirv-ext=+SPV_INTEL_function_variants \
+; RUN: --fnvar-spec-enable \
+; RUN: --fnvar-spv-out %t_targeted.spv \
+; RUN: --fnvar-category 2 --fnvar-family 3 --fnvar-arch 4 \
+; RUN: --fnvar-target 7 \
+; RUN: %t_multitarget.spv -o %t_targeted.bc
+; RUN: llvm-spirv %t_targeted.spv -to-text -o %t_targeted.spt
+; RUN: FileCheck < %t_targeted.spt %s --check-prefixes=CHECK-SPIRV-COMMON,CHECK-SPIRV-FP16
+
+; RUN: llvm-spirv -r \
+; RUN: --spirv-ext=+SPV_INTEL_function_variants \
+; RUN: --fnvar-spec-enable \
+; RUN: --fnvar-spv-out %t_targeted.spv \
+; RUN: --fnvar-category 2 --fnvar-family 3 \
+; RUN: --fnvar-target 7 \
+; RUN: %t_multitarget.spv -o %t_targeted.bc
+; RUN: llvm-spirv %t_targeted.spv -to-text -o %t_targeted.spt
+; RUN: FileCheck < %t_targeted.spt %s --check-prefixes=CHECK-SPIRV-COMMON,CHECK-SPIRV-FP16
+
+; RUN: llvm-spirv -r \
+; RUN: --spirv-ext=+SPV_INTEL_function_variants \
+; RUN: --fnvar-spec-enable \
+; RUN: --fnvar-spv-out %t_targeted.spv \
+; RUN: --fnvar-category 2 --fnvar-family 3 --fnvar-arch 5 \
+; RUN: --fnvar-target 8 --fnvar-features '1,2,3,4'\
+; RUN: %t_multitarget.spv -o %t_targeted.bc
+; RUN: llvm-spirv %t_targeted.spv -to-text -o %t_targeted.spt
+; RUN: FileCheck < %t_targeted.spt %s --check-prefixes=CHECK-SPIRV-COMMON,CHECK-SPIRV-FP16
+
+; RUN: llvm-spirv -r \
+; RUN: --spirv-ext=+SPV_INTEL_function_variants \
+; RUN: --fnvar-spec-enable \
+; RUN: --fnvar-spv-out %t_targeted.spv \
+; RUN: --fnvar-category 2 --fnvar-family 3 --fnvar-arch 4 \
+; RUN: --fnvar-target 7 \
+; RUN: --fnvar-capabilities '4,5,6,8,9,11,39' \
+; RUN: %t_multitarget.spv -o %t_targeted.bc
+; RUN: llvm-spirv %t_targeted.spv -to-text -o %t_targeted.spt
+; RUN: FileCheck < %t_targeted.spt %s --check-prefixes=CHECK-SPIRV-COMMON,CHECK-SPIRV-FP16
+
+; CHECK-SPIRV-COMMON: EntryPoint
+; CHECK-SPIRV-FP16: TypeFloat {{[0-9]+}} 16
+; CHECK-SPIRV-BASE: TypeFloat {{[0-9]+}} 32
+; CHECK-SPIRV-FP16-NOT: TypeFloat {{[0-9]+}} 32
+; CHECK-SPIRV-BASE-NOT: TypeFloat {{[0-9]+}} 16
+; CHECK-SPIRV-COMMON: Dot
+; CHECK-SPIRV-COMMON-NOT: Capability FunctionVariantsINTEL
+; CHECK-SPIRV-COMMON-NOT: Capability SpecConditionalINTEL
+; CHECK-SPIRV-COMMON-NOT: Extension "SPV_INTEL_function_variants"
+; CHECK-SPIRV-COMMON-NOT: ConditionalINTEL
+; CHECK-SPIRV-COMMON-NOT: ConditionalExtension
+; CHECK-SPIRV-COMMON-NOT: ConditionalCapability
+; CHECK-SPIRV-COMMON-NOT: ConditionalEntryPoint
+; CHECK-SPIRV-COMMON-NOT: SpecConstant
+; CHECK-SPIRV-COMMON-NOT: ConditionalCopyObjectINTEL
+
+;;; Input multi-target SPIR-V module
+119734787 67072 1114112 110 0
+2 Capability Addresses
+2 Capability Linkage
+2 Capability Kernel
+2 Capability Int64
+2 Capability Int8
+2 Capability FunctionVariantsINTEL
+3 ConditionalCapabilityINTEL 1 Float16Buffer
+3 ConditionalCapabilityINTEL 1 Float16
+8 Extension "SPV_INTEL_function_variants"
+5 ExtInstImport 2 "OpenCL.std"
+3 MemoryModel 2 2
+7 ConditionalEntryPointINTEL 3 6 4 "dot4" 5
+7 ConditionalEntryPointINTEL 1 6 6 "dot4" 7
+3 ExecutionMode 4 31
+3 ExecutionMode 6 31
+3 Source 3 102000
+6 Name 1 "dot4_fp16.spv"
+6 Name 3 "dot4_fp32.spv"
+11 Name 5 "__spirv_BuiltInGlobalInvocationId"
+11 Name 7 "__spirv_BuiltInGlobalInvocationId"
+3 Name 8 "A"
+3 Name 9 "B"
+3 Name 10 "C"
+4 Name 11 "call.i"
+4 Name 12 "conv.i"
+4 Name 13 "shl.i"
+4 Name 14 "conv1.i"
+4 Name 15 "call2.i"
+4 Name 16 "call4.i"
+4 Name 17 "call5.i"
+4 Name 18 "sext.i"
+5 Name 19 "arrayidx.i"
+4 Name 20 "entry"
+3 Name 21 "A"
+3 Name 22 "B"
+3 Name 23 "C"
+9 Name 24 "__clang_ocl_kern_imp_dot4"
+4 Name 25 "call"
+4 Name 26 "conv"
+3 Name 27 "shl"
+4 Name 28 "conv1"
+4 Name 29 "call2"
+4 Name 30 "call4"
+4 Name 31 "call5"
+4 Name 32 "sext"
+5 Name 33 "arrayidx"
+4 Name 34 "entry"
+3 Name 35 "A"
+3 Name 36 "B"
+3 Name 37 "C"
+4 Name 38 "call.i"
+4 Name 39 "conv.i"
+4 Name 40 "shl.i"
+4 Name 41 "conv1.i"
+4 Name 42 "call2.i"
+4 Name 43 "call4.i"
+4 Name 44 "call5.i"
+4 Name 45 "sext.i"
+5 Name 46 "arrayidx.i"
+4 Name 47 "entry"
+3 Name 48 "A"
+3 Name 49 "B"
+3 Name 50 "C"
+9 Name 51 "__clang_ocl_kern_imp_dot4"
+4 Name 52 "call"
+4 Name 53 "conv"
+3 Name 54 "shl"
+4 Name 55 "conv1"
+4 Name 56 "call2"
+4 Name 57 "call4"
+4 Name 58 "call5"
+4 Name 59 "sext"
+5 Name 60 "arrayidx"
+4 Name 61 "entry"
+
+4 Decorate 8 Alignment 4
+4 Decorate 8 FuncParamAttr 4
+4 Decorate 9 Alignment 4
+4 Decorate 9 FuncParamAttr 4
+4 Decorate 10 Alignment 4
+4 Decorate 10 FuncParamAttr 4
+3 Decorate 5 Constant
+4 Decorate 5 BuiltIn 28
+4 Decorate 21 Alignment 4
+4 Decorate 21 FuncParamAttr 4
+4 Decorate 22 Alignment 4
+4 Decorate 22 FuncParamAttr 4
+4 Decorate 23 Alignment 4
+4 Decorate 23 FuncParamAttr 4
+11 Decorate 24 LinkageAttributes "__clang_ocl_kern_imp_dot4" Export
+4 Decorate 62 DecorationConditionalINTEL 3
+4 Decorate 63 DecorationConditionalINTEL 3
+4 Decorate 64 DecorationConditionalINTEL 3
+4 Decorate 5 DecorationConditionalINTEL 3
+4 Decorate 4 DecorationConditionalINTEL 3
+4 Decorate 24 DecorationConditionalINTEL 3
+4 Decorate 35 Alignment 2
+4 Decorate 35 FuncParamAttr 4
+4 Decorate 36 Alignment 2
+4 Decorate 36 FuncParamAttr 4
+4 Decorate 37 Alignment 2
+4 Decorate 37 FuncParamAttr 4
+3 Decorate 7 Constant
+4 Decorate 7 BuiltIn 28
+4 Decorate 48 Alignment 2
+4 Decorate 48 FuncParamAttr 4
+4 Decorate 49 Alignment 2
+4 Decorate 49 FuncParamAttr 4
+4 Decorate 50 Alignment 2
+4 Decorate 50 FuncParamAttr 4
+11 Decorate 51 LinkageAttributes "__clang_ocl_kern_imp_dot4" Export
+4 Decorate 65 DecorationConditionalINTEL 1
+4 Decorate 66 DecorationConditionalINTEL 1
+4 Decorate 67 DecorationConditionalINTEL 1
+4 Decorate 7 DecorationConditionalINTEL 1
+4 Decorate 6 DecorationConditionalINTEL 1
+4 Decorate 51 DecorationConditionalINTEL 1
+4 Decorate 68 DecorationConditionalINTEL 3
+4 Decorate 69 DecorationConditionalINTEL 3
+4 Decorate 70 DecorationConditionalINTEL 3
+4 Decorate 71 DecorationConditionalINTEL 3
+4 Decorate 72 DecorationConditionalINTEL 1
+4 Decorate 73 DecorationConditionalINTEL 1
+4 Decorate 74 DecorationConditionalINTEL 1
+4 Decorate 75 DecorationConditionalINTEL 1
+4 TypeInt 76 8 0
+4 TypeInt 79 64 0
+4 TypeInt 80 32 0
+5 Constant 79 62 30 0
+5 Constant 79 63 32 0
+4 Constant 80 64 2
+5 Constant 79 65 31 0
+5 Constant 79 66 32 0
+4 Constant 80 67 2
+3 TypeFloat 68 32
+4 TypePointer 69 5 68
+2 TypeVoid 78
+4 TypePointer 77 5 76
+6 TypeFunction 70 78 69 69 77
+4 TypeVector 71 68 4
+3 TypeFloat 72 16
+4 TypePointer 73 5 72
+6 TypeFunction 74 78 73 73 77
+4 TypeVector 75 72 4
+4 TypeVector 81 79 3
+4 TypePointer 82 1 81
+2 TypeBool 83
+10 SpecConstantCapabilitiesINTEL 83 94 4 5 6 8 9 11 39
+7 SpecConstantArchitectureINTEL 83 91 2 3 UGreaterThanEqual 4
+6 SpecConstantOp 83 96 167 94 91
+4 SpecConstantTargetINTEL 83 92 7
+4 SpecConstantTargetINTEL 83 93 8
+6 SpecConstantOp 83 95 166 92 93
+6 SpecConstantOp 83 1 167 96 95
+8 SpecConstantCapabilitiesINTEL 83 87 4 5 6 11 39
+7 SpecConstantArchitectureINTEL 83 84 2 3 UGreaterThanEqual 0
+6 SpecConstantOp 83 89 167 87 84
+4 SpecConstantTargetINTEL 83 85 7
+4 SpecConstantTargetINTEL 83 86 8
+6 SpecConstantOp 83 88 166 85 86
+6 SpecConstantOp 83 90 167 89 88
+5 SpecConstantOp 83 97 168 1
+6 SpecConstantOp 83 3 167 90 97
+4 Variable 82 5 1
+4 Variable 82 7 1
+
+5 Function 78 4 0 70
+3 FunctionParameter 69 8
+3 FunctionParameter 69 9
+3 FunctionParameter 77 10
+
+2 Label 20
+6 Load 81 98 5 2 1
+5 CompositeExtract 79 11 98 0
+4 UConvert 80 12 11
+5 ShiftLeftLogical 80 13 12 64
+4 SConvert 79 14 13
+8 ExtInst 71 15 2 vloadn 14 8 4
+8 ExtInst 71 16 2 vloadn 14 9 4
+5 Dot 68 17 15 16
+5 ShiftLeftLogical 79 18 11 63
+5 ShiftRightArithmetic 79 99 18 62
+5 InBoundsPtrAccessChain 77 19 10 99
+4 Bitcast 69 100 19
+5 Store 100 17 2 4
+1 Return
+
+1 FunctionEnd
+
+5 Function 78 24 1 70
+3 FunctionParameter 69 21
+3 FunctionParameter 69 22
+3 FunctionParameter 77 23
+
+2 Label 34
+6 Load 81 101 5 2 1
+5 CompositeExtract 79 25 101 0
+4 UConvert 80 26 25
+5 ShiftLeftLogical 80 27 26 64
+4 SConvert 79 28 27
+8 ExtInst 71 29 2 vloadn 28 21 4
+8 ExtInst 71 30 2 vloadn 28 22 4
+5 Dot 68 31 29 30
+5 ShiftLeftLogical 79 32 25 63
+5 ShiftRightArithmetic 79 102 32 62
+5 InBoundsPtrAccessChain 77 33 23 102
+4 Bitcast 69 103 33
+5 Store 103 31 2 4
+1 Return
+
+1 FunctionEnd
+
+5 Function 78 6 0 74
+3 FunctionParameter 73 35
+3 FunctionParameter 73 36
+3 FunctionParameter 77 37
+
+2 Label 47
+6 Load 81 104 7 2 1
+5 CompositeExtract 79 38 104 0
+4 UConvert 80 39 38
+5 ShiftLeftLogical 80 40 39 67
+4 SConvert 79 41 40
+8 ExtInst 75 42 2 vloadn 41 35 4
+8 ExtInst 75 43 2 vloadn 41 36 4
+5 Dot 72 44 42 43
+5 ShiftLeftLogical 79 45 38 66
+5 ShiftRightArithmetic 79 105 45 65
+5 InBoundsPtrAccessChain 77 46 37 105
+4 Bitcast 73 106 46
+5 Store 106 44 2 2
+1 Return
+
+1 FunctionEnd
+
+5 Function 78 51 1 74
+3 FunctionParameter 73 48
+3 FunctionParameter 73 49
+3 FunctionParameter 77 50
+
+2 Label 61
+6 Load 81 107 7 2 1
+5 CompositeExtract 79 52 107 0
+4 UConvert 80 53 52
+5 ShiftLeftLogical 80 54 53 67
+4 SConvert 79 55 54
+8 ExtInst 75 56 2 vloadn 55 48 4
+8 ExtInst 75 57 2 vloadn 55 49 4
+5 Dot 72 58 56 57
+5 ShiftLeftLogical 79 59 52 66
+5 ShiftRightArithmetic 79 108 59 65
+5 InBoundsPtrAccessChain 77 60 50 108
+4 Bitcast 73 109 60
+5 Store 109 58 2 2
+1 Return
+
+1 FunctionEnd
diff --git a/tools/llvm-spirv/llvm-spirv.cpp b/tools/llvm-spirv/llvm-spirv.cpp
index b3cecc7d56..4fc9f6f0e0 100644
--- a/tools/llvm-spirv/llvm-spirv.cpp
+++ b/tools/llvm-spirv/llvm-spirv.cpp
@@ -51,7 +51,6 @@
#include "llvm/ADT/StringRef.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
-#include "llvm/IR/Constants.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
@@ -156,11 +155,10 @@ static cl::opt<SPIRV::BIsRepresentation> BIsRepresentation(
"SPIR-V Friendly IR")),
cl::init(SPIRV::BIsRepresentation::OpenCL12));
-static cl::opt<bool>
- PreserveOCLKernelArgTypeMetadataThroughString(
- "preserve-ocl-kernel-arg-type-metadata-through-string", cl::init(false),
- cl::desc("Preserve OpenCL kernel_arg_type and kernel_arg_type_qual "
- "metadata through OpString"));
+static cl::opt<bool> PreserveOCLKernelArgTypeMetadataThroughString(
+ "preserve-ocl-kernel-arg-type-metadata-through-string", cl::init(false),
+ cl::desc("Preserve OpenCL kernel_arg_type and kernel_arg_type_qual "
+ "metadata through OpString"));
static cl::opt<bool>
SPIRVToolsDis("spirv-tools-dis", cl::init(false),
@@ -203,9 +201,10 @@ static cl::opt<bool>
SPIRVMemToReg("spirv-mem2reg", cl::init(false),
cl::desc("LLVM/SPIR-V translation enable mem2reg"));
-static cl::opt<bool> SPIRVPreserveAuxData(
- "spirv-preserve-auxdata", cl::init(false),
- cl::desc("Preserve all auxiliary data, such as function attributes and metadata"));
+static cl::opt<bool>
+ SPIRVPreserveAuxData("spirv-preserve-auxdata", cl::init(false),
+ cl::desc("Preserve all auxiliary data, such as "
+ "function attributes and metadata"));
static cl::opt<bool> SpecConstInfo(
"spec-const-info",
@@ -287,6 +286,59 @@ static cl::opt<bool> SPIRVUseLLVMSPIRVBackendTarget(
"don't use the LLVM SPIR-V Backend target."),
cl::init(false));
+static cl::opt<uint32_t> FnVarCategory(
+ "fnvar-category",
+ cl::desc("Specify architecture category of the target device (omitting "
+ "this flag denotes that the target device can be of any "
+ "category). Used only with -r and --fnvar-spec-enable."),
+ cl::value_desc("category"), cl::ValueRequired);
+
+static cl::opt<uint32_t> FnVarFamily(
+ "fnvar-family",
+ cl::desc("Specify architecture family of the target device (omitting this "
+ "flag denotes that the target device can be of any family). Used "
+ "only with -r and --fnvar-spec-enable."),
+ cl::value_desc("family"), cl::ValueRequired);
+
+static cl::opt<uint32_t> FnVarArch(
+ "fnvar-arch",
+ cl::desc("Specify architecture of the target device (omitting this flag "
+ "denotes that the target device can be of any architecture). Used "
+ "only with -r and --fnvar-spec-enable."),
+ cl::value_desc("architecture"), cl::ValueRequired);
+
+static cl::opt<uint32_t>
+ FnVarTarget("fnvar-target",
+ cl::desc("Specify target of the target device (omitting this "
+ "flag denotes that the target device can be any "
+ "target). Used only with -r and --fnvar-spec-enable."),
+ cl::value_desc("target"), cl::ValueRequired);
+
+static cl::list<uint32_t> FnVarFeatures(
+ "fnvar-features", cl::CommaSeparated,
+ cl::desc("Specify features of the target device (omitting this flag "
+ "denotes that the target device supports all features). Used only "
+ "with -r and --fnvar-spec-enable."),
+ cl::value_desc("feature0,feature1,..."), cl::ValueRequired);
+
+static cl::list<uint32_t> FnVarCapabilities(
+    "fnvar-capabilities", cl::CommaSeparated,
+    cl::desc("Specify capabilities of the target device (omitting this flag "
+             "denotes that the target device supports all capabilities). Used "
+             "only with -r and --fnvar-spec-enable."),
+    cl::value_desc("capability0,capability1,..."), cl::ValueRequired);
+
+static cl::opt<std::string> FnVarSpvOut(
+ "fnvar-spv-out",
+ cl::desc("Save the specialized target-specific SPIR-V module to this file. "
+ "Used only with -r and --fnvar-spec-enable."),
+ cl::value_desc("file"), cl::ValueRequired);
+
+static cl::opt<bool> FnVarSpecEnable(
+ "fnvar-spec-enable", cl::init(false),
+ cl::desc("Enable specialization of function variants according to "
+ "SPV_INTEL_function_variants. Requires -r flag."));
+
static std::string removeExt(const std::string &FileName) {
size_t Pos = FileName.find_last_of(".");
if (Pos != std::string::npos)
@@ -802,8 +854,7 @@ int main(int Ac, char **Av) {
}
if (SPIRVPreserveAuxData) {
- Opts.setPreserveAuxData(
- SPIRVPreserveAuxData);
+ Opts.setPreserveAuxData(SPIRVPreserveAuxData);
if (!IsReverse)
Opts.setAllowedToUseExtension(
SPIRV::ExtensionID::SPV_KHR_non_semantic_info);
@@ -843,9 +894,9 @@ int main(int Ac, char **Av) {
SPIRV::DebugInfoEIS::NonSemantic_Shader_DebugInfo_200)
Opts.setAllowExtraDIExpressionsEnabled(true);
if (DebugEIS.getValue() ==
- SPIRV::DebugInfoEIS::NonSemantic_Shader_DebugInfo_100 ||
+ SPIRV::DebugInfoEIS::NonSemantic_Shader_DebugInfo_100 ||
DebugEIS.getValue() ==
- SPIRV::DebugInfoEIS::NonSemantic_Shader_DebugInfo_200)
+ SPIRV::DebugInfoEIS::NonSemantic_Shader_DebugInfo_200)
Opts.setAllowedToUseExtension(
SPIRV::ExtensionID::SPV_KHR_non_semantic_info);
}
@@ -857,6 +908,56 @@ int main(int Ac, char **Av) {
if (SPIRVEmitFunctionPtrAddrSpace.getNumOccurrences() != 0)
Opts.setEmitFunctionPtrAddrSpace(true);
+ Opts.setFnVarSpecEnable(FnVarSpecEnable);
+
+  // Detect flag use by occurrence: 0 is a valid value (e.g. --fnvar-arch 0).
+  const bool AnyFnVarTargetFlag =
+      FnVarCategory.getNumOccurrences() || FnVarFamily.getNumOccurrences() ||
+      FnVarArch.getNumOccurrences() || FnVarTarget.getNumOccurrences() ||
+      FnVarFeatures.getNumOccurrences() ||
+      FnVarCapabilities.getNumOccurrences() || FnVarSpvOut.getNumOccurrences();
+  if (!IsReverse && (FnVarSpecEnable || AnyFnVarTargetFlag)) {
+    errs() << "--fnvar-xxx flags can be used only with -r\n";
+    return -1;
+  }
+
+  if (!FnVarSpecEnable && AnyFnVarTargetFlag) {
+    errs() << "--fnvar-xxx flags need to be enabled with --fnvar-spec-enable\n";
+    return -1;
+  }
+
+ if (FnVarCategory.getNumOccurrences() > 0) {
+ Opts.setFnVarCategory(FnVarCategory);
+ }
+
+ if (FnVarFamily.getNumOccurrences() > 0) {
+ Opts.setFnVarFamily(FnVarFamily);
+ }
+
+ if (FnVarArch.getNumOccurrences() > 0) {
+ Opts.setFnVarArch(FnVarArch);
+ }
+
+ if (FnVarTarget.getNumOccurrences() > 0) {
+ Opts.setFnVarTarget(FnVarTarget);
+ }
+
+ if (!FnVarFeatures.empty()) {
+ Opts.setFnVarFeatures(FnVarFeatures);
+ }
+
+ if (!FnVarCapabilities.empty()) {
+ Opts.setFnVarCapabilities(FnVarCapabilities);
+ }
+
+ if (!FnVarSpvOut.empty()) {
+ Opts.setFnVarSpvOut(FnVarSpvOut);
+ }
+
+ if (!Opts.validateFnVarOpts()) {
+ return -1;
+ }
+
#ifdef _SPIRV_SUPPORT_TEXT_FMT
if (ToText && (ToBinary || IsReverse || IsRegularization)) {
errs() << "Cannot use -to-text with -to-binary, -r, -s\n";
@@ -906,15 +1007,15 @@ int main(int Ac, char **Av) {
std::optional<SPIRV::SPIRVModuleReport> BinReport =
SPIRV::getSpirvReport(IFS, ErrCode);
if (!BinReport) {
- std::cerr << "Invalid SPIR-V binary: \"" << SPIRV::getErrorMessage(ErrCode) << "\"\n";
+ std::cerr << "Invalid SPIR-V binary: \""
+ << SPIRV::getErrorMessage(ErrCode) << "\"\n";
return -1;
}
SPIRV::SPIRVModuleTextReport TextReport =
SPIRV::formatSpirvReport(BinReport.value());
- std::cout << "SPIR-V module report:"
- << "\n Version: " << TextReport.Version
+ std::cout << "SPIR-V module report:" << "\n Version: " << TextReport.Version
<< "\n Memory model: " << TextReport.MemoryModel
<< "\n Addressing model: " << TextReport.AddrModel << "\n";
@@ -931,7 +1032,8 @@ int main(int Ac, char **Av) {
std::cout << " Number of extended instruction sets: "
<< TextReport.ExtendedInstructionSets.size() << "\n";
for (auto &ExtendedInstructionSet : TextReport.ExtendedInstructionSets)
- std::cout << " Extended Instruction Set: " << ExtendedInstructionSet << "\n";
+ std::cout << " Extended Instruction Set: " << ExtendedInstructionSet
+ << "\n";
}
return 0;
}