/*========================== begin_copyright_notice ============================

Copyright (C) 2017-2025 Intel Corporation

SPDX-License-Identifier: MIT

============================= end_copyright_notice ===========================*/

//
/// GenXSimdCFConformance
/// ---------------------
///
/// This pass checks that the use of SIMD control flow (llvm.genx.simdcf.goto
/// and llvm.genx.simdcf.join) conforms to the rules required to allow us to
/// generate actual goto and join instructions. If not, the intrinsics are
/// lowered to code that implements the defined semantics for the intrinsics,
/// but does not use SIMD CF instructions, so is usually less efficient.
///
/// It also makes certain transformations to make goto/join legal in terms of
/// its position in the basic block. These can fail silently, in which case the
/// conformance check will fail on the goto/join in question:
///
/// * A goto and its extractvalues must be at the end of the block. (Actually,
///   if the !any result of the goto is used in a conditional branch at the end
///   of the block, then the goto being baled into the branch means that it is
///   treated as being at the end of the block anyway. The only reason we need
///   to sink it here is to ensure that isGotoBlock works.)
///
/// * For a join label block (a block that is the JIP of other gotos/joins), a
///   join must come at the start of the block.
///
/// * For a branching join block (one whose conditional branch condition is the
///   !any result from a join), the join must be at the end of the block.
///
/// * For a block that has one join with both of the above true, we need to move
///   all other code out of the block.
///
/// The pass is run twice: an "early SIMD CF conformance pass" (a module pass)
/// just before GenXLowering, and a "late SIMD CF conformance pass" (a function
/// group pass) just before second baling.
///
/// The early pass is the one that checks for conformance, and lowers the goto
/// and join intrinsics if the code is not conformant. The conformance checks
/// implement the rules listed in the documentation for the goto and join
/// intrinsics.
///
/// Lowering a goto issues a "failed to optimize SIMD control flow" warning. No
/// clue is given in the warning as to what caused the conformance failure,
/// however you (a compiler developer) can find out (for a test case submitted
/// by a compiler user) by turning on -debug and looking at the output from this
/// pass.
///
/// The late pass checks again for conformance, but if the code is not
/// conformant, it just errors. We could lower the gotos and joins there too,
/// but it would be more fiddly as we would have to ensure that the code
/// conforms with what is expected at that stage of compilation, and there is
/// no further chance to optimize it there.
///
/// We are not expecting this error to happen.
///
/// Otherwise, the late pass sets the register category of the EM and RM values
/// to "EM" and "RM", so they do not get any register allocated.
///
/// Conformance rules
/// ^^^^^^^^^^^^^^^^^
///
/// If the goto and join intrinsics are not used in a way that conforms to the
/// rules, then they will still have the semantics in their spec, but this pass
/// will lower at least some of them to equivalent but less efficient code.
///
/// The rules are:
///
/// 1. Because the hardware has a single EM (execution mask) register, all EM
///    values input to and generated by these intrinsics must not interfere with
///    each other; that is, they must have disjoint live ranges. For the
///    purposes of determining interference, if any EM value is a phi node
///    with incoming constant all ones, then the constant all ones value is
///    counted as being live from the start of the function and is not allowed
///    to interfere with other EM values (although it can interfere with other
///    such constant all ones values).
///
/// 2. An EM value is allowed to be defined:
///
///    a. as part of the struct returned by one of these intrinsics;
///
///    b. by a phi node, as long as each incoming is either an EM value or
///       a constant all ones;
///
///    c. by an extractvalue extracting it from a struct containing an EM value;
///
///    d. as a function argument, as long as an EM value is also returned by the
///       function (perhaps as part of a struct);
///
///    e. by an insertvalue as part of a return value struct;
///
///    f. as the return value of a non-intrinsic call (perhaps as part of a
///       struct), as long as there is also a call arg that is an EM value, and
///       the called function has the corresponding function arg and return
///       value as EM values;
///
///    g. since shufflevector from EM does not change EM and only makes it
///       shorter to create implicit predication of desired width, it's also
///       considered as an EM definition, but it can only be used by wrregion
///       and select.
///
/// 3. An EM value is allowed to be used:
///
///    a. as the OldEM input to one of these intrinsics;
///
///    b. in a phi node, as long as the result of the phi node is an EM value;
///
///    c. as the condition in a wrregion or select;
///
///    d. as the input to a shufflevector whose effect is to slice part of the
///       EM value starting at index 0, as long as the result of that slice is
///       only used as the condition in a wrregion or select;
///
///    e. as a call argument, as long as the corresponding function argument is
///       an EM value, and the call has an EM return value;
///
///    f. in a return (perhaps as part of a struct), as long as the function
///       also has an argument that is an EM value.
///
///    For an EM value defined in a goto, or a join whose scalar BranchCond
///    result is used in a conditional branch, or in an extractvalue out of the
///    result of such a goto or join, the only use allowed in the same basic
///    block as the goto/join is such an extractvalue.
///
/// 4. The OldEM input to the two intrinsics must be either an EM value or
///    constant all ones. In the latter case, and in the case of a constant
///    incoming to an EM phi node, its live range is considered to reach back
///    through all paths to the function entry for the purposes of rule (1).
///
/// 5. Each join point has a web of RM (resume mask) values, linked as by rules
///    (6) and (7). All RM values within one join point's web must not interfere
///    with each other; that is, they must have disjoint live ranges. For the
///    purposes of determining interference, if an RM value is a phi node with
///    incoming constant all zeros, then the constant all zeros value is
///    counted as being live from the start of the function and is not allowed
///    to interfere with other RM values for this join (although it can
///    interfere with other such constant all zeros values).
///
/// 6. An RM value is allowed to be defined:
///
///    a. as part of the struct returned by ``llvm.genx.simdcf.goto``;
///
///    b. by a phi node, as long as each incoming is either an RM value or
///       a constant all zeros.
///
/// 7. An RM value is allowed to be used:
///
///    a. as the OldRM input to ``llvm.genx.simdcf.goto``;
///
///    b. as the RM input to ``llvm.genx.simdcf.join``, but only to one join in
///       the whole web;
///
///    c. in a phi node, as long as the result of the phi node is an RM value.
///
/// 8. The OldRM input to ``llvm.genx.simdcf.goto``, or the RM input to
///    ``llvm.genx.simdcf.join``, must be either an RM value, or constant all
///    zeros. In the latter case, and in the case of a constant incoming to an
///    RM phi node, its live range is considered to reach back through all paths
///    to the function entry or to the web's ``llvm.genx.simdcf.join`` for the
///    purposes of rule (5).
///
/// 9. The BranchCond struct element of the result of ``llvm.genx.simdcf.goto``
///    must either be unused (unextracted), or, after being extractvalued,
///    must have exactly one use, which is in a
///    conditional branch terminating the same basic block. In the unused case,
///    the basic block must end with an unconditional branch. (This is a goto
///    that is immediately followed by a join.)
///
/// 10. The BranchCond struct element of the result of ``llvm.genx.simdcf.join``
///     must either be unused (unextracted), or, after being extractvalued,
///     have exactly one use, which is in a conditional branch terminating the
///     same basic block.
///
/// 11. It must be possible to derive an ordering for the basic blocks in a
///     function such that, in the conditional branch using the result of any
///     goto or join, the "false" successor is fall-through and the "true"
///     successor is to a join later on in the sequence. For a goto followed by
///     an unconditional branch, the successor is fall-through _and_ the next
///     join in sequence.
///
/// **IR restriction**: goto and join intrinsics must conform to these rules
/// (since this pass lowers any that do not).
///
//===----------------------------------------------------------------------===//
#include "FunctionGroup.h"
#include "GenX.h"
#include "GenXConstants.h"
#include "GenXGotoJoin.h"
#include "GenXLiveness.h"
#include "GenXModule.h"
#include "GenXTargetMachine.h"
#include "GenXUtil.h"

#include "vc/Utils/GenX/KernelInfo.h"
#include "vc/Utils/GenX/RegCategory.h"

#include "Probe/Assertion.h"
#include "llvmWrapper/IR/DerivedTypes.h"
#include "llvmWrapper/IR/InstrTypes.h"
#include "llvmWrapper/IR/Instructions.h"
#include "llvmWrapper/IR/IntrinsicInst.h"
#include "llvmWrapper/Support/TypeSize.h"

#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/CFG.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/GenXIntrinsics/GenXIntrinsics.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/DiagnosticPrinter.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Utils/Local.h"

#define DEBUG_TYPE "GENX_SIMDCFCONFORMANCE"

using namespace llvm;
using namespace genx;

// Hidden command-line switch "-enable-genx-goto-join", on by default. Its
// description says it enables the use of Gen goto/join instructions for SIMD
// control flow.
// NOTE(review): the code that reads this flag is not in this part of the
// file -- confirm where it is consulted before changing the default.
static cl::opt<bool> EnableGenXGotoJoin(
    "enable-genx-goto-join", cl::init(true), cl::Hidden,
    cl::desc(
        "Enable use of Gen goto/join instructions for SIMD control flow."));

namespace {

// Diagnostic information for error/warning relating to SIMD control flow.
//
// The diagnostic kind is a plugin kind obtained once, during static
// initialization, from getNextAvailablePluginDiagnosticKind() (see the KindID
// definition right after the class).
class DiagnosticInfoSimdCF : public DiagnosticInfoOptimizationBase {
private:
  // Kind ID shared by every diagnostic of this class.
  static const int KindID;

  static int getKindID() { return KindID; }

public:
  // Report Msg against Inst; the severity defaults to an error.
  // Only declared here -- the definition lies outside this class body.
  static void emit(Instruction *Inst, StringRef Msg,
                   DiagnosticSeverity Severity = DS_Error);

  // Msg is forwarded to the base class in the argument slot that print()
  // later reads back through RemarkName; no pass name is supplied.
  DiagnosticInfoSimdCF(DiagnosticSeverity Severity, const Function &Fn,
                       const DebugLoc &DLoc, StringRef Msg)
      : DiagnosticInfoOptimizationBase(
            static_cast<DiagnosticKind>(getKindID()), Severity,
            /*PassName=*/nullptr, Msg, Fn, DLoc) {}

  // This kind of message is always enabled, and not affected by -rpass.
  bool isEnabled() const override { return true; }

  // LLVM-style RTTI hook: matches diagnostics carrying this class's kind ID.
  static bool classof(const DiagnosticInfo *DI) {
    return DI->getKind() == getKindID();
  }

  // Prints the message with a fixed "GenXSimdCFConformance: " prefix.
  // TODO: consider changing format
  void print(DiagnosticPrinter &DP) const override {
    DP << "GenXSimdCFConformance: " << RemarkName;
  }
};

const int DiagnosticInfoSimdCF::KindID = getNextAvailablePluginDiagnosticKind();

// GenX SIMD control flow conformance pass -- common data between early and
// late passes.
//
// Both pass flavours derive from this class. It owns the gathered EM
// (execution mask) and RM (resume mask) value sets and declares the steps the
// two passes run over them.
class GenXSimdCFConformance {
protected:
  // Module being processed; assigned at the start of every pass run.
  Module *M = nullptr;
  // Function group being processed. Only the late pass assigns it, which is
  // what isLatePass() keys on.
  FunctionGroup *FG = nullptr;
  // Analyses fetched by the late pass. The early pass resets FGA and DTWrapper
  // to null and never assigns Liveness.
  FunctionGroupAnalysis *FGA = nullptr;
  DominatorTreeGroupWrapperPass *DTWrapper = nullptr;
  // Dominator trees keyed by function; emptied by clear().
  std::map<Function *, DominatorTree *> DTs;
  GenXLiveness *Liveness = nullptr;
  // Value returned from runOnModule / runOnFunctionGroup.
  bool Modified = false;
  // EM values collected by gatherEMVals().
  SetVector<SimpleValue> EMVals;
  // Per join: the web of RM values collected by gatherRMVals().
  std::map<CallInst *, SetVector<SimpleValue>> RMVals;
  // When set, the early pass lowers all SIMD CF instead of checking
  // conformance. It is seeded from hasStackCall() in the early pass and set
  // by gatherEMVals() when savemask/unmask/remask calls are found.
  bool lowerSimdCF = false;
  // Emptied at the end of gatherEMVals()/gatherRMVals() and by clear().
  // NOTE(review): the code that populates it is not visible in this part of
  // the file.
  SmallSet<Instruction *, 32> CrossInsts;

private:
  // GotoJoinEVs: container for goto/join Extract Value (EV) info. Also
  // allows removing duplicated EVs. It does so on construction and moves
  // the EVs right after the goto/join. Hoisting can be performed again with
  // the hoistEVs method. For instance, it is used on join hoisting to keep
  // correct EM live ranges.
  class GotoJoinEVs {
  private:
    // Struct element indices of the goto/join result. Index 1 is the RM for
    // a goto and the condition for a join; index 2 (the goto condition) is
    // only valid for a goto -- see testPosCorrectness().
    enum ValPos {
      EMPos = 0,
      RMPos = 1,
      JoinCondPos = 1,
      GotoCondPos = 2,
      PosNum
    };

    bool testPosCorrectness(const unsigned Index) const;

    // One slot per struct element index; null when no EV is recorded.
    ExtractValueInst *EVs[PosNum] = {nullptr, nullptr, nullptr};
    // Defaulted so the members are never read uninitialized, whatever the
    // out-of-line constructor does for a null GJ.
    bool IsGoto = false;
    Value *GotoJoin = nullptr;

    std::string getNameForMissingEV(unsigned EVIndex) const;
    void CanonicalizePHIs();
    void CollectEVs();

  public:
    GotoJoinEVs(Value *GJ = nullptr);
    ExtractValueInst *getEMEV() const;
    ExtractValueInst *getRMEV() const;
    ExtractValueInst *getCondEV() const;
    Value *getGotoJoin() const;
    Instruction *getSplitPoint() const;
    void setCondEV(ExtractValueInst *CondEV);
    bool isGoto() const;
    bool isJoin() const;
    void hoistEVs() const;
  };

  // Data kept per block queued in BlocksToOptimize: the "false" predecessor
  // and an EM instruction whose parent block is reported as the "true"
  // predecessor.
  class JoinPointOptData {
  private:
    BasicBlock *FalsePred;
    Instruction *EM;

  public:
    JoinPointOptData(BasicBlock *FalsePred = nullptr, Instruction *EM = nullptr)
        : FalsePred(FalsePred), EM(EM) {}
    // Dereferences EM, so it must not be called on a default-constructed
    // object.
    BasicBlock *getTruePred() const { return EM->getParent(); }
    BasicBlock *getFalsePred() const { return FalsePred; }
    Instruction *getRealEM() const { return EM; }
  };

  SetVector<SimpleValue> EMValsStack;
  MapVector<Value *, GotoJoinEVs> GotoJoinEVsMap;
  MapVector<BasicBlock *, JoinPointOptData> BlocksToOptimize;
  std::map<CallInst *, CallInst *> GotoJoinMap;
  std::map<Value *, Value *> EMProducers;
  std::map<Value *, Value *> LoweredEMValsMap;

protected:
  // Every data member carries a default member initializer, so nothing has
  // to be spelled out here.
  GenXSimdCFConformance() = default;
  void gatherEMVals();
  void gatherRMVals();
  void removeFromEMRMVals(Value *V);
  void moveCodeInGotoBlocks(bool hoistGotoUsers = false);
  void moveCodeInJoinBlocks();
  void ensureConformance();
  void lowerAllSimdCF();
  void canonicalizeEM();
  void splitGotoJoinBlocks();
  void lowerUnsuitableGetEMs();
  void optimizeRestoredSIMDCF();
  // Drops the state accumulated during a run. Both passes call it at the end
  // of a run; the late pass also calls it from releaseMemory().
  // NOTE(review): EMValsStack is not cleared here -- confirm its users always
  // leave it empty.
  void clear() {
    DTs.clear();
    EMVals.clear();
    RMVals.clear();
    GotoJoinMap.clear();
    GotoJoinEVsMap.clear();
    EMProducers.clear();
    LoweredEMValsMap.clear();
    BlocksToOptimize.clear();
    CrossInsts.clear();
  }
  DominatorTree *getDomTree(Function *F);

private:
  // True when running as the late (function group) pass.
  bool isLatePass() const { return FG != nullptr; }
  void emptyBranchingJoinBlocksInFunc(Function *F);
  void emptyBranchingJoinBlock(CallInst *Join);
  bool hoistJoin(CallInst *Join);
  bool checkEMVal(SimpleValue EMVal);
  bool checkGoto(SimpleValue EMVal);
  bool checkJoin(SimpleValue EMVal);
  bool checkGotoJoin(SimpleValue EMVal);
  void removeBadEMVal(SimpleValue EMVal);
  void pushValues(Value *V);
  bool getConnectedVals(SimpleValue Val, vc::RegCategory Cat,
                        bool IncludeOptional, CallInst *OkJoin,
                        SmallVectorImpl<SimpleValue> *ConnectedVals,
                        bool LowerBadUsers = false);
  void checkEMInterference();
  void checkInterference(SetVector<SimpleValue> *Vals,
                         SetVector<Value *> *BadDefs, Instruction *ConstStop);
  bool hoistGotoUser(Instruction *Inst, CallInst *Goto, unsigned operandNo);
  void gatherGotoJoinEMVals(bool IncludeIncoming = true);
  void handleEVs();
  void resolveBitCastChains();
  Value *eliminateBitCastPreds(Value *Val, std::set<Value *> &DeadInst,
                               std::set<Value *> &Visited);
  Value *getEMProducer(Value *Inst, std::set<Value *> &Visited,
                       bool BitCastAllowed = false);
  void handleCondValue(Value *GotoJoin);
  void handleNoCondEVCase(GotoJoinEVs &GotoJoinData);
  void handleOptimizedBranchCase(GotoJoinEVs &GotoJoinData,
                                 BasicBlock *&TrueSucc, BasicBlock *&FalseSucc);
  void handleExistingBranchCase(GotoJoinEVs &GotoJoinData,
                                BasicBlock *&TrueSucc, BasicBlock *&FalseSucc,
                                BranchInst *ExistingBranch);
  void addNewPhisIncomings(BasicBlock *BranchingBlock, BasicBlock *TrueSucc,
                           BasicBlock *FalseSucc);
  void collectCondEVUsers(ExtractValueInst *CondEV,
                          std::vector<Value *> &BadUsers,
                          BranchInst *&CorrectUser);
  void updateBadCondEVUsers(GotoJoinEVs &GotoJoinData,
                            std::vector<Value *> &BadUsers,
                            BasicBlock *TrueSucc, BasicBlock *FalseSucc);
  Value *findGotoJoinVal(vc::RegCategory Cat, BasicBlock *Loc,
                         Instruction *CondEV, BasicBlockEdge &TrueEdge,
                         BasicBlockEdge &FalseEdge, Value *TrueVal,
                         Value *FalseVal,
                         std::map<BasicBlock *, Value *> &foundVals);
  bool canUseLoweredEM(Instruction *Val);
  bool canUseRealEM(Instruction *Inst, unsigned opNo);
  void replaceUseWithLoweredEM(Instruction *Val, unsigned opNo,
                               SetVector<Value *> &ToRemove);
  Value *findLoweredEMValue(Value *Val);
  Value *buildLoweringViaGetEM(Value *Val, Instruction *InsertBefore);
  Value *getGetEMLoweredValue(Value *Val, Instruction *InsertBefore);
  Value *lowerEVIUse(ExtractValueInst *EVI, Instruction *User,
                     BasicBlock *PhiPredBlock = nullptr);
  Value *lowerPHIUse(PHINode *PN, SetVector<Value *> &ToRemove);
  Value *lowerArgumentUse(Argument *Arg);
  Value *insertCond(Value *OldVal, Value *NewVal, const Twine &Name,
                    Instruction *InsertBefore, const DebugLoc &DL);
  Value *truncateCond(Value *In, Type *Ty, const Twine &Name,
                      Instruction *InsertBefore, const DebugLoc &DL);
  void lowerGoto(CallInst *Goto);
  void lowerJoin(CallInst *Join);
  void replaceGotoJoinUses(CallInst *GotoJoin, ArrayRef<Value *> Vals);
  void optimizeLinearization(BasicBlock *BB, JoinPointOptData &JPData);
  bool isActualStoredEM(Instruction *Inst, JoinPointOptData &JPData);
  bool canBeMovedUnderSIMDCF(Value *Val, BasicBlock *CurrBB,
                             JoinPointOptData &JPData,
                             std::set<Instruction *> &Visited);
  bool isSelectConditionCondEV(SelectInst *Sel, JoinPointOptData &JPData);
  void replaceGetEMUse(Instruction *Inst, JoinPointOptData &JPData);
};

// GenX early SIMD control flow conformance pass
//
// Module-pass flavour of the conformance logic. The shared state and helper
// steps come from the GenXSimdCFConformance base.
class GenXEarlySimdCFConformance : public GenXSimdCFConformance,
                                   public ModulePass {
public:
  // Pass identifier handed to the ModulePass constructor below.
  static char ID;

  explicit GenXEarlySimdCFConformance() : ModulePass(ID) {}

  // Pass entry point; defined out of line.
  bool runOnModule(Module &M) override;

  // Declares nothing of its own; only forwards to the ModulePass
  // implementation.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    ModulePass::getAnalysisUsage(AU);
  }

  StringRef getPassName() const override {
    return StringRef("GenX early SIMD control flow conformance");
  }
};

// GenX late SIMD control flow conformance pass
//
// Function-group flavour of the conformance logic. The shared state and
// helper steps come from the GenXSimdCFConformance base.
class GenXLateSimdCFConformance : public FGPassImplInterface,
                                  public IDMixin<GenXLateSimdCFConformance>,
                                  public GenXSimdCFConformance {
public:
  explicit GenXLateSimdCFConformance() {}

  // Pass entry point; defined out of line.
  bool runOnFunctionGroup(FunctionGroup &FG) override;

  // Drops everything gathered during the last run.
  void releaseMemory() override { clear(); }

  static StringRef getPassName() {
    return StringRef("GenX late SIMD control flow conformance");
  }

  static void getAnalysisUsage(AnalysisUsage &AU) {
    // Analyses this pass needs before it runs.
    AU.addRequired<DominatorTreeGroupWrapperPass>();
    AU.addRequired<GenXLiveness>();
    AU.addRequired<TargetPassConfig>();
    // Analyses that stay valid after this pass.
    AU.addPreserved<GenXModule>();
    AU.addPreserved<GenXLiveness>();
    AU.addPreserved<FunctionGroupAnalysis>();
  }

private:
  // Late-pass-only steps; defined out of line.
  void setCategories();
  void modifyEMUses(Value *EM);
  void hoistExtractEMInstructions();
};

/***********************************************************************
 * testIsGotoJoin : assertion helper
 *
 * Returns true when GotoJoin is a genx_simdcf_goto or genx_simdcf_join
 * intrinsic. Any other value trips an assertion and yields false.
 */
bool testIsGotoJoin(const Value *const GotoJoin) {
  IGC_ASSERT(GotoJoin);
  switch (vc::getAnyIntrinsicID(GotoJoin)) {
  case GenXIntrinsic::genx_simdcf_goto:
  case GenXIntrinsic::genx_simdcf_join:
    return true;
  default:
    break;
  }
  // Not a goto/join: this is the failure the helper exists to flag.
  IGC_ASSERT(0);
  return false;
}

/***********************************************************************
 * testIsJoin : assertion helper
 *
 * Returns true when GotoJoin is a genx_simdcf_join intrinsic. Any other
 * value trips an assertion and yields false.
 */
bool testIsJoin(const Value *const GotoJoin) {
  IGC_ASSERT(GotoJoin);
  switch (vc::getAnyIntrinsicID(GotoJoin)) {
  case GenXIntrinsic::genx_simdcf_join:
    return true;
  default:
    break;
  }
  // Not a join: this is the failure the helper exists to flag.
  IGC_ASSERT(0);
  return false;
}

/***********************************************************************
 * testIsValidEMUse : assertion helper
 *
 * Classifies User (the user of an EM value reached through use ui) by its
 * intrinsic ID and returns true when that use is an expected one. Every
 * rejected case trips an assertion before false is returned.
 */
bool testIsValidEMUse(const Value *const User, const Value::use_iterator &ui) {
  IGC_ASSERT(User);
  switch (vc::getAnyIntrinsicID(User)) {
  // Intrinsics that may take EM in any operand position.
  case GenXIntrinsic::genx_rdpredregion:
  case GenXIntrinsic::genx_simdcf_goto:
  case GenXIntrinsic::genx_simdcf_join:
  case GenXIntrinsic::genx_simdcf_get_em:
  case GenXIntrinsic::genx_wrpredpredregion:
    return true;
  // wrregion may use EM only as its predicate operand.
  case GenXIntrinsic::genx_wrregioni:
  case GenXIntrinsic::genx_wrregionf: {
    const bool Result =
        (ui->getOperandNo() == GenXIntrinsic::GenXRegion::PredicateOperandNum);
    IGC_ASSERT(Result);
    return Result;
  }
  // Non-intrinsic users: only these instruction kinds are accepted.
  case GenXIntrinsic::not_any_intrinsic: {
    const bool Result = (isa<PHINode>(User) || isa<InsertValueInst>(User) ||
                         isa<CallInst>(User) || isa<ReturnInst>(User) ||
                         isa<ShuffleVectorInst>(User));
    IGC_ASSERT_MESSAGE(Result, "unexpected use of EM");
    return Result;
  }
  // Any other intrinsic: accepted for the listed instruction kinds, or for a
  // call whose callee is not marked as doesNotAccessMemory().
  default: {
    const bool Result =
        (isa<ReturnInst>(User) || isa<InsertValueInst>(User) ||
         isa<ExtractValueInst>(User) ||
         !cast<CallInst>(User)->getCalledFunction()->doesNotAccessMemory());
    IGC_ASSERT_MESSAGE(Result, "unexpected ALU intrinsic use of EM");
    return Result;
  }
  }
}

/***********************************************************************
 * testPosCorrectness : assertion helper
 *
 * Returns true when Index is a valid struct element index for the
 * goto/join this object describes: 0 and 1 always, 2 only for a goto.
 * Invalid indices trip an assertion and yield false.
 */
bool GenXSimdCFConformance::GotoJoinEVs::testPosCorrectness(
    const unsigned Index) const {
  switch (Index) {
  case EMPos:
  case RMPos: // shares its value with JoinCondPos
    return true;
  case GotoCondPos: {
    // The third element exists only in a goto's result.
    const bool Result = IsGoto;
    IGC_ASSERT_MESSAGE(Result, "Bad index in ExtractValue for goto/join!");
    return Result;
  }
  default:
    break;
  }
  IGC_ASSERT_MESSAGE(0, "Bad index in ExtractValue for goto/join!");
  return false;
}

} // end anonymous namespace

// Definition of the early pass identifier that the pass constructor hands to
// ModulePass.
char GenXEarlySimdCFConformance::ID = 0;
// Declaration of the initializer called by
// createGenXEarlySimdCFConformancePass(). Presumably the INITIALIZE_PASS_*
// macros below supply its definition -- confirm against the macro headers.
namespace llvm {
void initializeGenXEarlySimdCFConformancePass(PassRegistry &);
}
// Registration of the early pass. No pass dependencies are listed between
// BEGIN and END.
INITIALIZE_PASS_BEGIN(GenXEarlySimdCFConformance, "GenXEarlySimdCFConformance",
                      "GenXEarlySimdCFConformance", false, false)
INITIALIZE_PASS_END(GenXEarlySimdCFConformance, "GenXEarlySimdCFConformance",
                    "GenXEarlySimdCFConformance", false, false)

// Factory for the early conformance pass: runs the pass initializer against
// the global pass registry, then returns a newly allocated pass object that
// the caller takes over.
ModulePass *llvm::createGenXEarlySimdCFConformancePass() {
  PassRegistry &Registry = *PassRegistry::getPassRegistry();
  initializeGenXEarlySimdCFConformancePass(Registry);
  ModulePass *EarlyPass = new GenXEarlySimdCFConformance();
  return EarlyPass;
}

// Declaration of the wrapper-pass initializer and the wrapper type itself:
// the late pass implementation is adapted through FunctionGroupWrapperPass.
namespace llvm {
void initializeGenXLateSimdCFConformanceWrapperPass(PassRegistry &);
using GenXLateSimdCFConformanceWrapper =
    FunctionGroupWrapperPass<GenXLateSimdCFConformance>;
} // namespace llvm
// Registration of the late pass wrapper together with the passes it lists as
// dependencies.
INITIALIZE_PASS_BEGIN(GenXLateSimdCFConformanceWrapper,
                      "GenXLateSimdCFConformanceWrapper",
                      "GenXLateSimdCFConformanceWrapper", false, false)
