Skip to content

Commit 3b397d0

Browse files
committed
AdvancedInterfaceVariableScalarReplacement vectors handling fix
It was crashing on this small shader: #version 450 layout(location = 1) in vec3 x[2]; void main() { float crash = x[1].y; } Because of non-supported OpAccessChain-s targeting the array-ed vectors' component. Code now detects the component access and inserts a reduced OpAccessChain (targeting the replacement scalar variable component correctly). Added a test for this specific case. Fixed copyright header I got wrong last time.
1 parent c715d82 commit 3b397d0

File tree

3 files changed

+157
-24
lines changed

3 files changed

+157
-24
lines changed

source/opt/adv_interface_var_sroa.cpp

Lines changed: 79 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2022 Google LLC
1+
// Copyright (c) 2025 Epic Games, Inc.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -30,11 +30,20 @@ constexpr uint32_t kOpEntryPointInOperandInterface = 3;
3030
constexpr uint32_t kOpVariableStorageClassInOperandIndex = 0;
3131
constexpr uint32_t kOpTypeArrayElemTypeInOperandIndex = 0;
3232
constexpr uint32_t kOpTypeArrayLengthInOperandIndex = 1;
33+
constexpr uint32_t kOpTypeVectorComponentCountInOperandIndex = 1;
3334
constexpr uint32_t kOpTypeMatrixColCountInOperandIndex = 1;
3435
constexpr uint32_t kOpTypeMatrixColTypeInOperandIndex = 0;
3536
constexpr uint32_t kOpTypePtrTypeInOperandIndex = 1;
3637
constexpr uint32_t kOpConstantValueInOperandIndex = 0;
3738

