Skip to content

[llvm] Support multiple save/restore points in mir #119357

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 93 additions & 7 deletions llvm/include/llvm/CodeGen/MIRYamlMapping.h
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,79 @@ LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::yaml::CalledGlobal)
namespace llvm {
namespace yaml {

// Struct representing one save/restore point in the
// 'savePoint' / 'restorePoint' list. One point consists of machine basic block
// name and list of saved/restored in this basic block registers. There are
// two forms of Save/Restore point representation:
// 1. Without explicit register enumeration:
// savePoint: '%bb.n'
// restorePoint: '%bb.n'
// in this case we assume that all CalleeSavedRegisters
// are splilled/restored in these points
// 2. With explicit register:
// savePoint:
// - point: '%bb.1'
// registers:
// - '$rbx'
// - '$r12'
// ...
// restorePoint:
// - point: '%bb.1'
// registers:
// - '$rbx'
// - '$r12'
// If this representation form is used and no register is saved/restored in the
// selected BB, the empty list of register should be specified ( i.e. registers:
// [])
struct SaveRestorePointEntry {
StringValue Point;
std::vector<StringValue> Registers;

bool operator==(const SaveRestorePointEntry &Other) const {
return Point == Other.Point && Registers == Other.Registers;
}
};

using SaveRestorePoints =
std::variant<std::vector<SaveRestorePointEntry>, StringValue>;
Comment on lines +667 to +668
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why this is a variant, and not just the vector of points

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Several comments above (#119357 (comment)) @preames suggested to make support for multiple save/restore points backward compatible with single save/restore point approach.
It helped to "cut out a huge amount of spurious diff". So, StringValue in this variant needed for backward compatibility.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should make no effort to maintain MIR backwards compatibility. As a diff reduction strategy, this can be part 1. But we should absolutely rip out the old form


template <> struct PolymorphicTraits<SaveRestorePoints> {

static NodeKind getKind(const SaveRestorePoints &SRPoints) {
if (std::holds_alternative<std::vector<SaveRestorePointEntry>>(SRPoints))
return NodeKind::Sequence;
if (std::holds_alternative<StringValue>(SRPoints))
return NodeKind::Scalar;
llvm_unreachable("Unsupported NodeKind of SaveRestorePoints");
}

static SaveRestorePointEntry &getAsMap(SaveRestorePoints &SRPoints) {
llvm_unreachable("SaveRestorePoints can't be represented as Map");
}

static std::vector<SaveRestorePointEntry> &
getAsSequence(SaveRestorePoints &SRPoints) {
if (!std::holds_alternative<std::vector<SaveRestorePointEntry>>(SRPoints))
SRPoints = std::vector<SaveRestorePointEntry>();

return std::get<std::vector<SaveRestorePointEntry>>(SRPoints);
}

static StringValue &getAsScalar(SaveRestorePoints &SRPoints) {
if (!std::holds_alternative<StringValue>(SRPoints))
SRPoints = StringValue();

return std::get<StringValue>(SRPoints);
}
};

template <> struct MappingTraits<SaveRestorePointEntry> {
static void mapping(IO &YamlIO, SaveRestorePointEntry &Entry) {
YamlIO.mapRequired("point", Entry.Point);
YamlIO.mapRequired("registers", Entry.Registers);
}
};

template <> struct MappingTraits<MachineJumpTable> {
static void mapping(IO &YamlIO, MachineJumpTable &JT) {
YamlIO.mapRequired("kind", JT.Kind);
Expand All @@ -639,6 +712,14 @@ template <> struct MappingTraits<MachineJumpTable> {
}
};

} // namespace yaml
} // namespace llvm

LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::yaml::SaveRestorePointEntry)

namespace llvm {
namespace yaml {

/// Serializable representation of MachineFrameInfo.
///
/// Doesn't serialize attributes like 'StackAlignment', 'IsStackRealignable' and
Expand Down Expand Up @@ -666,8 +747,8 @@ struct MachineFrameInfo {
bool HasTailCall = false;
bool IsCalleeSavedInfoValid = false;
unsigned LocalFrameSize = 0;
StringValue SavePoint;
StringValue RestorePoint;
SaveRestorePoints SavePoints;
SaveRestorePoints RestorePoints;

bool operator==(const MachineFrameInfo &Other) const {
return IsFrameAddressTaken == Other.IsFrameAddressTaken &&
Expand All @@ -688,7 +769,8 @@ struct MachineFrameInfo {
HasMustTailInVarArgFunc == Other.HasMustTailInVarArgFunc &&
HasTailCall == Other.HasTailCall &&
LocalFrameSize == Other.LocalFrameSize &&
SavePoint == Other.SavePoint && RestorePoint == Other.RestorePoint &&
SavePoints == Other.SavePoints &&
RestorePoints == Other.RestorePoints &&
IsCalleeSavedInfoValid == Other.IsCalleeSavedInfoValid;
}
};
Expand Down Expand Up @@ -720,10 +802,14 @@ template <> struct MappingTraits<MachineFrameInfo> {
YamlIO.mapOptional("isCalleeSavedInfoValid", MFI.IsCalleeSavedInfoValid,
false);
YamlIO.mapOptional("localFrameSize", MFI.LocalFrameSize, (unsigned)0);
YamlIO.mapOptional("savePoint", MFI.SavePoint,
StringValue()); // Don't print it out when it's empty.
YamlIO.mapOptional("restorePoint", MFI.RestorePoint,
StringValue()); // Don't print it out when it's empty.
YamlIO.mapOptional(
"savePoint", MFI.SavePoints,
SaveRestorePoints(
StringValue())); // Don't print it out when it's empty.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Empty list?

YamlIO.mapOptional(
"restorePoint", MFI.RestorePoints,
SaveRestorePoints(
StringValue())); // Don't print it out when it's empty.
}
};

Expand Down
105 changes: 96 additions & 9 deletions llvm/include/llvm/CodeGen/MachineFrameInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ class CalleeSavedInfo {
explicit CalleeSavedInfo(unsigned R, int FI = 0) : Reg(R), FrameIdx(FI) {}

// Accessors.
Register getReg() const { return Reg; }
int getFrameIdx() const { return FrameIdx; }
unsigned getDstReg() const { return DstReg; }
Register getReg() const { return Reg; }
int getFrameIdx() const { return FrameIdx; }
unsigned getDstReg() const { return DstReg; }
void setFrameIdx(int FI) {
FrameIdx = FI;
SpilledToReg = false;
Expand All @@ -74,6 +74,36 @@ class CalleeSavedInfo {
bool isSpilledToReg() const { return SpilledToReg; }
};

class SaveRestorePoints {
public:
using PointsMap = DenseMap<MachineBasicBlock *, std::vector<CalleeSavedInfo>>;

private:
PointsMap Map;

public:
const PointsMap &get() const { return Map; }

const std::vector<CalleeSavedInfo> getCSInfo(MachineBasicBlock *MBB) const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoid vector copy, just return ArrayRef?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lookup here returns a default constructed vector, if MBB isn't found, ref on this vector is dangling.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replace with at? there should be no default construction

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

at here is not a good solution, because sometimes MBB really not present in map and it is not an error.
For example, not every Prolog is a Save BB (we don't save registers in it), but in emitPrologue we call getCSInfo() for Prolog.

return Map.lookup(MBB);
}

void set(PointsMap &&CSI) { Map = std::move(CSI); }

MachineBasicBlock *findAny(const CalleeSavedInfo &Match) const {
for (auto [BB, CSIV] : Map) {
for (auto &CSI : CSIV) {
if (CSI.getReg() == Match.getReg())
return BB;
}
}
return nullptr;
}

void clear() { Map.clear(); }
bool empty() const { return Map.empty(); }
};

/// The MachineFrameInfo class represents an abstract stack frame until
/// prolog/epilog code is inserted. This class is key to allowing stack frame
/// representation optimizations, such as frame pointer elimination. It also
Expand Down Expand Up @@ -331,9 +361,16 @@ class MachineFrameInfo {
bool HasTailCall = false;

/// Not null, if shrink-wrapping found a better place for the prologue.
MachineBasicBlock *Save = nullptr;
MachineBasicBlock *Prolog = nullptr;
/// Not null, if shrink-wrapping found a better place for the epilogue.
MachineBasicBlock *Restore = nullptr;
MachineBasicBlock *Epilog = nullptr;

/// Not empty, if shrink-wrapping found a better place for saving callee
/// saves.
SaveRestorePoints SavePoints;
/// Not empty, if shrink-wrapping found a better place for restoring callee
/// saves.
SaveRestorePoints RestorePoints;

/// Size of the UnsafeStack Frame
uint64_t UnsafeStackSize = 0;
Expand Down Expand Up @@ -809,6 +846,28 @@ class MachineFrameInfo {
/// \copydoc getCalleeSavedInfo()
std::vector<CalleeSavedInfo> &getCalleeSavedInfo() { return CSInfo; }

/// Returns callee saved info vector for provided save point in
/// the current function.
const std::vector<CalleeSavedInfo>
getSaveCSInfo(MachineBasicBlock *MBB) const {
return SavePoints.getCSInfo(MBB);
}

/// Returns callee saved info vector for provided restore point
/// in the current function.
const std::vector<CalleeSavedInfo>
getRestoreCSInfo(MachineBasicBlock *MBB) const {
return RestorePoints.getCSInfo(MBB);
}

MachineBasicBlock *findSpilledIn(const CalleeSavedInfo &CSI) const {
return SavePoints.findAny(CSI);
}

MachineBasicBlock *findRestoredIn(const CalleeSavedInfo &CSI) const {
return RestorePoints.findAny(CSI);
}

/// Used by prolog/epilog inserter to set the function's callee saved
/// information.
void setCalleeSavedInfo(std::vector<CalleeSavedInfo> CSI) {
Expand All @@ -820,10 +879,38 @@ class MachineFrameInfo {

void setCalleeSavedInfoValid(bool v) { CSIValid = v; }

MachineBasicBlock *getSavePoint() const { return Save; }
void setSavePoint(MachineBasicBlock *NewSave) { Save = NewSave; }
MachineBasicBlock *getRestorePoint() const { return Restore; }
void setRestorePoint(MachineBasicBlock *NewRestore) { Restore = NewRestore; }
const SaveRestorePoints::PointsMap &getRestorePoints() const {
return RestorePoints.get();
}

const SaveRestorePoints::PointsMap &getSavePoints() const {
return SavePoints.get();
}

void setSavePoints(SaveRestorePoints::PointsMap NewSavePoints) {
SavePoints.set(std::move(NewSavePoints));
}

void setRestorePoints(SaveRestorePoints::PointsMap NewRestorePoints) {
RestorePoints.set(std::move(NewRestorePoints));
}

static const SaveRestorePoints::PointsMap constructSaveRestorePoints(
const SaveRestorePoints::PointsMap &SRPoints,
const DenseMap<MachineBasicBlock *, MachineBasicBlock *> &BBMap) {
SaveRestorePoints::PointsMap Pts{};
for (auto &Src : SRPoints)
Pts.insert({BBMap.find(Src.first)->second, Src.second});
return Pts;
}

MachineBasicBlock *getProlog() const { return Prolog; }
void setProlog(MachineBasicBlock *BB) { Prolog = BB; }
MachineBasicBlock *getEpilog() const { return Epilog; }
void setEpilog(MachineBasicBlock *BB) { Epilog = BB; }

void clearSavePoints() { SavePoints.clear(); }
void clearRestorePoints() { RestorePoints.clear(); }

uint64_t getUnsafeStackSize() const { return UnsafeStackSize; }
void setUnsafeStackSize(uint64_t Size) { UnsafeStackSize = Size; }
Expand Down
68 changes: 56 additions & 12 deletions llvm/lib/CodeGen/MIRParser/MIRParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ class MIRParserImpl {
bool initializeFrameInfo(PerFunctionMIParsingState &PFS,
const yaml::MachineFunction &YamlMF);

bool initializeSaveRestorePoints(PerFunctionMIParsingState &PFS,
const yaml::SaveRestorePoints &YamlSRPoints,
bool IsSavePoints);

bool initializeCallSiteInfo(PerFunctionMIParsingState &PFS,
const yaml::MachineFunction &YamlMF);

Expand Down Expand Up @@ -851,18 +855,12 @@ bool MIRParserImpl::initializeFrameInfo(PerFunctionMIParsingState &PFS,
MFI.setHasTailCall(YamlMFI.HasTailCall);
MFI.setCalleeSavedInfoValid(YamlMFI.IsCalleeSavedInfoValid);
MFI.setLocalFrameSize(YamlMFI.LocalFrameSize);
if (!YamlMFI.SavePoint.Value.empty()) {
MachineBasicBlock *MBB = nullptr;
if (parseMBBReference(PFS, MBB, YamlMFI.SavePoint))
return true;
MFI.setSavePoint(MBB);
}
if (!YamlMFI.RestorePoint.Value.empty()) {
MachineBasicBlock *MBB = nullptr;
if (parseMBBReference(PFS, MBB, YamlMFI.RestorePoint))
return true;
MFI.setRestorePoint(MBB);
}
if (initializeSaveRestorePoints(PFS, YamlMFI.SavePoints,
/*IsSavePoints=*/true))
return true;
if (initializeSaveRestorePoints(PFS, YamlMFI.RestorePoints,
/*IsSavePoints=*/false))
return true;

std::vector<CalleeSavedInfo> CSIInfo;
// Initialize the fixed frame objects.
Expand Down Expand Up @@ -1077,6 +1075,52 @@ bool MIRParserImpl::initializeConstantPool(PerFunctionMIParsingState &PFS,
return false;
}

// Return true if basic block was incorrectly specified in MIR
bool MIRParserImpl::initializeSaveRestorePoints(
PerFunctionMIParsingState &PFS, const yaml::SaveRestorePoints &YamlSRPoints,
bool IsSavePoints) {
SMDiagnostic Error;
MachineBasicBlock *MBB = nullptr;
llvm::SaveRestorePoints::PointsMap SRPoints;
MachineFunction &MF = PFS.MF;
MachineFrameInfo &MFI = MF.getFrameInfo();

if (std::holds_alternative<std::vector<yaml::SaveRestorePointEntry>>(
YamlSRPoints)) {
const auto &VectorRepr =
std::get<std::vector<yaml::SaveRestorePointEntry>>(YamlSRPoints);
if (VectorRepr.empty())
return false;
std::vector<CalleeSavedInfo> Registers;
for (const auto &Entry : VectorRepr) {
const auto &MBBSource = Entry.Point;
if (parseMBBReference(PFS, MBB, MBBSource.Value))
return true;
Registers.clear();
for (auto &RegStr : Entry.Registers) {
Register Reg;
if (parseNamedRegisterReference(PFS, Reg, RegStr.Value, Error))
return error(Error, RegStr.SourceRange);
Registers.push_back(CalleeSavedInfo(Reg));
}
SRPoints.insert(std::make_pair(MBB, Registers));
}
} else {
yaml::StringValue StringRepr = std::get<yaml::StringValue>(YamlSRPoints);
if (StringRepr.Value.empty())
return false;
if (parseMBBReference(PFS, MBB, StringRepr))
return true;
SRPoints.insert(std::make_pair(MBB, MFI.getCalleeSavedInfo()));
}

if (IsSavePoints)
MFI.setSavePoints(SRPoints);
else
MFI.setRestorePoints(SRPoints);
return false;
}

bool MIRParserImpl::initializeJumpTableInfo(PerFunctionMIParsingState &PFS,
const yaml::MachineJumpTable &YamlJTI) {
MachineJumpTableInfo *JTI = PFS.MF.getOrCreateJumpTableInfo(YamlJTI.Kind);
Expand Down
Loading