1
- // Copyright (c) 2022 Google LLC
1
+ // Copyright (c) 2025 Epic Games, Inc.
2
2
//
3
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
4
// you may not use this file except in compliance with the License.
@@ -30,11 +30,20 @@ constexpr uint32_t kOpEntryPointInOperandInterface = 3;
30
30
constexpr uint32_t kOpVariableStorageClassInOperandIndex = 0 ;
31
31
constexpr uint32_t kOpTypeArrayElemTypeInOperandIndex = 0 ;
32
32
constexpr uint32_t kOpTypeArrayLengthInOperandIndex = 1 ;
33
+ constexpr uint32_t kOpTypeVectorComponentCountInOperandIndex = 1 ;
33
34
constexpr uint32_t kOpTypeMatrixColCountInOperandIndex = 1 ;
34
35
constexpr uint32_t kOpTypeMatrixColTypeInOperandIndex = 0 ;
35
36
constexpr uint32_t kOpTypePtrTypeInOperandIndex = 1 ;
36
37
constexpr uint32_t kOpConstantValueInOperandIndex = 0 ;
37
38
39
+ // Get the component count of the OpTypeVector |vector_type|.
40
+ uint32_t GetVectorComponentCount (Instruction* vector_type) {
41
+ assert (vector_type->opcode () == spv::Op::OpTypeVector);
42
+ uint32_t component_count =
43
+ vector_type->GetSingleWordInOperand (kOpTypeVectorComponentCountInOperandIndex );
44
+ return component_count;
45
+ }
46
+
38
47
// Get the length of the OpTypeArray |array_type|.
39
48
uint32_t GetArrayLength (analysis::DefUseManager* def_use_mgr,
40
49
Instruction* array_type) {
@@ -223,6 +232,8 @@ Pass::Status AdvancedInterfaceVariableScalarReplacement::ProcessEntryPoint(
223
232
224
233
ReplaceInEntryPoint (&entry_point, replaced_interface_vars, scalar_vars);
225
234
235
+ context ()->InvalidateAnalysesExceptFor (IRContext::Analysis::kAnalysisNone );
236
+
226
237
return status;
227
238
}
228
239
@@ -330,19 +341,46 @@ bool AdvancedInterfaceVariableScalarReplacement::ReplaceInterfaceVariable(
330
341
// We are going to replace the access chain with either direct usage of the
331
342
// replacement scalar variable, or a set of composite loads/stores.
332
343
333
- const Replacement* target =
344
+ LookupResult result =
334
345
LookupReplacement (access_chain, &replacement, var.extra_array_length );
335
- if (!target ) {
346
+ if (!result. replacement ) {
336
347
// Error has been already logged by |LookupReplacement|.
337
348
return false ;
338
349
}
350
+ const Replacement* target = result.replacement ;
339
351
340
352
if (!target->HasChildren () && var.extra_array_length == 0 ) {
341
- // Replace with a direct use of the scalar variable.
342
353
auto scalar = target->GetScalarVariable ();
343
354
assert (scalar);
344
- context ()->ReplaceAllUsesWith (access_chain->result_id (),
345
- scalar->result_id ());
355
+
356
+ uint32_t replacement = 0 ;
357
+ if (result.index >= 0 ) {
358
+ // Our scalar is a vector and access chain in question targets a
359
+ // specific component denoted by result.index.
360
+ assert (target->GetVectorComponentCount () > 0 );
361
+ // Replace with an access chain into a direct use of the scalar variable.
362
+ uint32_t indirection_id = TakeNextId ();
363
+ if (indirection_id == 0 ) {
364
+ return false ;
365
+ }
366
+
367
+ uint32_t vector_component_type_id = context ()->get_def_use_mgr ()->GetDef (target->GetTypeId ())->GetSingleWordInOperand (0 );
368
+
369
+ uint32_t index_id = context ()->get_constant_mgr ()->GetUIntConstId (result.index );
370
+ Operand index_operand = {SPV_OPERAND_TYPE_ID, {index_id}};
371
+ std::unique_ptr<Instruction> vector_access_chain =
372
+ CreateAccessChain (context (), indirection_id, scalar,
373
+ vector_component_type_id, index_operand);
374
+ replacement = vector_access_chain->result_id ();
375
+
376
+ auto inst = access_chain->InsertBefore (std::move (vector_access_chain));
377
+ inst->UpdateDebugInfoFrom (access_chain);
378
+ get_def_use_mgr ()->AnalyzeInstDef (inst);
379
+ } else {
380
+ // Replace with a direct use of the scalar variable.
381
+ replacement = scalar->result_id ();
382
+ }
383
+ context ()->ReplaceAllUsesWith (access_chain->result_id (), replacement);
346
384
} else {
347
385
// The current access chain's target is a composite, meaning that there
348
386
// are other instructions using the pointer. We need to convert those to
@@ -732,7 +770,7 @@ bool AdvancedInterfaceVariableScalarReplacement::ReplaceStore(
732
770
return true ;
733
771
}
734
772
735
- const AdvancedInterfaceVariableScalarReplacement::Replacement*
773
+ AdvancedInterfaceVariableScalarReplacement::LookupResult
736
774
AdvancedInterfaceVariableScalarReplacement::LookupReplacement (
737
775
Instruction* access_chain, const Replacement* root,
738
776
uint32_t extra_array_length) {
@@ -744,37 +782,59 @@ AdvancedInterfaceVariableScalarReplacement::LookupReplacement(
744
782
// array, hence we skip it when looking-up the rest.
745
783
uint32_t start_index = extra_array_length == 0 ? 1 : 2 ;
746
784
785
+ uint32_t num_indices = access_chain->NumInOperands ();
786
+
747
787
// Finds the target replacement, which might be a scalar or nested
748
788
// composite.
749
- for (uint32_t i = start_index; i < access_chain-> NumInOperands () ; ++i) {
789
+ for (uint32_t i = start_index; i < num_indices ; ++i) {
750
790
uint32_t index_id = access_chain->GetSingleWordInOperand (i);
751
791
752
792
const analysis::Constant* index_constant =
753
793
const_mgr->FindDeclaredConstant (index_id);
754
794
if (!index_constant) {
755
795
context ()->EmitErrorMessage (
756
796
" Variable cannot be replaced: index is not constant" , access_chain);
757
- return nullptr ;
797
+ return {};
798
+ }
799
+
800
+ // OpAccessChain treats indices as signed.
801
+ int64_t index_value = index_constant->GetSignExtendedValue ();
802
+
803
+ // Very last index can target the vector type, which we
804
+ // have as a scalar.
805
+ if (i == num_indices - 1 ) {
806
+ if (root->GetScalarVariable ()) {
807
+ if (index_value < 0 ||
808
+ index_value >=
809
+ static_cast <int64_t >(root->GetVectorComponentCount ())) {
810
+ // Out of bounds access, this is illegal IR.
811
+ // Notice that OpAccessChain indexing is 0-based, so we should also
812
+ // reject index == size-of-array.
813
+ context ()->EmitErrorMessage (
814
+ " Variable cannot be replaced: invalid index" , access_chain);
815
+ return {};
816
+ }
817
+ // Current root is our replacement scalar - a vector, in fact.
818
+ return {root, index_value};
819
+ }
758
820
}
759
821
760
822
assert (root->HasChildren ());
761
823
const auto & children = root->GetChildren ();
762
824
763
- // OpAccessChain treats indices as signed.
764
- int64_t index_value = index_constant->GetSignExtendedValue ();
765
825
if (index_value < 0 ||
766
826
index_value >= static_cast <int64_t >(children.size ())) {
767
827
// Out of bounds access, this is illegal IR.
768
828
// Notice that OpAccessChain indexing is 0-based, so we should also
769
829
// reject index == size-of-array.
770
830
context ()->EmitErrorMessage (" Variable cannot be replaced: invalid index" ,
771
831
access_chain);
772
- return nullptr ;
832
+ return {} ;
773
833
}
774
834
775
835
root = &children[index_value];
776
836
}
777
- return root;
837
+ return { root} ;
778
838
}
779
839
780
840
AdvancedInterfaceVariableScalarReplacement::Replacement
@@ -863,7 +923,12 @@ AdvancedInterfaceVariableScalarReplacement::CreateReplacementVariables(
863
923
std::unique_ptr<Instruction> variable = CreateVariable (
864
924
type->result_id (), storage_class, var.def , var.extra_array_length );
865
925
866
- node->SetSingleScalarVariable (variable.get ());
926
+ uint32_t vector_component_count = 0 ;
927
+ if (opcode == spv::Op::OpTypeVector) {
928
+ vector_component_count = GetVectorComponentCount (type);
929
+ }
930
+
931
+ node->SetSingleScalarVariable (variable.get (), vector_component_count);
867
932
scalar_vars->push_back (variable.get ());
868
933
869
934
uint32_t var_id = variable->result_id ();
0 commit comments