39+
// Get the component count of the OpTypeVector |vector_type|.
40+
uint32_t GetVectorComponentCount(Instruction* vector_type) {
41+
assert(vector_type->opcode() == spv::Op::OpTypeVector);
42+
uint32_t component_count =
43+
vector_type->GetSingleWordInOperand(kOpTypeVectorComponentCountInOperandIndex);
44+
return component_count;
45+
}
46+
3847
// Get the length of the OpTypeArray |array_type|.
3948
uint32_t GetArrayLength(analysis::DefUseManager* def_use_mgr,
4049
Instruction* array_type) {
@@ -223,6 +232,8 @@ Pass::Status AdvancedInterfaceVariableScalarReplacement::ProcessEntryPoint(
223232

224233
ReplaceInEntryPoint(&entry_point, replaced_interface_vars, scalar_vars);
225234

235+
context()->InvalidateAnalysesExceptFor(IRContext::Analysis::kAnalysisNone);
236+
226237
return status;
227238
}
228239

@@ -330,19 +341,46 @@ bool AdvancedInterfaceVariableScalarReplacement::ReplaceInterfaceVariable(
330341
// We are going to replace the access chain with either direct usage of the
331342
// replacement scalar variable, or a set of composite loads/stores.
332343

333-
const Replacement* target =
344+
LookupResult result =
334345
LookupReplacement(access_chain, &replacement, var.extra_array_length);
335-
if (!target) {
346+
if (!result.replacement) {
336347
// Error has been already logged by |LookupReplacement|.
337348
return false;
338349
}
350+
const Replacement* target = result.replacement;
339351

340352
if (!target->HasChildren() && var.extra_array_length == 0) {
341-
// Replace with a direct use of the scalar variable.
342353
auto scalar = target->GetScalarVariable();
343354
assert(scalar);
344-
context()->ReplaceAllUsesWith(access_chain->result_id(),
345-
scalar->result_id());
355+
356+
uint32_t replacement = 0;
357+
if (result.index >= 0) {
358+
// Our scalar is a vector and access chain in question targets a
359+
// specific component denoted by result.index.
360+
assert(target->GetVectorComponentCount() > 0);
361+
// Replace with an access chain into a direct use of the scalar variable.
362+
uint32_t indirection_id = TakeNextId();
363+
if (indirection_id == 0) {
364+
return false;
365+
}
366+
367+
uint32_t vector_component_type_id = context()->get_def_use_mgr()->GetDef(target->GetTypeId())->GetSingleWordInOperand(0);
368+
369+
uint32_t index_id = context()->get_constant_mgr()->GetUIntConstId(result.index);
370+
Operand index_operand = {SPV_OPERAND_TYPE_ID, {index_id}};
371+
std::unique_ptr<Instruction> vector_access_chain =
372+
CreateAccessChain(context(), indirection_id, scalar,
373+
vector_component_type_id, index_operand);
374+
replacement = vector_access_chain->result_id();
375+
376+
auto inst = access_chain->InsertBefore(std::move(vector_access_chain));
377+
inst->UpdateDebugInfoFrom(access_chain);
378+
get_def_use_mgr()->AnalyzeInstDef(inst);
379+
} else {
380+
// Replace with a direct use of the scalar variable.
381+
replacement = scalar->result_id();
382+
}
383+
context()->ReplaceAllUsesWith(access_chain->result_id(), replacement);
346384
} else {
347385
// The current access chain's target is a composite, meaning that there
348386
// are other instructions using the pointer. We need to convert those to
@@ -732,7 +770,7 @@ bool AdvancedInterfaceVariableScalarReplacement::ReplaceStore(
732770
return true;
733771
}
734772

735-
const AdvancedInterfaceVariableScalarReplacement::Replacement*
773+
AdvancedInterfaceVariableScalarReplacement::LookupResult
736774
AdvancedInterfaceVariableScalarReplacement::LookupReplacement(
737775
Instruction* access_chain, const Replacement* root,
738776
uint32_t extra_array_length) {
@@ -744,37 +782,59 @@ AdvancedInterfaceVariableScalarReplacement::LookupReplacement(
744782
// array, hence we skip it when looking-up the rest.
745783
uint32_t start_index = extra_array_length == 0 ? 1 : 2;
746784

785+
uint32_t num_indices = access_chain->NumInOperands();
786+
747787
// Finds the target replacement, which might be a scalar or nested
748788
// composite.
749-
for (uint32_t i = start_index; i < access_chain->NumInOperands(); ++i) {
789+
for (uint32_t i = start_index; i < num_indices; ++i) {
750790
uint32_t index_id = access_chain->GetSingleWordInOperand(i);
751791

752792
const analysis::Constant* index_constant =
753793
const_mgr->FindDeclaredConstant(index_id);
754794
if (!index_constant) {
755795
context()->EmitErrorMessage(
756796
"Variable cannot be replaced: index is not constant", access_chain);
757-
return nullptr;
797+
return {};
798+
}
799+
800+
// OpAccessChain treats indices as signed.
801+
int64_t index_value = index_constant->GetSignExtendedValue();
802+
803+
// Very last index can target the vector type, which we
804+
// have as a scalar.
805+
if (i == num_indices - 1) {
806+
if (root->GetScalarVariable()) {
807+
if (index_value < 0 ||
808+
index_value >=
809+
static_cast<int64_t>(root->GetVectorComponentCount())) {
810+
// Out of bounds access, this is illegal IR.
811+
// Notice that OpAccessChain indexing is 0-based, so we should also
812+
// reject index == size-of-array.
813+
context()->EmitErrorMessage(
814+
"Variable cannot be replaced: invalid index", access_chain);
815+
return {};
816+
}
817+
// Current root is our replacement scalar - a vector, in fact.
818+
return {root, index_value};
819+
}
758820
}
759821

760822
assert(root->HasChildren());
761823
const auto& children = root->GetChildren();
762824

763-
// OpAccessChain treats indices as signed.
764-
int64_t index_value = index_constant->GetSignExtendedValue();
765825
if (index_value < 0 ||
766826
index_value >= static_cast<int64_t>(children.size())) {
767827
// Out of bounds access, this is illegal IR.
768828
// Notice that OpAccessChain indexing is 0-based, so we should also
769829
// reject index == size-of-array.
770830
context()->EmitErrorMessage("Variable cannot be replaced: invalid index",
771831
access_chain);
772-
return nullptr;
832+
return {};
773833
}
774834

775835
root = &children[index_value];
776836
}
777-
return root;
837+
return {root};
778838
}
779839

780840
AdvancedInterfaceVariableScalarReplacement::Replacement
@@ -863,7 +923,12 @@ AdvancedInterfaceVariableScalarReplacement::CreateReplacementVariables(
863923
std::unique_ptr<Instruction> variable = CreateVariable(
864924
type->result_id(), storage_class, var.def, var.extra_array_length);
865925

866-
node->SetSingleScalarVariable(variable.get());
926+
uint32_t vector_component_count = 0;
927+
if (opcode == spv::Op::OpTypeVector) {
928+
vector_component_count = GetVectorComponentCount(type);
929+
}
930+
931+
node->SetSingleScalarVariable(variable.get(), vector_component_count);
867932
scalar_vars->push_back(variable.get());
868933

869934
uint32_t var_id = variable->result_id();

source/opt/adv_interface_var_sroa.h

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2022 Google LLC
1+
// Copyright (c) 2025 Epic Games, Inc.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -90,14 +90,21 @@ class AdvancedInterfaceVariableScalarReplacement : public Pass {
9090

9191
Instruction* GetScalarVariable() const { return scalar_var; }
9292

93-
void SetSingleScalarVariable(Instruction* var) { scalar_var = var; }
93+
void SetSingleScalarVariable(Instruction* var, uint32_t in_vector_component_count) {
94+
scalar_var = var;
95+
vector_component_count = in_vector_component_count;
96+
}
9497

9598
uint32_t GetTypeId() const { return type_id; }
9699

100+
// Returns 0, if the Replacement is not a vector.
101+
uint32_t GetVectorComponentCount() const { return vector_component_count; }
102+
97103
private:
98104
std::vector<Replacement> children;
99105
Instruction* scalar_var;
100106
uint32_t type_id;
107+
uint32_t vector_component_count;
101108
};
102109

103110
// Collects all interface variables used by the |entry_point|.
@@ -171,14 +178,21 @@ class AdvancedInterfaceVariableScalarReplacement : public Pass {
171178
Instruction* optional_access_chain,
172179
uint32_t extra_array_length);
173180

181+
struct LookupResult {
182+
// The replacement node, nullptr if not found.
183+
const Replacement* replacement = nullptr;
184+
// If |replacement| is a vector, which was also indexed by |access_chain|,
185+
// this will have that used index value.
186+
int64_t index = -1;
187+
};
174188
// Looks up the replacement node according to the indices from the access
175189
// chain |access_chain|, using the passed |root| as a base. If any index in
176190
// the chain is non-constant or ouf-of-bound, return nullptr. If
177191
// |extra_array_length| is not zero, the first index in the chain is skipped,
178192
// as it is the one used for extra arrayness.
179-
const Replacement* LookupReplacement(Instruction* access_chain,
180-
const Replacement* root,
181-
uint32_t extra_array_length);
193+
LookupResult LookupReplacement(Instruction* access_chain,
194+
const Replacement* root,
195+
uint32_t extra_array_length);
182196

183197
// Creates a variable with type |type_id| and storage class |storage_class|.
184198
// Debug info for the newly created variable is copied from the source

test/opt/adv_interface_var_sroa_test.cpp

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2022 Google LLC
1+
// Copyright (c) 2025 Epic Games, Inc.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -70,15 +70,11 @@ TEST_F(AdvancedInterfaceVariableScalarReplacementTest,
7070
; CHECK-DAG: OpDecorate [[y]] Location 0
7171
; CHECK-DAG: OpDecorate [[gl_InvocationID]] BuiltIn InvocationId
7272
; CHECK-DAG: OpDecorate [[z0]] Location 0
73-
; CHECK-DAG: OpDecorate [[z0]] Component 0
7473
; CHECK-DAG: OpDecorate [[z1]] Location 1
75-
; CHECK-DAG: OpDecorate [[z1]] Component 0
7674
; CHECK-DAG: OpDecorate [[z0]] Patch
7775
; CHECK-DAG: OpDecorate [[z1]] Patch
7876
; CHECK-DAG: OpDecorate [[w0]] Location 2
79-
; CHECK-DAG: OpDecorate [[w0]] Component 0
8077
; CHECK-DAG: OpDecorate [[w1]] Location 3
81-
; CHECK-DAG: OpDecorate [[w1]] Component 0
8278
; CHECK-DAG: OpDecorate [[w0]] Patch
8379
; CHECK-DAG: OpDecorate [[w1]] Patch
8480
; CHECK-DAG: OpDecorate [[u0]] Location 3
@@ -335,6 +331,64 @@ TEST_F(AdvancedInterfaceVariableScalarReplacementTest,
335331
SinglePassRunAndMatch<AdvancedInterfaceVariableScalarReplacement>(spirv, true, true);
336332
}
337333

334+
TEST_F(AdvancedInterfaceVariableScalarReplacementTest,
335+
ReplaceInterfaceVarsWithScalas_Vectors) {
336+
const std::string spirv = R"(
337+
OpCapability Shader
338+
OpMemoryModel Logical GLSL450
339+
OpEntryPoint Vertex %func "shader" %x
340+
341+
; CHECK: OpName [[y:%\w+]] "y"
342+
; CHECK: OpName [[x0:%\w+]] "x[0]"
343+
; CHECK: OpName [[x1:%\w+]] "x[1]"
344+
; CHECK-NOT: OpName {{%\w+}} "y"
345+
OpName %x "x"
346+
OpName %y "y"
347+
348+
; CHECK-DAG: OpDecorate [[x0]] Location 1
349+
; CHECK-DAG: OpDecorate [[x1]] Location 2
350+
OpDecorate %x Location 1
351+
352+
%float = OpTypeFloat 32
353+
%int = OpTypeInt 32 1
354+
%uint = OpTypeInt 32 0
355+
%int_1 = OpConstant %int 1
356+
%uint_1 = OpConstant %uint 1
357+
%uint_2 = OpConstant %uint 2
358+
%v3float = OpTypeVector %float 3
359+
%_arr_v3float_uint_2 = OpTypeArray %v3float %uint_2
360+
%_ptr_Input_float = OpTypePointer Input %float
361+
%_ptr_Input_v3float = OpTypePointer Input %v3float
362+
%_ptr_Input__arr_v3float_uint_2 = OpTypePointer Input %_arr_v3float_uint_2
363+
%_ptr_Function_float = OpTypePointer Function %float
364+
%_ptr_Function__vec3 = OpTypePointer Function %v3float
365+
366+
%x = OpVariable %_ptr_Input__arr_v3float_uint_2 Input
367+
; CHECK-DAG: [[x0]] = OpVariable %_ptr_Input_v3float Input
368+
; CHECK-DAG: [[x1]] = OpVariable %_ptr_Input_v3float Input
369+
370+
%void = OpTypeVoid
371+
%void_f = OpTypeFunction %void
372+
%func = OpFunction %void None %void_f
373+
%label = OpLabel
374+
375+
%y = OpVariable %_ptr_Function_float Function
376+
; CHECK-DAG [[y]] = OpVariable %_ptr_Function_float Function
377+
378+
; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Input_float [[x1]] %uint_1
379+
; CHECK: [[val:%\w+]] = OpLoad %float [[ptr]]
380+
; CHECK: OpStore [[y]] [[val]]
381+
%x_z_ptr = OpAccessChain %_ptr_Input_float %x %int_1 %uint_1
382+
%x_z_val = OpLoad %float %x_z_ptr
383+
OpStore %y %x_z_val
384+
385+
OpReturn
386+
OpFunctionEnd
387+
)";
388+
389+
SinglePassRunAndMatch<AdvancedInterfaceVariableScalarReplacement>(spirv, true, true);
390+
}
391+
338392
TEST_F(AdvancedInterfaceVariableScalarReplacementTest,
339393
CheckPatchDecorationPreservation) {
340394
// Make sure scalars for the variables with the extra arrayness have the extra

0 commit comments

Comments
 (0)