INITIALIZE_PASS_DEPENDENCY(FunctionGroupAnalysis)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeGroupWrapperPassWrapper)
INITIALIZE_PASS_DEPENDENCY(GenXLivenessWrapper)
INITIALIZE_PASS_DEPENDENCY(GenXModule)
INITIALIZE_PASS_END(GenXLateSimdCFConformanceWrapper,
                    "GenXLateSimdCFConformanceWrapper",
                    "GenXLateSimdCFConformanceWrapper", false, false)

// Factory for the late conformance pass wrapper: runs the wrapper initializer
// against the global pass registry, then returns a newly allocated wrapper
// object that the caller takes over.
ModulePass *llvm::createGenXLateSimdCFConformanceWrapperPass() {
  PassRegistry &Registry = *PassRegistry::getPassRegistry();
  initializeGenXLateSimdCFConformanceWrapperPass(Registry);
  ModulePass *LatePass = new GenXLateSimdCFConformanceWrapper();
  return LatePass;
}

static bool hasStackCall(const Module &M) {
  return std::any_of(M.begin(), M.end(),
                     [](const auto &F) { return vc::requiresStackCall(&F); });
}

/***********************************************************************
 * runOnModule : run the early SIMD control flow conformance pass for this
 *  module
 *
 * Resets the per-run state, then runs the steps below in a fixed order.
 * Returns the Modified flag, which this function clears up front and which
 * the individual steps are expected to set.
 */
bool GenXEarlySimdCFConformance::runOnModule(Module &ArgM) {
  LLVM_DEBUG(dbgs() << "Early SIMD CF Conformance starts\n");

  Modified = false;
  M = &ArgM;
  // The early pass has no function group and no group analyses; a null FG is
  // also what marks this run as "not the late pass".
  FG = nullptr;
  FGA = nullptr;
  DTWrapper = nullptr;
  // Any function requiring a stack call forces the lower-everything path
  // below. gatherEMVals() may also set this flag.
  lowerSimdCF = hasStackCall(ArgM);
  // Perform actions to create correct DF for EM
  canonicalizeEM();
  // Gather the EM values, both from goto/join and phi nodes.
  gatherEMVals();
  // Gather the RM values from gotos and phi nodes.
  gatherRMVals();
  // Hoist instructions that do not depend on the goto's result.
  // This is needed to split the blocks correctly.
  moveCodeInGotoBlocks();
  // Split Goto/Join blocks to recreate actual SIMD CF
  splitGotoJoinBlocks();
  // Handle instructions that depend on the goto's result
  moveCodeInGotoBlocks(true);
  // Handle Joins to create correct SIMD CF structure
  moveCodeInJoinBlocks();
  // TODO: currently all SIMD CF is lowered if there is
  // an unmask construction in module. It is very suboptimal.
  if (lowerSimdCF)
    lowerAllSimdCF();
  else {
    // Repeatedly check the code for conformance and lower non-conformant gotos
    // and joins until the code stabilizes.
    ensureConformance();
    optimizeRestoredSIMDCF();
  }
  // Perform check for genx_simdcf_get_em intrinsics and remove redundant ones.
  lowerUnsuitableGetEMs();
  // Drop everything gathered during this run.
  clear();

  LLVM_DEBUG(dbgs() << "Early SIMD CF Conformance ends\n");

  return Modified;
}

/***********************************************************************
 * runOnFunctionGroup : run the late SIMD control flow conformance pass for this
 * FunctionGroup
 *
 * Binds the pass to ArgFG and its module, fetches the analyses, then runs
 * the steps below in a fixed order. Returns the Modified flag, which this
 * function clears up front and which the individual steps are expected to
 * set.
 */
bool GenXLateSimdCFConformance::runOnFunctionGroup(FunctionGroup &ArgFG) {
  LLVM_DEBUG(dbgs() << "Late SIMD CF Conformance starts\n");

  Modified = false;
  // A non-null FG is what marks this run as the late pass.
  FG = &ArgFG;
  M = FG->getModule();
  // Get analyses that we use and/or modify.
  FGA = &getAnalysis<FunctionGroupAnalysis>();
  DTWrapper = &getAnalysis<DominatorTreeGroupWrapperPass>();
  Liveness = &getAnalysis<GenXLiveness>();
  // Gather the EM values, both from goto/join and phi nodes.
  gatherEMVals();
  // Gather the RM values from gotos and phi nodes.
  gatherRMVals();
  // Hoist EM extraction instructions and delete duplicates in simple cases.
  hoistExtractEMInstructions();
  // Move code in goto and join blocks as necessary.
  moveCodeInGotoBlocks();
  moveCodeInJoinBlocks();
  // Check the code for conformance. In this late pass, we do not expect to
  // find non-conformance.
  ensureConformance();
  // For remaining unlowered gotos and joins (the ones that will become SIMD
  // control flow instructions), mark the webs of EM and RM values as
  // category EM or RM respectively. For EM, this also modifies uses as needed.
  setCategories();
  // Drop everything gathered during this run.
  clear();

  LLVM_DEBUG(dbgs() << "Late SIMD CF Conformance ends\n");

  return Modified;
}

/***********************************************************************
 * gatherGotoJoinEMVals : gather the EM values for gotos/joins only
 *
 * IncludeIncoming is used for adding goto/join def to EMVals
 */
void GenXSimdCFConformance::gatherGotoJoinEMVals(bool IncludeIncoming) {
  // We find gotos and joins by scanning all uses of the intrinsics and (in the
  // case of the late pass) ignoring ones not in this function group, rather
  // than scanning the whole IR.
  Type *I1Ty = Type::getInt1Ty(M->getContext());
  for (auto IID :
       {GenXIntrinsic::genx_simdcf_goto, GenXIntrinsic::genx_simdcf_join}) {
    Type *EMTy = IGCLLVM::FixedVectorType::get(I1Ty, 32);
    for (unsigned Width = 1; Width <= 32; Width <<= 1) {
      Type *Tys[] = {EMTy, IGCLLVM::FixedVectorType::get(I1Ty, Width)};
      auto *GotoJoinFunc = vc::getAnyDeclaration(M, IID, Tys);
      for (auto ui = GotoJoinFunc->use_begin(), ue = GotoJoinFunc->use_end();
           ui != ue; ++ui) {
        auto *GotoJoin = dyn_cast<CallInst>(ui->getUser());
        if (!GotoJoin)
          continue;
        if (FG && (FGA->getGroup(GotoJoin->getFunction()) != FG ||
                   ui->getOperandNo() != IGCLLVM::getNumArgOperands(GotoJoin)))
          continue;
        // We have a goto/join (in our function group in the case of the late
        // pass).  Add the EM value (struct index 0) to EMVals.
        EMVals.insert(SimpleValue(GotoJoin, 0));
        // Also add its EM input to EMVals, if not a constant.
        if (IncludeIncoming && !isa<Constant>(GotoJoin->getOperand(0)))
          EMVals.insert(SimpleValue(GotoJoin->getOperand(0), 0));
      }
      if (GotoJoinFunc->use_empty())
        GotoJoinFunc->eraseFromParent();
    }
  }
}

/***********************************************************************
 * gatherEMVals : gather the EM values, including phi nodes
 */
void GenXSimdCFConformance::gatherEMVals() {
  LLVM_DEBUG(dbgs() << "gatherEMVals: start\n");
  // Collect gotos/joins and their defs
  gatherGotoJoinEMVals(true);

  Type *I1Ty = Type::getInt1Ty(M->getContext());
  Type *EMTy = IGCLLVM::FixedVectorType::get(I1Ty, 32);
  Type *Tys[] = {EMTy};
  auto *SavemaskFunc = vc::getAnyDeclaration(
      M, GenXIntrinsic::genx_simdcf_savemask, Tys);
  for (auto ui = SavemaskFunc->use_begin(), ue = SavemaskFunc->use_end();
       ui != ue; ++ui) {
    auto *Savemask = dyn_cast<CallInst>(ui->getUser());
    if (!Savemask)
      continue;
    if (FG && (FGA->getGroup(Savemask->getFunction()) != FG ||
               ui->getOperandNo() != IGCLLVM::getNumArgOperands(Savemask)))
      continue;
    lowerSimdCF = true;
    // Add its EM input to EMVals, if not a constant.
    if (!isa<Constant>(Savemask->getOperand(0)))
      EMVals.insert(SimpleValue(Savemask->getOperand(0), 0));
  }
  if (SavemaskFunc->use_empty())
    SavemaskFunc->eraseFromParent();

  auto *UnmaskFunc = vc::getAnyDeclaration(
      M, GenXIntrinsic::genx_simdcf_unmask, Tys);
  for (auto ui = UnmaskFunc->use_begin(), ue = UnmaskFunc->use_end(); ui != ue;
       ++ui) {
    auto *Unmask = dyn_cast<CallInst>(ui->getUser());
    if (!Unmask)
      continue;
    if (FG && (FGA->getGroup(Unmask->getFunction()) != FG ||
               ui->getOperandNo() != IGCLLVM::getNumArgOperands(Unmask)))
      continue;
    lowerSimdCF = true;
    // We have a unmask (in our function group in the case of the late
    EMVals.insert(SimpleValue(Unmask));
  }
  if (UnmaskFunc->use_empty())
    UnmaskFunc->eraseFromParent();

  auto *RemaskFunc = vc::getAnyDeclaration(
      M, GenXIntrinsic::genx_simdcf_remask, Tys);
  for (auto ui = RemaskFunc->use_begin(), ue = RemaskFunc->use_end(); ui != ue;
       ++ui) {
    auto *Remask = dyn_cast<CallInst>(ui->getUser());
    if (!Remask)
      continue;
    if (FG && (FGA->getGroup(Remask->getFunction()) != FG ||
               ui->getOperandNo() != IGCLLVM::getNumArgOperands(Remask)))
      continue;
    lowerSimdCF = true;
    // We have a remask (in our function group in the case of the late
    // pass).  Add the EM value (struct index 0) to EMVals.
    EMVals.insert(SimpleValue(Remask));
    // Also add its EM input to EMVals, if not a constant.
    if (!isa<Constant>(Remask->getOperand(0)))
      EMVals.insert(SimpleValue(Remask->getOperand(0)));
  }
  if (RemaskFunc->use_empty())
    RemaskFunc->eraseFromParent();

  // delete useless cm_unmask_begin and cm_unmask_end
  auto *UnmaskEF =
      vc::getAnyDeclaration(M, GenXIntrinsic::genx_unmask_end);
  for (auto ui = UnmaskEF->use_begin(), ue = UnmaskEF->use_end(); ui != ue;) {
    auto u = ui->getUser();
    ++ui;
    if (auto *UnmaskEnd = dyn_cast<CallInst>(u))
      UnmaskEnd->eraseFromParent();
  }
  if (UnmaskEF->use_empty())
    UnmaskEF->eraseFromParent();

  auto *UnmaskBF =
      vc::getAnyDeclaration(M, GenXIntrinsic::genx_unmask_begin);
  for (auto ui = UnmaskBF->use_begin(), ue = UnmaskBF->use_end(); ui != ue;) {
    auto u = ui->getUser();
    ++ui;
    if (auto *UnmaskBeg = dyn_cast<CallInst>(u))
      UnmaskBeg->eraseFromParent();
  }
  if (UnmaskBF->use_empty())
    UnmaskBF->eraseFromParent();

  // Find related phi nodes and values related by insertvalue/extractvalue/call
  // using EMVal as a worklist.
  for (unsigned i = 0; i != EMVals.size(); ++i) {
    SimpleValue EMVal = EMVals[i];
    // For this EM value, get the connected values.
    SmallVector<SimpleValue, 8> ConnectedVals;
    getConnectedVals(EMVal, vc::RegCategory::EM, /*IncludeOptional=*/true,
                     /*OkJoin=*/nullptr, &ConnectedVals);
    // Add the connected values to EMVals.
    for (auto j = ConnectedVals.begin(), je = ConnectedVals.end(); j != je; ++j)
      if (!isa<Constant>(j->getValue()))
        EMVals.insert(*j);
  }
  CrossInsts.clear();
}

/***********************************************************************
 * gatherRMVals : gather RM values for each join
 */
void GenXSimdCFConformance::gatherRMVals() {
  LLVM_DEBUG(dbgs() << "gatherRMVals: start\n");
  for (auto ji = EMVals.begin(), je = EMVals.end(); ji != je; ++ji) {
    auto EMVal = *ji;
    auto IID = vc::getAnyIntrinsicID(EMVal.getValue());
    if (IID != GenXIntrinsic::genx_simdcf_join)
      continue;
    auto Join = cast<CallInst>(EMVal.getValue());
    // We have a join. Gather its web of RM values.
    auto RMValsEntry = &RMVals[Join];
    if (!isa<Constant>(Join->getOperand(1)))
      RMValsEntry->insert(Join->getOperand(1));
    for (unsigned rvi = 0; rvi != RMValsEntry->size(); ++rvi) {
      SimpleValue RM = (*RMValsEntry)[rvi];
      // RM is a value in this join's RM web. Get other values related by phi
      // nodes and extractvalues and gotos.
      SmallVector<SimpleValue, 8> ConnectedVals;
      getConnectedVals(RM, vc::RegCategory::RM, /*IncludeOptional=*/true, Join,
                       &ConnectedVals);
      for (auto j = ConnectedVals.begin(), je = ConnectedVals.end(); j != je;
           ++j)
        if (!isa<Constant>(j->getValue()))
          RMValsEntry->insert(*j);
    }
  }
  CrossInsts.clear();
}

/***********************************************************************
 * findGotoJoinVal : find goto/join that should be applied at the
 * specified location
 *
 * It uses dominator tree to find the value needed. Category is used to
 * set proper name for instruction and doesn't affect reg category
 * that is used in reg alloc. It only shows what we are dealing with.
 */
Value *GenXSimdCFConformance::findGotoJoinVal(
    vc::RegCategory Cat, BasicBlock *Loc, Instruction *GotoJoinEV,
    BasicBlockEdge &TrueEdge, BasicBlockEdge &FalseEdge, Value *TrueVal,
    Value *FalseVal, std::map<BasicBlock *, Value *> &foundVals) {
  // Both edges must leave the same block and lead to different blocks.
  IGC_ASSERT(TrueEdge.getStart() == FalseEdge.getStart());
  IGC_ASSERT(TrueEdge.getEnd() != FalseEdge.getEnd());
  IGC_ASSERT_MESSAGE(
      (Cat == vc::RegCategory::EM || Cat == vc::RegCategory::Predicate),
      "Handling only EM and Cond!");

  LLVM_DEBUG(dbgs() << "findGotoJoinVal: Entering " << Loc->getName() << "\n");

  // foundVals memoizes the answer per block; it also terminates the recursion
  // below when the CFG contains cycles.
  const auto Cached = foundVals.find(Loc);
  if (Cached != foundVals.end())
    return Cached->second;

  // If one of the two edges dominates Loc, the value is known directly.
  DominatorTree *DT = getDomTree(Loc->getParent());
  if (DT->dominates(TrueEdge, Loc)) {
    LLVM_DEBUG(dbgs() << "findGotoJoinVal: Dominated by True Edge\n");
    foundVals[Loc] = TrueVal;
    return TrueVal;
  }
  if (DT->dominates(FalseEdge, Loc)) {
    LLVM_DEBUG(dbgs() << "findGotoJoinVal: Dominated by False Edge\n");
    foundVals[Loc] = FalseVal;
    return FalseVal;
  }

  // Neither edge dominates Loc: a phi is required. It normally goes into the
  // immediate dominator of Loc; when that dominator is the block holding
  // GotoJoinEV itself, the phi is placed in Loc instead.
  auto *IDomNode = DT->getNode(Loc)->getIDom();
  IGC_ASSERT_MESSAGE(IDomNode, "No IDom found!");
  BasicBlock *PhiBlock = (IDomNode->getBlock() == GotoJoinEV->getParent())
                             ? Loc
                             : IDomNode->getBlock();

  const char *PhiName =
      (Cat == vc::RegCategory::EM) ? "ExecMaskEV" : "CondEV";
  auto *NewPhi = PHINode::Create(GotoJoinEV->getType(), pred_size(PhiBlock),
                                 PhiName, &PhiBlock->front());
  // Register the phi before recursing so that recursive queries for these
  // blocks resolve to it.
  foundVals[PhiBlock] = NewPhi;
  if (PhiBlock != Loc)
    foundVals[Loc] = NewPhi;

  for (BasicBlock *Pred : predecessors(PhiBlock)) {
    Value *Incoming = nullptr;
    if (Pred != TrueEdge.getStart()) {
      // Ordinary predecessor: resolve its value recursively.
      Incoming = findGotoJoinVal(Cat, Pred, GotoJoinEV, TrueEdge, FalseEdge,
                                 TrueVal, FalseVal, foundVals);
    } else if (Pred->getTerminator()->getSuccessor(0) == PhiBlock) {
      // The predecessor is the block the edges start from, so dominators are
      // not consulted: successor 0 of its terminator is the true branch.
      Incoming = TrueVal;
      LLVM_DEBUG(dbgs() << "findGotoJoinVal: Usual case\n");
    } else {
      // Not expected to happen, but the false value handles it.
      Incoming = FalseVal;
      LLVM_DEBUG(dbgs() << "findGotoJoinVal: Strange case\n");
    }
    NewPhi->addIncoming(Incoming, Pred);
  }

  LLVM_DEBUG(dbgs() << "findGotoJoinVal: Built PHI for EV:" << *NewPhi << "\n");
  return NewPhi;
}

/**
 * collectCondEVUsers : gather Cond EV users
 *
 * Bad users: they should not use cond EV.
 * Correct user: conditional branch in CondEV's BB. This
 * is the only possible conformant user.
 */
void GenXSimdCFConformance::collectCondEVUsers(ExtractValueInst *CondEV,
                                               std::vector<Value *> &BadUsers,
                                               BranchInst *&CorrectUser) {
  // Classify every use of CondEV:
  //  - a branch located in CondEV's own block is the (single) correct user;
  //  - everything else, including a branch in another block, is a bad user
  //    and is appended to BadUsers (once per use).
  for (Use &U : CondEV->uses()) {
    User *Consumer = U.getUser();
    auto *Br = dyn_cast<BranchInst>(Consumer);
    const bool IsBranchInOwnBlock =
        Br && Br->getParent() == CondEV->getParent();

    if (IsBranchInOwnBlock) {
      IGC_ASSERT_MESSAGE(!CorrectUser, "Found another correct user!");
      LLVM_DEBUG(dbgs() << "collectCondEVUsers: Found correct user:\n"
                        << *Br << "\n");
      CorrectUser = Br;
      continue;
    }

    // A branch in the wrong block is simply treated as a non-baled
    // conditional branch, i.e. as a bad user.
    LLVM_DEBUG(dbgs() << "collectCondEVUsers: Found bad CondEV user:\n"
                      << *Consumer << "\n");
    BadUsers.push_back(Consumer);
  }
}

/**
 * updateBadCondEVUsers : update bad cond EV users
 *
 * It replaces cond EV uses by values that can be
 * obtained on true and false paths
 */
void GenXSimdCFConformance::updateBadCondEVUsers(
    GenXSimdCFConformance::GotoJoinEVs &GotoJoinData,
    std::vector<Value *> &BadUsers, BasicBlock *TrueSucc,
    BasicBlock *FalseSucc) {
  ExtractValueInst *CondEV = GotoJoinData.getCondEV();
  IGC_ASSERT_MESSAGE(CondEV, "Expected valid CondEV!");

  // On the true edge the condition is replaced by all-ones, on the false
  // edge by zero.
  BasicBlock *CondBlock = CondEV->getParent();
  BasicBlockEdge TrueEdge(CondBlock, TrueSucc);
  BasicBlockEdge FalseEdge(CondBlock, FalseSucc);
  Constant *TrueVal = Constant::getAllOnesValue(CondEV->getType());
  Constant *FalseVal = Constant::getNullValue(CondEV->getType());

  // Per-block cache shared by all findGotoJoinVal queries below.
  std::map<BasicBlock *, Value *> FoundCondEV;

  for (Value *Bad : BadUsers) {
    auto *UserInst = cast<Instruction>(Bad);
    const unsigned NumOps = UserInst->getNumOperands();
    for (unsigned OpIdx = 0; OpIdx < NumOps; ++OpIdx) {
      if (UserInst->getOperand(OpIdx) != CondEV)
        continue;
      // Rewrite this operand with the value valid in the user's block.
      Value *Replacement = findGotoJoinVal(
          vc::RegCategory::Predicate, UserInst->getParent(), CondEV, TrueEdge,
          FalseEdge, TrueVal, FalseVal, FoundCondEV);
      UserInst->setOperand(OpIdx, Replacement);
    }
  }
}

/**
 * addNewPhisIncomings : add new incomings after split
 *
 * It is needed to update phis after turning unconditional
 * branch into conditional one. True successor is assumed to
 * be correct join point, but the only thing we know here
 * is that FalseSucc branches to TrueSucc. Branching Block's
 * successors are TrueSucc and FalseSucc.
 */
void GenXSimdCFConformance::addNewPhisIncomings(BasicBlock *BranchingBlock,
                                                BasicBlock *TrueSucc,
                                                BasicBlock *FalseSucc) {
  // Walk the leading phi nodes of TrueSucc. For each one, the value that
  // arrives from BranchingBlock is also registered as arriving from FalseSucc.
  BasicBlock::iterator It = TrueSucc->begin();
  while (auto *Phi = dyn_cast<PHINode>(&*It)) {
    Value *FromBranching = Phi->getIncomingValueForBlock(BranchingBlock);
    Phi->addIncoming(FromBranching, FalseSucc);
    ++It;
  }
}

/**
 * handleNoCondEVCase : handle case when there is no
 * CondEV for goto/join.
 *
 * It performs split for goto in order to prepare
 * goto for possible EM lower. Goto is branch itself
 * so such transformation doesn't introduce any
 * overhead in case of conformant SIMD CF.
 *
 * TODO: this transformation can be reverted in case of
 * non-conformant SIMD CF if necessary data was saved.
 * It is not done now because no non-conformant cases
 * were found so far.
 */
void GenXSimdCFConformance::handleNoCondEVCase(
    GenXSimdCFConformance::GotoJoinEVs &GotoJoinData) {
  IGC_ASSERT_MESSAGE(!GotoJoinData.getCondEV(), "Unexpected CondEV!");

  // Only gotos are processed here; joins are left untouched.
  if (GotoJoinData.isJoin())
    return;

  // Advance from the recorded split point until we reach the terminator or
  // the next goto/join call. The split has to happen before the next
  // goto/join so that its conformance is not disturbed.
  auto SplitPoint = GotoJoinData.getSplitPoint();
  while (!SplitPoint->isTerminator()) {
    if (auto *Call = dyn_cast<CallInst>(SplitPoint)) {
      auto IID = vc::getAnyIntrinsicID(Call);
      if (IID == GenXIntrinsic::genx_simdcf_goto ||
          IID == GenXIntrinsic::genx_simdcf_join)
        break;
    }
    SplitPoint = SplitPoint->getNextNode();
  }

  // Materialize the missing condition extract (struct index 2) right before
  // the split point and record it for this goto.
  Value *GotoJoin = GotoJoinData.getGotoJoin();
  ExtractValueInst *CondEV = ExtractValueInst::Create(
      GotoJoin, {2}, "missing_extractcond", SplitPoint);
  GotoJoinData.setCondEV(CondEV);

  auto *F = CondEV->getFunction();
  auto *Br = dyn_cast<BranchInst>(SplitPoint);

  // This CF is non-conformant: there should be a join point before this
  // conditional branch, but it wasn't found. Skip it (the CFG is unchanged).
  if (Br && Br->isConditional())
    return;

  if (Br) {
    // Unconditional branch: make it conditional on CondEV. Both outcomes end
    // up in the old destination, the false one via a fresh "goto_split" block.
    BasicBlock *Dest = Br->getSuccessor(0);
    BasicBlock *Split =
        BasicBlock::Create(CondEV->getContext(), "goto_split", F, Dest);
    BranchInst::Create(Dest, Split);
    BranchInst::Create(Dest, Split, CondEV, Br);

    // Phis in the destination now see one more predecessor.
    addNewPhisIncomings(CondEV->getParent(), Dest, Split);

    Br->eraseFromParent();
  } else {
    // Split point is in the middle of BB. We assume that there is a join
    // point after it.
    // TODO: consider adding this check. No such cases were found now.
    BasicBlock *TrueSucc =
        CondEV->getParent()->splitBasicBlock(SplitPoint, "cond_ev_true_split");
    BasicBlock *Head = CondEV->getParent();
    // Drop the terminator left behind by the split; it is recreated below as
    // a conditional branch.
    Head->getTerminator()->eraseFromParent();
    LLVM_DEBUG(dbgs() << "Created " << TrueSucc->getName()
                      << " to handle missing conditional branch\n");

    // The false path needs a brand new block placed before TrueSucc.
    BasicBlock *FalseSucc = BasicBlock::Create(
        CondEV->getContext(), "cond_ev_false_split", F, TrueSucc);
    LLVM_DEBUG(dbgs() << "Created " << FalseSucc->getName()
                      << " to handle missing conditional branch\n");

    // Head branches on CondEV; FalseSucc falls through to TrueSucc.
    BranchInst::Create(TrueSucc, FalseSucc, CondEV, Head);
    BranchInst::Create(TrueSucc, FalseSucc);
  }

  // CFG changed: update DomTree.
  // TODO: there must be workaround to do it in a more optimal way
  getDomTree(F)->recalculate(*F);
}

/**
 * handleOptimizedBranchCase : perform split for optimized branch case
 *
 * TODO: this makes sense only in the case when the true successor is a
 * join block, otherwise it will introduce more overhead due to
 * goto/join lowering. Also there should be check that this
 * join really uses current EM and RM. This issue is resolved
 * at the end of this pass in EM/RM liveness analysis and cannot
 * be done easy at this point. For now assume that everything OK
 * with it here.
 *
 * TODO: It is possible to undo this transformation if we store
 * all necessary data here. Currently it is not done:
 * no non-conformant cases found for now.
 *
 * Due to earlier transformations we can split BB after the last
 * goto/join EV. It will solve issue with join located in this
 * basic block. Code movements to sink goto/join will be performed
 * further, we don't need to focus on it here.
 */
void GenXSimdCFConformance::handleOptimizedBranchCase(
    GenXSimdCFConformance::GotoJoinEVs &GotoJoinData, BasicBlock *&TrueSucc,
    BasicBlock *&FalseSucc) {
  // The split happens at the split point recorded for this goto/join.
  auto SplitPoint = GotoJoinData.getSplitPoint();

  ExtractValueInst *CondEV = GotoJoinData.getCondEV();
  IGC_ASSERT_MESSAGE(CondEV, "Expected valid CondEV!");
  auto *F = CondEV->getFunction();

  // Everything from the split point on becomes the true successor, which is
  // assumed to be the join point.
  TrueSucc =
      CondEV->getParent()->splitBasicBlock(SplitPoint, "cond_ev_true_split");
  LLVM_DEBUG(dbgs() << "Created " << TrueSucc->getName()
                    << " to handle missing conditional branch\n");

  // Remove the terminator left by the split; a conditional branch replaces
  // it below.
  BasicBlock *Head = CondEV->getParent();
  Head->getTerminator()->eraseFromParent();

  // The false successor is a new block placed before TrueSucc.
  FalseSucc = BasicBlock::Create(CondEV->getContext(), "cond_ev_false_split",
                                 F, TrueSucc);
  LLVM_DEBUG(dbgs() << "Created " << FalseSucc->getName()
                    << " to handle missing conditional branch\n");

  // Head branches on CondEV; FalseSucc jumps straight to TrueSucc.
  BranchInst::Create(TrueSucc, FalseSucc, CondEV, Head);
  BranchInst::Create(TrueSucc, FalseSucc);

  // Remember the new blocks together with the EM extract for a possible
  // later optimization.
  BlocksToOptimize[TrueSucc] =
      JoinPointOptData(FalseSucc, GotoJoinData.getEMEV());

  // CFG changed: update DomTree.
  // TODO: there must be workaround to do it in a more optimal way
  getDomTree(F)->recalculate(*F);
}

/**
 * handleExistingBranchCase : perform actions needed to
 * handle case when branch wasn't optimized
 *
 * It stores True/False successors and adds new BB
 * in case when both successors are the same BB.
 */
