 //===----------------------------------------------------------------------===//

 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
@@ -100,11 +101,11 @@ class InterleavedAccessImpl {
   unsigned MaxFactor = 0u;

   /// Transform an interleaved load into target specific intrinsics.
-  bool lowerInterleavedLoad(LoadInst *LI,
+  bool lowerInterleavedLoad(Instruction *LoadOp,
                             SmallSetVector<Instruction *, 32> &DeadInsts);

   /// Transform an interleaved store into target specific intrinsics.
-  bool lowerInterleavedStore(StoreInst *SI,
+  bool lowerInterleavedStore(Instruction *StoreOp,
                              SmallSetVector<Instruction *, 32> &DeadInsts);

   /// Transform a load and a deinterleave intrinsic into target specific
@@ -131,7 +132,7 @@ class InterleavedAccessImpl {
   /// made.
   bool replaceBinOpShuffles(ArrayRef<ShuffleVectorInst *> BinOpShuffles,
                             SmallVectorImpl<ShuffleVectorInst *> &Shuffles,
-                            LoadInst *LI);
+                            Instruction *LI);
 };

 class InterleavedAccess : public FunctionPass {
@@ -250,10 +251,23 @@ static bool isReInterleaveMask(ShuffleVectorInst *SVI, unsigned &Factor,
 }

 bool InterleavedAccessImpl::lowerInterleavedLoad(
-    LoadInst *LI, SmallSetVector<Instruction *, 32> &DeadInsts) {
-  if (!LI->isSimple() || isa<ScalableVectorType>(LI->getType()))
+    Instruction *LoadOp, SmallSetVector<Instruction *, 32> &DeadInsts) {
+  if (isa<ScalableVectorType>(LoadOp->getType()))
     return false;

+  if (auto *LI = dyn_cast<LoadInst>(LoadOp)) {
+    if (!LI->isSimple())
+      return false;
+  } else if (auto *VPLoad = dyn_cast<VPIntrinsic>(LoadOp)) {
+    assert(VPLoad->getIntrinsicID() == Intrinsic::vp_load);
+    // Require a constant mask and evl.
+    if (!isa<ConstantVector>(VPLoad->getArgOperand(1)) ||
+        !isa<ConstantInt>(VPLoad->getArgOperand(2)))
+      return false;
+  } else {
+    llvm_unreachable("unsupported load operation");
+  }
+
   // Check if all users of this load are shufflevectors. If we encounter any
   // users that are extractelement instructions or binary operators, we save
   // them to later check if they can be modified to extract from one of the
@@ -265,7 +279,7 @@ bool InterleavedAccessImpl::lowerInterleavedLoad(
   // binop are the same load.
   SmallSetVector<ShuffleVectorInst *, 4> BinOpShuffles;

-  for (auto *User : LI->users()) {
+  for (auto *User : LoadOp->users()) {
     auto *Extract = dyn_cast<ExtractElementInst>(User);
     if (Extract && isa<ConstantInt>(Extract->getIndexOperand())) {
       Extracts.push_back(Extract);
@@ -294,13 +308,31 @@ bool InterleavedAccessImpl::lowerInterleavedLoad(
   unsigned Factor, Index;

   unsigned NumLoadElements =
-      cast<FixedVectorType>(LI->getType())->getNumElements();
+      cast<FixedVectorType>(LoadOp->getType())->getNumElements();
   auto *FirstSVI = Shuffles.size() > 0 ? Shuffles[0] : BinOpShuffles[0];
   // Check if the first shufflevector is DE-interleave shuffle.
   if (!isDeInterleaveMask(FirstSVI->getShuffleMask(), Factor, Index, MaxFactor,
                           NumLoadElements))
     return false;

+  // If this is a vp.load, record its mask (NOT shuffle mask).
+  BitVector MaskedIndices(NumLoadElements);
+  if (auto *VPLoad = dyn_cast<VPIntrinsic>(LoadOp)) {
+    auto *Mask = cast<ConstantVector>(VPLoad->getArgOperand(1));
+    assert(cast<FixedVectorType>(Mask->getType())->getNumElements() ==
+           NumLoadElements);
+    if (auto *Splat = Mask->getSplatValue()) {
+      // All-zeros mask, bail out early.
+      if (Splat->isZeroValue())
+        return false;
+    } else {
+      for (unsigned i = 0U; i < NumLoadElements; ++i) {
+        if (Mask->getAggregateElement(i)->isZeroValue())
+          MaskedIndices.set(i);
+      }
+    }
+  }
+
   // Holds the corresponding index for each DE-interleave shuffle.
   SmallVector<unsigned, 4> Indices;

@@ -327,9 +359,9 @@ bool InterleavedAccessImpl::lowerInterleavedLoad(

     assert(Shuffle->getShuffleMask().size() <= NumLoadElements);

-    if (cast<Instruction>(Shuffle->getOperand(0))->getOperand(0) == LI)
+    if (cast<Instruction>(Shuffle->getOperand(0))->getOperand(0) == LoadOp)
       Indices.push_back(Index);
-    if (cast<Instruction>(Shuffle->getOperand(0))->getOperand(1) == LI)
+    if (cast<Instruction>(Shuffle->getOperand(0))->getOperand(1) == LoadOp)
       Indices.push_back(Index);
   }

@@ -339,25 +371,61 @@ bool InterleavedAccessImpl::lowerInterleavedLoad(
     return false;

   bool BinOpShuffleChanged =
-      replaceBinOpShuffles(BinOpShuffles.getArrayRef(), Shuffles, LI);
+      replaceBinOpShuffles(BinOpShuffles.getArrayRef(), Shuffles, LoadOp);
+
+  // Check if we extract only the unmasked elements.
+  if (MaskedIndices.any()) {
+    if (any_of(Shuffles, [&](const auto *Shuffle) {
+          ArrayRef<int> ShuffleMask = Shuffle->getShuffleMask();
+          for (int Idx : ShuffleMask) {
+            if (Idx < 0)
+              continue;
+            if (MaskedIndices.test(unsigned(Idx)))
+              return true;
+          }
+          return false;
+        })) {
+      LLVM_DEBUG(dbgs() << "IA: trying to extract a masked element through "
+                        << "shufflevector\n");
+      return false;
+    }
+  }
+  // Check if we extract only the elements within evl.
+  if (auto *VPLoad = dyn_cast<VPIntrinsic>(LoadOp)) {
+    uint64_t EVL = cast<ConstantInt>(VPLoad->getArgOperand(2))->getZExtValue();
+    if (any_of(Shuffles, [&](const auto *Shuffle) {
+          ArrayRef<int> ShuffleMask = Shuffle->getShuffleMask();
+          for (int Idx : ShuffleMask) {
+            if (Idx < 0)
+              continue;
+            if (unsigned(Idx) >= EVL)
+              return true;
+          }
+          return false;
+        })) {
+      LLVM_DEBUG(
+          dbgs() << "IA: trying to extract an element out of EVL range\n");
+      return false;
+    }
+  }

-  LLVM_DEBUG(dbgs() << "IA: Found an interleaved load: " << *LI << "\n");
+  LLVM_DEBUG(dbgs() << "IA: Found an interleaved load: " << *LoadOp << "\n");

   // Try to create target specific intrinsics to replace the load and shuffles.
-  if (!TLI->lowerInterleavedLoad(LI, Shuffles, Indices, Factor)) {
+  if (!TLI->lowerInterleavedLoad(LoadOp, Shuffles, Indices, Factor)) {
     // If Extracts is not empty, tryReplaceExtracts made changes earlier.
     return !Extracts.empty() || BinOpShuffleChanged;
   }

   DeadInsts.insert_range(Shuffles);

-  DeadInsts.insert(LI);
+  DeadInsts.insert(LoadOp);
   return true;
 }

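Note: for context, here is a minimal sketch of the kind of IR this path is now meant to recognize. It is a hypothetical example with assumed names and types, not taken from this commit or its tests. The vp.load carries a constant all-true mask and a constant EVL covering every lane, so it survives the masked-element and EVL-range checks above:

    ; Hypothetical IR sketch (assumed v8i32 types; not from this commit).
    ; Factor-2 deinterleave: constant all-true mask, constant EVL of 8.
    %wide = call <8 x i32> @llvm.vp.load.v8i32.p0(
                ptr %p,
                <8 x i1> <i1 true, i1 true, i1 true, i1 true,
                          i1 true, i1 true, i1 true, i1 true>,
                i32 8)
    %evens = shufflevector <8 x i32> %wide, <8 x i32> poison,
                           <4 x i32> <i32 0, i32 2, i32 4, i32 6>
    %odds = shufflevector <8 x i32> %wide, <8 x i32> poison,
                          <4 x i32> <i32 1, i32 3, i32 5, i32 7>

On success, a target's lowerInterleavedLoad hook can replace this with a segment load (for example, RISC-V vlseg2e32). If the EVL were instead, say, 6 while a shuffle still read lane 7, the EVL-range check above would reject the transform.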
 bool InterleavedAccessImpl::replaceBinOpShuffles(
     ArrayRef<ShuffleVectorInst *> BinOpShuffles,
-    SmallVectorImpl<ShuffleVectorInst *> &Shuffles, LoadInst *LI) {
+    SmallVectorImpl<ShuffleVectorInst *> &Shuffles, Instruction *LoadOp) {
   for (auto *SVI : BinOpShuffles) {
     BinaryOperator *BI = cast<BinaryOperator>(SVI->getOperand(0));
     Type *BIOp0Ty = BI->getOperand(0)->getType();
@@ -380,9 +448,9 @@ bool InterleavedAccessImpl::replaceBinOpShuffles(
                       << "\n  With    : " << *NewSVI1 << "\n    And   : "
                       << *NewSVI2 << "\n  And     : " << *NewBI << "\n");
     RecursivelyDeleteTriviallyDeadInstructions(SVI);
-    if (NewSVI1->getOperand(0) == LI)
+    if (NewSVI1->getOperand(0) == LoadOp)
       Shuffles.push_back(NewSVI1);
-    if (NewSVI2->getOperand(0) == LI)
+    if (NewSVI2->getOperand(0) == LoadOp)
       Shuffles.push_back(NewSVI2);
   }

@@ -454,27 +522,79 @@ bool InterleavedAccessImpl::tryReplaceExtracts(
 }

 bool InterleavedAccessImpl::lowerInterleavedStore(
-    StoreInst *SI, SmallSetVector<Instruction *, 32> &DeadInsts) {
-  if (!SI->isSimple())
-    return false;
+    Instruction *StoreOp, SmallSetVector<Instruction *, 32> &DeadInsts) {
+  Value *StoredValue;
+  if (auto *SI = dyn_cast<StoreInst>(StoreOp)) {
+    if (!SI->isSimple())
+      return false;
+    StoredValue = SI->getValueOperand();
+  } else if (auto *VPStore = dyn_cast<VPIntrinsic>(StoreOp)) {
+    assert(VPStore->getIntrinsicID() == Intrinsic::vp_store);
+    // Require a constant mask and evl.
+    if (!isa<ConstantVector>(VPStore->getArgOperand(2)) ||
+        !isa<ConstantInt>(VPStore->getArgOperand(3)))
+      return false;
+    StoredValue = VPStore->getArgOperand(0);
+  } else {
+    llvm_unreachable("unsupported store operation");
+  }

-  auto *SVI = dyn_cast<ShuffleVectorInst>(SI->getValueOperand());
+  auto *SVI = dyn_cast<ShuffleVectorInst>(StoredValue);
   if (!SVI || !SVI->hasOneUse() || isa<ScalableVectorType>(SVI->getType()))
     return false;

+  unsigned NumStoredElements =
+      cast<FixedVectorType>(SVI->getType())->getNumElements();
+  // If this is a vp.store, record its mask (NOT shuffle mask).
+  BitVector MaskedIndices(NumStoredElements);
+  if (auto *VPStore = dyn_cast<VPIntrinsic>(StoreOp)) {
+    auto *Mask = cast<ConstantVector>(VPStore->getArgOperand(2));
+    assert(cast<FixedVectorType>(Mask->getType())->getNumElements() ==
+           NumStoredElements);
+    if (auto *Splat = Mask->getSplatValue()) {
+      // All-zeros mask, bail out early.
+      if (Splat->isZeroValue())
+        return false;
+    } else {
+      for (unsigned i = 0U; i < NumStoredElements; ++i) {
+        if (Mask->getAggregateElement(i)->isZeroValue())
+          MaskedIndices.set(i);
+      }
+    }
+  }
+
   // Check if the shufflevector is RE-interleave shuffle.
   unsigned Factor;
   if (!isReInterleaveMask(SVI, Factor, MaxFactor))
     return false;

-  LLVM_DEBUG(dbgs() << "IA: Found an interleaved store: " << *SI << "\n");
+  // Check if we store only the unmasked elements.
+  if (MaskedIndices.any()) {
+    if (any_of(SVI->getShuffleMask(), [&](int Idx) {
+          return Idx >= 0 && MaskedIndices.test(unsigned(Idx));
+        })) {
+      LLVM_DEBUG(dbgs() << "IA: trying to store a masked element\n");
+      return false;
+    }
+  }
+  // Check if we store only the elements within evl.
+  if (auto *VPStore = dyn_cast<VPIntrinsic>(StoreOp)) {
+    uint64_t EVL = cast<ConstantInt>(VPStore->getArgOperand(3))->getZExtValue();
+    if (any_of(SVI->getShuffleMask(),
+               [&](int Idx) { return Idx >= 0 && unsigned(Idx) >= EVL; })) {
+      LLVM_DEBUG(dbgs() << "IA: trying to store an element out of EVL range\n");
+      return false;
+    }
+  }
+
+  LLVM_DEBUG(dbgs() << "IA: Found an interleaved store: " << *StoreOp << "\n");

   // Try to create target specific intrinsics to replace the store and shuffle.
-  if (!TLI->lowerInterleavedStore(SI, SVI, Factor))
+  if (!TLI->lowerInterleavedStore(StoreOp, SVI, Factor))
     return false;

   // Already have a new target specific interleaved store. Erase the old store.
-  DeadInsts.insert(SI);
+  DeadInsts.insert(StoreOp);
   DeadInsts.insert(SVI);
   return true;
 }
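Note: the store path mirrors the load path. A minimal hypothetical sketch (assumed names and types, not from this commit) of IR it is intended to match: a factor-2 re-interleaving shuffle feeding a vp.store whose mask is a constant all-true vector and whose constant EVL of 8 covers every stored lane, so both store-side checks above pass:

    ; Hypothetical IR sketch (assumed v4i32/v8i32 types; not from this commit).
    %interleaved = shufflevector <4 x i32> %a, <4 x i32> %b,
        <8 x i32> <i32 0, i32 4, i32 1, i32 5, i32 2, i32 6, i32 3, i32 7>
    call void @llvm.vp.store.v8i32.p0(
        <8 x i32> %interleaved, ptr %p,
        <8 x i1> <i1 true, i1 true, i1 true, i1 true,
                  i1 true, i1 true, i1 true, i1 true>,
        i32 8)

A zero mask bit at any lane the shuffle writes, or a constant EVL below 8 here, would make lowerInterleavedStore bail out before reaching the target hook.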
@@ -766,12 +886,15 @@ bool InterleavedAccessImpl::runOnFunction(Function &F) {
   SmallSetVector<Instruction *, 32> DeadInsts;
   bool Changed = false;

+  using namespace PatternMatch;
   for (auto &I : instructions(F)) {
-    if (auto *LI = dyn_cast<LoadInst>(&I))
-      Changed |= lowerInterleavedLoad(LI, DeadInsts);
+    if (match(&I, m_CombineOr(m_Load(m_Value()),
+                              m_Intrinsic<Intrinsic::vp_load>())))
+      Changed |= lowerInterleavedLoad(&I, DeadInsts);

-    if (auto *SI = dyn_cast<StoreInst>(&I))
-      Changed |= lowerInterleavedStore(SI, DeadInsts);
+    if (match(&I, m_CombineOr(m_Store(m_Value(), m_Value()),
+                              m_Intrinsic<Intrinsic::vp_store>())))
+      Changed |= lowerInterleavedStore(&I, DeadInsts);

     if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
       // At present, we only have intrinsics to represent (de)interleaving