void GenXSimdCFConformance::handleExistingBranchCase(
    GenXSimdCFConformance::GotoJoinEVs &GotoJoinData, BasicBlock *&TrueSucc,
    BasicBlock *&FalseSucc, BranchInst *ExistingBranch) {
  ExtractValueInst *CondEV = GotoJoinData.getCondEV();
  IGC_ASSERT_MESSAGE(CondEV, "Expected valid CondEV!");
  IGC_ASSERT_MESSAGE(ExistingBranch->isConditional(),
                     "Expected conditional branch!");

  // Report the successors of the existing branch to the caller.
  TrueSucc = ExistingBranch->getSuccessor(0);
  FalseSucc = ExistingBranch->getSuccessor(1);

  // Distinct successors: nothing to restructure.
  if (TrueSucc != FalseSucc)
    return;

  // Both edges lead to the same block. Introduce a new block on the false
  // edge so that the two paths become distinguishable.
  FalseSucc = BasicBlock::Create(CondEV->getContext(), "cond_ev_split",
                                 CondEV->getFunction(), TrueSucc);
  BranchInst::Create(TrueSucc, FalseSucc);
  ExistingBranch->setSuccessor(1, FalseSucc);

  LLVM_DEBUG(dbgs() << "Created " << FalseSucc->getName()
                    << " to handle always taken CONDITIONAL branch\n");

  // Update phis in TrueSucc. Before the change the branching block reached
  // TrueSucc through two edges, so a well-formed phi carries two entries for
  // it. One of those edges now comes from FalseSucc: retarget the second
  // entry instead of appending a new one, otherwise the phi would keep a
  // stale duplicate and end up with more entries than predecessors. If a phi
  // has only a single entry for the branching block, fall back to adding a
  // new incoming value for FalseSucc.
  BasicBlock *BranchingBlock = CondEV->getParent();
  for (auto It = TrueSucc->begin(); auto *PN = dyn_cast<PHINode>(&*It);
       ++It) {
    unsigned EntriesSeen = 0;
    bool Redirected = false;
    for (unsigned Idx = 0, Num = PN->getNumIncomingValues(); Idx != Num;
         ++Idx) {
      if (PN->getIncomingBlock(Idx) != BranchingBlock)
        continue;
      if (++EntriesSeen == 2) {
        PN->setIncomingBlock(Idx, FalseSucc);
        Redirected = true;
        break;
      }
    }
    if (!Redirected)
      PN->addIncoming(PN->getIncomingValueForBlock(BranchingBlock), FalseSucc);
  }

  // CFG changed: update DomTree.
  // TODO: there must be workaround to do it in a more optimal way
  DominatorTree *domTree = getDomTree(CondEV->getFunction());
  domTree->recalculate(*CondEV->getFunction());
}

/**
 * handleCondValue : perform analysis on Cond EV usage and fix
 * it if needed
 *
 * The basic use case is optimized False Successor. That
 * often happens in standard SimplifyCFG pass.
 */
void GenXSimdCFConformance::handleCondValue(Value *GotoJoin) {
  GotoJoinEVs &EVs = GotoJoinEVsMap[GotoJoin];
  ExtractValueInst *CondEV = EVs.getCondEV();

  // Without a cond EV there are no users to fix. For a goto we still create
  // the branch: it makes possible bad EM users easier to handle, and since a
  // goto is a branch itself this adds no overhead for conformant SIMD CF.
  if (!CondEV) {
    handleNoCondEVCase(EVs);
    return;
  }

  // Split the users of the cond EV into the single legitimate branch and
  // everything else.
  std::vector<Value *> BadUsers;
  BranchInst *CorrectUser = nullptr;
  collectCondEVUsers(CondEV, BadUsers, CorrectUser);

  // A join with no bad users needs nothing. A goto is processed regardless so
  // that the TrueSucc == FalseSucc case gets normalized, which simplifies
  // further analysis.
  if (EVs.isJoin() && BadUsers.empty())
    return;

  BasicBlock *TrueSucc = nullptr;
  BasicBlock *FalseSucc = nullptr;
  if (CorrectUser) {
    // Branch is still here. Perform actions needed.
    handleExistingBranchCase(EVs, TrueSucc, FalseSucc, CorrectUser);
  } else {
    // Branch was optimized by some pass. We need to create it again.
    handleOptimizedBranchCase(EVs, TrueSucc, FalseSucc);
  }

  // Finally rewrite the bad users in terms of the two successors.
  updateBadCondEVUsers(EVs, BadUsers, TrueSucc, FalseSucc);
}

/***********************************************************************
 * splitGotoJoinBlocks : split Basic Blocks that contain goto/join
 *
 * This is used to solve problems that can be introduced by some
 * standard LLVM passes: one of them is simplified CFG that leads to
 * goto/join's condition usage by non-branch instruction. After this
 * transformation each BB will contain only one goto or join instruction
 * (or none of them), that fact allows us to make further changes simpler.
 */
void GenXSimdCFConformance::splitGotoJoinBlocks() {
  LLVM_DEBUG(dbgs() << "splitGotoJoinBlocks: start\n");

  for (auto &Entry : GotoJoinEVsMap) {
    Value *GotoJoin = Entry.first;
    auto &EVs = Entry.second;

    LLVM_DEBUG(dbgs() << "Trying to split BB for:\n" << *GotoJoin << "\n");

    // First repair the condition value usage for this goto/join.
    handleCondValue(GotoJoin);

    // Only joins need the extra split below.
    if (!EVs.isJoin())
      continue;

    // If something other than the terminator sits at the join's split point,
    // move it (and everything after it) into a new block.
    auto SplitPoint = EVs.getSplitPoint();
    if (!SplitPoint->isTerminator()) {
      SplitPoint->getParent()->splitBasicBlock(SplitPoint, "split_for_join");
      // CFG changed: update DomTree.
      // TODO: there must be workaround to do it in a more optimal way
      auto *F = SplitPoint->getFunction();
      getDomTree(F)->recalculate(*F);
    }
  }

  LLVM_DEBUG(dbgs() << "Done splitting\n\n" << *M << "\n\n");
}

/***********************************************************************
 * removeFromEMRMVals : remove a value from EMVals or RMVals
 *
 * This is used just before erasing a phi node in moveCodeInJoinBlocks.
 */
void GenXSimdCFConformance::removeFromEMRMVals(Value *V) {
  // Only vectors of i1 are candidates; anything else is ignored.
  auto *VT = dyn_cast<VectorType>(V->getType());
  const bool IsI1Vector = VT && VT->getElementType()->isIntegerTy(1);
  if (!IsI1Vector)
    return;

  // Try the EM set first; stop as soon as the value was removed somewhere.
  if (EMVals.remove(SimpleValue(V, 0)))
    return;

  // Otherwise look through the RM sets, one per map entry.
  for (auto &Entry : RMVals) {
    if (Entry.second.remove(SimpleValue(V, 0)))
      return;
  }
}

/***********************************************************************
 * hoistGotoUser : hoist instruction that uses goto's EV and is located
 * after it in the same basic block.
 *
 * Since goto must be at the end of basic block, we have to solve
 * this problem somehow. Current approach is to duplicate instruction
 * on both paths (true and false) and update uses.
 *
 * It is always possible to perform such transformation even if there
 * is a chain of users: we just can duplicate them all. Since we know
 * all values on the true path, it should be possible to perform full
 * calculation in this case. However, it is not done now because it can
 * lead to much worse code when SIMD CF is not conformant (we are not
 * sure that it is conformant at this point).
 */
bool GenXSimdCFConformance::hoistGotoUser(Instruction *Inst, CallInst *Goto,
                                          unsigned operandNo) {
  // Inst       : instruction that uses an extract of Goto.
  // Goto       : the goto call whose extract is used.
  // operandNo  : index of the Inst operand that holds the goto extract.
  // Returns true when every use of Inst was rewritten, false when the
  // transformation was abandoned.

  // Find branch for goto
  ExtractValueInst *CondEV = GotoJoinEVsMap[Goto].getCondEV();
  IGC_ASSERT_MESSAGE(CondEV, "Expected valid CondEV!");
  auto BrIt = std::find_if(
      CondEV->use_begin(), CondEV->use_end(), [&Goto](const Use &u) {
        auto *Br = dyn_cast<BranchInst>(u.getUser());
        return (Br && Br->getParent() == Goto->getParent() &&
                Br->isConditional());
      });
  IGC_ASSERT_MESSAGE(BrIt != CondEV->use_end(),
                     "All gotos should become branching earlier!");

  BranchInst *Br = cast<BranchInst>(BrIt->getUser());
  BasicBlock *TrueSucc = Br->getSuccessor(0);
  BasicBlock *FalseSucc = Br->getSuccessor(1);

  // Handle FallThrough block with phis.
  //
  // TODO: it is redundant in some cases. For example, there can be Phi that
  // uses bitcasts from EM from two paths. In this case we can use one
  // GetEM from Phi with EM. Currently there is no trivial mechanism
  // to check for that because in this case Phi arguments are supposed to use
  // different Execution Masks according to DF.
  //
  // Temporary solution for that is to place a splitter block that branches to
  // such bb directly. Examples of that case can be found in local-atomics
  // tests in ISPC.
  if (isa<PHINode>(&FalseSucc->front())) {
    BasicBlock *Splitter =
        BasicBlock::Create(FalseSucc->getContext(), "phi_fallthrough_splitter",
                           FalseSucc->getParent());
    Splitter->moveAfter(Goto->getParent());
    BranchInst::Create(FalseSucc, Splitter);
    Br->setSuccessor(1, Splitter);
    // Update phis: entries that came from the goto block now come from the
    // splitter.
    for (auto CurrInst = &FalseSucc->front();
         auto *PN = dyn_cast<PHINode>(CurrInst);
         CurrInst = CurrInst->getNextNode()) {
      for (unsigned idx = 0, num = PN->getNumIncomingValues(); idx < num;
           ++idx) {
        if (PN->getIncomingBlock(idx) == Goto->getParent())
          PN->setIncomingBlock(idx, Splitter);
      }
    }
    FalseSucc = Splitter;
    // CFG changed: update DomTree.
    // TODO: there must be workaround to do it in a more optimal way
    DominatorTree *domTree = getDomTree(CondEV->getFunction());
    domTree->recalculate(*CondEV->getFunction());
  }

  // Users located in the same block as Inst are not supported.
  // TODO: it can be solved with duplicated instructions.
  // Currently we are not going to duplicate them.
  //
  // This is checked up front, before any clone or phi is created. Bailing out
  // in the middle of the use walk below would erase the clones while phis
  // already built by findGotoJoinVal still reference them.
  for (User *U : Inst->users()) {
    if (cast<Instruction>(U)->getParent() == Inst->getParent())
      return false;
  }

  // Copy instruction and set the value for true block. Place it before goto.
  Instruction *TrueVal = Inst->clone();
  TrueVal->insertBefore(Goto);
  TrueVal->setOperand(operandNo, Constant::getNullValue(
                                     Inst->getOperand(operandNo)->getType()));

  // Copy instruction and place it in the false successor. Get EM will be
  // created later to handle its goto use.
  Instruction *FalseVal = Inst->clone();
  FalseVal->insertBefore(FalseSucc->getFirstNonPHI());

  // Handle all users: compute the replacement for every use first ...
  BasicBlockEdge TrueEdge(Goto->getParent(), TrueSucc);
  BasicBlockEdge FalseEdge(Goto->getParent(), FalseSucc);
  std::map<BasicBlock *, Value *> foundVals;
  std::vector<Value *> newOperands;
  for (auto ui = Inst->use_begin(), ue = Inst->use_end(); ui != ue; ++ui) {
    auto User = cast<Instruction>(ui->getUser());

    // For a phi user the value is needed at the end of the incoming block,
    // not in the phi's own block.
    BasicBlock *Loc = User->getParent();
    if (auto *PN = dyn_cast<PHINode>(User))
      Loc = PN->getIncomingBlock(ui->getOperandNo());

    // Store new value
    Value *NewOperand = nullptr;
    if (Loc == Goto->getParent())
      NewOperand = TrueVal;
    else
      NewOperand = findGotoJoinVal(vc::RegCategory::EM, Loc, Inst, TrueEdge,
                                   FalseEdge, TrueVal, FalseVal, foundVals);

    newOperands.push_back(NewOperand);
  }

  // ... then update uses. The use list of Inst is walked in the same order as
  // above, so newOperands[i] belongs to the i-th use. The iterator is
  // advanced before setOperand because rewriting the operand removes the
  // current use from the list.
  unsigned i = 0;
  for (auto ui = Inst->use_begin(), ue = Inst->use_end(); ui != ue;) {
    auto User = cast<Instruction>(ui->getUser());
    unsigned opNo = ui->getOperandNo();
    ++ui;
    User->setOperand(opNo, newOperands[i++]);
  }

  return true;
}

/***********************************************************************
 * moveCodeInGotoBlocks : move code in goto blocks
 *
 * A goto and its extractvalues must be at the end of the block. (Actually, if
 * the !any result of the goto is used in a conditional branch at the end of
 * the block, then the goto being baled into the branch means that it is
 * treated as being at the end of the block anyway. The only reason we need to
 * sink it here is to ensure that isGotoBlock works.)
 *
 * This can silently fail to sink a goto, in which case checkGoto will spot that
 * the goto is not conformant.
 */
void GenXSimdCFConformance::moveCodeInGotoBlocks(bool hoistGotoUsers) {
  // hoistGotoUsers: when true, instructions that use an extract of the goto
  // are handed to hoistGotoUser instead of being given up on immediately.
  LLVM_DEBUG(dbgs() << "moveCodeInGotoBlocks: start\n");
  for (auto gi = EMVals.begin(), ge = EMVals.end(); gi != ge; ++gi) {
    auto EMVal = *gi;
    auto IID = vc::getAnyIntrinsicID(EMVal.getValue());
    if (IID != GenXIntrinsic::genx_simdcf_goto)
      continue;
    auto Goto = cast<CallInst>(EMVal.getValue());
    // We want to sink the goto and its extracts. In fact we hoist any other
    // instruction, checking that it does not use the extracts.
    // With hoistGotoUsers, we are trying to hoist them, too.
    // We are skipping all instructions that use skipped instructions
    // in order to save dominance.
    std::set<Instruction *> Skipping;
    for (Instruction *NextInst = Goto->getNextNode();;) {
      auto Inst = NextInst;
      // Validate the pointer before it is dereferenced for the first time.
      IGC_ASSERT(Inst);
      if (Inst->isTerminator())
        break;
      // Fetch the successor now: Inst may be moved in front of the goto below.
      NextInst = Inst->getNextNode();
      // Extracts of the goto itself stay where they are.
      if (auto *Extract = dyn_cast<ExtractValueInst>(Inst))
        if (Extract->getOperand(0) == Goto)
          continue;
      bool Failed = false;
      for (unsigned oi = 0, oe = Inst->getNumOperands(); oi != oe; ++oi) {
        // An instruction depending on one we already left behind has to be
        // left behind too.
        if (auto *I = dyn_cast<Instruction>(Inst->getOperand(oi)))
          if (Skipping.count(I)) {
            LLVM_DEBUG(dbgs() << "Skipping " << Inst->getName()
                              << " due to use of skipped inst\n");
            Skipping.insert(Inst);
            Failed = true;
            break;
          }
        if (auto *Extract = dyn_cast<ExtractValueInst>(Inst->getOperand(oi)))
          if (Extract->getOperand(0) == Goto) {
            // This is used after splitting basic blocks.
            // To perform this all gotos must be branching since EM
            // is changed by goto.
            if (hoistGotoUsers && hoistGotoUser(Inst, Goto, oi)) {
              continue;
            }
            LLVM_DEBUG(dbgs() << "moveCodeInGotoBlocks: " << Goto->getName()
                              << " failed\n");
            LLVM_DEBUG(dbgs() << "Could not hoist " << Inst->getName() << "\n");
            Failed = true;
            Skipping.insert(Inst);
            break; // Intervening instruction uses extract of goto; abandon
          }
      }
      if (Failed)
        continue;
      // Hoist the instruction.
      Inst->removeFromParent();
      Inst->insertBefore(Goto);
    }
  }
}

/***********************************************************************
 * moveCodeInJoinBlocks : move code in join blocks as necessary
 *
 * 1. For a join label block (a block that is the JIP of other gotos/joins), a
 *    join must come at the start of the block.
 *
 * 2. For a branching join block (one whose conditional branch condition is the
 *    !any result from a join), the join must be at the end of the block.
 *
 * 3. For a block that has one join with both of the above true, we need to move
 *    all other code out of the block.
 *
 * We achieve this as follows:
 *
 * a. First handle case 3. For any such block, hoist any other code to the end
 *    of its immediate dominator. To allow for the immediate dominator also
 *    being a case 3 join, we process blocks in post-order depth first search
 *    order, so we visit a block before its dominator. Thus code from a case 3
 *    join block eventually gets moved up to its closest dominating block that
 *    is not a case 3 join block.
 *
 *    Because it is more convenient and does not hurt, we also hoist the code
 *    before the first join in a block that initially looks like it is case 3,
 *    even if it then turns out not to be a case 3 join because it has multiple
 *    joins.
 *
 * b. Then scan all joins handling case 1.
 *
 * c. No need to handle case 2 here, as it (together with a similar requirement
 *    to sink a goto in a branching goto block) is checked in checkConformance
 *    and treated as sunk subsequently by virtue of getting baled in to the
 *    branch.
 *
 * This happens in both SIMD CF conformance passes, in case constant loading
 * etc sneaks code back into the wrong place in a join block. Any pass after
 * the late SIMD CF conformance pass needs to be careful not to sneak code back
 * into a join block.
 *
 * Any failure to do the above is not flagged here, but it will be spotted when
 * checking the join for conformance.
 *
 * moveCodeInGotoBlocks needs to run first, as we rely on its sinking of an
 * unconditional branch goto for isBranchingGotoJoinBlock to work.
 */
void GenXSimdCFConformance::moveCodeInJoinBlocks() {
  LLVM_DEBUG(dbgs() << "moveCodeInJoinBlocks: start\n");
  // a. Handle case 3 join blocks.
  // Which functions are visited depends on whether a function group (FG) is
  // set: without one, all non-empty functions of the module are processed.
  if (!FG) {
    // Early pass: iterate all funcs in the module.
    for (auto mi = M->begin(), me = M->end(); mi != me; ++mi) {
      Function *F = &*mi;
      // Functions without a body have no blocks to process.
      if (!F->empty())
        emptyBranchingJoinBlocksInFunc(F);
    }
  } else {
    // Late pass: iterate all funcs in the function group.
    for (auto fgi = FG->begin(), fge = FG->end(); fgi != fge; ++fgi) {
      Function *F = *fgi;
      emptyBranchingJoinBlocksInFunc(F);
    }
  }
  // b. Process all other joins (in fact all joins, but ones successfully
  // processed above will not need anything doing).
  // Get the joins into a vector first, because the code below modifies EMVals.
  SmallVector<CallInst *, 4> Joins;
  for (auto ji = EMVals.begin(), je = EMVals.end(); ji != je; ++ji) {
    auto EMVal = *ji;
    auto IID = vc::getAnyIntrinsicID(EMVal.getValue());
    if (IID != GenXIntrinsic::genx_simdcf_join)
      continue;
    Joins.push_back(cast<CallInst>(EMVal.getValue()));
  }
  for (auto ji = Joins.begin(), je = Joins.end(); ji != je; ++ji) {
    auto Join = *ji;
    auto JoinBlock = Join->getParent();
    // If the join's block is a join label, hoisting the join is all that is
    // needed; the result of hoistJoin is not checked here.
    if (GotoJoin::isJoinLabel(JoinBlock, /*SkipCriticalEdgeSplitter=*/true))
      hoistJoin(Join);
    else {
      // The join is in a block that is not a join label. Also check the case
      // that there is a predecessor that:
      // 1. has one successor; and
      // 2. is empty other than phi nodes; and
      // 3. is a join label.
      // In that case we merge the two blocks, merging phi nodes.
      // I have seen this situation arise where LLVM decides to add a loop
      // pre-header block.
      BasicBlock *PredBlock = nullptr;
      // Scan the users of JoinBlock for an unconditional branch that is the
      // first non-phi, non-debug instruction of its block; the first such
      // block that is a join label is taken as PredBlock.
      for (auto ui = JoinBlock->use_begin(), ue = JoinBlock->use_end();
           ui != ue; ++ui) {
        auto *Br = dyn_cast<BranchInst>(ui->getUser());
        if (!Br || Br->isConditional())
          continue;
        auto BB = Br->getParent();
        if (BB->getFirstNonPHIOrDbg() != Br)
          continue;
        if (GotoJoin::isJoinLabel(BB, /*SkipCriticalEdgeSplitter=*/true)) {
          PredBlock = BB;
          break;
        }
      }
      if (PredBlock) {
        // We have such a predecessor block. First hoist the join in our block.
        if (hoistJoin(Join)) {
          // Join hoisting succeeded. Now merge the blocks.
          LLVM_DEBUG(dbgs()
                     << "moveCodeInJoinBlocks: merging " << PredBlock->getName()
                     << " into " << JoinBlock->getName() << "\n");
          // First adjust the phi nodes to include both blocks' incomings.
          for (auto *Phi = dyn_cast<PHINode>(&JoinBlock->front()); Phi;
               Phi = dyn_cast<PHINode>(Phi->getNextNode())) {
            // Idx is the position of the PredBlock entry in this phi, or
            // negative when the phi has no entry for PredBlock.
            int Idx = Phi->getBasicBlockIndex(PredBlock);
            if (Idx >= 0) {
              Value *Incoming = Phi->getIncomingValue(Idx);
              auto *PredPhi = dyn_cast<PHINode>(Incoming);
              // Only a phi that lives in PredBlock itself is expanded below.
              if (PredPhi && PredPhi->getParent() != PredBlock)
                PredPhi = nullptr;
              if (PredPhi) {
                // The incoming in JoinBlock is a phi node in PredBlock. Add its
                // incomings.
                Phi->removeIncomingValue(Idx, /*DeletePHIIfEmpty=*/false);
                for (unsigned oi = 0, oe = PredPhi->getNumIncomingValues();
                     oi != oe; ++oi)
                  Phi->addIncoming(PredPhi->getIncomingValue(oi),
                                   PredPhi->getIncomingBlock(oi));
              } else {
                // Otherwise, add the predecessors of PredBlock to the phi node
                // in JoinBlock.
                for (auto ui2 = PredBlock->use_begin(),
                          ue2 = PredBlock->use_end();
                     ui2 != ue2; ++ui2) {
                  Instruction *Term = dyn_cast<Instruction>(ui2->getUser());
                  IGC_ASSERT(Term);
                  if (Term->isTerminator()) {
                    auto PredPred = Term->getParent();
                    // The first predecessor found reuses the existing
                    // PredBlock slot (Idx is then reset to -1 as a marker);
                    // every further one gets a new entry with the same value.
                    if (Idx >= 0) {
                      Phi->setIncomingBlock(Idx, PredPred);
                      Idx = -1;
                    } else
                      Phi->addIncoming(Incoming, PredPred);
                  }
                }
              }
            }
          }
          // Any phi in PredBlock that is not used in a phi in JoinBlock (and
          // so still has at least one use after the code above) needs to be
          // moved to JoinBlock, with itself added as the extra incomings. The
          // incoming blocks to JoinBlock other than PredBlock must be loop
          // back edges.
          // Each iteration removes the first phi from PredBlock (by erasing or
          // moving it), so the loop ends once PredBlock no longer starts with
          // a phi.
          for (;;) {
            auto *Phi = dyn_cast<PHINode>(&PredBlock->front());
            if (!Phi)
              break;
            if (Phi->use_empty()) {
              // Dead phi: drop it from the EM/RM bookkeeping before erasing.
              removeFromEMRMVals(Phi);
              Phi->eraseFromParent();
              continue;
            }
            // Add the phi itself as incoming value for every terminator that
            // targets JoinBlock from a block other than PredBlock.
            for (auto ui = JoinBlock->use_begin(), ue = JoinBlock->use_end();
                 ui != ue; ++ui) {
              auto *Term = dyn_cast<Instruction>(ui->getUser());
              IGC_ASSERT(Term);
              if (!Term->isTerminator())
                continue;
              auto TermBB = Term->getParent();
              if (TermBB == PredBlock)
                continue;
              Phi->addIncoming(Phi, TermBB);
            }
            // Relocate the phi to the very start of JoinBlock.
            Phi->removeFromParent();
            Phi->insertBefore(&JoinBlock->front());
          }
          // Adjust branches targeting PredBlock to target JoinBlock instead.
          PredBlock->replaceAllUsesWith(JoinBlock);
          // Remove PredBlock.
          PredBlock->eraseFromParent();
        }
      }
    }
  }
}

/***********************************************************************
 * emptyBranchingJoinBlocksInFunc : visit the blocks of a function in
 *    post-order and, for every branching join block, try to move the
 *    instructions other than the join (and its extracts) out of it
 *
 * Enter:   F = function to process
 *
 * See comment for moveCodeInJoinBlocks above.
 */
void GenXSimdCFConformance::emptyBranchingJoinBlocksInFunc(Function *F) {
  for (BasicBlock *Block : post_order(&F->getEntryBlock()))
    if (CallInst *BranchingJoin = GotoJoin::isBranchingJoinBlock(Block))
      emptyBranchingJoinBlock(BranchingJoin);
}

/***********************************************************************
 * emptyBranchingJoinBlock : empty instructions other than the join (and its
 *      extracts) from this branching join block
 *
 * Enter:   Join = the join whose parent block is to be emptied
 *
 * Walks the block from its first non-phi, non-debug instruction up to the
 * terminator and hoists each instruction into the immediate dominator of the
 * block. The walk stops early when a second join is encountered, and is
 * abandoned as soon as an instruction is found that uses another instruction
 * of this block. Instructions hoisted before that point stay hoisted.
 * Sets Modified when at least one instruction has been moved.
 */
void GenXSimdCFConformance::emptyBranchingJoinBlock(CallInst *Join) {
  BasicBlock *BB = Join->getParent();
  // Insertion point in the dominating block; computed on first use below.
  Instruction *InsertBefore = nullptr;
  for (Instruction *NextInst = BB->getFirstNonPHIOrDbg();;) {
    auto Inst = NextInst;
    if (Inst->isTerminator())
      break;
    // Remember the successor now: Inst may be moved to another block below.
    NextInst = Inst->getNextNode();
    if (Inst == Join)
      continue; // do not hoist the join itself
    auto IID = vc::getAnyIntrinsicID(Inst);
    if (IID == GenXIntrinsic::genx_simdcf_join)
      break; // we have encountered another join; there must be more than one
    if (auto *EV = dyn_cast<ExtractValueInst>(Inst))
      if (EV->getOperand(0) == Join)
        continue; // do not hoist an extract of the join
    // Check that the instruction's operands do not use anything in this block
    // (the phi nodes, or the join and extracts being left behind).
    for (unsigned oi = 0, oe = Inst->getNumOperands(); oi != oe; ++oi) {
      auto *Opnd = dyn_cast<Instruction>(Inst->getOperand(oi));
      if (Opnd && Opnd->getParent() == BB) {
        LLVM_DEBUG(dbgs() << "Failed to empty branching join label for join "
                          << Join->getName() << "\n");
        return; // Instruction uses something in this block: abandon.
      }
    }
    if (!InsertBefore) {
      // Lazy determination of the insert point. If it is a branching goto/join
      // block, insert before the goto/join.
      auto DomTree = getDomTree(BB->getParent());
      IGC_ASSERT(DomTree);
      auto BBNode = DomTree->getNode(BB);
      IGC_ASSERT(BBNode);
      // NOTE(review): getIDom() is dereferenced without a null check;
      // presumably a branching join block is never the entry block -- confirm.
      auto InsertBB = BBNode->getIDom()->getBlock();
      InsertBefore = GotoJoin::isBranchingGotoJoinBlock(InsertBB);
      if (!InsertBefore)
        InsertBefore = InsertBB->getTerminator();
    }
    // Hoist the instruction.
    Inst->removeFromParent();
    Inst->insertBefore(InsertBefore);
    Modified = true;
  }
}

/***********************************************************************
 * getDomTree : get dominator tree for a function
 *
 * Enter:   F = function whose dominator tree is wanted
 *
 * Return:  the dominator tree for F
 *
 * When DTWrapper is set (late pass) the tree comes from the
 * DominatorTreeGroupWrapper. Otherwise (early pass, which is a module pass)
 * the tree is computed on first request and remembered in DTs.
 */
DominatorTree *GenXSimdCFConformance::getDomTree(Function *F) {
  // In late pass, use the DominatorTreeGroupWrapper.
  if (DTWrapper)
    return DTWrapper->getDomTree(F);

  // In early pass: look the function up once and fill the slot if it is
  // still empty.
  auto &Slot = DTs[F];
  if (!Slot) {
    auto *Fresh = new DominatorTree;
    Fresh->recalculate(*F);
    Slot = Fresh;
  }
  return Slot;
}

/***********************************************************************
 * hoistJoin : hoist a join to the top of its basic block if possible
 *
 * Enter:   Join = the join to hoist
 *
 * Return:  true if the join is now the first non-phi, non-debug instruction
 *          of its block (or already was), false if it could not be hoisted
 *
 * This is used for a join in a block that is a join label, but not a branching
 * join block. See comment for emptyJoinBlocks above.
 *
 * There might be multiple joins in the function, and the one supplied is not
 * necessarily the first one. If it is a later one, this function will silently
 * fail, which is harmless. If it silently fails for the first join, then we
 * end up with a join label block that does not start with a join, which
 * checkConformance will spot later on.
 *
 * The result is also consumed by moveCodeInJoinBlocks when it wants to merge
 * a loop pre-header back into the join block.
 */
bool GenXSimdCFConformance::hoistJoin(CallInst *Join) {
  BasicBlock *JoinBB = Join->getParent();

  // Hoisting is only possible when no argument of the join is defined by a
  // non-phi instruction of the join's own block. The one tolerated exception
  // is an extractvalue of a goto/join that lives in a different block: such
  // an extract is simply relocated to just after that goto/join.
  const unsigned NumArgs = IGCLLVM::getNumArgOperands(Join);
  for (unsigned ArgIdx = 0; ArgIdx != NumArgs; ++ArgIdx) {
    auto *Def = dyn_cast<Instruction>(Join->getOperand(ArgIdx));
    if (!Def || isa<PHINode>(Def) || Def->getParent() != JoinBB)
      continue; // operand does not block the hoist

    if (auto *Extract = dyn_cast<ExtractValueInst>(Def)) {
      Value *Aggregate = Extract->getOperand(0);
      unsigned AggIID = vc::getAnyIntrinsicID(Aggregate);
      const bool FromGotoJoin = AggIID == GenXIntrinsic::genx_simdcf_goto ||
                                AggIID == GenXIntrinsic::genx_simdcf_join;
      if (FromGotoJoin) {
        auto *Producer = cast<CallInst>(Aggregate);
        if (Producer->getParent() != JoinBB) {
          LLVM_DEBUG(dbgs() << "moving out of join block: " << *Extract
                            << "\n");
          Extract->removeFromParent();
          Extract->insertBefore(Producer->getNextNode());
          continue;
        }
      }
    }

    LLVM_DEBUG(dbgs() << "hoistJoin: " << Join->getName() << " failed\n");
    return false; // failed -- join uses non-phi instruction before it
  }

  // Nothing stands in the way: place the join right after the phis.
  Instruction *BlockStart = JoinBB->getFirstNonPHIOrDbg();
  if (BlockStart == Join)
    return true; // already at start
  Join->removeFromParent();
  Join->insertBefore(BlockStart);
  // Such transformation should be performed only for Early Conformance pass
  if (!FG)
    GotoJoinEVsMap[Join].hoistEVs();
  Modified = true;
  return true;
}

/***********************************************************************
 * ensureConformance : check for conformance, and lower any non-conformant
 *    gotos and joins
 */
void GenXSimdCFConformance::ensureConformance() {
  LLVM_DEBUG(dbgs() << "ensureConformance: start\n");
  // Push all EM values onto the stack for checking. Push the joins last, since
  // we want to process those before their corresponding gotos, so that
  // GotoJoinMap is set for a goto by the time we process a valid goto.
  for (auto i = EMVals.begin(), e = EMVals.end(); i != e; ++i) {
    auto IID = vc::getAnyIntrinsicID(i->getValue());
    if (IID != GenXIntrinsic::genx_simdcf_join &&
        IID != GenXIntrinsic::genx_simdcf_unmask &&
        IID != GenXIntrinsic::genx_simdcf_remask) {
      EMValsStack.insert(*i);
      LLVM_DEBUG(if (auto *Inst = dyn_cast<Instruction>(i->getValue())) {
        auto FuncName = Inst->getFunction()->getName();
        dbgs() << "Entry EMVals " << FuncName << " - " << *Inst << "\n";
      });
    }
  }
  for (auto i = EMVals.begin(), e = EMVals.end(); i != e; ++i) {
    auto IID = vc::getAnyIntrinsicID(i->getValue());
    if (IID == GenXIntrinsic::genx_simdcf_join)
      EMValsStack.insert(*i);
  } // Process the stack.
  SmallVector<CallInst *, 4> GotosToLower;
  SmallVector<CallInst *, 4> JoinsToLower;
  for (;;) {
    if (!EMValsStack.empty()) {
      // Remove and process the top entry on the stack.
      auto EMVal = EMValsStack.pop_back_val();
      if (checkEMVal(EMVal))
        continue;
      removeBadEMVal(EMVal);
      if (!EMVal.getIndex()) {
        if (auto *CI = dyn_cast<CallInst>(EMVal.getValue())) {
          auto IID = vc::getAnyIntrinsicID(EMVal.getValue());
          switch (IID) {
          case GenXIntrinsic::genx_simdcf_goto:
            GotosToLower.push_back(CI);
            break;
          case GenXIntrinsic::genx_simdcf_join:
            JoinsToLower.push_back(CI);
            break;
          default:
            break;
          }
        }
      }
      continue;
    }
    // The stack is empty. Check for EM values interfering with each other.
    checkEMInterference();
    if (EMValsStack.empty()) {
      // Stack still empty; we have finished.
      break;
    }
  }

  if (isLatePass()) {
    // In the late pass, we are not expecting to have found any non-conformant
    // gotos and joins that need lowering. All such gotos and joins should have
    // been identified in the early pass, unless passes in between have
    // transformed the code in an unexpected way that has made the simd CF
    // non-conformant. Give an error here if this has happened.
    LLVM_DEBUG(
        if (!GotosToLower.empty()) {
          dbgs() << "Not empty GotosToLower:";
          for (auto *Dump : GotosToLower)
            dbgs() << *Dump;
        } if (!JoinsToLower.empty()) {
          dbgs() << "Not empty JoinsToLower:";
          for (auto *Dump : JoinsToLower)
            dbgs() << *Dump;
        });
    IGC_ASSERT_EXIT_MESSAGE(
        GotosToLower.empty(),
        "unexpected non-conformant SIMD CF in late SIMD CF conformance pass");
    IGC_ASSERT_EXIT_MESSAGE(
        JoinsToLower.empty(),
        "unexpected non-conformant SIMD CF in late SIMD CF conformance pass");
  }

  // Lower gotos and joins that turned out to be non-conformant.
  for (auto i = GotosToLower.begin(), e = GotosToLower.end(); i != e; ++i)
    lowerGoto(*i);
  for (auto i = JoinsToLower.begin(), e = JoinsToLower.end(); i != e; ++i)
    lowerJoin(*i);
}

/***********************************************************************
 * getEMProducer : recursively check that a value is produced by EM terms.
 *
 * Enter:   User = value to examine
 *          Visited = values already entered on the current walk
 *          BitCastAllowed = whether to look through bitcasts
 *
 * Return:  User if it deals with EM, nullptr otherwise
 *
 * It goes through all phis, insertvalues and bitcasts (the latter only when
 * BitCastAllowed is true) and determines whether the EM is correct in DF
 * terms. It doesn't check live range interference, but can spot
 * non-conformant usage in case when EM from bad instruction is being used.
 *
 * This approach is used when we need to perform some actions on full
 * EM data flow, for example, to insert phis when eliminating redundant
 * bitcasts.
 *
 * All found EM producers are stored in EMProducers and can be used
 * later without performing full search.
 *
 * TODO: currently returns User if it deals with EM. It is done in
 * this way as workaround for possible future changes (for example,
 * getConnectedVals refactor). The idea of such approach is to be
 * able to update info if something changes.
 */
Value *GenXSimdCFConformance::getEMProducer(Value *User,
                                            std::set<Value *> &Visited,
                                            bool BitCastAllowed) {
  LLVM_DEBUG(dbgs() << "Looking for EM producer for value:\n" << *User << "\n");

  // A value met again on the same walk: only a phi is accepted.
  if (Visited.count(User))
    return isa<PHINode>(User) ? User : nullptr;

  // Check for previously found value
  auto Known = EMProducers.find(User);
  if (Known != EMProducers.end()) {
    LLVM_DEBUG(if (Known->second) dbgs() << "Using previously found value:\n"
                                         << *Known->second << "\n";
               else dbgs() << "Using previously found empty-value!\n");
    return Known->second;
  }

  // The value kinds tested below are mutually exclusive, so at most one of
  // the following sections applies; anything that does not return from its
  // section ends up at the common failure exit.

  if (auto *C = dyn_cast<Constant>(User)) {
    // All one is considered as EM at entry point
    if (C->isAllOnesValue()) {
      LLVM_DEBUG(dbgs() << "EMProducer is an AllOne constant\n");
      EMProducers[C] = C;
      return C;
    }
  }

  if (auto *PN = dyn_cast<PHINode>(User)) {
    // For phi node, check all its preds. They all must be EMs
    Visited.insert(PN);
    for (Value *Pred : PN->incoming_values()) {
      if (getEMProducer(Pred, Visited, BitCastAllowed))
        continue;
      LLVM_DEBUG(dbgs() << "!!! Bad phi pred detected for:\n" << *PN << "\n");
      EMProducers[PN] = nullptr;
      return nullptr;
    }

    LLVM_DEBUG(dbgs() << "EMProducer is phi itself:\n" << *PN << "\n");
    EMProducers[PN] = PN;
    return PN;
  }

  if (auto *EVI = dyn_cast<ExtractValueInst>(User)) {
    // Extract value can be an EV from goto/join or from callee that
    // returned it. For the second case we check that the pred is
    // still in EM values since it could be lowered.
    if (auto *CI = dyn_cast<CallInst>(EVI->getOperand(0))) {
      // Goto/join check
      auto IID = vc::getAnyIntrinsicID(CI);
      if (IID == GenXIntrinsic::genx_simdcf_goto ||
          IID == GenXIntrinsic::genx_simdcf_join) {
        LLVM_DEBUG(dbgs() << "Reached goto/join\n");
        EMProducers[EVI] = EVI;
        return EVI;
      }

      // EV from other calls.
      if (EMVals.count(SimpleValue(CI, EVI->getIndices()[0]))) {
        LLVM_DEBUG(dbgs() << "Value from return\n");
        EMProducers[EVI] = EVI;
        return EVI;
      }
    }
  }

  if (auto *Arg = dyn_cast<Argument>(User)) {
    // For argument we need to ensure that it is still in EM values
    // since it could be lowered.
    if (EMVals.count(SimpleValue(Arg, Arg->getArgNo()))) {
      LLVM_DEBUG(dbgs() << "Input argument\n");
      EMProducers[Arg] = Arg;
      return Arg;
    }
  }

  if (auto *IVI = dyn_cast<InsertValueInst>(User)) {
    // Insert value prepares structure for return. Check the
    // value that is being inserted
    Visited.insert(IVI);
    Value *Inserted = IVI->getInsertedValueOperand();
    if (auto EMProd = getEMProducer(Inserted, Visited, BitCastAllowed)) {
      LLVM_DEBUG(dbgs() << "Insert for return\n");
      EMProducers[IVI] = EMProd;
      return IVI;
    }
  }

  if (BitCastAllowed) {
    if (auto *BCI = dyn_cast<BitCastInst>(User)) {
      // BitCast doesn't produce new EM. Just go through it.
      Visited.insert(BCI);
      Value *Source = BCI->getOperand(0);
      if (auto EMProd = getEMProducer(Source, Visited, BitCastAllowed)) {
        LLVM_DEBUG(dbgs() << "Bitcast from EM producer\n");
        EMProducers[BCI] = EMProd;
        return BCI;
      }
    }
  }

  // All other instructions cannot be treated as EM producers
  LLVM_DEBUG(dbgs() << "!!! IT IS NOT A EM PRODUCER !!!\n");
  return nullptr;
}

/***********************************************************************
 * lowerUnsuitableGetEMs : remove all unsuitable get_em intrinsics.
 *
 * This intrinsic is unsuitable if:
 *   - It uses constant value: it is simply redundant
 *   - The EM argument is not actually a EM: this may happen if
 *     SIMD CF was non-conformant and this EM was lowered.
 */
void GenXSimdCFConformance::lowerUnsuitableGetEMs() {
  auto *I1Ty = Type::getInt1Ty(M->getContext());
  auto *GetEMDecl = vc::getAnyDeclaration(
      M, GenXIntrinsic::genx_simdcf_get_em,
      {IGCLLVM::FixedVectorType::get(I1Ty, 32)});
  std::vector<Instruction *> ToDelete;
  for (auto ui = GetEMDecl->use_begin(); ui != GetEMDecl->use_end();) {
    std::set<Value *> Visited;
    auto GetEM = cast<Instruction>(ui->getUser());
    ++ui;
    auto GetEMPred = GetEM->getOperand(0);

    if (GetEM->use_empty()) {
      ToDelete.push_back(GetEM);
      continue;
    }

    // Constants and non-EM values should be used directly
    if (dyn_cast<Constant>(GetEMPred) ||
        !getEMProducer(dyn_cast<Instruction>(GetEMPred), Visited)) {
      GetEM->replaceAllUsesWith(GetEM->getOperand(0));
      ToDelete.push_back(GetEM);
    }
  }
  if (GetEMDecl->use_empty())
    GetEMDecl->eraseFromParent();

  for (auto *Inst : ToDelete) {
    Inst->eraseFromParent();
  }
}

/***********************************************************************
 * lowerAllSimdCF : do NOT check for conformance, and simply lower
 * all any gotos, joins, and unmasks
 */
void GenXSimdCFConformance::lowerAllSimdCF() {
  for (auto i = EMVals.begin(), e = EMVals.end(); i != e; ++i) {
    if (auto *CI = dyn_cast<CallInst>(i->getValue())) {
      auto IID = vc::getAnyIntrinsicID(i->getValue());
      if (IID == GenXIntrinsic::genx_simdcf_join)
        lowerJoin(CI);
      else if (IID == GenXIntrinsic::genx_simdcf_goto)
        lowerGoto(CI);
      else if (IID == GenXIntrinsic::genx_simdcf_unmask) {
        auto SaveMask = CI->getArgOperand(0);
        if (auto *CI0 = dyn_cast<CallInst>(SaveMask)) {
          IRBuilder<> Builder(CI0);
          auto Replace =
              Builder.CreateBitCast(CI0->getArgOperand(0), CI0->getType());
          CI0->replaceAllUsesWith(Replace);
          CI0->eraseFromParent();
        }
        IRBuilder<> Builder(CI);
        auto Replace =
            Builder.CreateBitCast(CI->getArgOperand(1), CI->getType());
        CI->replaceAllUsesWith(Replace);
        CI->eraseFromParent();
      } else if (IID == GenXIntrinsic::genx_simdcf_remask) {
        IRBuilder<> Builder(CI);
        auto Replace =
            Builder.CreateBitCast(CI->getArgOperand(1), CI->getType());
        CI->replaceAllUsesWith(Replace);
        CI->eraseFromParent();
      }
    }
  }
}

/***********************************************************************
 * checkEMVal : check an EM value for conformance
 *
 * Enter:   EMVal = the EM value to check
 *
 * Return:  true if ok, false if the EM value needs to be removed
 *
 * The value fails when goto/join use is disabled, when getConnectedVals
 * rejects it, when a connected value is neither an all-ones constant nor a
 * member of EMVals, or when the goto/join specific check fails.
 */
bool GenXSimdCFConformance::checkEMVal(SimpleValue EMVal) {
  LLVM_DEBUG(dbgs() << "checkEMVal: " << *EMVal.getValue() << "#"
                    << EMVal.getIndex() << "\n");
  if (!EnableGenXGotoJoin)
    return false; // use of goto/join disabled

  // Check connected values. Do not lower bad users in Late Pass because
  // current SIMD CF Conformance check approach expects that SIMD CF must
  // be OK at this point if it wasn't lowered during Early Pass.
  SmallVector<SimpleValue, 8> Linked;
  const bool DefAndUsesOk =
      getConnectedVals(EMVal, vc::RegCategory::EM, /*IncludeOptional=*/true,
                       /*OkJoin=*/nullptr, &Linked,
                       /*LowerBadUsers=*/!FG);
  if (!DefAndUsesOk) {
    LLVM_DEBUG(dbgs() << "checkEMVal: invalid def or uses\n");
    return false; // something invalid about the EM value itself
  }

  // Check that all connected values are EM values.
  for (SimpleValue Other : Linked) {
    auto *C = dyn_cast<Constant>(Other.getValue());
    if (C && !C->isAllOnesValue()) {
      LLVM_DEBUG(
          dbgs()
          << "checkEMVal: ConnectedVal is constant that is not all ones\n");
      return false; // uses constant that is not all ones, invalid
    }
    if (!C && !EMVals.count(Other)) {
      LLVM_DEBUG(dbgs() << "checkEMVal: ConnectedVal is not in EMVals : ");
      LLVM_DEBUG(dbgs() << "val == " << *Other.getValue() << "\n");
      return false; // connected value is not in EMVals
    }
    LLVM_DEBUG(dbgs() << "checkEMVal: ConnectedVal checked "
                      << *Other.getValue() << "\n");
  }

  // Gotos and joins have additional requirements of their own.
  auto IID = vc::getAnyIntrinsicID(EMVal.getValue());
  if (IID == GenXIntrinsic::genx_simdcf_goto)
    return checkGoto(EMVal);
  if (IID == GenXIntrinsic::genx_simdcf_join)
    return checkJoin(EMVal);
  return true;
}

/***********************************************************************
 * checkGotoJoinSunk : check whether a goto/join is sunk to the bottom of
 *    its basic block, other than extractvalues from its result
 *
 * Return:  true if everything between the goto/join and the terminator is
 *          an extractvalue of that goto/join, and the terminator is a branch
 */
static bool checkGotoJoinSunk(CallInst *GotoJoin) {
  Instruction *Cursor = GotoJoin->getNextNode();
  while (!Cursor->isTerminator()) {
    auto *Extract = dyn_cast<ExtractValueInst>(Cursor);
    const bool IsOwnExtract = Extract && Extract->getOperand(0) == GotoJoin;
    if (!IsOwnExtract)
      return false;
    Cursor = Cursor->getNextNode();
  }
  // Reached the terminator: only a branch is acceptable.
  return isa<BranchInst>(Cursor);
}

/***********************************************************************
 * checkGoto : check conformance of an actual goto instruction
 *
 * Return:  true if the goto passes the common goto/join checks, has a
 *          linked join recorded in GotoJoinMap, and is sunk to the end of
 *          its block
 */
bool GenXSimdCFConformance::checkGoto(SimpleValue EMVal) {
  if (!checkGotoJoin(EMVal))
    return false;

  auto *Goto = cast<CallInst>(EMVal.getValue());

  // Check that there is a linked join. (We do not need to check here that the
  // linked join is an EM value; that happened in checkEMVal due to the join
  // being treated as a linked value in getConnectedVals.)
  if (!GotoJoinMap[Goto]) {
    LLVM_DEBUG(dbgs() << "checkGoto: no linked join\n");
    return false;
  }

  // Check that the goto is sunk to the end of the block, other than extracts
  // from its result, and a branch. moveCodeInGotoBlocks ensures that if
  // possible; if that failed, this conformance check fails.
  const bool Sunk = checkGotoJoinSunk(Goto);
  if (!Sunk) {
    LLVM_DEBUG(dbgs() << "checkGoto: not sunk\n");
  }
  return Sunk;
}

/***********************************************************************
 * checkJoin : check conformance of an actual join instruction
 *
 * Enter:   EMVal = EM value whose underlying value is the join call
 *
 * Return:  true if the join is conformant, false otherwise
 *
 * Besides the positional checks, this rebuilds the web of RM values for the
 * join in RMVals[Join] and, on success, records the join in GotoJoinMap for
 * every goto found in that web.
 */
bool GenXSimdCFConformance::checkJoin(SimpleValue EMVal) {
  if (!checkGotoJoin(EMVal))
    return false;
  // Check that the join is at the start of the block. emptyJoinBlock should
  // have ensured this, unless the code was such that it could not.
  auto Join = cast<CallInst>(EMVal.getValue());
  if (!GotoJoin::isValidJoin(Join)) {
    LLVM_DEBUG(dbgs() << "checkJoin: not valid join\n");
    return false;
  }
  // If the !any result of this join is used in a conditional branch at the
  // end, check that the join is sunk to the end of the block, other than
  // extracts from its result, and a branch. moveCodeInJoinBlocks ensures that
  // if possible; if that failed, this conformance check fails.
  if (auto *Br = dyn_cast<BranchInst>(Join->getParent()->getTerminator()))
    if (Br->isConditional())
      if (auto *EV = dyn_cast<ExtractValueInst>(Br->getCondition()))
        if (EV->getOperand(0) == Join)
          if (!checkGotoJoinSunk(Join)) {
            LLVM_DEBUG(dbgs() << "checkJoin: not sunk\n");
            return false;
          }
  // Gather the web of RM values. Any web left over from an earlier check of
  // this join is discarded first.
  auto RMValsEntry = &RMVals[Join];
  RMValsEntry->clear();
  LLVM_DEBUG(dbgs() << "checkJoin: gather web of RM vals for " << *Join
                    << "\n");
  // The web is seeded with operand 1 of the join, unless that is a constant.
  if (!isa<Constant>(Join->getOperand(1)))
    RMValsEntry->insert(Join->getOperand(1));
  // Index-based loop on purpose: the web grows while it is being walked, and
  // the size is re-read on every iteration so newly added values are visited.
  for (unsigned rvi = 0; rvi != RMValsEntry->size(); ++rvi) {
    SimpleValue RM = (*RMValsEntry)[rvi];
    // RM is a value in this join's RM web. Get other values related by phi
    // nodes and extractvalues and gotos.
    SmallVector<SimpleValue, 8> ConnectedVals;
    bool Ok = getConnectedVals(RM, vc::RegCategory::RM,
                               /*IncludeOptional=*/false, Join, &ConnectedVals);
    LLVM_DEBUG(
        dbgs() << "checkJoin: getConnectedVals-> " << RM.getValue()->getName()
               << "#" << RM.getIndex() << "\n";
        for (auto i = ConnectedVals.begin(), e = ConnectedVals.end(); i != e;
             ++i) dbgs()
        << "   " << i->getValue()->getName() << "#" << i->getIndex() << "\n");
    if (!Ok) {
      LLVM_DEBUG(dbgs() << "checkJoin: illegal RM value in web\n");
      return false;
    }
    for (auto j = ConnectedVals.begin(), je = ConnectedVals.end(); j != je;
         ++j) {
      SimpleValue ConnectedVal = *j;
      if (auto *C = dyn_cast<Constant>(ConnectedVal.getValue())) {
        // A constant in the RM web must be all zeros.
        if (!C->isNullValue()) {
          LLVM_DEBUG(dbgs() << "checkJoin: non-0 constant in RM web\n");
          return false;
        }
      } else {
        // Insert the non-constant value.  If it is a goto with struct index
        // other than 1, it is illegal. The checks below only run the first
        // time a value enters the web (insert returns true).
        if (RMValsEntry->insert(ConnectedVal)) {
          LLVM_DEBUG(dbgs() << "checkJoin: New one: "
                            << ConnectedVal.getValue()->getName() << "#"
                            << ConnectedVal.getIndex() << "\n");
          auto IID = vc::getAnyIntrinsicID(ConnectedVal.getValue());
          switch (IID) {
          case GenXIntrinsic::genx_simdcf_join:
            // A join appearing as a value of the web is rejected.
            LLVM_DEBUG(dbgs() << "checkJoin: multiple joins in RM web\n");
            return false;
          case GenXIntrinsic::genx_simdcf_goto:
            if (ConnectedVal.getIndex() != 1 /* struct index of RM result */) {
              LLVM_DEBUG(dbgs() << "checkJoin: wrong struct index in goto\n");
              return false;
            }
            break;
          default:
            break;
          }
        }
      }
    }
  }
  // Check whether the RM values interfere with each other. Any entry in
  // BadDefs after the call makes the join non-conformant.
  SetVector<Value *> BadDefs;
  checkInterference(RMValsEntry, &BadDefs, Join);
  if (!BadDefs.empty()) {
    LLVM_DEBUG(dbgs() << "checkJoin: RMs interfere\n");
    return false;
  }
  // Set GotoJoinMap for each goto in the RM web.
  for (unsigned rvi = 0; rvi != RMValsEntry->size(); ++rvi) {
    SimpleValue RM = (*RMValsEntry)[rvi];
    auto IID = vc::getAnyIntrinsicID(RM.getValue());
    if (IID == GenXIntrinsic::genx_simdcf_goto)
      GotoJoinMap[cast<CallInst>(RM.getValue())] = Join;
  }
  return true;
}

/***********************************************************************
 * getEmptyCriticalEdgeSplitterSuccessor : if BB is an empty critical edge
 *    splitter block (one predecessor and one successor), then return the
 *    single successor
 *
 * Enter:   BB = block to examine
 *
 * Return:  the single successor, or nullptr if BB does not have exactly one
 *          use, contains anything other than phis and debug instructions
 *          before its terminator, or does not have exactly one successor
 */
static BasicBlock *getEmptyCriticalEdgeSplitterSuccessor(BasicBlock *BB) {
  if (!BB->hasOneUse())
    return nullptr; // not exactly one predecessor
  // getFirstNonPHIOrDbg already yields an Instruction pointer, so no cast is
  // needed; guard against a null result instead of dereferencing it blindly.
  Instruction *First = BB->getFirstNonPHIOrDbg();
  if (!First || !First->isTerminator())
    return nullptr; // not empty
  auto *TI = cast<IGCLLVM::TerminatorInst>(First);
  if (TI->getNumSuccessors() != 1)
    return nullptr; // not exactly one successor
  return TI->getSuccessor(0);
}

/***********************************************************************
 * checkGotoJoin : common code to check conformance of an actual goto or join
 *    instruction
 *
 * Return:  true if the scalar (!any) result of the goto/join is either
 *          unused, or extracted exactly once and used exactly once, by a
 *          branch in the same block whose "true" successor is a join label
 *          (possibly reached through an empty critical edge splitter block)
 */
bool GenXSimdCFConformance::checkGotoJoin(SimpleValue EMVal) {
  auto *GotoOrJoin = cast<CallInst>(EMVal.getValue());

  // Look for the (single) used extract of the scalar result.
  ExtractValueInst *AnyResult = nullptr;
  for (auto *U : GotoOrJoin->users()) {
    auto *EV = dyn_cast<ExtractValueInst>(U);
    if (!EV || isa<VectorType>(EV->getType()) || EV->use_empty())
      continue;
    if (AnyResult) {
      LLVM_DEBUG(dbgs() << "checkGotoJoin: goto/join has more than one "
                           "extract of its !any result\n");
      return false;
    }
    AnyResult = EV;
  }
  if (!AnyResult)
    return true; // scalar result not used: nothing more to check

  // The extract must feed the branch at the end of the block, and only that.
  if (!AnyResult->hasOneUse()) {
    LLVM_DEBUG(dbgs() << "checkGotoJoin: goto/join's !any result does not "
                         "have exactly one use\n");
    return false;
  }
  auto *Br = dyn_cast<BranchInst>(*AnyResult->user_begin());
  if (!Br || Br->getParent() != GotoOrJoin->getParent()) {
    LLVM_DEBUG(dbgs() << "checkGotoJoin: goto/join's !any result not used in "
                         "conditional branch in same block\n");
    return false;
  }

  // For a goto/join with a conditional branch, check that the "true"
  // successor is a join label. We also tolerate there being an empty
  // critical edge splitter block in between; this will get removed in
  // setCategories in this pass.
  auto StartsWithJoin = [](BasicBlock *BB) {
    return vc::getAnyIntrinsicID(BB->getFirstNonPHIOrDbg()) ==
           GenXIntrinsic::genx_simdcf_join;
  };
  BasicBlock *TrueSucc = Br->getSuccessor(0);
  if (StartsWithJoin(TrueSucc))
    return true;
  BasicBlock *BeyondSplitter = getEmptyCriticalEdgeSplitterSuccessor(TrueSucc);
  if (BeyondSplitter && StartsWithJoin(BeyondSplitter))
    return true;
  // Either there is no empty critical edge splitter, or the block behind it
  // is not a join label.
  LLVM_DEBUG(
      dbgs() << "checkGotoJoin: goto/join true successor not join label\n");
  return false;
}

/***********************************************************************
 * removeBadEMVal : remove a bad EM value
 *
 * Enter:   EMVal = the non-conformant EM value
 *
 * Takes the value out of EMVals and queues every value connected to it that
 * is still in EMVals on EMValsStack, so it gets re-checked for conformance.
 * Does nothing further if the value was not in EMVals.
 */
void GenXSimdCFConformance::removeBadEMVal(SimpleValue EMVal) {
  LLVM_DEBUG(dbgs() << "removeBadEMVal: "; EMVal.print(dbgs()); dbgs() << "\n");
  const bool WasTracked = EMVals.remove(EMVal);
  if (!WasTracked)
    return; // was not in EMVals

  // Push anything related to it onto the stack for re-checking.
  SmallVector<SimpleValue, 8> Neighbours;
  getConnectedVals(EMVal, vc::RegCategory::EM, /*IncludeOptional=*/true,
                   /*OkJoin=*/nullptr, &Neighbours);
  for (SimpleValue Neighbour : Neighbours) {
    if (!EMVals.count(Neighbour))
      continue;
    EMValsStack.insert(Neighbour);
  }
}

/***********************************************************************
 * pushValues : push EM struct elements in a value onto EMValsStack
 *
 * Enter:   V = value whose flattened elements are examined; only the
 *              elements that are currently in EMVals are pushed
 */
void GenXSimdCFConformance::pushValues(Value *V) {
  const unsigned NumElements = IndexFlattener::getNumElements(V->getType());
  for (unsigned Idx = 0; Idx < NumElements; ++Idx) {
    SimpleValue Element(V, Idx);
    if (!EMVals.count(Element))
      continue;
    EMValsStack.insert(Element);
  }
}

/***********************************************************************
 * checkAllUsesAreSelectOrWrRegion : check that all uses of a value are the
 *    condition in select or wrregion or wrpredpredregion (or a predicate
 *    in a non-ALU intrinsic)
 *
 * Enter:   V = value whose uses are checked
 *
 * Return:  true if every use is acceptable, false at the first use that is
 *          not
 *
 * This is used in getConnectedVals below for the result of a use of an EM
 * value in an rdpredregion, or a shufflevector that is a slice so will be
 * lowered to rdpredregion.
 *
 * This is not a pure check: uses that can be expressed as a select (and/or
 * with V, or/and with ~V, zext/sext of V) are rewritten into a select on
 * the fly, and the original instructions are erased.
 */
static bool checkAllUsesAreSelectOrWrRegion(Value *V) {
  LLVM_DEBUG(dbgs() << "checkAllUsesAreSelectOrWrRegion: start\n");
  for (auto ui2 = V->use_begin(); ui2 != V->use_end(); /*empty*/) {
    auto User2 = cast<Instruction>(ui2->getUser());
    unsigned OpNum = ui2->getOperandNo();
    // Step to the next use before User2 is possibly erased below.
    // NOTE(review): this does not protect against User2 (or the and/or
    // erased together with it) holding a second use of V -- confirm such
    // input cannot reach this point.
    ++ui2;
    LLVM_DEBUG(dbgs() << "checkAllUsesAreSelectOrWrRegion: for user "
                      << *User2 << "\n");
    if (isa<SelectInst>(User2))
      continue;

    // Matches uses that can be turned into select.
    if (auto *BI = dyn_cast<BinaryOperator>(User2)) {
      LLVM_DEBUG(dbgs() << "checkAllUsesAreSelectOrWrRegion: binary op\n");

      auto Opc = BI->getOpcode();
      Constant *AllOne = Constant::getAllOnesValue(V->getType());
      Constant *AllNul = Constant::getNullValue(V->getType());

      // EM && X -> sel EM X 0
      // EM || X -> sel EM 1 X
      if (Opc == BinaryOperator::And || Opc == BinaryOperator::Or) {
        Value *Ops[3] = {V, nullptr, nullptr};
        if (Opc == BinaryOperator::And) {
          Ops[1] = BI->getOperand(1 - OpNum);
          Ops[2] = AllNul;
        } else if (Opc == BinaryOperator::Or) {
          Ops[1] = AllOne;
          Ops[2] = BI->getOperand(1 - OpNum);
        }
        auto SI = SelectInst::Create(Ops[0], Ops[1], Ops[2], ".revsel", BI, BI);
        BI->replaceAllUsesWith(SI);
        BI->eraseFromParent();
        continue;
      }

      // ~EM || X ==> sel EM, X, 1
      using namespace PatternMatch;
      if (BI->hasOneUse() &&
          BI->user_back()->getOpcode() == BinaryOperator::Or &&
          match(BI, m_Xor(m_Specific(V), m_Specific(AllOne)))) {
        Instruction *OrInst = BI->user_back();
        Value *Op = OrInst->getOperand(0) != BI ? OrInst->getOperand(0)
                                                : OrInst->getOperand(1);
        auto SI = SelectInst::Create(V, Op, AllOne, ".revsel", OrInst, OrInst);
        OrInst->replaceAllUsesWith(SI);
        OrInst->eraseFromParent();
        BI->eraseFromParent();
        continue;
      }

      // ~EM && X ==> sel EM, 0, X
      if (BI->hasOneUse() &&
          BI->user_back()->getOpcode() == BinaryOperator::And &&
          match(BI, m_Xor(m_Specific(V), m_Specific(AllOne)))) {
        Instruction *AndInst = BI->user_back();
        Value *Op = AndInst->getOperand(0) != BI ? AndInst->getOperand(0)
                                                 : AndInst->getOperand(1);
        auto SI =
            SelectInst::Create(V, AllNul, Op, ".revsel", AndInst, AndInst);
        AndInst->replaceAllUsesWith(SI);
        AndInst->eraseFromParent();
        BI->eraseFromParent();
        continue;
      }
      LLVM_DEBUG(dbgs() << "checkAllUsesAreSelectOrWrRegion: failed\n");
    } else if (auto *CI = dyn_cast<CastInst>(User2)) {
      LLVM_DEBUG(dbgs() << "checkAllUsesAreSelectOrWrRegion: cast inst\n");
      // Turn zext/sext to select.
      if (CI->getOpcode() == Instruction::CastOps::ZExt ||
          CI->getOpcode() == Instruction::CastOps::SExt) {
        unsigned NElts =
            cast<IGCLLVM::FixedVectorType>(V->getType())->getNumElements();
        unsigned NBits = CI->getType()->getScalarSizeInBits();
        // A set lane becomes 1 for zext and -1 (all ones) for sext.
        int Val = (CI->getOpcode() == Instruction::CastOps::ZExt) ? 1 : -1;
        // Build the constant as a signed value so that -1 is sign-extended
        // to all ones at any width (including more than 64 bits) instead of
        // relying on implicit truncation of an unsigned 64-bit pattern.
        APInt One(NBits, Val, /*isSigned=*/true);
        Constant *LHS = ConstantVector::getSplat(
            IGCLLVM::getElementCount(NElts),
            ConstantInt::get(CI->getType()->getScalarType(), One));
        Constant *AllNul = Constant::getNullValue(CI->getType());
        auto SI = SelectInst::Create(V, LHS, AllNul, ".revsel", CI, CI);
        CI->replaceAllUsesWith(SI);
        CI->eraseFromParent();
        continue;
      }
      LLVM_DEBUG(dbgs() << "checkAllUsesAreSelectOrWrRegion: failed\n");
    }

    // Not convertible to a select: the use is still acceptable if it is one
    // of the intrinsic forms below.
    unsigned IID = vc::getAnyIntrinsicID(User2);
    if (GenXIntrinsic::isWrRegion(IID))
      continue;
    // wrpredpredregion: V must be the last argument of the call.
    if (IID == GenXIntrinsic::genx_wrpredpredregion &&
        OpNum == IGCLLVM::getNumArgOperands(cast<CallInst>(User2)) - 1)
      continue;
    if (vc::isAnyNonTrivialIntrinsic(IID) &&
        !cast<CallInst>(User2)->doesNotAccessMemory())
      continue;
    LLVM_DEBUG(
        dbgs() << "checkAllUsesAreSelectOrWrRegion: not found pattern!\n");

    return false;
  }
  return true;
}

// Drop the VCSimdCFArg / VCSimdCFRet function attributes from CalledFunc
// when its signature makes them removable:
//  * VCSimdCFArg is removed as soon as some formal argument has type ValTy;
//  * VCSimdCFRet is removed when the return type, or any flattened element of
//    a struct return type, is a vector of i1.
static inline void PrepareFunctionAttributes(Function *CalledFunc,
                                             Type *ValTy) {
  if (CalledFunc->hasFnAttribute(vc::FunctionMD::VCSimdCFArg)) {
    for (const Argument &FormalArg : CalledFunc->args()) {
      if (FormalArg.getType() != ValTy)
        continue;
      CalledFunc->removeFnAttr(vc::FunctionMD::VCSimdCFArg);
      LLVM_DEBUG(dbgs() << "Remove VCSimdCFArg attr from "
                        << CalledFunc->getName() << "\n");
      break;
    }
  }

  if (!CalledFunc->hasFnAttribute(vc::FunctionMD::VCSimdCFRet))
    return;

  // A "predicate vector" here is any vector whose scalar type is i1.
  auto IsPredicateVector = [](Type *Ty) {
    return Ty->isVectorTy() && Ty->getScalarType()->isIntegerTy(1);
  };

  Type *RetTy = CalledFunc->getReturnType();
  if (auto *ST = dyn_cast<StructType>(RetTy)) {
    // Look at every flattened element of the returned struct.
    const unsigned NumElts = IndexFlattener::getNumElements(ST);
    for (unsigned EltIdx = 0; EltIdx != NumElts; ++EltIdx)
      if (IsPredicateVector(IndexFlattener::getElementType(ST, EltIdx)))
        CalledFunc->removeFnAttr(vc::FunctionMD::VCSimdCFRet);
  }
  if (IsPredicateVector(RetTy))
    CalledFunc->removeFnAttr(vc::FunctionMD::VCSimdCFRet);
}

/***********************************************************************
 * getConnectedVals : for a SimpleValue, get other SimpleValues connected to
 *    it through phi nodes, insertvalue, extractvalue, goto/join, and maybe
 *    args and return values
 *
 * Enter:   Val = SimpleValue to start at
 *          Cat = vc::RegCategory::EM to do EM connections
 *                vc::RegCategory::RM to do RM connections
 *          IncludeOptional = for EM connections, include optional connections
 *                where Val is a function arg and it is connected to call args,
 *                and where Val is the operand to return and it is connected to
 *                the returned value at call sites
 *          OkJoin = for RM connections, error if a use in a join other than
 *                this one is found
 *          ConnectedVals = vector to store connected values in
 *          LowerBadUsers = if true, users that are not suitable for EM/RM are
 *                rewritten via replaceUseWithLoweredEM (and the values that
 *                rewriting reports are dropped with removeFromEMRMVals)
 *                instead of making the check fail
 *
 * Return:  true if ok, false if def or some use is not suitable for EM/RM.
 *          When LowerBadUsers is set, unsuitable users collected during the
 *          use scan do not cause a false return; only the hard failures that
 *          return immediately do.
 *
 * The provided value must be non-constant, but the returned connected values
 * may include constants. Duplicates may be stored in ConnectedVals.
 *
 * Side effects: users other than phi/call/extractvalue are recorded in
 * CrossInsts; seeing such a user a second time sets lowerSimdCF.
 *
 * This function is used in three different ways by its callers:
 *
 * 1. to gather a web of putative EM values or RM values starting at goto/join
 *    instructions;
 *
 * 2. to test whether a putative EM/RM value is valid by whether its connected
 *    neighbors are EM/RM values;
 *
 * 3. when removing a value from the EM/RM values list, to find its connected
 *    neighbors to re-run step 2 on each of them.
 *
 * TODO: some refactoring should be performed here due to quite big
 *       CF with many different actions. Also some of these actions
 *       are repeated in different situations.
 */
bool GenXSimdCFConformance::getConnectedVals(
    SimpleValue Val, vc::RegCategory Cat, bool IncludeOptional,
    CallInst *OkJoin, SmallVectorImpl<SimpleValue> *ConnectedVals,
    bool LowerBadUsers) {
  // Check the def first.
  if (auto *Arg = dyn_cast<Argument>(Val.getValue())) {
    if (Cat != vc::RegCategory::EM)
      return false; // can't have RM argument
    // Connected to some return value. There is a problem here in that it might
    // find another predicate return value that is nothing to do with SIMD CF,
    // and thus stop SIMD CF being optimized. But passing a predicate in and
    // out of a function is rare outside of SIMD CF, so we do not worry about
    // that.
    // It is possible that EM was optimized from ret. In this case the ret type
    // is void. Allow such situation.
    Function *F = Arg->getParent();
    unsigned RetIdx = 0;
    auto RetTy = F->getReturnType();
    auto ValTy = IndexFlattener::getElementType(Val.getValue()->getType(),
                                                Val.getIndex());
    if (auto *ST = dyn_cast<StructType>(RetTy)) {
      for (unsigned End = IndexFlattener::getNumElements(ST);; ++RetIdx) {
        if (RetIdx == End)
          return false; // no predicate ret value found
        if (IndexFlattener::getElementType(ST, RetIdx) == ValTy)
          break;
      }
    } else if (RetTy != ValTy && !RetTy->isVoidTy())
      return F->hasFnAttribute(
          vc::FunctionMD::VCSimdCFRet); // no predicate ret value found
    if (!RetTy->isVoidTy())
      for (auto fi = F->begin(), fe = F->end(); fi != fe; ++fi)
        if (auto *Ret = dyn_cast<ReturnInst>(fi->getTerminator()))
          ConnectedVals->push_back(SimpleValue(Ret->getOperand(0), RetIdx));
    if (IncludeOptional) {
      // With IncludeOptional, also add the corresponding arg at each call
      // site.
      for (auto *U : F->users())
        if (auto *CI = checkFunctionCall(U, F))
          ConnectedVals->push_back(
              SimpleValue(CI->getArgOperand(Arg->getArgNo()), Val.getIndex()));
    }
  } else if (auto *Phi = dyn_cast<PHINode>(Val.getValue())) {
    // phi: add (the corresponding struct element of) each incoming
    for (unsigned oi = 0, oe = Phi->getNumIncomingValues(); oi != oe; ++oi)
      ConnectedVals->push_back(
          SimpleValue(Phi->getIncomingValue(oi), Val.getIndex()));
  } else if (auto *EVI = dyn_cast<ExtractValueInst>(Val.getValue())) {
    // extractvalue: add the appropriate struct element of the input
    ConnectedVals->push_back(SimpleValue(
        EVI->getOperand(0),
        Val.getIndex() + IndexFlattener::flatten(
                             cast<StructType>(EVI->getOperand(0)->getType()),
                             EVI->getIndices())));
  } else if (auto *IVI = dyn_cast<InsertValueInst>(Val.getValue())) {
    // insertvalue: add the appropriate struct element in either the
    // aggregate input or the value to insert input
    unsigned InsertedIndex =
        Val.getIndex() -
        IndexFlattener::flatten(cast<StructType>(IVI->getType()),
                                IVI->getIndices());
    unsigned NumElements =
        IndexFlattener::getNumElements(IVI->getOperand(1)->getType());
    SimpleValue SV;
    if (InsertedIndex < NumElements)
      SV = SimpleValue(IVI->getOperand(1), InsertedIndex);
    else
      SV = SimpleValue(IVI->getOperand(0), Val.getIndex());
    ConnectedVals->push_back(SV);
  } else if (auto *SVI = dyn_cast<ShuffleVectorInst>(Val.getValue())) {
    // shufflevector: add the EM use
    ConnectedVals->push_back(SimpleValue(SVI->getOperand(0), 0));
  } else if (auto *CI = dyn_cast<CallInst>(Val.getValue())) {
    switch (vc::getAnyIntrinsicID(CI)) {
    case GenXIntrinsic::genx_simdcf_goto:
      // goto: invalid unless it is the EM/RM result of goto as applicable
      if (Val.getIndex() != (Cat == vc::RegCategory::EM ? 0U : 1U))
        return false;
      // Add the corresponding input.
      ConnectedVals->push_back(CI->getOperand(Val.getIndex()));
      // If doing EM connections, add the corresponding join. This does
      // nothing if checkJoin has not yet run for the corresponding join,
      // since GotoJoinMap has not yet been set up for our goto. We tolerate
      // that situation; if the goto really has no linked join, that is
      // picked up later in checkGoto.
      if (Cat == vc::RegCategory::EM)
        if (auto Join = GotoJoinMap[cast<CallInst>(Val.getValue())])
          ConnectedVals->push_back(
              SimpleValue(Join, 0 /* struct idx of EM result */));
      break;
    case GenXIntrinsic::genx_simdcf_join: {
      // join: invalid unless it is the EM result
      if (Val.getIndex() || Cat != vc::RegCategory::EM)
        return false;
      // Add the corresponding input.
      ConnectedVals->push_back(CI->getOperand(Val.getIndex()));
      // Add the corresponding gotos. This does nothing if checkJoin has not
      // yet run for this join, since RMVals has not yet been set up for it.
      // That is OK, because adding the corresponding gotos here is required
      // only when we are called by removeBadEMVal to remove the join, so the
      // gotos get re-checked and found to be invalid.
      auto RMValsEntry = &RMVals[cast<CallInst>(Val.getValue())];
      for (auto i = RMValsEntry->begin(), e = RMValsEntry->end(); i != e; ++i)
        if (vc::getAnyIntrinsicID(i->getValue()) ==
            GenXIntrinsic::genx_simdcf_goto)
          ConnectedVals->push_back(
              SimpleValue(i->getValue(), 0 /* struct idx of EM result */));
      break;
    }
    case GenXIntrinsic::genx_simdcf_savemask:
    case GenXIntrinsic::genx_simdcf_remask:
    case GenXIntrinsic::genx_simdcf_get_em:
      // Add the corresponding input. The uses are deliberately not scanned
      // for these defs: we return straight away.
      ConnectedVals->push_back(CI->getOperand(0));
      return true;
    case GenXIntrinsic::genx_constantpred:
      // constantpred: add the constant. Don't add any other uses of it,
      // because it might be commoned up with other RM webs.
      ConnectedVals->push_back(CI->getOperand(0));
      return true;
    case GenXIntrinsic::not_any_intrinsic: {
      // Value returned from a call.
      if (Cat != vc::RegCategory::EM)
        return false; // invalid for RM
      // Add the corresponding value at each return in the called function.
      auto CalledFunc = CI->getCalledFunction();
      for (auto fi = CalledFunc->begin(), fe = CalledFunc->end(); fi != fe;
           ++fi)
        if (auto *Ret = dyn_cast<ReturnInst>(fi->getTerminator()))
          if (!Ret->getType()->isVoidTy())
            ConnectedVals->push_back(
                SimpleValue(Ret->getOperand(0), Val.getIndex()));
      // Connected to some call arg. There is a problem here in that it might
      // find another predicate arg that is nothing to do with SIMD CF, and
      // thus stop SIMD CF being optimized. But passing a predicate in and
      // out of a function is rare outside of SIMD CF, so we do not worry
      // about that.
      auto ValTy =
          IndexFlattener::getElementType(Val.getType(), Val.getIndex());

      PrepareFunctionAttributes(CalledFunc, ValTy);

      if (!CalledFunc->hasFnAttribute(vc::FunctionMD::VCSimdCFArg))
        for (unsigned Idx = 0, End = IGCLLVM::getNumArgOperands(CI);; ++Idx) {
          if (Idx == End)
            return false; // no corresponding call arg found
          if (CI->getArgOperand(Idx)->getType() == ValTy) {
            ConnectedVals->push_back(SimpleValue(CI->getArgOperand(Idx), 0));
            break;
          }
        }
      break;
    }
    default:
      return false; // unexpected call as def
    }
  } else
    return false; // unexpected instruction as def
  // Check the uses.
  std::vector<SimpleValue> UsersToLower;
  LLVM_DEBUG(dbgs() << "getConnectedVals: get all uses for " << *Val.getValue()
                    << "\n");
  for (auto ui = Val.getValue()->use_begin(), ue = Val.getValue()->use_end();
       ui != ue; ++ui) {
    auto User = cast<Instruction>(ui->getUser());
    // Can't have two different EMs in one instruction
    if (!isa<PHINode>(User) && !isa<CallInst>(User) &&
        !isa<ExtractValueInst>(User) && !CrossInsts.insert(User).second) {
      LLVM_DEBUG(dbgs() << "Found multy-EM use instruction:\n" << *User);
      lowerSimdCF = true;
    }
    LLVM_DEBUG(dbgs() << "getConnectedVals: -> geted " << *User << "\n");
    if (auto *Phi = dyn_cast<PHINode>(User)) {
      // Use in phi node. Add the phi result.
      ConnectedVals->push_back(SimpleValue(Phi, Val.getIndex()));
      continue;
    }
    if (auto *EVI = dyn_cast<ExtractValueInst>(User)) {
      // Use in extractvalue.
      // If extracting the right index, add the result.
      unsigned StartIndex = IndexFlattener::flatten(
          cast<StructType>(EVI->getOperand(0)->getType()), EVI->getIndices());
      unsigned NumIndices = IndexFlattener::getNumElements(EVI->getType());
      unsigned ExtractedIndex = Val.getIndex() - StartIndex;
      if (ExtractedIndex < NumIndices)
        ConnectedVals->push_back(SimpleValue(EVI, ExtractedIndex));
      continue;
    }
    if (auto *IVI = dyn_cast<InsertValueInst>(User)) {
      // Use in insertvalue. Could be either the aggregate input or the value
      // to insert.
      unsigned StartIndex = IndexFlattener::flatten(
          cast<StructType>(IVI->getType()), IVI->getIndices());
      unsigned NumIndices =
          IndexFlattener::getNumElements(IVI->getOperand(1)->getType());
      if (!ui->getOperandNo()) {
        // Use in insertvalue as the aggregate input. Add the corresponding
        // element in the result, as long as it is not overwritten by the
        // insertvalue.
        if (Val.getIndex() - StartIndex >= NumIndices)
          ConnectedVals->push_back(SimpleValue(IVI, Val.getIndex()));
      } else {
        // Use in insertvalue as the value to insert. Add the corresponding
        // element in the result.
        ConnectedVals->push_back(SimpleValue(IVI, StartIndex + Val.getIndex()));
      }
      continue;
    }
    if (isa<ReturnInst>(User)) {
      // Use in a return.
      if (Cat != vc::RegCategory::EM)
        return false; // invalid for RM
      // Connected to some function arg. There is a problem here in that it
      // might find another predicate arg that is nothing to do with SIMD CF,
      // and thus stop SIMD CF being optimized. But passing a predicate in and
      // out of a function is rare outside of SIMD CF, so we do not worry
      // about that.
      auto ValTy =
          IndexFlattener::getElementType(Val.getType(), Val.getIndex());
      auto F = User->getFunction();

      bool Lower = false;
      PrepareFunctionAttributes(F, ValTy);
      if (!F->hasFnAttribute(vc::FunctionMD::VCSimdCFArg))
        for (auto ai = F->arg_begin(), ae = F->arg_end();; ++ai) {
          if (ai == ae) {
            // no arg of the right type found
            Lower = true;
            UsersToLower.push_back(SimpleValue(User, ui->getOperandNo()));
            LLVM_DEBUG(dbgs() << "getConnectedVals: ai == ae push_back "
                              << *User << " No=" << ui->getOperandNo() << "\n");
            break;
          }
          auto Arg = &*ai;
          if (Arg->getType() == ValTy) {
            ConnectedVals->push_back(SimpleValue(Arg, 0));
            break;
          }
        }
      if (IncludeOptional && !Lower) {
        // With IncludeOptional, also add the values connected by being the
        // return value at each call site.
        for (auto *U : F->users())
          if (auto *CI = checkFunctionCall(U, F))
            ConnectedVals->push_back(SimpleValue(CI, Val.getIndex()));
      }
      continue;
    }
    if (isa<SelectInst>(User)) {
      // A use in a select is allowed only for EM used as the condition.
      if (Cat != vc::RegCategory::EM || ui->getOperandNo() != 0) {
        UsersToLower.push_back(SimpleValue(User, ui->getOperandNo()));
        LLVM_DEBUG(dbgs() << "getConnectedVals: cat push_back " << *User
                          << " No=" << ui->getOperandNo() << "\n");
      }
      continue;
    }
    if (auto *SVI = dyn_cast<ShuffleVectorInst>(User)) {
      if (!ShuffleVectorAnalyzer(SVI).isReplicatedSlice()) {
        UsersToLower.push_back(SimpleValue(User, ui->getOperandNo()));
        LLVM_DEBUG(dbgs() << "getConnectedVals: shuffle push_back " << *User
                          << " No=" << ui->getOperandNo() << "\n");
        continue;
      }
      // This is a shufflevector that is a replicated slice, so it can be
      // lowered to rdpredregion or baled with instruction with channels.
      // (We only see this in the early pass; GenXLowering has
      // turned it into rdpredregion by the late pass.) Check that all its uses
      // are select or wrregion.
      if (!checkAllUsesAreSelectOrWrRegion(SVI)) {
        UsersToLower.push_back(SimpleValue(User, ui->getOperandNo()));
        LLVM_DEBUG(dbgs() << "getConnectedVals: UsersToLower push_back "
                          << *User << " No=" << ui->getOperandNo() << "\n");
        continue;
      }
      // Shufflevector produces EM for value baled inst, so this is a (almost)
      // real EM def: add it here to perform correct EM interference check
      ConnectedVals->push_back(SimpleValue(SVI, ui->getOperandNo()));
      continue;
    }
    if (auto *CI = dyn_cast<CallInst>(User)) {
      LLVM_DEBUG(dbgs() << "getConnectedVals: CallInst income\n"
                        << *CI << "\n");
      switch (vc::getAnyIntrinsicID(CI)) {
      case GenXIntrinsic::genx_simdcf_get_em:
        LLVM_DEBUG(dbgs() << "getConnectedVals: case genx_simdcf_get_em\n");
        IGC_ASSERT(Cat == vc::RegCategory::EM);
        // Skip it if the category is right. This
        // intrinsic doesn't produce EM
        break;
      case GenXIntrinsic::genx_simdcf_unmask:
      case GenXIntrinsic::genx_simdcf_remask:
        LLVM_DEBUG(dbgs() << "getConnectedVals: case genx_simdcf_**mask\n");
        IGC_ASSERT(Cat == vc::RegCategory::EM);
        ConnectedVals->push_back(SimpleValue(CI, 0));
        break;
      case GenXIntrinsic::genx_simdcf_goto:
        LLVM_DEBUG(dbgs() << "getConnectedVals: case genx_simdcf_goto\n");
        // use in goto: valid only if arg 0 (EM) or 1 (RM)
        LLVM_DEBUG(dbgs() << "with operand no = " << ui->getOperandNo()
                          << "\n");
        if (ui->getOperandNo() != (Cat == vc::RegCategory::EM ? 0U : 1U))
          return false;
        // Add corresponding result.
        ConnectedVals->push_back(SimpleValue(CI, ui->getOperandNo()));
        break;
      case GenXIntrinsic::genx_simdcf_join:
        LLVM_DEBUG(dbgs() << "getConnectedVals: case genx_simdcf_join\n");
        // use in join: valid only if arg 0 (EM) or 1 (RM)
        if (ui->getOperandNo() != (Cat == vc::RegCategory::EM ? 0U : 1U))
          return false;
        // If EM, add corresponding result.
        if (Cat == vc::RegCategory::EM)
          ConnectedVals->push_back(SimpleValue(CI, 0));
        else if (OkJoin && OkJoin != CI) {
          // RM value used in a join other than OkJoin. That is illegal, as we
          // can only have one join per RM web.
          LLVM_DEBUG(dbgs() << "getConnectedVals: found illegal join "
                            << CI->getName() << "\n");
          return false;
        }
        break;
      case GenXIntrinsic::genx_wrregionf:
      case GenXIntrinsic::genx_wrregioni:
        break; // Use as wrregion predicate is allowed.
      case GenXIntrinsic::genx_rdpredregion:
        LLVM_DEBUG(dbgs() << "getConnectedVals: case genx_rdpredregion\n");
        // We only see rdpredregion in the late pass; in the early pass it is
        // still a shufflevector.  Check that all its uses are select or
        // wrregion.
        if (!checkAllUsesAreSelectOrWrRegion(CI)) {
          UsersToLower.push_back(SimpleValue(User, ui->getOperandNo()));
          LLVM_DEBUG(dbgs() << "getConnectedVals: UsersToLower push_back "
                            << *CI << " No=" << ui->getOperandNo() << "\n");
        }
        break;
      case GenXIntrinsic::genx_wrpredpredregion:
        LLVM_DEBUG(dbgs() << "getConnectedVals: case genx_wrpredpredregion\n");
        // Use in wrpredpredregion allowed as the last arg.
        if (ui->getOperandNo() + 1 != IGCLLVM::getNumArgOperands(CI)) {
          UsersToLower.push_back(SimpleValue(User, ui->getOperandNo()));
          LLVM_DEBUG(dbgs() << "getConnectedVals: operands ne push_back " << *CI
                            << " No=" << ui->getOperandNo() << "\n");
        }
        break;
      default:
        LLVM_DEBUG(dbgs() << "case default\n");
        // Allowed as an predicate in a non-ALU intrinsic.
        if (CI->getCalledFunction()->doesNotAccessMemory()) {
          UsersToLower.push_back(SimpleValue(User, ui->getOperandNo()));
          LLVM_DEBUG(dbgs()
                     << "getConnectedVals: does not access memory push_back "
                     << *CI << " No=" << ui->getOperandNo() << "\n");
        }
        break;
      case GenXIntrinsic::not_any_intrinsic: {
        LLVM_DEBUG(dbgs() << "getConnectedVals: case not_any_intrinsic\n");
        // Use in subroutine call. Add the corresponding function arg.
        Function *CalledFunc = CI->getCalledFunction();
        IGC_ASSERT(CalledFunc);

        auto ValTy = IndexFlattener::getElementType(Val.getValue()->getType(),
                                                    Val.getIndex());

        PrepareFunctionAttributes(CalledFunc, ValTy);
        // If the attribute is set, do not check.
        if (!CalledFunc->hasFnAttribute(vc::FunctionMD::VCSimdCFArg)) {

          auto ai = CalledFunc->arg_begin();
          for (unsigned Count = ui->getOperandNo(); Count; --Count, ++ai)
            ;
          Argument *Arg = &*ai;
          ConnectedVals->push_back(SimpleValue(Arg, Val.getIndex()));
        }
        // Connected to some return value from the call. There is a problem
        // here in that it might find another predicate return value that is
        // nothing to do with SIMD CF, and thus stop SIMD CF being optimized.
        // But passing a predicate in and out of a function is rare outside
        // of SIMD CF, so we do not worry about that.
        if (!CalledFunc->hasFnAttribute(vc::FunctionMD::VCSimdCFRet)) {
          unsigned RetIdx = 0;
          if (auto *ST = dyn_cast<StructType>(CI->getType())) {
            LLVM_DEBUG(dbgs()
                       << "getConnectedVals: StructType get" << *ST << "\n");
            for (unsigned End = IndexFlattener::getNumElements(ST);; ++RetIdx) {
              if (RetIdx == End) {
                UsersToLower.push_back(SimpleValue(
                    User, ui->getOperandNo())); // no predicate ret value found
                LLVM_DEBUG(dbgs() << "getConnectedVals: push_back " << *CI
                                  << " No=" << ui->getOperandNo() << "\n");
                // Stop the search here: every element has been examined, and
                // indexing the struct with RetIdx == End would be out of
                // range.
                break;
              }
              if (IndexFlattener::getElementType(ST, RetIdx) == ValTy) {
                ConnectedVals->push_back(SimpleValue(CI, RetIdx));
                break;
              }
            }
          } else if (CI->getType() == ValTy)
            ConnectedVals->push_back(SimpleValue(CI, 0));
          else if (!CI->getType()->isVoidTy()) {
            UsersToLower.push_back(SimpleValue(
                User, ui->getOperandNo())); // no predicate ret value found
            LLVM_DEBUG(dbgs() << "getConnectedVals: push_back " << *CI
                              << " No=" << ui->getOperandNo() << "\n");
          }
        }
        break;
      }
      }
      continue;
    }
    // Any other kind of user is not suitable for EM/RM.
    UsersToLower.push_back(SimpleValue(User, ui->getOperandNo()));
    LLVM_DEBUG(dbgs() << "getConnectedVals: default push_back " << *User
                      << " No=" << ui->getOperandNo() << "\n");
  }

  if (LowerBadUsers) {
    // Rewrite every unsuitable use, then drop whatever values the rewriting
    // reported in ToRemove from the EM/RM value lists.
    SetVector<Value *> ToRemove;
    for (auto &BadUser : UsersToLower) {
      replaceUseWithLoweredEM(dyn_cast<Instruction>(BadUser.getValue()),
                              BadUser.getIndex(), ToRemove);
    }
    for (auto Inst : ToRemove) {
      removeFromEMRMVals(Inst);
    }
  } else {
    if (!UsersToLower.empty()) {
      LLVM_DEBUG(dbgs() << "getConnectedVals: find bad users:\n";
                 for (auto &BadUser
                      : UsersToLower) {
                   dbgs() << "    ";
                   BadUser.print(dbgs());
                 });
      return false;
    }
  }

  return true;
}

// Tell whether V is an EM value or part of one. A shufflevector qualifies
// when either of its first two operands does; anything else is decided by
// GotoJoin::isEMValue.
static bool isEM(Value *V) {
  auto *Shuffle = dyn_cast<ShuffleVectorInst>(V);
  if (!Shuffle)
    return GotoJoin::isEMValue(V);
  if (isEM(Shuffle->getOperand(0)))
    return true;
  return isEM(Shuffle->getOperand(1));
}

// canonicalizeEM : canonicalize EM uses so that EM uses will not
// stop SIMD-CF conformance.
void GenXSimdCFConformance::canonicalizeEM() {
  using namespace PatternMatch;
  std::vector<Instruction *> DeadInstructions;

  for (auto &F : M->getFunctionList())
    for (auto &BB : F) {
      for (Instruction *Inst = BB.getTerminator(); Inst;) {
        // select(C0&C1, a, b) -> select(C0, select(C1, a, b), b)
        // select(C0|C1, a, b) -> select(C0, a, select(C1, a, b))
        Value *C0, *C1, *A, *B;
        if (match(Inst, m_Select(m_BinOp(m_Value(C0), m_Value(C1)), m_Value(A),
                                 m_Value(B)))) {
          bool C1IsEM = isEM(C1);
          if (C1IsEM || isEM(C0)) {
            Value *Cond = Inst->getOperand(0);
            if (Cond->getType()->isVectorTy()) {
              BinaryOperator *BO = cast<BinaryOperator>(Cond);
              // Set Inst as insert point in order to save dominance
              IRBuilder<> Builder(Inst);
              if (C1IsEM)
                std::swap(C0, C1);
              if (BO->getOpcode() == BinaryOperator::And) {
                Value *V = Builder.CreateSelect(C1, A, B);
                V = Builder.CreateSelect(C0, V, B);
                Inst->replaceAllUsesWith(V);
                DeadInstructions.push_back(Inst);
              } else if (BO->getOpcode() == BinaryOperator::Or) {
                Value *V = Builder.CreateSelect(C1, A, B);
                V = Builder.CreateSelect(C0, A, V);
                Inst->replaceAllUsesWith(V);
                DeadInstructions.push_back(Inst);
              }
            }
          }
        }

        Inst = (Inst == &BB.front()) ? nullptr : Inst->getPrevNode();
      }
    }

  for (Instruction *I : DeadInstructions)
    RecursivelyDeleteTriviallyDeadInstructions(I);

  // Collect data for gotos/joins EVs
  handleEVs();
  // Resolve bitcast chains so they don't break conformance
  resolveBitCastChains();
}

/***********************************************************************
 * handleEVs : collect goto/join EVs and perform some transformations
 * on them.
 *
 * All transformations are done in GotoJoinEVs constructor.
 */
void GenXSimdCFConformance::handleEVs() {
  // Collect gotos/joins
  gatherGotoJoinEMVals(false);
  for (auto &val : EMVals) {
    Value *GotoJoin = val.getValue();
    IGC_ASSERT(testIsGotoJoin(GotoJoin));
    GotoJoinEVsMap[GotoJoin] = GotoJoinEVs(GotoJoin);
  }
  EMVals.clear();
}

/***********************************************************************
 * eliminateBitCastPreds : perform bitcast elimination on EM DF
 *
 * Enter:   Val = value to start from; the asserts below expect its
 *                EMProducers entry to have been recorded already
 *          DeadInst = set that receives bitcasts and replaced phis, which
 *                may have become redundant
 *          Visited = values already handled in this walk
 *
 * Return:  the value to use in place of Val: the bitcast-free source, a phi
 *          of <32 x i1> type, or an all-ones constant of that type
 */
Value *GenXSimdCFConformance::eliminateBitCastPreds(
    Value *Val, std::set<Value *> &DeadInst, std::set<Value *> &Visited) {
  Type *EMType =
      IGCLLVM::FixedVectorType::get(Type::getInt1Ty(M->getContext()), 32);

  // A value seen before (e.g. through a phi cycle) resolves to whatever is
  // currently recorded for it.
  if (!Visited.insert(Val).second)
    return EMProducers[Val];

  if (auto *BCI = dyn_cast<BitCastInst>(Val)) {
    IGC_ASSERT_MESSAGE(EMProducers[BCI] == BCI->getOperand(0),
                       "Bad EM producer was saved!");

    // Look through the bitcast; it is a candidate for removal.
    DeadInst.insert(BCI);
    return eliminateBitCastPreds(BCI->getOperand(0), DeadInst, Visited);
  }

  if (auto *PN = dyn_cast<PHINode>(Val)) {
    IGC_ASSERT_MESSAGE(EMProducers[PN] == PN, "Bad EM producer was saved!");

    const unsigned NumIncoming = PN->getNumIncomingValues();
    PHINode *NewPN = nullptr;
    if (PN->getType() != EMType) {
      // Different type at phi. This may happen if its incoming value
      // became bitcast.
      LLVM_DEBUG(dbgs() << "eliminateBitCastPreds: Creating new PHI for:\n"
                        << *PN << "\n");
      NewPN = PHINode::Create(EMType, NumIncoming, "EMTerm", PN);
      EMProducers[NewPN] = NewPN;
      // In case of cycle, we will return newly created phi
      EMProducers[PN] = NewPN;
      // Phi can become redundant after it
      DeadInst.insert(PN);
    }

    for (unsigned Idx = 0; Idx < NumIncoming; ++Idx) {
      Value *Incoming =
          eliminateBitCastPreds(PN->getIncomingValue(Idx), DeadInst, Visited);
      if (NewPN) {
        // Populate the replacement phi edge by edge.
        NewPN->addIncoming(Incoming, PN->getIncomingBlock(Idx));
        continue;
      }
      // Same type: patch the existing phi in place.
      PN->setIncomingValue(Idx, Incoming);
      PN->setIncomingBlock(Idx, PN->getIncomingBlock(Idx));
    }

    if (NewPN)
      return NewPN;
    return PN;
  }

  if (auto *C = dyn_cast<Constant>(Val)) {
    IGC_ASSERT_MESSAGE(C->isAllOnesValue(), "Should be checked before!");
    IGC_ASSERT_MESSAGE(EMProducers[C] == C, "Bad EM producer was saved!");

    return Constant::getAllOnesValue(EMType);
  }

  // Anything else is expected to be the final producer, already of EM type.
  IGC_ASSERT(Val);
  IGC_ASSERT_MESSAGE(EMProducers[Val] == Val, "Bad EM producer was saved!");
  IGC_ASSERT_MESSAGE(Val->getType() == EMType,
                     "Unexpected final EM producer!");

  return Val;
}

/***********************************************************************
 * resolveBitCastChains : resolve EM -> (bitcast) -> EM chains
 *
 * Standard LLVM passes create such chains sometimes and it makes
 * SIMD CF non-conformant. Here we check this and make changes to
 * resolve it if possible. If it is not, SIMD CF remains non-conformant
 * and is lowered later.
 */
void GenXSimdCFConformance::resolveBitCastChains() {
  LLVM_DEBUG(dbgs() << "resolveBitCastChains: Resolving Bitcast chains:\n");

  // We don't have EM values here so we have to gather them
  // here, too. This is because we can change EM values set
  // during these transformations.
  gatherEMVals();

  std::set<Value *> DeadInst;
  for (auto &Val : EMVals) {
    if (auto *PN = dyn_cast<PHINode>(Val.getValue())) {
      LLVM_DEBUG(dbgs() << "resolveBitCastChains: Found phi:\n" << *PN << "\n");
    } else if (auto *BCI = dyn_cast<BitCastInst>(Val.getValue())) {
      LLVM_DEBUG(dbgs() << "resolveBitCastChains: Found bitcast:\n"
                        << *BCI << "\n");
    } else
      continue;

    std::set<Value *> Visited;
    Instruction *I = dyn_cast<Instruction>(Val.getValue());
    Value *EMProd = getEMProducer(I, Visited, true);

    if (!EMProd) {
      LLVM_DEBUG(dbgs() << "resolveBitCastChains: !!! Not EM producer was "
                           "detected when resolving bitcast chains !!!\n");
      continue;
    }

    Visited.clear();
    Value *NewEMProd = eliminateBitCastPreds(EMProd, DeadInst, Visited);
    if (NewEMProd != EMProd) {
      EMProd->replaceAllUsesWith(NewEMProd);
    }
  }

  EMVals.clear();

  for (auto DI : DeadInst) {
    if (auto *I = dyn_cast<Instruction>(DI))
      RecursivelyDeleteTriviallyDeadInstructions(I);
  }

  // TODO: since we are using EMProducers only here and during get_em check,
  // clean it after these transformation sinse it may contain dead data.
  EMProducers.clear();

  LLVM_DEBUG(
      dbgs() << "resolveBitCastChains: Done resolving bitcast chains:\n");
}

/***********************************************************************
 * checkEMInterference : check for EM values interfering with each other,
 *      lowering gotos/joins as necessary
 *
 * There is only one EM in the hardware, and we need to model that by ensuring
 * that our multiple EM values, including phi nodes, do not interfere with each
 * other. This is effectively a register allocator with only one register.
 */
void GenXSimdCFConformance::checkEMInterference() {
  // Do an interference check, returning a list of defs that appear in the live
  // range of other values.
  SetVector<Value *> BadDefs;
  checkInterference(&EMVals, &BadDefs, nullptr);
  for (auto i = BadDefs.begin(), e = BadDefs.end(); i != e; ++i)
    removeBadEMVal(*i);
}

/***********************************************************************
 * findLoweredEMValue : look up an already lowered EM value
 *
 * Return:  the entry recorded for Val in LoweredEMValsMap, or nullptr if
 *          there is none
 */
Value *GenXSimdCFConformance::findLoweredEMValue(Value *Val) {
  LLVM_DEBUG(dbgs() << "findLoweredEMValue: Looking for lowered value for:\n"
                    << *Val << "\n");

  auto Entry = LoweredEMValsMap.find(Val);
  if (Entry == LoweredEMValsMap.end()) {
    LLVM_DEBUG(dbgs() << "findLoweredEMValue: No lowered value was found\n");
    return nullptr;
  }

  auto *Lowered = Entry->second;
  LLVM_DEBUG(dbgs() << "findLoweredEMValue: Found lowered value:\n"
                    << *Lowered << "\n");
  return Lowered;
}

/***********************************************************************
 * buildLoweringViaGetEM : build GetEM instruction to get explicit EM
 *   from Val.
 *
 * The new call is inserted before InsertBefore and recorded in
 * LoweredEMValsMap as the lowered value of Val.
 */
Value *GenXSimdCFConformance::buildLoweringViaGetEM(Value *Val,
                                                    Instruction *InsertBefore) {
  Function *Decl = vc::getAnyDeclaration(M, GenXIntrinsic::genx_simdcf_get_em,
                                         {Val->getType()});
  Value *Lowered = CallInst::Create(Decl, {Val}, "getEM", InsertBefore);
  // Remember the result so later requests for Val can reuse it.
  LoweredEMValsMap[Val] = Lowered;

  LLVM_DEBUG(dbgs() << "buildLoweringViaGetEM: Built getEM:\n"
                    << *Lowered << "\n");

  return Lowered;
}

/***********************************************************************
 * getGetEMLoweredValue : return the lowered EM value for Val, building a
 *   GetEM instruction before InsertBefore when none has been recorded yet.
 */
Value *GenXSimdCFConformance::getGetEMLoweredValue(Value *Val,
                                                   Instruction *InsertBefore) {
  if (auto *Known = findLoweredEMValue(Val))
    return Known;
  return buildLoweringViaGetEM(Val, InsertBefore);
}

/***********************************************************************
 * lowerEVIUse : lower ExtractValue use.
 *
 * Enter:   EVI = extractvalue of a goto/join result that carries the EM
 *          User = instruction using EVI
 *          PhiPredBlock = incoming block when User is a phi, else nullptr
 *
 * EM is being lowered via genx_simdcf_get_em intrinsic.
 */
Value *GenXSimdCFConformance::lowerEVIUse(ExtractValueInst *EVI,
                                          Instruction *User,
                                          BasicBlock *PhiPredBlock) {
  LLVM_DEBUG(dbgs() << "lowerEVIUse: Lowering EVI use:\n" << *EVI << "\n");

  auto *GJ = dyn_cast<CallInst>(EVI->getOperand(0));
  IGC_ASSERT_MESSAGE(testIsGotoJoin(GJ), "Bad ExtractValue with EM!");

  BasicBlock *DefBB = GJ->getParent();

  // The CFG was corrected for SIMD CF by earlier transformations
  // so isBranchingGotoJoinBlock works correctly here.
  if (GotoJoin::isBranchingGotoJoinBlock(DefBB) != GJ) {
    // Non-branching case: must be join. Insert get_em right after join's EM
    IGC_ASSERT_MESSAGE(testIsJoin(GJ),
                       "Gotos should be turned into branching earlier!");

    LLVM_DEBUG(dbgs() << "lowerEVIUse: Handling simple join case\n");

    return getGetEMLoweredValue(EVI, EVI->getNextNode());
  }

  // Branching case: the lowered value differs between the two out edges.
  LLVM_DEBUG(dbgs() << "lowerEVIUse: Handling branching block case\n");

  auto *Term = DefBB->getTerminator();
  BasicBlock *TrueBlock = Term->getSuccessor(0);
  BasicBlock *FalseBlock = Term->getSuccessor(1);

  // True edge sees a zeroed value, false edge sees get_em placed at the top
  // of the false successor. GetEM is removed later if redundant.
  Value *TrueVal = Constant::getNullValue(EVI->getType());
  Value *FalseVal = getGetEMLoweredValue(EVI, FalseBlock->getFirstNonPHI());

  // A phi fed straight over the true edge just takes the zeroed value.
  const bool IsDirectPhiTrueEdge =
      PhiPredBlock == DefBB && TrueBlock == User->getParent();
  if (IsDirectPhiTrueEdge) {
    IGC_ASSERT(PhiPredBlock);
    IGC_ASSERT_MESSAGE(FalseBlock != TrueBlock,
                       "Crit edge should be inserted earlier!");
    return TrueVal;
  }

  // Otherwise resolve the value reaching the use location.
  BasicBlock *Loc = PhiPredBlock;
  if (!Loc)
    Loc = User->getParent();

  std::map<BasicBlock *, Value *> foundVals;
  BasicBlockEdge TrueEdge(DefBB, TrueBlock);
  BasicBlockEdge FalseEdge(DefBB, FalseBlock);

  return findGotoJoinVal(vc::RegCategory::EM, Loc, EVI, TrueEdge, FalseEdge,
                         TrueVal, FalseVal, foundVals);
}

/***********************************************************************
 * lowerPHIUse : lower PHI use.
 *
 * EM is being lowered via genx_simdcf_get_em intrinsic.
 * When the phi's block is not a join label, the intrinsic is inserted at
 * the first non-phi position of that block. For join label blocks the phi
 * is cloned and every incoming value of the clone is lowered in turn.
 *
 * Lowered phis are recorded in LoweredEMValsMap so that a phi is not
 * lowered twice.
 */
Value *GenXSimdCFConformance::lowerPHIUse(PHINode *PN,
                                          SetVector<Value *> &ToRemove) {
  LLVM_DEBUG(dbgs() << "lowerPHIUse: Lowering PHI use:\n" << *PN << "\n");

  // Reuse an earlier lowering if there is one.
  if (auto *Cached = findLoweredEMValue(PN))
    return Cached;

  BasicBlock *PhiBB = PN->getParent();
  if (!GotoJoin::isJoinLabel(PhiBB)) {
    auto Lowered = getGetEMLoweredValue(PN, PhiBB->getFirstNonPHI());
    LLVM_DEBUG(dbgs() << "lowerPHIUse: Created " << *Lowered << "\n");
    return Lowered;
  }

  LLVM_DEBUG(dbgs() << "Performing full lowering\n");

  // Place a copy of the phi right after the original and register it as the
  // lowered value before touching its operands.
  auto *Clone = cast<PHINode>(PN->clone());
  Clone->insertAfter(PN);
  LoweredEMValsMap[PN] = Clone;

  LLVM_DEBUG(dbgs() << "lowerPHIUse: Cloned phi before lowering values:\n"
                    << *Clone << "\n");

  // Lower each incoming value of the clone.
  const unsigned NumIncoming = Clone->getNumIncomingValues();
  for (unsigned Idx = 0; Idx != NumIncoming; ++Idx)
    replaceUseWithLoweredEM(Clone, Idx, ToRemove);

  LLVM_DEBUG(dbgs() << "lowerPHIUse: Cloned phi with lowered values:\n"
                    << *Clone << "\n");

  return Clone;
}

/***********************************************************************
 * lowerArgumentUse : lower argument use.
 *
 * EM is being lowered via genx_simdcf_get_em intrinsic.
 * The get_em goes at the first non-phi position of the function's first
 * block. Lowering can be needed if the argument's user was moved under
 * SIMD CF for some reason.
 */
Value *GenXSimdCFConformance::lowerArgumentUse(Argument *Arg) {
  LLVM_DEBUG(dbgs() << "lowerArgumentUse: Lowering argument use:\n"
                    << *Arg << "\n");

  BasicBlock &EntryBB = Arg->getParent()->front();
  return getGetEMLoweredValue(Arg, EntryBB.getFirstNonPHI());
}

/***********************************************************************
 * replaceUseWithLoweredEM : lower incoming EM for user.
 *
 * Enter:   Val = instruction whose operand is an EM value
 *          operandNo = index of that operand
 *          ToRemove = set that receives every instruction handled here
 *
 * EM is being lowered via genx_simdcf_get_em intrinsic.
 */
void GenXSimdCFConformance::replaceUseWithLoweredEM(
    Instruction *Val, unsigned operandNo, SetVector<Value *> &ToRemove) {
  Value *EM = Val->getOperand(operandNo);

  LLVM_DEBUG(dbgs() << "replaceUseWithLoweredEM: Replacing EM use:\n"
                    << *EM << "\nwith lowered EM for:\n"
                    << *Val << "\n");

  // Stays null when nothing has to be rewritten in Val itself.
  Value *Replacement = nullptr;

  if (auto *EVI = dyn_cast<ExtractValueInst>(EM)) {
    // For a phi user the lowering depends on the incoming edge.
    auto *PhiUser = dyn_cast<PHINode>(Val);
    BasicBlock *PhiPredBlock =
        PhiUser ? PhiUser->getIncomingBlock(operandNo) : nullptr;
    Replacement = lowerEVIUse(EVI, Val, PhiPredBlock);
  } else if (auto *SVI = dyn_cast<ShuffleVectorInst>(EM)) {
    // Look through the shuffle and lower its first operand instead; the
    // shuffle itself stays as the operand of Val.
    replaceUseWithLoweredEM(SVI, 0, ToRemove);
  } else if (auto *PN = dyn_cast<PHINode>(EM)) {
    Replacement = lowerPHIUse(PN, ToRemove);
  } else if (auto *Arg = dyn_cast<Argument>(EM)) {
    Replacement = lowerArgumentUse(Arg);
  } else if (isa<Constant>(EM) &&
             EM->getType()->getScalarType()->isIntegerTy(1)) {
    // An i1 (or vector of i1) constant is used as is.
    Replacement = EM;
  } else {
    // All other instructions should not be EM producers with correct DF
    IGC_ASSERT_EXIT_MESSAGE(0, "Failed to lower EM!");
  }

  if (Replacement)
    Val->setOperand(operandNo, Replacement);

  ToRemove.insert(Val);
}

/***********************************************************************
 * canUseLoweredEM : check whether instruction can use lowered EM
 *
 * Lowered EM is an explicit value, so any instruction may consume it except
 * goto and join, which take the EM implicitly. A phi is also rejected when
 * one of its incoming values is an extractvalue taken from a goto or join.
 */
bool GenXSimdCFConformance::canUseLoweredEM(Instruction *Val) {
  auto IsGotoOrJoin = [](auto *V) {
    auto IID = vc::getAnyIntrinsicID(V);
    return IID == GenXIntrinsic::genx_simdcf_goto ||
           IID == GenXIntrinsic::genx_simdcf_join;
  };

  if (IsGotoOrJoin(Val))
    return false;

  auto *PN = dyn_cast<PHINode>(Val);
  if (!PN)
    return true;

  // For phi, check that it does not deal with goto or join.
  for (Value *Incoming : PN->incoming_values()) {
    auto *Extract = dyn_cast<ExtractValueInst>(Incoming);
    if (Extract && IsGotoOrJoin(Extract->getOperand(0)))
      return false;
  }

  return true;
}

/***********************************************************************
 * canUseRealEM : check whether instruction can use real EM that is
 * passed via #opNo operand.
 *
 * TODO: It is used only by linearized fragment optimization now.
 * This function should be extended and put into getConnectedVals
 * algorithm in order to make the latter simpler. For now,
 * this check is passed only by selects, shufflevectors and wrregions
 * because moving these instructions makes sense during the
 * optimization.
 */
bool GenXSimdCFConformance::canUseRealEM(Instruction *Inst, unsigned opNo) {
  // Select: real EM can be condition only
  if (isa<SelectInst>(Inst))
    return opNo == 0;

  if (auto *SVI = dyn_cast<ShuffleVectorInst>(Inst)) {
    // TODO: getConnectedVals checks only this, but
    // there is no check for idxs correctness.
    // They should be 0, 1, 2, ..., EXEC_SIZE - 1 for
    // EM truncation.
    return ShuffleVectorAnalyzer::isReplicatedSlice(SVI) &&
           checkAllUsesAreSelectOrWrRegion(SVI);
  }

  // Left switch for further extensions
  switch (vc::getAnyIntrinsicID(Inst)) {
  case GenXIntrinsic::genx_wrregionf:
  case GenXIntrinsic::genx_wrregioni:
    // Real EM can be wrregion predicate only
    return opNo == GenXIntrinsic::GenXRegion::PredicateOperandNum;
  default:
    return false;
  }
}

/***********************************************************************
 * checkInterference : check for a list of values interfering with each other
 *
 * Enter:   Vals = values to check (not constants)
 *          BadDefs = SetVector in which to store any def that is found in the
 *                    live range of another def
 *          ConstStop = instruction to treat as the def point of a constantpred,
 *                      nullptr to treat the start of the function as the def
 *                      point
 *
 * This code finds interference by scanning back from uses, finding other defs,
 * relying on the dominance property of SSA. Having found that two EM values A
 * and B interfere due to the def of A appearing in the live range of B, we
 * could choose either one to lower its goto and join. In fact we choose A (the
 * found def), as that tends to lower inner SIMD CF, giving a chance for the
 * outer SIMD CF to become legal.
 *
 * Because GenXSimdCFConformance runs before live ranges are determined, so
 * that it can modify code as it wants, we cannot use the normal interference
 * testing code in GenXLiveness.
 *
 * The idea of ConstStop is different depending on whether we are testing
 * interference of all EM values, or all RM values for a particular join:
 *
 * * For interference between all EM values, any constant (input to
 *   constantpred intrinsic) must be all ones, which is checked elsewhere. It
 *   represents the state of the execution mask at the start of the function,
 *   therefore we need to pretend that the constantpred's live range extends
 *   back to the start of the function.  This is done by the caller setting
 *   ConstStop to 0.
 *
 * * For interference between all RM values for one particular join, any
 *   constant must be all zeros, which is checked elsewhere. It represents the
 *   state of that join's resume mask on entry to the function, and just after
 *   executing the join. Therefore we need to pretend that the constantpred's
 *   live range extends back to those two places. This is done by the caller
 *   setting ConstStop to the join instruction.
 */
void GenXSimdCFConformance::checkInterference(SetVector<SimpleValue> *Vals,
                                              SetVector<Value *> *BadDefs,
                                              Instruction *ConstStop) {
  // Scan the live range of each value, looking for a def of another value.
  // Finding such a def indicates interference.
  // ToRemove collects the instructions handled by replaceUseWithLoweredEM
  // (it inserts every instruction whose operand it processed). They are
  // passed to removeFromEMRMVals once the whole scan is finished, and are
  // ignored as interfering defs in the meantime.
  SetVector<Value *> ToRemove;
  for (auto evi = Vals->begin(), eve = Vals->end(); evi != eve; ++evi) {
    Value *EMVal = evi->getValue();
    bool IsConstantPred =
        vc::getAnyIntrinsicID(EMVal) == GenXIntrinsic::genx_constantpred;
    // Set of blocks where we know the value is live out.
    // It is shared by all uses of EMVal, so a block already scanned from its
    // end for one use is not made pending again for another use.
    SmallSet<BasicBlock *, 8> LiveOut;
    // Start from each use and scan backwards. If the EMVal was affected by
    // transformations, there is no need to check other uses.
    for (auto ui = EMVal->use_begin(), ue = EMVal->use_end();
         ui != ue && ToRemove.count(EMVal) == 0;) {
      auto User = cast<Instruction>(ui->getUser());
      auto OpNo = ui->getOperandNo();
      // Step the iterator before doing any work: replaceUseWithLoweredEM
      // below may call setOperand on this very use.
      ++ui;
      if (auto *EVI = dyn_cast<ExtractValueInst>(User)) {
        // Ignore a use that is an extractvalue not involving the right struct
        // index.
        unsigned StartIndex = IndexFlattener::flatten(
            cast<StructType>(EVI->getOperand(0)->getType()), EVI->getIndices());
        unsigned NumIndices = IndexFlattener::getNumElements(EVI->getType());
        // Single range test for [StartIndex, StartIndex + NumIndices).
        // NOTE(review): presumably relies on unsigned wraparound when
        // getIndex() is below StartIndex - confirm getIndex() is unsigned.
        if (evi->getIndex() - StartIndex >= NumIndices)
          continue;
      }
      // For a phi use, the value is live only along the corresponding
      // incoming edge; remember that predecessor.
      BasicBlock *PhiPred = nullptr;
      if (auto *Phi = dyn_cast<PHINode>(User))
        PhiPred = Phi->getIncomingBlock(OpNo);
      // Inst is the current scan position; nullptr means the current block is
      // finished and the next pending block (if any) should be taken.
      auto Inst = User;
      // Blocks that still have to be scanned backwards from their terminator.
      SmallVector<BasicBlock *, 4> PendingBBStack;
      for (;;) {
        if (!Inst) {
          // Go on to the next pending predecessor.
          if (PendingBBStack.empty())
            break;
          Inst = PendingBBStack.pop_back_val()->getTerminator();
        }
        if (&Inst->getParent()->front() == Inst) {
          // Reached the start of the block. Make all unprocessed predecessors
          // pending. Except if the use is in a phi node and this is the first
          // time we reach the start of a block: in that case, mark only the
          // corresponding phi block is pending.
          if (PhiPred) {
            if (LiveOut.insert(PhiPred).second)
              PendingBBStack.push_back(PhiPred);
            PhiPred = nullptr;
          } else {
            BasicBlock *InstBB = Inst->getParent();
            // LiveOut.insert(...).second is true only for blocks not seen
            // before, so each predecessor is queued at most once.
            std::copy_if(pred_begin(InstBB), pred_end(InstBB),
                         std::back_inserter(PendingBBStack),
                         [&LiveOut](BasicBlock *BB) {
                           return LiveOut.insert(BB).second;
                         });
          }
          Inst = nullptr;
          continue;
        }
        // Go back to the previous instruction. (This happens even when
        // starting at the end of a new block, thus skipping scanning the uses
        // of the terminator, but that's OK because the terminator never uses
        // our EM or RM values.)
        Inst = Inst->getPrevNode();
        if (Inst == EMVal && !IsConstantPred) {
          // Reached the def of the value. Stop scanning, unless the def is
          // constantpred, in which case we pretend it was live from the
          // ConstStop.
          Inst = nullptr;
          continue;
        }
        if (Inst == ConstStop && IsConstantPred) {
          // For a constantpred value, we have reached the point that we want
          // to treat as its definition point.  Stop scanning.
          Inst = nullptr;
          continue;
        }
        // Check if this is the def of some other EM value.
        // Only instructions producing a vector of i1 are looked up in Vals.
        if (auto *VT = dyn_cast<VectorType>(Inst->getType()))
          if (VT->getElementType()->isIntegerTy(1))
            if (Vals->count(Inst) && !ToRemove.count(Inst)) {
              // It is the def of some other EM value. Mark that one as
              // interfering. However do not mark it if both values are
              // constantpred, since we pretend all of those are defined at the
              // start of the function.
              if (!IsConstantPred || vc::getAnyIntrinsicID(Inst) !=
                                         GenXIntrinsic::genx_constantpred) {
                LLVM_DEBUG(
                    dbgs()
                    << "GenXSimdCFConformance::checkInterference: def of "
                    << Inst->getName() << " found in live range of "
                    << EMVal->getName() << "\n");
                auto *SVI = dyn_cast<ShuffleVectorInst>(Inst);
                if (SVI && SVI->getOperand(0) == EMVal) {
                  // Shuffle vector is baled as EM of another size: this check
                  // is to ensure that the EM in SVI is still actual
                  LLVM_DEBUG(
                      dbgs()
                      << "\tShuffle vector with correct arg, skipping it\n");
                } else if (canUseLoweredEM(User) && !FG) {
                  // Lower EM in Early Pass
                  // (taken only when FG is null; the use is rewritten instead
                  // of reporting the found def as bad)
                  replaceUseWithLoweredEM(User, OpNo, ToRemove);
                  LLVM_DEBUG(dbgs() << "\tSucceded to lower EM for that use\n");
                } else {
                  LLVM_DEBUG(dbgs() << "\t!!! Failed to lower EM for that use: "
                                       "def will be lowered\n");
                  BadDefs->insert(Inst);
                }
                // Done for that use
                break;
              }
            }
      }
    }
  }

  // Drop everything collected in ToRemove from the EM/RM value sets.
  for (auto Inst : ToRemove) {
    removeFromEMRMVals(Inst);
  }
}

/***********************************************************************
 * insertCond : insert a vector of i1 value into the start of another one
 *
 * Enter:   OldVal = value to insert into
 *          NewVal = value to insert, at index 0
 *          Name = name for any new instruction
 *          InsertBefore = where to insert any new instruction
 *          DL = debug loc to give any new instruction
 *
 * Return:  NewVal itself when both values have the same width, otherwise
 *          a newly created shufflevector
 */
Value *GenXSimdCFConformance::insertCond(Value *OldVal, Value *NewVal,
                                         const Twine &Name,
                                         Instruction *InsertBefore,
                                         const DebugLoc &DL) {
  const unsigned OldWidth =
      cast<IGCLLVM::FixedVectorType>(OldVal->getType())->getNumElements();
  const unsigned NewWidth =
      cast<IGCLLVM::FixedVectorType>(NewVal->getType())->getNumElements();
  if (NewWidth == OldWidth)
    return NewVal;

  // The insert takes two shufflevectors: the first widens NewVal to OldVal's
  // width, the second merges the widened value with OldVal.
  // GenXLowering decides whether this is suitable to lower to wrpredregion, or
  // needs to be lowered to something less efficient.
  Type *I32Ty = Type::getInt32Ty(InsertBefore->getContext());
  SmallVector<Constant *, 32> Mask;

  // Widening mask: lanes of NewVal first, undef for the remaining lanes.
  for (unsigned Lane = 0; Lane != NewWidth; ++Lane)
    Mask.push_back(ConstantInt::get(I32Ty, Lane));
  auto UndefLane = UndefValue::get(I32Ty);
  for (unsigned Lane = NewWidth; Lane != OldWidth; ++Lane)
    Mask.push_back(UndefLane);
  auto *Extended = new ShuffleVectorInst(
      NewVal, UndefValue::get(NewVal->getType()), ConstantVector::get(Mask),
      NewVal->getName() + ".extend", InsertBefore);
  Extended->setDebugLoc(DL);

  // Nothing to merge with when the old value is undef.
  if (isa<UndefValue>(OldVal))
    return Extended;

  // Merging mask: the low lanes select from the widened value (second shuffle
  // operand, hence the OldWidth offset), the rest select from OldVal.
  Mask.clear();
  for (unsigned Lane = 0; Lane != NewWidth; ++Lane)
    Mask.push_back(ConstantInt::get(I32Ty, Lane + OldWidth));
  for (unsigned Lane = NewWidth; Lane != OldWidth; ++Lane)
    Mask.push_back(ConstantInt::get(I32Ty, Lane));
  auto *Combined = new ShuffleVectorInst(
      OldVal, Extended, ConstantVector::get(Mask), Name, InsertBefore);
  Combined->setDebugLoc(DL);
  return Combined;
}

/***********************************************************************
 * truncateCond : truncate a vector of i1 value
 *
 * Enter:   In = input value
 *          Ty = type to truncate to
 *          Name = name for any new instruction
 *          InsertBefore = where to insert any new instruction
 *          DL = debug loc to give any new instruction
 *
 * Return:  In itself when it already has the requested width, otherwise a
 *          newly created shufflevector
 */
Value *GenXSimdCFConformance::truncateCond(Value *In, Type *Ty,
                                           const Twine &Name,
                                           Instruction *InsertBefore,
                                           const DebugLoc &DL) {
  const unsigned SrcWidth =
      cast<IGCLLVM::FixedVectorType>(In->getType())->getNumElements();
  const unsigned DstWidth = cast<IGCLLVM::FixedVectorType>(Ty)->getNumElements();
  if (SrcWidth == DstWidth)
    return In;

  // Do the truncate with shufflevector. GenXLowering lowers it to rdpredregion.
  // The mask simply selects lanes 0 .. DstWidth-1 of the input.
  Type *I32Ty = Type::getInt32Ty(InsertBefore->getContext());
  SmallVector<Constant *, 32> Mask;
  for (unsigned Lane = 0; Lane != DstWidth; ++Lane)
    Mask.push_back(ConstantInt::get(I32Ty, Lane));

  auto *Truncated =
      new ShuffleVectorInst(In, UndefValue::get(In->getType()),
                            ConstantVector::get(Mask), Name, InsertBefore);
  Truncated->setDebugLoc(DL);
  return Truncated;
}

/***********************************************************************
 * lowerGoto : lower a llvm.genx.simdcf.goto
 *
 * The goto is replaced by explicit mask arithmetic emitted right before it:
 *   NewRM      = OldRM | (SubEM & ~Cond)
 *   NewEM      = EM with (SubEM & Cond) inserted at index 0
 *   BranchCond = !any(SubEM & Cond)
 *
 * This also outputs a warning that we failed to optimize a SIMD branch.
 * We always output it, rather than including it in the -rpass mechanism
 * to enable or disable the warning, as it is an unexpected situation that
 * we want our users to report.
 */
void GenXSimdCFConformance::lowerGoto(CallInst *Goto) {
  LLVM_DEBUG(dbgs() << "lowerGoto: " << *Goto << "\n");
  const DebugLoc &DL = Goto->getDebugLoc();
  if (EnableGenXGotoJoin && !lowerSimdCF)
    DiagnosticInfoSimdCF::emit(Goto, "failed to optimize SIMD branch",
                               DS_Warning);

  // Emits a binary operator in front of the goto, carrying its debug loc.
  auto EmitBinOp = [&](Instruction::BinaryOps Opcode, Value *LHS, Value *RHS,
                       const Twine &InstName) {
    auto *BO = BinaryOperator::Create(Opcode, LHS, RHS, InstName, Goto);
    BO->setDebugLoc(DL);
    return BO;
  };

  Value *OldEM = Goto->getOperand(0);
  Value *OldRM = Goto->getArgOperand(1);
  Value *Cond = Goto->getOperand(2);

  // EM is always 32 bit. Extract SubEM, of the same width as Cond, from it.
  Value *OldSubEM = truncateCond(OldEM, Cond->getType(),
                                 OldEM->getName() + ".sub", Goto, DL);

  // Result 1: NewRM = OldRM | (SubEM & ~Cond)
  auto *NotCond = EmitBinOp(Instruction::Xor, Cond,
                            Constant::getAllOnesValue(Cond->getType()),
                            Goto->getName() + ".notcond");
  auto *Disabling = EmitBinOp(Instruction::And, NotCond, OldSubEM,
                              Goto->getName() + ".disabling");
  auto *NewRM = EmitBinOp(Instruction::Or, OldRM, Disabling,
                          Goto->getName() + ".newRM");

  // Result 0: SubEM & Cond, inserted back into the full EM.
  auto *SubEM = EmitBinOp(Instruction::And, OldSubEM, Cond,
                          Goto->getName() + ".subEM");
  Value *NewEM = insertCond(OldEM, SubEM, Goto->getName() + ".EM", Goto, DL);

  // Result 2: BranchCond = !any(SubEM)
  Function *AnyFunc =
      vc::getAnyDeclaration(M, GenXIntrinsic::genx_any, SubEM->getType());
  auto *Any = CallInst::Create(AnyFunc, SubEM, SubEM->getName() + ".any", Goto);
  Any->setDebugLoc(DL);
  auto *NotAny =
      EmitBinOp(Instruction::Xor, Any,
                Constant::getAllOnesValue(Any->getType()),
                Any->getName() + ".not");

  // Replace uses, then drop the intrinsic call itself.
  Value *Results[3] = {NewEM, NewRM, NotAny};
  replaceGotoJoinUses(Goto, Results);
  Goto->eraseFromParent();
  Modified = true;
}

/***********************************************************************
 * lowerJoin : lower a llvm.genx.simdcf.join
 *
 * The join is replaced by explicit mask arithmetic emitted right before it:
 *   NewEM      = EM with (SubEM | RM) inserted at index 0
 *   BranchCond = !any(SubEM | RM)
 */
void GenXSimdCFConformance::lowerJoin(CallInst *Join) {
  LLVM_DEBUG(dbgs() << "lowerJoin: " << *Join << "\n");
  const DebugLoc &DL = Join->getDebugLoc();

  Value *OldEM = Join->getOperand(0);
  Value *RM = Join->getOperand(1);

  // EM is always 32 bit. Extract SubEM, of the same width as RM, from it.
  Value *OldSubEM =
      truncateCond(OldEM, RM->getType(), OldEM->getName() + ".sub", Join, DL);

  // Re-enable the lanes recorded in RM.
  auto *SubEM = BinaryOperator::Create(Instruction::Or, OldSubEM, RM,
                                       Join->getName() + ".subEM", Join);
  SubEM->setDebugLoc(DL);

  // Result 0: the merged sub-mask inserted back into the full EM.
  Value *NewEM = insertCond(OldEM, SubEM, Join->getName() + ".EM", Join, DL);

  // Result 1: BranchCond = !any(SubEM)
  Function *AnyFunc =
      vc::getAnyDeclaration(M, GenXIntrinsic::genx_any, SubEM->getType());
  auto *Any = CallInst::Create(AnyFunc, SubEM, SubEM->getName() + ".any", Join);
  Any->setDebugLoc(DL);
  auto *NotAny = BinaryOperator::Create(
      Instruction::Xor, Any, Constant::getAllOnesValue(Any->getType()),
      Any->getName() + ".not", Join);
  NotAny->setDebugLoc(DL);

  // Replace uses, then drop the intrinsic call itself.
  Value *Results[2] = {NewEM, NotAny};
  replaceGotoJoinUses(Join, Results);
  Join->eraseFromParent();
  Modified = true;
}

/***********************************************************************
 * replaceGotoJoinUses : replace uses of goto/join
 *
 * Enter:   GotoJoin = the goto/join call being lowered
 *          Vals = replacement values, indexed by struct element number
 *
 * The goto and join intrinsics have multiple return values in a struct.
 * This attempts to find the extractvalues and replace those directly.
 * If uses of the whole struct remain afterwards, the replacement values are
 * aggregated into a struct with insertvalue instructions. Otherwise any
 * replacement value that ended up unused is erased.
 */
void GenXSimdCFConformance::replaceGotoJoinUses(CallInst *GotoJoin,
                                                ArrayRef<Value *> Vals) {
  // Gather the extractvalue users first: they are erased below, which must
  // not happen while the use list is being walked.
  SmallVector<ExtractValueInst *, 4> Extracts;
  for (User *U : GotoJoin->users())
    if (auto *Extract = dyn_cast<ExtractValueInst>(U))
      Extracts.push_back(Extract);

  for (ExtractValueInst *Extract : Extracts) {
    unsigned Index = Extract->getIndices()[0];
    // Leave alone an extract for which no replacement value was supplied.
    if (Index >= Vals.size())
      continue;
    Extract->replaceAllUsesWith(Vals[Index]);
    Extract->eraseFromParent();
  }

  if (!GotoJoin->use_empty()) {
    // There are still some uses of the original goto/join. We need to
    // aggregate the result values into a struct.
    Value *StructVal = UndefValue::get(GotoJoin->getType());
    Instruction *InsertBefore = GotoJoin->getNextNode();
    for (unsigned Index = 0,
                  End = cast<StructType>(GotoJoin->getType())->getNumElements();
         Index != End; ++Index)
      StructVal = InsertValueInst::Create(StructVal, Vals[Index], Index, "",
                                          InsertBefore);
    GotoJoin->replaceAllUsesWith(StructVal);
    return;
  }

  // Remove code for unused value. This is particularly useful at an outer
  // join, where the !any(NewEM) is unused, so we don't need to compute it.
  for (Value *V : Vals) {
    while (V && V->use_empty()) {
      auto *I = dyn_cast<Instruction>(V);
      // Only instructions can be erased. Stop here: V does not change, so
      // looping again on an unused constant or argument would never end.
      if (!I)
        break;
      unsigned NumOperands = I->getNumOperands();
      if (auto *CI = dyn_cast<CallInst>(I))
        NumOperands = IGCLLVM::getNumArgOperands(CI);
      // Continue up the chain only through single-operand instructions.
      V = nullptr;
      if (NumOperands == 1)
        V = I->getOperand(0);
      I->eraseFromParent();
    }
  }
}

/***********************************************************************
 * fixBlockDataBeforeRemoval : prepare a basic block for erasing
 *
 * Enter:   BB = block about to be removed
 *          SuccBB = its expected single successor
 *
 * Leading phi nodes of BB are erased, and leading debug variable intrinsics
 * are moved to the first non-phi position of SuccBB.
 */
static void fixBlockDataBeforeRemoval(BasicBlock *BB, BasicBlock *SuccBB) {
  // Erase phis one at a time from the front of the block.
  for (auto *Phi = dyn_cast<PHINode>(BB->begin()); Phi;
       Phi = dyn_cast<PHINode>(BB->begin()))
    Phi->eraseFromParent();

  IGC_ASSERT_MESSAGE(BB->getSingleSuccessor() == SuccBB,
                     "Awaiting only one successor");

  // When SuccBB has more than one predecessor, the location of each moved
  // debug intrinsic is killed.
  const bool KillLocations = !SuccBB->hasNPredecessors(1);
  Instruction *InsertBefore = SuccBB->getFirstNonPHI();
  for (auto *Dbg = dyn_cast<DbgVariableIntrinsic>(BB->begin()); Dbg;
       Dbg = dyn_cast<DbgVariableIntrinsic>(BB->begin())) {
    Dbg->moveBefore(InsertBefore);
    if (KillLocations)
      IGCLLVM::setKillLocation(Dbg);
  }
}

/***********************************************************************
 * setCategories : set webs of EM and RM values to category EM or RM
 *
 * This also modifies EM uses as needed, and removes an empty critical edge
 * splitter block sitting on the "true" edge of a branching goto/join.
 */
void GenXLateSimdCFConformance::setCategories() {
  // First the EM values.
  for (auto ei = EMVals.begin(); ei != EMVals.end(); /* empty */) {
    SimpleValue EMVal = *ei;
    // Advance before processing the current entry.
    // NOTE(review): removeFromEMRMVals() is called further down while this
    // iteration is in progress; confirm it cannot invalidate `ei`.
    ei++;
    // For this EM value, set its category and modify its uses.
    Liveness->getOrCreateLiveRange(EMVal)->setCategory(vc::RegCategory::EM);
    LLVM_DEBUG(dbgs() << "Set category for:\n" << *EMVal.getValue() << "\n");
    // Struct-typed values (the goto/join results themselves) are skipped.
    if (!isa<StructType>(EMVal.getValue()->getType()))
      modifyEMUses(EMVal.getValue());
    switch (vc::getAnyIntrinsicID(EMVal.getValue())) {
    case GenXIntrinsic::genx_simdcf_join: {
      // For a join, set the category of each RM value.
      auto RMValsEntry = &RMVals[cast<CallInst>(EMVal.getValue())];
      for (auto vi = RMValsEntry->begin(), ve = RMValsEntry->end(); vi != ve;
           ++vi) {
        SimpleValue RMVal = *vi;
        // For this RM value, set its category.
        Liveness->getOrCreateLiveRange(RMVal)->setCategory(vc::RegCategory::RM);
      }
    }
    // Fall through...
    case GenXIntrinsic::genx_simdcf_goto: {
      // See if this is a branching goto/join where the "true" successor is
      // an empty critical edge splitter block.
      auto CI = cast<CallInst>(EMVal.getValue());
      BasicBlock *BB = CI->getParent();
      if (GotoJoin::isBranchingGotoJoinBlock(BB) == CI) {
        BasicBlock *TrueSucc = BB->getTerminator()->getSuccessor(0);
        if (BasicBlock *TrueSuccSucc =
                getEmptyCriticalEdgeSplitterSuccessor(TrueSucc)) {
          // Forward single-input phis to their only incoming value so that
          // erasing them below leaves no dangling uses.
          for (PHINode &Phi : TrueSucc->phis()) {
            if (Phi.getNumIncomingValues() == 1) {
              auto *PredInst = Phi.getIncomingValue(0);
              Phi.replaceAllUsesWith(PredInst);
              Liveness->eraseLiveRange(&Phi);
              removeFromEMRMVals(&Phi);
            } else {
              // A phi with several inputs cannot be forwarded, yet it is
              // erased below all the same: report it.
              IGC_ASSERT_MESSAGE(false, "BB has unremovable phi");
            }
          }
          // Remove phi and move dbg-info
          fixBlockDataBeforeRemoval(TrueSucc, TrueSuccSucc);
          IGC_ASSERT_MESSAGE(TrueSucc->front().isTerminator(),
                             "BB is not empty for removal");
          // For a branching goto/join where the "true" successor is an empty
          // critical edge splitter block, remove the empty block, to ensure
          // that the "true" successor is a join label.
          // Adjust phi nodes in TrueSuccSucc.
          adjustPhiNodesForBlockRemoval(TrueSuccSucc, TrueSucc);
          // Replace the use (we know there is only the one).
          BB->getTerminator()->setSuccessor(0, TrueSuccSucc);
          // Erase the critical edge splitter block.
          TrueSucc->eraseFromParent();
          Modified = true;
        }
      }
      break;
    }
    default:
      break;
    }
  }
}

/***********************************************************************
 * modifyEMUses : modify EM uses as needed
 *
 * Walks the uses of EM and, transitively, of every rdpredregion that
 * reads it, collecting the selects that use one of those values as
 * their condition. Each collected select is then replaced by a
 * wrregion whose mask is the select condition, and the select is
 * erased.
 *
 * A constant (non-undef) false value is loaded with ConstantLoader
 * first. The exception is a select that feeds a phi whose incoming
 * blocks are merged under a branch on the same goto/join result the
 * select is predicated by: there the false value is replaced by undef.
 */
void GenXLateSimdCFConformance::modifyEMUses(Value *EM) {
  LLVM_DEBUG(dbgs() << "modifyEMUses: " << EM->getName() << "\n");
  // Gather the selects we need to modify, at the same time as handling other
  // uses of the EM values.
  SmallVector<SelectInst *, 4> Selects;
  SmallVector<Value *, 4> EMs;
  EMs.push_back(EM);
  // EMs grows while it is being walked, so index it and re-read size().
  for (unsigned ei = 0; ei != EMs.size(); ++ei) {
    EM = EMs[ei];
    // Scan EM's uses.
    for (auto ui = EM->use_begin(), ue = EM->use_end(); ui != ue; ++ui) {
      auto User = cast<Instruction>(ui->getUser());
      if (auto *Sel = dyn_cast<SelectInst>(User)) {
        // EM is expected to be the select condition (operand 0).
        IGC_ASSERT(!ui->getOperandNo());
        Selects.push_back(Sel);
      } else {
        IGC_ASSERT(testIsValidEMUse(User, ui));
        if (vc::getAnyIntrinsicID(User) == GenXIntrinsic::genx_rdpredregion) {
          // An rdpredregion of the EM. Find its uses in select too.
          EMs.push_back(User);
        }
      }
    }
  }
  // Modify each select into a predicated wrregion.
  const GenXSubtarget &Subtarget = getAnalysis<TargetPassConfig>()
                                       .getTM<GenXTargetMachine>()
                                       .getGenXSubtarget();
  const DataLayout &DL = M->getDataLayout();
  for (SelectInst *Sel : Selects) {
    Value *FalseVal = Sel->getFalseValue();

    // This code removes redundancy introduced by
    // select & phi lying within the same goto-join region
    // and effectively duplicating the work.
    bool LoadFalseVal = true;
    Instruction *Goto = nullptr;
    if (auto *ExtrCond = dyn_cast<ExtractValueInst>(Sel->getCondition()))
      Goto = dyn_cast<Instruction>(ExtrCond->getAggregateOperand());
    // Without a goto/join behind the condition nothing below can match, so
    // the phi users only need to be scanned when one was found.
    if (Goto) {
      for (auto *U : Sel->users()) {
        auto *SelPhiUser = dyn_cast<PHINode>(U);
        // The common dominator query reads incoming blocks 0 and 1, so a phi
        // with fewer than two incoming edges cannot be inspected.
        if (!SelPhiUser || SelPhiUser->getNumIncomingValues() < 2)
          continue;
        DominatorTree *DomTree = getDomTree(Sel->getFunction());
        // NOTE: we should expect exactly 2 incoming blocks,
        // but sometimes we may have more due to both
        // NoCondEV cases and critical edge splitting,
        // and that should not affect correctness of the transformation
        auto *DomBB = DomTree->findNearestCommonDominator(
            SelPhiUser->getIncomingBlock(0), SelPhiUser->getIncomingBlock(1));
        if (!DomBB)
          continue;
        auto *Term = dyn_cast<BranchInst>(DomBB->getTerminator());
        if (!Term || !Term->isConditional())
          continue;
        auto *ExtrCond = dyn_cast<ExtractValueInst>(Term->getCondition());
        if (ExtrCond && ExtrCond->getAggregateOperand() == Goto) {
          LoadFalseVal = false;
          break;
        }
      }
    }

    if (auto *C = dyn_cast<Constant>(FalseVal)) {
      if (!isa<UndefValue>(C)) {
        if (LoadFalseVal) {
          // The false value needs loading if it is a constant other than
          // undef.
          SmallVector<Instruction *, 4> AddedInstructions;
          FalseVal =
              ConstantLoader(C, Subtarget, DL, nullptr, &AddedInstructions)
                  .loadBig(Sel);
          // Give every instruction ConstantLoader added the debug loc of the
          // select it was created for.
          for (Instruction *I : AddedInstructions)
            I->setDebugLoc(Sel->getDebugLoc());
        } else {
          // As mentioned above, we're trying to eliminate
          // redundancy with select+phi in a goto/join region.
          // So we convert select to a wrr with an undef source
          // for it to effectively become a simple mov
          FalseVal = UndefValue::get(C->getType());
        }
      }
    }
    Region R(Sel);
    R.Mask = Sel->getCondition();
    IGC_ASSERT(FalseVal);
    Value *Wr = R.createWrRegion(FalseVal, Sel->getTrueValue(), Sel->getName(),
                                 Sel, Sel->getDebugLoc());
    Sel->replaceAllUsesWith(Wr);
    Liveness->eraseLiveRange(Sel);
    Sel->eraseFromParent();
    Modified = true;
  }
}

/***********************************************************************
 * hoistExtractEMInstructions : hoist EM-related extractvalue
 * instructions and remove duplicates if there are such to comply
 * EM conformance
 *
 * Duplicates may be insterted by LICM before late pass. This method
 * is not needed in early pass, because redundant extractvalues are
 * deleted by preceding EarlyCSE there.
 *
 * Currently, RM duplicates are not handled, because no case where it
 * might be needed was found as of yet
 */
void GenXLateSimdCFConformance::hoistExtractEMInstructions() {
  DenseMap<SimpleValue, ExtractValueInst *> EVs;
  SmallVector<ExtractValueInst *, 8> ToRemove;

  for (auto &&EMVal : EMVals)
    if (auto *V = dyn_cast<ExtractValueInst>(EMVal.getValue())) {
      auto *StructVal = cast<Instruction>(V->getAggregateOperand());
      auto [It, IsInserted] =
          EVs.try_emplace(SimpleValue{StructVal, V->getIndices()}, V);
      if (IsInserted)
        V->moveAfter(StructVal);
      else {
        ToRemove.push_back(V);
        V->replaceAllUsesWith(It->second);
      }
    }
  for (auto &&V : ToRemove) {
    removeFromEMRMVals(V);
    V->eraseFromParent();
  }
}

/***********************************************************************
 * optimizeRestoredSIMDCF : perform optimization on restored SIMD CF
 *
 * Restored SIMD CF is built from linear code blocks that came from
 * llvm transformations. Some code could have been moved out of SIMD CF,
 * past the join point, during these transformations. This function
 * tries to put such code back.
 *
 * TODO: some other transformations could be applied after SIMD CF was
 * linearized. Maybe this function should be updated in future.
 */
void GenXSimdCFConformance::optimizeRestoredSIMDCF() {
  for (auto &Data : BlocksToOptimize) {
    const auto &Block = Data.first;
    auto &JoinData = Data.second;

    // Only blocks whose real EM is still tracked in EMVals are optimized;
    // blocks with lowered EM values are skipped.
    const bool EMIsTracked =
        EMVals.count(SimpleValue(JoinData.getRealEM(), 0)) != 0;
    if (EMIsTracked) {
      optimizeLinearization(Block, JoinData);
      continue;
    }

    LLVM_DEBUG(dbgs() << "optimizeRestoredSIMDCF: skipping "
                      << Block->getName() << "\n");
  }
}

/***********************************************************************
 * isActualStoredEM : check if Inst is an actual stored EM
 *
 * This function is called during linear fragment optimization.
 * The actual stored EM is a PHI node with const/getEM inputs here.
 * Actuality is checked via the EM-getEM map.
 *
 * Inst may also be a shufflevector (truncated EM); in that case the
 * PHI is looked for in its first operand. Returns true only when the
 * PHI has exactly two incoming values, one from each of the join
 * point's true/false preds, the false pred value is the entry stored
 * in LoweredEMValsMap for the real EM, and the true pred value is the
 * null constant of its type.
 */
bool GenXSimdCFConformance::isActualStoredEM(Instruction *Inst,
                                             JoinPointOptData &JPData) {
  LLVM_DEBUG(dbgs() << "isActualStoredEM: visiting\n" << *Inst << "\n");
  PHINode *PN = dyn_cast<PHINode>(Inst);

  // Linearized block should be turned into a hammock: stored EM
  // must come via PHI with two preds. Go through shufflevector
  // in case of truncated EM.
  if (auto *SVI = dyn_cast<ShuffleVectorInst>(Inst)) {
    LLVM_DEBUG(dbgs() << "Truncated EM detected\n");

    // Check SVI trunc correctness
    if (!canUseRealEM(Inst, 0)) {
      LLVM_DEBUG(dbgs() << "Bad trunc via SVI: not an actual EM\n");
      return false;
    }

    PN = dyn_cast<PHINode>(SVI->getOperand(0));
  }
  if (!PN || PN->getNumIncomingValues() != 2) {
    LLVM_DEBUG(dbgs() << "Incompatable inst: not an actual EM\n");
    return false;
  }

  // The PHI may not belong to this hammock at all (e.g. when it was reached
  // through a shufflevector operand). getIncomingValueForBlock requires the
  // block to be one of the PHI's incoming blocks, so look the blocks up first
  // and treat a miss as "not an actual EM".
  const int FalseIdx = PN->getBasicBlockIndex(JPData.getFalsePred());
  const int TrueIdx = PN->getBasicBlockIndex(JPData.getTruePred());
  if (FalseIdx < 0 || TrueIdx < 0) {
    LLVM_DEBUG(
        dbgs() << "PHI is not fed by hammock preds: not an actual EM\n");
    return false;
  }

  Value *ExpectedGetEM = PN->getIncomingValue(FalseIdx);
  Value *ExpectedConstEM = PN->getIncomingValue(TrueIdx);

  IGC_ASSERT_MESSAGE(ExpectedGetEM, "Bad phi in hammock!");
  IGC_ASSERT_MESSAGE(ExpectedConstEM, "Bad phi in hammock!");

  // Find stored value
  auto It = LoweredEMValsMap.find(JPData.getRealEM());
  if (It == LoweredEMValsMap.end()) {
    LLVM_DEBUG(dbgs() << "No EM was stored: not an actual EM\n");
    return false;
  }

  // Check if the val from SIMD BB is a stored via get.em EM
  if (ExpectedGetEM != It->second) {
    LLVM_DEBUG(
        dbgs() << "SIMD BB value is not a correct get.em: not an actual EM\n");
    return false;
  }

  // Check if the val from True BB is an all null constant
  if (ExpectedConstEM != Constant::getNullValue(ExpectedConstEM->getType())) {
    LLVM_DEBUG(
        dbgs()
        << "True BB value is not a correct constant: not an actual EM\n");
    return false;
  }

  LLVM_DEBUG(dbgs() << "All checks passed\n");
  return true;
}

/***********************************************************************
 * canBeMovedUnderSIMDCF : check if Instruction can be moved under
 * SIMD CF
 *
 * This function is called during linear fragment optimization.
 * We can move an instruction if such movement does not corrupt
 * dominance. Sometimes several instructions have to be moved
 * together: operands are checked by a recursive call, and every
 * instruction in the chain is stored in the Visited set.
 *
 * Returns true for non-instruction values. Returns false for an
 * instruction that was already visited, lives outside CurrBB, is a PHI
 * node, is a join, does not have exactly one use, or has an operand
 * that neither dominates the false pred, nor is the actual stored EM,
 * nor can be moved itself.
 */
bool GenXSimdCFConformance::canBeMovedUnderSIMDCF(
    Value *Val, BasicBlock *CurrBB, JoinPointOptData &JPData,
    std::set<Instruction *> &Visited) {
  Instruction *Inst = dyn_cast<Instruction>(Val);

  // Can be non-inst. In this case we have nothing to check.
  if (!Inst)
    return true;

  LLVM_DEBUG(dbgs() << "canBeMovedUnderSIMDCF: visiting\n" << *Inst << "\n");

  // Mark instruction as visited. Return if it was already added to set:
  // we don't expect it to be here.
  if (!Visited.insert(Inst).second) {
    LLVM_DEBUG(dbgs() << "Instruction was already visited: do not move\n");
    return false;
  }

  // Instruction is not located in linearized fragment
  if (Inst->getParent() != CurrBB) {
    LLVM_DEBUG(dbgs() << "Out of linearized fragment: do not move\n");
    return false;
  }

  // Do not move PHI nodes. The mover in optimizeLinearization only walks
  // non-PHI instructions, so a PHI would stay in CurrBB while the code (or
  // the replacement PHI) that uses it ends up referring to it from the false
  // pred, where it is not available.
  if (isa<PHINode>(Inst)) {
    LLVM_DEBUG(dbgs() << "PHI node: do not move\n");
    return false;
  }

  // Do not move join instruction
  if (vc::getAnyIntrinsicID(Inst) == GenXIntrinsic::genx_simdcf_join) {
    LLVM_DEBUG(dbgs() << "Join instruction: do not move\n");
    return false;
  }

  // TODO: current assumption is that nothing except linearization was applied
  // Skip instruction that has more than one user
  if (!Inst->hasOneUse()) {
    LLVM_DEBUG(dbgs() << "More than one user: do not move\n");
    return false;
  }

  // Check operands
  for (unsigned i = 0, e = Inst->getNumOperands(); i < e; ++i) {
    Instruction *Pred = dyn_cast<Instruction>(Inst->getOperand(i));

    // Not an instruction: not blocking moving
    if (!Pred)
      continue;

    // Check for dominance
    DominatorTree *DomTree = getDomTree(CurrBB->getParent());
    if (DomTree->dominates(Pred, JPData.getFalsePred()))
      continue;

    // Check for actual saved EM: it is a phi located in current BB,
    // so the dominance check failed
    if (isActualStoredEM(Pred, JPData))
      continue;

    // Dominance check and EM check failed: instruction is inside this block
    LLVM_DEBUG(dbgs() << "Recursive call for operand #" << i << "\n");
    if (canBeMovedUnderSIMDCF(Pred, CurrBB, JPData, Visited))
      continue;

    // Recursive call failed: do not move
    LLVM_DEBUG(dbgs() << "canBeMovedUnderSIMDCF: bad operand: do not move\n");
    return false;
  }

  LLVM_DEBUG(dbgs() << "canBeMovedUnderSIMDCF: move\n" << *Inst << "\n");
  return true;
}

/***********************************************************************
 * isSelectConditionCondEV : check if Select's condition is a stored
 * Cond EV value.
 *
 * This function is called during linear fragment optimization.
 * Linear fragment optimization is based on the fact that LLVM performed
 * code movement with a PHI -> select transformation. This function
 * checks if the select condition is a CondEV from the previous SIMD
 * branching instruction.
 *
 * The condition qualifies when it is a PHI in the select's own block
 * whose value from the true pred is a "one" constant and whose value
 * from the false pred is a null constant.
 *
 * This function can also handle constant vectorization if it
 * was applied: it doesn't break SIMD CF CondEV semantics.
 */
bool GenXSimdCFConformance::isSelectConditionCondEV(SelectInst *Sel,
                                                    JoinPointOptData &JPData) {
  auto *CondPhi = dyn_cast<PHINode>(Sel->getCondition());

  // The condition must be a PHI, and that CondEV PHI must be in the same BB
  if (!CondPhi || CondPhi->getParent() != Sel->getParent())
    return false;

  Value *FromTruePred =
      CondPhi->getIncomingValueForBlock(JPData.getTruePred());
  Value *FromFalsePred =
      CondPhi->getIncomingValueForBlock(JPData.getFalsePred());

  IGC_ASSERT_MESSAGE(FromTruePred, "Bad phi in hammock!");
  IGC_ASSERT_MESSAGE(FromFalsePred, "Bad phi in hammock!");

  auto *TrueConst = dyn_cast<Constant>(FromTruePred);
  auto *FalseConst = dyn_cast<Constant>(FromFalsePred);

  // It is not necessary to check constant type due to CondEV semantics
  return TrueConst && FalseConst && TrueConst->isOneValue() &&
         FalseConst->isNullValue();
}

/***********************************************************************
 * replaceGetEMUse : find and replace GetEM uses to fix dominance.
 *
 * This function is called during linear fragment optimization.
 * After Inst has been chosen to move to the SIMD BB, its EM operands
 * have to be updated to match the new data flow.
 *
 * In many cases we can place the real EM instead of the lowered one.
 * GetEM may become redundant - it will be removed later in this pass.
 *
 * Note: SIMD CF will be non-conformant if Inst is left at the point
 * where the new EM was generated. Inst is moved after that replacement
 * in linearized fragment optimization so conformance is not broken.
 */
void GenXSimdCFConformance::replaceGetEMUse(Instruction *Inst,
                                            JoinPointOptData &JPData) {
  const unsigned NumOps = Inst->getNumOperands();
  for (unsigned OpIdx = 0; OpIdx != NumOps; ++OpIdx) {
    auto *StoredEM = dyn_cast<Instruction>(Inst->getOperand(OpIdx));

    // Only an instruction from the same BB can be the stored EM
    if (!StoredEM || StoredEM->getParent() != Inst->getParent())
      continue;

    if (!isActualStoredEM(StoredEM, JPData))
      continue;

    if (!canUseRealEM(Inst, OpIdx)) {
      // Replace with lowered EM
      auto LoweredIt = LoweredEMValsMap.find(JPData.getRealEM());
      IGC_ASSERT_EXIT_MESSAGE(LoweredIt != LoweredEMValsMap.end(),
                              "Should be checked earlier");
      Instruction *LoweredEM = cast<Instruction>(LoweredIt->second);
      Inst->setOperand(OpIdx, LoweredEM);

      if (StoredEM->use_empty())
        StoredEM->eraseFromParent();
      continue;
    }

    // Replace with real EM
    Instruction *Replacement = JPData.getRealEM();
    Instruction *WideEM = nullptr;
    if (isa<ShuffleVectorInst>(StoredEM)) {
      // Copy truncation via SVI: clone the shuffle at the end of the false
      // pred and feed the clone with the real EM.
      Replacement = StoredEM->clone();
      Replacement->insertBefore(JPData.getFalsePred()->getTerminator());
      Replacement->setOperand(0, JPData.getRealEM());
      WideEM = cast<Instruction>(StoredEM->getOperand(0));
    }
    Inst->setOperand(OpIdx, Replacement);

    // Remove the old operand if it is not needed anymore, then do the same
    // for the value it was truncated from.
    // GetEM that was used here will be handled later.
    if (StoredEM->use_empty())
      StoredEM->eraseFromParent();
    if (WideEM && WideEM->use_empty())
      WideEM->eraseFromParent();
  }
}

/***********************************************************************
 * optimizeLinearization : optimize linearized fragment
 *
 * This optimization restores SIMD CF for a linearized fragment.
 * To detect code that can be moved under SIMD CF, we need to find
 * the following select inst:
 *    Val = select CondEV, OldVal, NewVal
 * Details can be found below.
 */
void GenXSimdCFConformance::optimizeLinearization(BasicBlock *BB,
                                                  JoinPointOptData &JPData) {
  std::set<Instruction *> InstsToMove;
  std::vector<SelectInst *> SelectsToOptimize;

  // Walk the block bottom-up, from the instruction before the terminator
  // until the PHIs are reached.
  for (Instruction *Cur = BB->getTerminator()->getPrevNode();
       Cur && !isa<PHINode>(Cur); Cur = Cur->getPrevNode()) {
    // We are looking for "Val = select CondEV, OldVal, NewVal" instruction.
    //
    // Linearization put NewVal calculations after JP. OldVal came from True BB
    // via PHI instruction. We can move NewVal calculations under SIMD CF and
    // place a PHINode instead of this select. Val, OldVal and NewVal will be
    // coalesced and allocated on the same register later.
    auto *Select = dyn_cast<SelectInst>(Cur);
    if (!Select)
      continue;
    if (!isSelectConditionCondEV(Select, JPData))
      continue;

    // Check if OldVal came from outside. Also it can be a constant.
    // TODO: current assumption is that nothing except linearization was
    // applied. It is possible that OldVal was moved down after it. We also can
    // move it back but some analysis is required to avoid possible overhead.
    // Not done now.
    if (auto *OldValInst = dyn_cast<Instruction>(Select->getTrueValue())) {
      DominatorTree *DomTree = getDomTree(BB->getParent());
      // Must dominate this BB and SIMD BB
      const bool DominatesBoth =
          DomTree->dominates(OldValInst, BB) &&
          DomTree->dominates(OldValInst, JPData.getFalsePred());
      if (!DominatesBoth)
        continue;
    }

    // Check NewVal
    std::set<Instruction *> Chain;
    if (!canBeMovedUnderSIMDCF(Select->getFalseValue(), BB, JPData, Chain))
      continue;

    // We can optimize this select
    InstsToMove.insert(Chain.begin(), Chain.end());
    SelectsToOptimize.push_back(Select);
  }

  // Move instructions
  // FIXME: there must be a way to do it in a better manner
  // The idea of this is to keep the instructions' order so we don't break
  // dominance when movement is performed.
  std::vector<Instruction *> InBlockOrder;
  for (Instruction *Cur = BB->getFirstNonPHI(); Cur;
       Cur = Cur->getNextNode()) {
    if (InstsToMove.count(Cur))
      InBlockOrder.push_back(Cur);
  }
  for (Instruction *Moved : InBlockOrder) {
    replaceGetEMUse(Moved, JPData);
    Moved->moveBefore(JPData.getFalsePred()->getTerminator());
  }

  // Handle selects: each one becomes a PHI of its true value (from the true
  // pred) and its false value (from the false pred).
  for (SelectInst *Select : SelectsToOptimize) {
    PHINode *Merge = PHINode::Create(Select->getType(), 2, "optimized_sel",
                                     BB->getFirstNonPHI());
    Merge->addIncoming(Select->getTrueValue(), JPData.getTruePred());
    Merge->addIncoming(Select->getFalseValue(), JPData.getFalsePred());
    Select->replaceAllUsesWith(Merge);
    Select->eraseFromParent();
  }
}

/***********************************************************************
 * GotoJoinEVs::GotoJoinEVs : collect and handle EVs. See CollectEVs
 * for more info.
 *
 * A null GJ leaves the object uninitialized: nothing is classified or
 * collected.
 */
GenXSimdCFConformance::GotoJoinEVs::GotoJoinEVs(Value *GJ)
    : GotoJoin(GJ), IsGoto(false) {
  if (!GotoJoin)
    return;

  // Classify the call: anything that is not a goto is expected to be a join.
  const auto IID = vc::getAnyIntrinsicID(GotoJoin);
  IsGoto = (IID == GenXIntrinsic::genx_simdcf_goto);
  IGC_ASSERT_MESSAGE(IsGoto || IID == GenXIntrinsic::genx_simdcf_join,
                     "Expected goto or join!");

  CollectEVs();
}

/***********************************************************************
 * GotoJoinEVs::getEMEV : get the EV that extracts the Execution Mask
 * from the goto/join
 */
ExtractValueInst *GenXSimdCFConformance::GotoJoinEVs::getEMEV() const {
  static_assert(EMPos < (sizeof(EVs) / sizeof(*EVs)));
  IGC_ASSERT_MESSAGE(GotoJoin, "Uninitialized GotoJoinEVs Data!");
  ExtractValueInst *EMExtract = EVs[EMPos];
  return EMExtract;
}

/***********************************************************************
 * GotoJoinEVs::getRMEV : get the EV that extracts the Resume Mask
 * from the goto. Only valid for goto data.
 */
ExtractValueInst *GenXSimdCFConformance::GotoJoinEVs::getRMEV() const {
  static_assert(RMPos < (sizeof(EVs) / sizeof(*EVs)));
  IGC_ASSERT_MESSAGE(GotoJoin, "Uninitialized GotoJoinEVs Data!");
  IGC_ASSERT_MESSAGE(IsGoto, "Only goto returns RM!");
  ExtractValueInst *RMExtract = EVs[RMPos];
  return RMExtract;
}

/***********************************************************************
 * GotoJoinEVs::getCondEV : get the EV that extracts the goto/join
 * condition. The slot that is read depends on whether this is a goto
 * or a join.
 */
ExtractValueInst *GenXSimdCFConformance::GotoJoinEVs::getCondEV() const {
  static_assert(GotoCondPos < (sizeof(EVs) / sizeof(*EVs)));
  static_assert(JoinCondPos < (sizeof(EVs) / sizeof(*EVs)));
  IGC_ASSERT_MESSAGE(GotoJoin, "Uninitialized GotoJoinEVs Data!");
  return IsGoto ? EVs[GotoCondPos] : EVs[JoinCondPos];
}

/***********************************************************************
 * GotoJoinEVs::getGotoJoin : get the goto/join value this EVs info
 * was built for
 */
Value *GenXSimdCFConformance::GotoJoinEVs::getGotoJoin() const {
  Value *Stored = GotoJoin;
  IGC_ASSERT_MESSAGE(Stored, "Uninitialized GotoJoinEVs Data!");
  return Stored;
}

/***********************************************************************
 * GotoJoinEVs::getSplitPoint : find the first instruction after the
 * goto/join that is not an EV of it. Such an instruction always exists
 * in correct IR - the BB terminator is one.
 */
Instruction *GenXSimdCFConformance::GotoJoinEVs::getSplitPoint() const {
  IGC_ASSERT_MESSAGE(GotoJoin, "Uninitialized GotoJoinEVs Data!");
  Instruction *Candidate = cast<Instruction>(GotoJoin)->getNextNode();
  // Step over the run of extractvalues that read the goto/join directly.
  while (isa<ExtractValueInst>(Candidate) &&
         Candidate->getOperand(0) == GotoJoin)
    Candidate = Candidate->getNextNode();
  return Candidate;
}

/***********************************************************************
 * GotoJoinEVs::setCondEV : set the EV for the goto/join condition. It
 * is needed on basic block splitting to handle a bad Cond EV user.
 * The condition EV must not be set yet.
 */
void GenXSimdCFConformance::GotoJoinEVs::setCondEV(ExtractValueInst *CondEV) {
  static_assert(GotoCondPos < (sizeof(EVs) / sizeof(*EVs)));
  static_assert(JoinCondPos < (sizeof(EVs) / sizeof(*EVs)));
  IGC_ASSERT_MESSAGE(GotoJoin, "Uninitialized GotoJoinEVs Data!");
  IGC_ASSERT_MESSAGE(!getCondEV(), "CondEV is already set!");
  // Goto and join keep their condition EV in different slots.
  auto &CondSlot = IsGoto ? EVs[GotoCondPos] : EVs[JoinCondPos];
  CondSlot = CondEV;
}

/***********************************************************************
 * GotoJoinEVs::isGoto : check whether this EVs info belongs to a goto
 */
bool GenXSimdCFConformance::GotoJoinEVs::isGoto() const {
  IGC_ASSERT_MESSAGE(GotoJoin, "Uninitialized GotoJoinEVs Data!");
  const bool BelongsToGoto = IsGoto;
  return BelongsToGoto;
}

/***********************************************************************
 * GotoJoinEVs::isJoin : check whether this EVs info belongs to a join
 */
bool GenXSimdCFConformance::GotoJoinEVs::isJoin() const {
  IGC_ASSERT_MESSAGE(GotoJoin, "Uninitialized GotoJoinEVs Data!");
  const bool BelongsToJoin = !IsGoto;
  return BelongsToJoin;
}

/***********************************************************************
 * GotoJoinEVs::getNameForMissingEV : helper for generating variable
 * names while creating missing EV instructions. The name is "missing"
 * followed by a suffix chosen by the value extraction index; an index
 * without a dedicated suffix yields plain "missing".
 */
std::string GenXSimdCFConformance::GotoJoinEVs::getNameForMissingEV(
    unsigned EVIndex) const {
  const char *Suffix = "";
  switch (EVIndex) {
  case EMPos:
    Suffix = "EMEV";
    break;
  case RMPos:
    Suffix = "RMEV";
    break;
  case GotoCondPos:
    Suffix = "CondEV";
    break;
  default:
    break;
  }
  return std::string("missing") + Suffix;
}

/***********************************************************************
 * GotoJoindEVs::CanonicalizePHIs : make goto/join -> EV -> PHI chains
 * conformant to the same pattern.
 *
 * Check if goto/join is used by any PHI nodes directly. In such cases,
 * unify the CFG through the following transformation:
 *
 * -----------------------------------------------------------------------
 * | BB.phi:                                                             |
 * | %phinode = phi %structtype [%gotojoin.1, %BB1], [%gotojoin.2, BB2]  |
 * | %ev = extractvalue %structtype %phinode, <n>                        |
 * | ; use %ev                                                           |
 * -----------------------------------------------------------------------
 *                          |       |       |
 *                          V       V       V
 * -----------------------------------------------------------------------
 * | BB1:                                                                |
 * | %ev.1 = extractvalue %structtype %gotojoin.1, <n>                   |
 * | ...                                                                 |
 * | BB2:                                                                |
 * | %ev.2 = extractvalue %structtype %gotojoin.2, <n>                   |
 * | ...                                                                 |
 * | BB.phi:                                                             |
 * | %phinode = phi %<extracted.type> [%ev.1, %BB1], [%ev.2, %BB.2]      |
 * | ; use %phinode                                                      |
 * -----------------------------------------------------------------------
 */
void GenXSimdCFConformance::GotoJoinEVs::CanonicalizePHIs() {
  // Collect all PHIs from GotoJoin users.
  SmallVector<PHINode *, 4> PHINodes;
  for (auto *U : GotoJoin->users())
    if (auto *PN = dyn_cast<PHINode>(U))
      PHINodes.push_back(PN);

  for (auto *PN : PHINodes) {
    SmallVector<ExtractValueInst *, 4> EVsToReplace;
    // Iterate over EVs for each PHI node.
    for (auto *PHIUser : PN->users()) {
      auto *EV = dyn_cast<ExtractValueInst>(PHIUser);
      IGC_ASSERT_MESSAGE(EV, "Bad user of goto/join!");
      IGC_ASSERT_MESSAGE(EV->getNumIndices() == 1,
                         "Expected 1 index in Extract Value for goto/join!");
      // Extract index info from the EV - we need to create similar EVs for the
      // goto/join calls themselves.
      ArrayRef<unsigned> idxArray = EV->getIndices();
      // Create a corresponding EV for each of PHI's incoming goto/join calls,
      // then reset the PHI node to reference these new EVs.
      PN->mutateType(EV->getType());
      for (unsigned i = 0; i < PN->getNumIncomingValues(); ++i) {
        auto *newEV = ExtractValueInst::Create(
            PN->getIncomingValue(i), idxArray, getNameForMissingEV(idxArray[0]),
            PN->getIncomingBlock(i));
        PN->setIncomingValue(i, newEV);
      }
      // Now that the PHI points to the goto/join EVs, it should be used
      // instead of the pre-existing EV from PHI's basic block. Mark the
      // EV for replacement and deletion.
      EVsToReplace.push_back(EV);
    }
    for (auto *EV : EVsToReplace) {
      EV->replaceAllUsesWith(PN);
      EV->eraseFromParent();
    }
  }
}

/***********************************************************************
 * GotoJoindEVs::CollectEVs : handle and store goto/join EVs
 *
 * This does the following steps:
 *  - Locate EVs. If we found a duplicate, just replace users.
 *  - Move EVs right after the goto/join
 *  - Add missing EM and RM. This is needed for correct liverange
 *    interference analysis.
 */
void GenXSimdCFConformance::GotoJoinEVs::CollectEVs() {
  IGC_ASSERT_MESSAGE(GotoJoin, "Uninitialized GotoJoinEVs Data!");
  IGC_ASSERT_MESSAGE(testIsGotoJoin(GotoJoin), "Expected goto or join!");

  auto *GotoJoinInst = dyn_cast<Instruction>(GotoJoin);

  // Before handling EVs, ensure that PHIs always use extractvalue results
  // instead of referencing goto/join results directly.
  CanonicalizePHIs();
  // Now that EV/PHI graph is canonical, collect EVs, hoist them, resolve
  // duplications.
  for (auto ui = GotoJoin->use_begin(), ue = GotoJoin->use_end(); ui != ue;) {

    auto *EV = dyn_cast<ExtractValueInst>(ui->getUser());
    ++ui;

    IGC_ASSERT_EXIT_MESSAGE(EV, "Bad user of goto/join!");
    IGC_ASSERT_EXIT_MESSAGE(EV->getNumIndices() == 1,
                            "Expected 1 index in Extract Value for goto/join!");

    const unsigned idx = EV->getIndices()[0];
    IGC_ASSERT(testPosCorrectness(idx));

    LLVM_DEBUG(dbgs() << "Found EV:\n" << *EV << "\n");
    IGC_ASSERT(idx < (sizeof(EVs) / sizeof(*EVs)));
    if (EVs[idx]) {
      LLVM_DEBUG(dbgs() << "Duplication: replacing users with:\n"
                        << *EVs[idx] << "\n");
      EV->replaceAllUsesWith(EVs[idx]);
      EV->eraseFromParent();
    } else {
      LLVM_DEBUG(dbgs() << "Saving it.\n");
      EVs[idx] = EV;
    }
  }

  // Add missing EVs for masks
  for (unsigned idx = 0, end = IsGoto ? RMPos : EMPos; idx <= end; ++idx) {
    IGC_ASSERT(idx < (sizeof(EVs) / sizeof(*EVs)));
    if (EVs[idx])
      continue;

    auto EV = ExtractValueInst::Create(
        GotoJoin, {idx}, getNameForMissingEV(idx), GotoJoinInst->getParent());
    EVs[idx] = EV;
  }

  hoistEVs();
}

/***********************************************************************
 * GotoJoinEVs::hoistEVs : move EVs right after goto/join
 */
void GenXSimdCFConformance::GotoJoinEVs::hoistEVs() const {
  IGC_ASSERT_MESSAGE(GotoJoin, "Uninitialized GotoJoinEVs Data!");

  LLVM_DEBUG(dbgs() << "Moving EV users after:\n" << *GotoJoin << "\n");

  const size_t count = (sizeof(EVs) / sizeof(*EVs));
  for (size_t idx = 0; idx < count; ++idx) {
    if (EVs[idx])
      EVs[idx]->moveAfter(dyn_cast<Instruction>(GotoJoin));
  }
}

/***********************************************************************
 * DiagnosticInfoSimdCF::emit : emit an error or warning
 *
 * Builds a diagnostic for Inst's function and debug location with the
 * given message and severity, and reports it through Inst's context.
 */
void DiagnosticInfoSimdCF::emit(Instruction *Inst, StringRef Msg,
                                DiagnosticSeverity Severity) {
  auto &Fn = *Inst->getFunction();
  DiagnosticInfoSimdCF Diag(Severity, Fn, Inst->getDebugLoc(), Msg);
  Inst->getContext().diagnose(Diag);
}
