From 6cd7d8dafd6a5cb438c5bd595e126f3cd863814e Mon Sep 17 00:00:00 2001 From: mingmingl Date: Thu, 1 May 2025 09:37:59 -0700 Subject: [PATCH 1/8] Add classes to construct and use data access profiles --- llvm/include/llvm/ADT/MapVector.h | 2 + .../include/llvm/ProfileData/DataAccessProf.h | 156 +++++++++++ llvm/include/llvm/ProfileData/InstrProf.h | 16 +- llvm/lib/ProfileData/CMakeLists.txt | 1 + llvm/lib/ProfileData/DataAccessProf.cpp | 246 ++++++++++++++++++ llvm/lib/ProfileData/InstrProf.cpp | 8 +- llvm/unittests/ProfileData/MemProfTest.cpp | 161 ++++++++++++ 7 files changed, 579 insertions(+), 11 deletions(-) create mode 100644 llvm/include/llvm/ProfileData/DataAccessProf.h create mode 100644 llvm/lib/ProfileData/DataAccessProf.cpp diff --git a/llvm/include/llvm/ADT/MapVector.h b/llvm/include/llvm/ADT/MapVector.h index c11617a81c97d..fe0d106795c34 100644 --- a/llvm/include/llvm/ADT/MapVector.h +++ b/llvm/include/llvm/ADT/MapVector.h @@ -57,6 +57,8 @@ class MapVector { return std::move(Vector); } + ArrayRef getArrayRef() const { return Vector; } + size_type size() const { return Vector.size(); } /// Grow the MapVector so that it can contain at least \p NumEntries items diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h new file mode 100644 index 0000000000000..2cce4945fddd5 --- /dev/null +++ b/llvm/include/llvm/ProfileData/DataAccessProf.h @@ -0,0 +1,156 @@ +//===- DataAccessProf.h - Data access profile format support ---------*- C++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains support to construct and use data access profiles. +// +// For the original RFC of this pass please see +// https://discourse.llvm.org/t/rfc-profile-guided-static-data-partitioning/83744 +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_PROFILEDATA_DATAACCESSPROF_H_ +#define LLVM_PROFILEDATA_DATAACCESSPROF_H_ + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseMapInfoVariant.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ProfileData/InstrProf.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/StringSaver.h" + +#include +#include + +namespace llvm { + +namespace data_access_prof { +// The location of data in the source code. +struct DataLocation { + // The filename where the data is located. + StringRef FileName; + // The line number in the source code. + uint32_t Line; +}; + +// The data access profiles for a symbol. +struct DataAccessProfRecord { + // Represents a data symbol. The semantic comes in two forms: a symbol index + // for symbol name if `IsStringLiteral` is false, or the hash of a string + // content if `IsStringLiteral` is true. Required. + uint64_t SymbolID; + + // The access count of symbol. Required. + uint64_t AccessCount; + + // True iff this is a record for string literal (symbols with name pattern + // `.str.*` in the symbol table). Required. + bool IsStringLiteral; + + // The locations of data in the source code. Optional. + llvm::SmallVector Locations; +}; + +/// Encapsulates the data access profile data and the methods to operate on it. +/// This class provides profile look-up, serialization and deserialization. +class DataAccessProfData { +public: + // SymbolID is either a string representing symbol name, or a uint64_t + // representing the content hash of a string literal. + using SymbolID = std::variant; + using StringToIndexMap = llvm::MapVector; + + DataAccessProfData() : saver(Allocator) {} + + /// Serialize profile data to the output stream. + /// Storage layout: + /// - Serialized strings. + /// - The encoded hashes. + /// - Records. + Error serialize(ProfOStream &OS) const; + + /// Deserialize this class from the given buffer. + Error deserialize(const unsigned char *&Ptr); + + /// Returns a pointer of profile record for \p SymbolID, or nullptr if there + /// isn't a record. Internally, this function will canonicalize the symbol + /// name before the lookup. + const DataAccessProfRecord *getProfileRecord(const SymbolID SymID) const; + + /// Returns true if \p SymID is seen in profiled binaries and cold. + bool isKnownColdSymbol(const SymbolID SymID) const; + + /// Methods to add symbolized data access profile. Returns error if duplicated + /// symbol names or content hashes are seen. The user of this class should + /// aggregate counters that corresponds to the same symbol name or with the + /// same string literal hash before calling 'add*' methods. + Error addSymbolizedDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount); + Error addSymbolizedDataAccessProfile( + SymbolID SymbolID, uint64_t AccessCount, + const llvm::SmallVector &Locations); + Error addKnownSymbolWithoutSamples(SymbolID SymbolID); + + /// Returns a iterable StringRef for strings in the order they are added. + auto getStrings() const { + ArrayRef> RefSymbolNames( + StrToIndexMap.begin(), StrToIndexMap.end()); + return llvm::make_first_range(RefSymbolNames); + } + + /// Returns array reference for various internal data structures. + inline ArrayRef< + std::pair, DataAccessProfRecord>> + getRecords() const { + return Records.getArrayRef(); + } + inline ArrayRef getKnownColdSymbols() const { + return KnownColdSymbols.getArrayRef(); + } + inline ArrayRef getKnownColdHashes() const { + return KnownColdHashes.getArrayRef(); + } + +private: + /// Serialize the symbol strings into the output stream. + Error serializeStrings(ProfOStream &OS) const; + + /// Deserialize the symbol strings from \p Ptr and increment \p Ptr to the + /// start of the next payload. + Error deserializeStrings(const unsigned char *&Ptr, + const uint64_t NumSampledSymbols, + uint64_t NumColdKnownSymbols); + + /// Decode the records and increment \p Ptr to the start of the next payload. + Error deserializeRecords(const unsigned char *&Ptr); + + /// A helper function to compute a storage index for \p SymbolID. + uint64_t getEncodedIndex(const SymbolID SymbolID) const; + + // `Records` stores the records and `SymbolToRecordIndex` maps a symbol ID to + // its record index. + MapVector Records; + + // Use MapVector to keep input order of strings for serialization and + // deserialization. + StringToIndexMap StrToIndexMap; + llvm::SetVector KnownColdHashes; + llvm::SetVector KnownColdSymbols; + // Keeps owned copies of the input strings. + llvm::BumpPtrAllocator Allocator; + llvm::UniqueStringSaver saver; +}; + +} // namespace data_access_prof +} // namespace llvm + +#endif // LLVM_PROFILEDATA_DATAACCESSPROF_H_ diff --git a/llvm/include/llvm/ProfileData/InstrProf.h b/llvm/include/llvm/ProfileData/InstrProf.h index 2d011c89f27cb..8a6be22bdb1a4 100644 --- a/llvm/include/llvm/ProfileData/InstrProf.h +++ b/llvm/include/llvm/ProfileData/InstrProf.h @@ -357,6 +357,12 @@ void createPGONameMetadata(GlobalObject &GO, StringRef PGOName); /// the duplicated profile variables for Comdat functions. bool needsComdatForCounter(const GlobalObject &GV, const Module &M); +/// \c NameStrings is a string composed of one of more possibly encoded +/// sub-strings. The substrings are separated by 0 or more zero bytes. This +/// method decodes the string and calls `NameCallback` for each substring. +Error readAndDecodeStrings(StringRef NameStrings, + std::function NameCallback); + /// An enum describing the attributes of an instrumented profile. enum class InstrProfKind { Unknown = 0x0, @@ -493,6 +499,11 @@ class InstrProfSymtab { public: using AddrHashMap = std::vector>; + // Returns the canonial name of the given PGOName. In a canonical name, all + // suffixes that begins with "." except ".__uniq." are stripped. + // FIXME: Unify this with `FunctionSamples::getCanonicalFnName`. + static StringRef getCanonicalName(StringRef PGOName); + private: using AddrIntervalMap = IntervalMap>; @@ -528,11 +539,6 @@ class InstrProfSymtab { static StringRef getExternalSymbol() { return "** External Symbol **"; } - // Returns the canonial name of the given PGOName. In a canonical name, all - // suffixes that begins with "." except ".__uniq." are stripped. - // FIXME: Unify this with `FunctionSamples::getCanonicalFnName`. - static StringRef getCanonicalName(StringRef PGOName); - // Add the function into the symbol table, by creating the following // map entries: // name-set = {PGOFuncName} union {getCanonicalName(PGOFuncName)} diff --git a/llvm/lib/ProfileData/CMakeLists.txt b/llvm/lib/ProfileData/CMakeLists.txt index eb7c2a3c1a28a..67a69d7761b2c 100644 --- a/llvm/lib/ProfileData/CMakeLists.txt +++ b/llvm/lib/ProfileData/CMakeLists.txt @@ -1,4 +1,5 @@ add_llvm_component_library(LLVMProfileData + DataAccessProf.cpp GCOV.cpp IndexedMemProfData.cpp InstrProf.cpp diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp new file mode 100644 index 0000000000000..cf538d6a1b28e --- /dev/null +++ b/llvm/lib/ProfileData/DataAccessProf.cpp @@ -0,0 +1,246 @@ +#include "llvm/ProfileData/DataAccessProf.h" +#include "llvm/ADT/DenseMapInfoVariant.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ProfileData/InstrProf.h" +#include "llvm/Support/Compression.h" +#include "llvm/Support/Endian.h" +#include "llvm/Support/Errc.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/StringSaver.h" +#include "llvm/Support/raw_ostream.h" +#include + +namespace llvm { +namespace data_access_prof { + +// If `Map` has an entry keyed by `Str`, returns the entry iterator. Otherwise, +// creates an owned copy of `Str`, adds a map entry for it and returns the +// iterator. +static MapVector::iterator +saveStringToMap(MapVector &Map, + llvm::UniqueStringSaver &saver, StringRef Str) { + auto [Iter, Inserted] = Map.try_emplace(saver.save(Str), Map.size()); + return Iter; +} + +const DataAccessProfRecord * +DataAccessProfData::getProfileRecord(const SymbolID SymbolID) const { + auto Key = SymbolID; + if (std::holds_alternative(SymbolID)) + Key = InstrProfSymtab::getCanonicalName(std::get(SymbolID)); + + auto It = Records.find(Key); + if (It != Records.end()) + return &It->second; + + return nullptr; +} + +bool DataAccessProfData::isKnownColdSymbol(const SymbolID SymID) const { + if (std::holds_alternative(SymID)) + return KnownColdHashes.count(std::get(SymID)); + return KnownColdSymbols.count(std::get(SymID)); +} + +Error DataAccessProfData::addSymbolizedDataAccessProfile(SymbolID Symbol, + uint64_t AccessCount) { + uint64_t RecordID = -1; + bool IsStringLiteral = false; + SymbolID Key; + if (std::holds_alternative(Symbol)) { + RecordID = std::get(Symbol); + Key = RecordID; + IsStringLiteral = true; + } else { + StringRef SymbolName = std::get(Symbol); + if (SymbolName.empty()) + return make_error("Empty symbol name", + llvm::errc::invalid_argument); + + StringRef CanonicalName = InstrProfSymtab::getCanonicalName(SymbolName); + Key = CanonicalName; + RecordID = saveStringToMap(StrToIndexMap, saver, CanonicalName)->second; + IsStringLiteral = false; + } + + auto [Iter, Inserted] = Records.try_emplace( + Key, DataAccessProfRecord{RecordID, AccessCount, IsStringLiteral}); + if (!Inserted) + return make_error("Duplicate symbol or string literal added. " + "User of DataAccessProfData should " + "aggregate count for the same symbol. ", + llvm::errc::invalid_argument); + + return Error::success(); +} + +Error DataAccessProfData::addSymbolizedDataAccessProfile( + SymbolID SymbolID, uint64_t AccessCount, + const llvm::SmallVector &Locations) { + if (Error E = addSymbolizedDataAccessProfile(SymbolID, AccessCount)) + return E; + + auto &Record = Records.back().second; + for (const auto &Location : Locations) + Record.Locations.push_back( + {saveStringToMap(StrToIndexMap, saver, Location.FileName)->first, + Location.Line}); + + return Error::success(); +} + +Error DataAccessProfData::addKnownSymbolWithoutSamples(SymbolID SymbolID) { + if (std::holds_alternative(SymbolID)) { + KnownColdHashes.insert(std::get(SymbolID)); + return Error::success(); + } + StringRef SymbolName = std::get(SymbolID); + if (SymbolName.empty()) + return make_error("Empty symbol name", + llvm::errc::invalid_argument); + StringRef CanonicalSymName = InstrProfSymtab::getCanonicalName(SymbolName); + KnownColdSymbols.insert(CanonicalSymName); + return Error::success(); +} + +Error DataAccessProfData::deserialize(const unsigned char *&Ptr) { + uint64_t NumSampledSymbols = + support::endian::readNext(Ptr); + uint64_t NumColdKnownSymbols = + support::endian::readNext(Ptr); + if (Error E = deserializeStrings(Ptr, NumSampledSymbols, NumColdKnownSymbols)) + return E; + + uint64_t Num = + support::endian::readNext(Ptr); + for (uint64_t I = 0; I < Num; ++I) + KnownColdHashes.insert( + support::endian::readNext(Ptr)); + + return deserializeRecords(Ptr); +} + +Error DataAccessProfData::serializeStrings(ProfOStream &OS) const { + OS.write(StrToIndexMap.size()); + OS.write(KnownColdSymbols.size()); + + std::vector Strs; + Strs.reserve(StrToIndexMap.size() + KnownColdSymbols.size()); + for (const auto &Str : StrToIndexMap) + Strs.push_back(Str.first.str()); + for (const auto &Str : KnownColdSymbols) + Strs.push_back(Str.str()); + + std::string CompressedStrings; + if (!Strs.empty()) + if (Error E = collectGlobalObjectNameStrings( + Strs, compression::zlib::isAvailable(), CompressedStrings)) + return E; + const uint64_t CompressedStringLen = CompressedStrings.length(); + // Record the length of compressed string. + OS.write(CompressedStringLen); + // Write the chars in compressed strings. + for (auto &c : CompressedStrings) + OS.writeByte(static_cast(c)); + // Pad up to a multiple of 8. + // InstrProfReader could read bytes according to 'CompressedStringLen'. + const uint64_t PaddedLength = alignTo(CompressedStringLen, 8); + for (uint64_t K = CompressedStringLen; K < PaddedLength; K++) + OS.writeByte(0); + return Error::success(); +} + +uint64_t DataAccessProfData::getEncodedIndex(const SymbolID SymbolID) const { + if (std::holds_alternative(SymbolID)) + return std::get(SymbolID); + + return StrToIndexMap.find(std::get(SymbolID))->second; +} + +Error DataAccessProfData::serialize(ProfOStream &OS) const { + if (Error E = serializeStrings(OS)) + return E; + OS.write(KnownColdHashes.size()); + for (const auto &Hash : KnownColdHashes) + OS.write(Hash); + OS.write((uint64_t)(Records.size())); + for (const auto &[Key, Rec] : Records) { + OS.write(getEncodedIndex(Rec.SymbolID)); + OS.writeByte(Rec.IsStringLiteral); + OS.write(Rec.AccessCount); + OS.write(Rec.Locations.size()); + for (const auto &Loc : Rec.Locations) { + OS.write(getEncodedIndex(Loc.FileName)); + OS.write32(Loc.Line); + } + } + return Error::success(); +} + +Error DataAccessProfData::deserializeStrings(const unsigned char *&Ptr, + uint64_t NumSampledSymbols, + uint64_t NumColdKnownSymbols) { + uint64_t Len = + support::endian::readNext(Ptr); + + // With M=NumSampledSymbols and N=NumColdKnownSymbols, the first M strings are + // symbols with samples, and next N strings are known cold symbols. + uint64_t StringCnt = 0; + std::function addName = [&](StringRef Name) { + if (StringCnt < NumSampledSymbols) + saveStringToMap(StrToIndexMap, saver, Name); + else + KnownColdSymbols.insert(saver.save(Name)); + ++StringCnt; + return Error::success(); + }; + if (Error E = + readAndDecodeStrings(StringRef((const char *)Ptr, Len), addName)) + return E; + + Ptr += alignTo(Len, 8); + return Error::success(); +} + +Error DataAccessProfData::deserializeRecords(const unsigned char *&Ptr) { + SmallVector Strings = llvm::to_vector(getStrings()); + + uint64_t NumRecords = + support::endian::readNext(Ptr); + + for (uint64_t I = 0; I < NumRecords; ++I) { + uint64_t ID = + support::endian::readNext(Ptr); + + bool IsStringLiteral = + support::endian::readNext(Ptr); + + uint64_t AccessCount = + support::endian::readNext(Ptr); + + SymbolID SymbolID; + if (IsStringLiteral) + SymbolID = ID; + else + SymbolID = Strings[ID]; + if (Error E = addSymbolizedDataAccessProfile(SymbolID, AccessCount)) + return E; + + auto &Record = Records.back().second; + + uint64_t NumLocations = + support::endian::readNext(Ptr); + + Record.Locations.reserve(NumLocations); + for (uint64_t J = 0; J < NumLocations; ++J) { + uint64_t FileNameIndex = + support::endian::readNext(Ptr); + uint32_t Line = + support::endian::readNext(Ptr); + Record.Locations.push_back({Strings[FileNameIndex], Line}); + } + } + return Error::success(); +} +} // namespace data_access_prof +} // namespace llvm diff --git a/llvm/lib/ProfileData/InstrProf.cpp b/llvm/lib/ProfileData/InstrProf.cpp index 88621787c1dd9..254f941acde82 100644 --- a/llvm/lib/ProfileData/InstrProf.cpp +++ b/llvm/lib/ProfileData/InstrProf.cpp @@ -573,12 +573,8 @@ Error InstrProfSymtab::addVTableWithName(GlobalVariable &VTable, return Error::success(); } -/// \c NameStrings is a string composed of one of more possibly encoded -/// sub-strings. The substrings are separated by 0 or more zero bytes. This -/// method decodes the string and calls `NameCallback` for each substring. -static Error -readAndDecodeStrings(StringRef NameStrings, - std::function NameCallback) { +Error readAndDecodeStrings(StringRef NameStrings, + std::function NameCallback) { const uint8_t *P = NameStrings.bytes_begin(); const uint8_t *EndP = NameStrings.bytes_end(); while (P < EndP) { diff --git a/llvm/unittests/ProfileData/MemProfTest.cpp b/llvm/unittests/ProfileData/MemProfTest.cpp index 3e430aa4eae58..f6362448a5734 100644 --- a/llvm/unittests/ProfileData/MemProfTest.cpp +++ b/llvm/unittests/ProfileData/MemProfTest.cpp @@ -10,14 +10,17 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLForwardCompat.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/DebugInfo/DIContext.h" #include "llvm/DebugInfo/Symbolize/SymbolizableModule.h" #include "llvm/IR/Value.h" #include "llvm/Object/ObjectFile.h" +#include "llvm/ProfileData/DataAccessProf.h" #include "llvm/ProfileData/MemProfData.inc" #include "llvm/ProfileData/MemProfReader.h" #include "llvm/ProfileData/MemProfYAML.h" #include "llvm/Support/raw_ostream.h" +#include "gmock/gmock-more-matchers.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -36,6 +39,8 @@ using ::llvm::StringRef; using ::llvm::object::SectionedAddress; using ::llvm::symbolize::SymbolizableModule; using ::testing::ElementsAre; +using ::testing::ElementsAreArray; +using ::testing::HasSubstr; using ::testing::IsEmpty; using ::testing::Pair; using ::testing::Return; @@ -747,6 +752,162 @@ TEST(MemProf, YAMLParser) { ElementsAre(0x3000))))); } +static std::string ErrorToString(Error E) { + std::string ErrMsg; + llvm::raw_string_ostream OS(ErrMsg); + llvm::logAllUnhandledErrors(std::move(E), OS); + return ErrMsg; +} + +// Test the various scenarios when DataAccessProfData should return error on +// invalid input. +TEST(MemProf, DataAccessProfileError) { + // Returns error if the input symbol name is empty. + llvm::data_access_prof::DataAccessProfData Data; + EXPECT_THAT(ErrorToString(Data.addSymbolizedDataAccessProfile("", 100)), + HasSubstr("Empty symbol name")); + + // Returns error when the same symbol gets added twice. + ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("foo", 100)); + EXPECT_THAT(ErrorToString(Data.addSymbolizedDataAccessProfile("foo", 100)), + HasSubstr("Duplicate symbol or string literal added")); + + // Returns error when the same string content hash gets added twice. + ASSERT_FALSE(Data.addSymbolizedDataAccessProfile((uint64_t)135246, 1000)); + EXPECT_THAT(ErrorToString( + Data.addSymbolizedDataAccessProfile((uint64_t)135246, 1000)), + HasSubstr("Duplicate symbol or string literal added")); +} + +// Test the following operations on DataAccessProfData: +// - Profile record look up. +// - Serialization and de-serialization. +TEST(MemProf, DataAccessProfile) { + using namespace llvm::data_access_prof; + llvm::data_access_prof::DataAccessProfData Data; + + // In the bool conversion, Error is true if it's in a failure state and false + // if it's in an accept state. Use ASSERT_FALSE or EXPECT_FALSE for no error. + ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("foo.llvm.123", 100)); + ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)789)); + ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym2")); + ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("bar.__uniq.321", 123, + { + DataLocation{"file2", 3}, + })); + ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym1")); + ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)678)); + ASSERT_FALSE(Data.addSymbolizedDataAccessProfile( + (uint64_t)135246, 1000, + {DataLocation{"file1", 1}, DataLocation{"file2", 2}})); + + { + // Test that symbol names and file names are stored in the input order. + EXPECT_THAT(llvm::to_vector(Data.getStrings()), + ElementsAre("foo", "bar.__uniq.321", "file2", "file1")); + EXPECT_THAT(Data.getKnownColdSymbols(), ElementsAre("sym2", "sym1")); + EXPECT_THAT(Data.getKnownColdHashes(), ElementsAre(789, 678)); + + // Look up profiles. + EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)789)); + EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)678)); + EXPECT_TRUE(Data.isKnownColdSymbol("sym2")); + EXPECT_TRUE(Data.isKnownColdSymbol("sym1")); + + EXPECT_EQ(Data.getProfileRecord("non-existence"), nullptr); + EXPECT_EQ(Data.getProfileRecord((uint64_t)789987), nullptr); + + EXPECT_THAT( + Data.getProfileRecord("foo.llvm.123"), + AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0), + testing::Field(&DataAccessProfRecord::AccessCount, 100), + testing::Field(&DataAccessProfRecord::IsStringLiteral, false), + testing::Field(&DataAccessProfRecord::Locations, + testing::IsEmpty()))); + EXPECT_THAT( + *Data.getProfileRecord("bar.__uniq.321"), + AllOf( + testing::Field(&DataAccessProfRecord::SymbolID, 1), + testing::Field(&DataAccessProfRecord::AccessCount, 123), + testing::Field(&DataAccessProfRecord::IsStringLiteral, false), + testing::Field(&DataAccessProfRecord::Locations, + ElementsAre(AllOf( + testing::Field(&DataLocation::FileName, "file2"), + testing::Field(&DataLocation::Line, 3)))))); + EXPECT_THAT( + *Data.getProfileRecord((uint64_t)135246), + AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 135246), + testing::Field(&DataAccessProfRecord::AccessCount, 1000), + testing::Field(&DataAccessProfRecord::IsStringLiteral, true), + testing::Field( + &DataAccessProfRecord::Locations, + ElementsAre( + AllOf(testing::Field(&DataLocation::FileName, "file1"), + testing::Field(&DataLocation::Line, 1)), + AllOf(testing::Field(&DataLocation::FileName, "file2"), + testing::Field(&DataLocation::Line, 2)))))); + } + + // Tests serialization and de-serialization. + llvm::data_access_prof::DataAccessProfData deserializedData; + { + std::string serializedData; + llvm::raw_string_ostream OS(serializedData); + llvm::ProfOStream POS(OS); + + EXPECT_FALSE(Data.serialize(POS)); + + const unsigned char *p = + reinterpret_cast(serializedData.data()); + ASSERT_THAT(llvm::to_vector(deserializedData.getStrings()), + testing::IsEmpty()); + EXPECT_FALSE(deserializedData.deserialize(p)); + + EXPECT_THAT(llvm::to_vector(deserializedData.getStrings()), + ElementsAre("foo", "bar.__uniq.321", "file2", "file1")); + EXPECT_THAT(deserializedData.getKnownColdSymbols(), + ElementsAre("sym2", "sym1")); + EXPECT_THAT(deserializedData.getKnownColdHashes(), ElementsAre(789, 678)); + + // Look up profiles after deserialization. + EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)789)); + EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)678)); + EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym2")); + EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym1")); + + auto Records = + llvm::to_vector(llvm::make_second_range(deserializedData.getRecords())); + + EXPECT_THAT( + Records, + ElementsAre( + AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0), + testing::Field(&DataAccessProfRecord::AccessCount, 100), + testing::Field(&DataAccessProfRecord::IsStringLiteral, false), + testing::Field(&DataAccessProfRecord::Locations, + testing::IsEmpty())), + AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 1), + testing::Field(&DataAccessProfRecord::AccessCount, 123), + testing::Field(&DataAccessProfRecord::IsStringLiteral, false), + testing::Field( + &DataAccessProfRecord::Locations, + ElementsAre(AllOf( + testing::Field(&DataLocation::FileName, "file2"), + testing::Field(&DataLocation::Line, 3))))), + AllOf( + testing::Field(&DataAccessProfRecord::SymbolID, 135246), + testing::Field(&DataAccessProfRecord::AccessCount, 1000), + testing::Field(&DataAccessProfRecord::IsStringLiteral, true), + testing::Field( + &DataAccessProfRecord::Locations, + ElementsAre( + AllOf(testing::Field(&DataLocation::FileName, "file1"), + testing::Field(&DataLocation::Line, 1)), + AllOf(testing::Field(&DataLocation::FileName, "file2"), + testing::Field(&DataLocation::Line, 2))))))); + } +} + // Verify that the YAML parser accepts a GUID expressed as a function name. TEST(MemProf, YAMLParserGUID) { StringRef YAMLData = R"YAML( From 47275299fd3e9ef242f108040df8ffe46f9cd7b0 Mon Sep 17 00:00:00 2001 From: mingmingl Date: Mon, 5 May 2025 16:57:56 -0700 Subject: [PATCH 2/8] resolve review feedback --- .../include/llvm/ProfileData/DataAccessProf.h | 37 ++++++++++------ llvm/lib/ProfileData/DataAccessProf.cpp | 43 ++++++++++--------- llvm/unittests/ProfileData/MemProfTest.cpp | 23 +++++----- 3 files changed, 57 insertions(+), 46 deletions(-) diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h index 2cce4945fddd5..36648d6298ee5 100644 --- a/llvm/include/llvm/ProfileData/DataAccessProf.h +++ b/llvm/include/llvm/ProfileData/DataAccessProf.h @@ -45,6 +45,11 @@ struct DataLocation { // The data access profiles for a symbol. struct DataAccessProfRecord { + DataAccessProfRecord(uint64_t SymbolID, uint64_t AccessCount, + bool IsStringLiteral) + : SymbolID(SymbolID), AccessCount(AccessCount), + IsStringLiteral(IsStringLiteral) {} + // Represents a data symbol. The semantic comes in two forms: a symbol index // for symbol name if `IsStringLiteral` is false, or the hash of a string // content if `IsStringLiteral` is true. Required. @@ -58,7 +63,7 @@ struct DataAccessProfRecord { bool IsStringLiteral; // The locations of data in the source code. Optional. - llvm::SmallVector Locations; + llvm::SmallVector Locations; }; /// Encapsulates the data access profile data and the methods to operate on it. @@ -70,7 +75,7 @@ class DataAccessProfData { using SymbolID = std::variant; using StringToIndexMap = llvm::MapVector; - DataAccessProfData() : saver(Allocator) {} + DataAccessProfData() : Saver(Allocator) {} /// Serialize profile data to the output stream. /// Storage layout: @@ -94,10 +99,13 @@ class DataAccessProfData { /// symbol names or content hashes are seen. The user of this class should /// aggregate counters that corresponds to the same symbol name or with the /// same string literal hash before calling 'add*' methods. - Error addSymbolizedDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount); - Error addSymbolizedDataAccessProfile( - SymbolID SymbolID, uint64_t AccessCount, - const llvm::SmallVector &Locations); + Error setDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount); + /// Similar to the method above, for records with \p Locations representing + /// the `filename:line` where this symbol shows up. Note because of linker's + /// merge of identical symbols (e.g., unnamed_addr string literals), one + /// symbol is likely to have multiple locations. + Error setDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount, + const llvm::SmallVector &Locations); Error addKnownSymbolWithoutSamples(SymbolID SymbolID); /// Returns a iterable StringRef for strings in the order they are added. @@ -122,13 +130,13 @@ class DataAccessProfData { private: /// Serialize the symbol strings into the output stream. - Error serializeStrings(ProfOStream &OS) const; + Error serializeSymbolsAndFilenames(ProfOStream &OS) const; /// Deserialize the symbol strings from \p Ptr and increment \p Ptr to the /// start of the next payload. - Error deserializeStrings(const unsigned char *&Ptr, - const uint64_t NumSampledSymbols, - uint64_t NumColdKnownSymbols); + Error deserializeSymbolsAndFilenames(const unsigned char *&Ptr, + const uint64_t NumSampledSymbols, + uint64_t NumColdKnownSymbols); /// Decode the records and increment \p Ptr to the start of the next payload. Error deserializeRecords(const unsigned char *&Ptr); @@ -136,6 +144,12 @@ class DataAccessProfData { /// A helper function to compute a storage index for \p SymbolID. uint64_t getEncodedIndex(const SymbolID SymbolID) const; + // Keeps owned copies of the input strings. + // NOTE: Keep `Saver` initialized before other class members that reference + // its string copies and destructed after they are destructed. + llvm::BumpPtrAllocator Allocator; + llvm::UniqueStringSaver Saver; + // `Records` stores the records and `SymbolToRecordIndex` maps a symbol ID to // its record index. MapVector Records; @@ -145,9 +159,6 @@ class DataAccessProfData { StringToIndexMap StrToIndexMap; llvm::SetVector KnownColdHashes; llvm::SetVector KnownColdSymbols; - // Keeps owned copies of the input strings. - llvm::BumpPtrAllocator Allocator; - llvm::UniqueStringSaver saver; }; } // namespace data_access_prof diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp index cf538d6a1b28e..c52533c13919c 100644 --- a/llvm/lib/ProfileData/DataAccessProf.cpp +++ b/llvm/lib/ProfileData/DataAccessProf.cpp @@ -18,8 +18,8 @@ namespace data_access_prof { // iterator. static MapVector::iterator saveStringToMap(MapVector &Map, - llvm::UniqueStringSaver &saver, StringRef Str) { - auto [Iter, Inserted] = Map.try_emplace(saver.save(Str), Map.size()); + llvm::UniqueStringSaver &Saver, StringRef Str) { + auto [Iter, Inserted] = Map.try_emplace(Saver.save(Str), Map.size()); return Iter; } @@ -38,12 +38,12 @@ DataAccessProfData::getProfileRecord(const SymbolID SymbolID) const { bool DataAccessProfData::isKnownColdSymbol(const SymbolID SymID) const { if (std::holds_alternative(SymID)) - return KnownColdHashes.count(std::get(SymID)); - return KnownColdSymbols.count(std::get(SymID)); + return KnownColdHashes.contains(std::get(SymID)); + return KnownColdSymbols.contains(std::get(SymID)); } -Error DataAccessProfData::addSymbolizedDataAccessProfile(SymbolID Symbol, - uint64_t AccessCount) { +Error DataAccessProfData::setDataAccessProfile(SymbolID Symbol, + uint64_t AccessCount) { uint64_t RecordID = -1; bool IsStringLiteral = false; SymbolID Key; @@ -59,12 +59,12 @@ Error DataAccessProfData::addSymbolizedDataAccessProfile(SymbolID Symbol, StringRef CanonicalName = InstrProfSymtab::getCanonicalName(SymbolName); Key = CanonicalName; - RecordID = saveStringToMap(StrToIndexMap, saver, CanonicalName)->second; + RecordID = saveStringToMap(StrToIndexMap, Saver, CanonicalName)->second; IsStringLiteral = false; } - auto [Iter, Inserted] = Records.try_emplace( - Key, DataAccessProfRecord{RecordID, AccessCount, IsStringLiteral}); + auto [Iter, Inserted] = + Records.try_emplace(Key, RecordID, AccessCount, IsStringLiteral); if (!Inserted) return make_error("Duplicate symbol or string literal added. " "User of DataAccessProfData should " @@ -74,16 +74,16 @@ Error DataAccessProfData::addSymbolizedDataAccessProfile(SymbolID Symbol, return Error::success(); } -Error DataAccessProfData::addSymbolizedDataAccessProfile( +Error DataAccessProfData::setDataAccessProfile( SymbolID SymbolID, uint64_t AccessCount, const llvm::SmallVector &Locations) { - if (Error E = addSymbolizedDataAccessProfile(SymbolID, AccessCount)) + if (Error E = setDataAccessProfile(SymbolID, AccessCount)) return E; auto &Record = Records.back().second; for (const auto &Location : Locations) Record.Locations.push_back( - {saveStringToMap(StrToIndexMap, saver, Location.FileName)->first, + {saveStringToMap(StrToIndexMap, Saver, Location.FileName)->first, Location.Line}); return Error::success(); @@ -108,7 +108,8 @@ Error DataAccessProfData::deserialize(const unsigned char *&Ptr) { support::endian::readNext(Ptr); uint64_t NumColdKnownSymbols = support::endian::readNext(Ptr); - if (Error E = deserializeStrings(Ptr, NumSampledSymbols, NumColdKnownSymbols)) + if (Error E = deserializeSymbolsAndFilenames(Ptr, NumSampledSymbols, + NumColdKnownSymbols)) return E; uint64_t Num = @@ -120,7 +121,7 @@ Error DataAccessProfData::deserialize(const unsigned char *&Ptr) { return deserializeRecords(Ptr); } -Error DataAccessProfData::serializeStrings(ProfOStream &OS) const { +Error DataAccessProfData::serializeSymbolsAndFilenames(ProfOStream &OS) const { OS.write(StrToIndexMap.size()); OS.write(KnownColdSymbols.size()); @@ -158,7 +159,7 @@ uint64_t DataAccessProfData::getEncodedIndex(const SymbolID SymbolID) const { } Error DataAccessProfData::serialize(ProfOStream &OS) const { - if (Error E = serializeStrings(OS)) + if (Error E = serializeSymbolsAndFilenames(OS)) return E; OS.write(KnownColdHashes.size()); for (const auto &Hash : KnownColdHashes) @@ -177,9 +178,9 @@ Error DataAccessProfData::serialize(ProfOStream &OS) const { return Error::success(); } -Error DataAccessProfData::deserializeStrings(const unsigned char *&Ptr, - uint64_t NumSampledSymbols, - uint64_t NumColdKnownSymbols) { +Error DataAccessProfData::deserializeSymbolsAndFilenames( + const unsigned char *&Ptr, uint64_t NumSampledSymbols, + uint64_t NumColdKnownSymbols) { uint64_t Len = support::endian::readNext(Ptr); @@ -188,9 +189,9 @@ Error DataAccessProfData::deserializeStrings(const unsigned char *&Ptr, uint64_t StringCnt = 0; std::function addName = [&](StringRef Name) { if (StringCnt < NumSampledSymbols) - saveStringToMap(StrToIndexMap, saver, Name); + saveStringToMap(StrToIndexMap, Saver, Name); else - KnownColdSymbols.insert(saver.save(Name)); + KnownColdSymbols.insert(Saver.save(Name)); ++StringCnt; return Error::success(); }; @@ -223,7 +224,7 @@ Error DataAccessProfData::deserializeRecords(const unsigned char *&Ptr) { SymbolID = ID; else SymbolID = Strings[ID]; - if (Error E = addSymbolizedDataAccessProfile(SymbolID, AccessCount)) + if (Error E = setDataAccessProfile(SymbolID, AccessCount)) return E; auto &Record = Records.back().second; diff --git a/llvm/unittests/ProfileData/MemProfTest.cpp b/llvm/unittests/ProfileData/MemProfTest.cpp index f6362448a5734..b7b8d642ad930 100644 --- a/llvm/unittests/ProfileData/MemProfTest.cpp +++ b/llvm/unittests/ProfileData/MemProfTest.cpp @@ -764,18 +764,17 @@ static std::string ErrorToString(Error E) { TEST(MemProf, DataAccessProfileError) { // Returns error if the input symbol name is empty. llvm::data_access_prof::DataAccessProfData Data; - EXPECT_THAT(ErrorToString(Data.addSymbolizedDataAccessProfile("", 100)), + EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("", 100)), HasSubstr("Empty symbol name")); // Returns error when the same symbol gets added twice. - ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("foo", 100)); - EXPECT_THAT(ErrorToString(Data.addSymbolizedDataAccessProfile("foo", 100)), + ASSERT_FALSE(Data.setDataAccessProfile("foo", 100)); + EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("foo", 100)), HasSubstr("Duplicate symbol or string literal added")); // Returns error when the same string content hash gets added twice. - ASSERT_FALSE(Data.addSymbolizedDataAccessProfile((uint64_t)135246, 1000)); - EXPECT_THAT(ErrorToString( - Data.addSymbolizedDataAccessProfile((uint64_t)135246, 1000)), + ASSERT_FALSE(Data.setDataAccessProfile((uint64_t)135246, 1000)); + EXPECT_THAT(ErrorToString(Data.setDataAccessProfile((uint64_t)135246, 1000)), HasSubstr("Duplicate symbol or string literal added")); } @@ -788,16 +787,16 @@ TEST(MemProf, DataAccessProfile) { // In the bool conversion, Error is true if it's in a failure state and false // if it's in an accept state. Use ASSERT_FALSE or EXPECT_FALSE for no error. - ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("foo.llvm.123", 100)); + ASSERT_FALSE(Data.setDataAccessProfile("foo.llvm.123", 100)); ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)789)); ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym2")); - ASSERT_FALSE(Data.addSymbolizedDataAccessProfile("bar.__uniq.321", 123, - { - DataLocation{"file2", 3}, - })); + ASSERT_FALSE(Data.setDataAccessProfile("bar.__uniq.321", 123, + { + DataLocation{"file2", 3}, + })); ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym1")); ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)678)); - ASSERT_FALSE(Data.addSymbolizedDataAccessProfile( + ASSERT_FALSE(Data.setDataAccessProfile( (uint64_t)135246, 1000, {DataLocation{"file1", 1}, DataLocation{"file2", 2}})); From 80249bce82307ffef1817747d1c2fca100a9ea53 Mon Sep 17 00:00:00 2001 From: Mingming Liu Date: Tue, 6 May 2025 10:39:53 -0700 Subject: [PATCH 3/8] Apply suggestions from code review Co-authored-by: Kazu Hirata --- llvm/include/llvm/ProfileData/DataAccessProf.h | 11 ++++++----- llvm/include/llvm/ProfileData/InstrProf.h | 2 +- llvm/lib/ProfileData/DataAccessProf.cpp | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h index 36648d6298ee5..81289b60b96ed 100644 --- a/llvm/include/llvm/ProfileData/DataAccessProf.h +++ b/llvm/include/llvm/ProfileData/DataAccessProf.h @@ -97,7 +97,7 @@ class DataAccessProfData { /// Methods to add symbolized data access profile. Returns error if duplicated /// symbol names or content hashes are seen. The user of this class should - /// aggregate counters that corresponds to the same symbol name or with the + /// aggregate counters that correspond to the same symbol name or with the /// same string literal hash before calling 'add*' methods. Error setDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount); /// Similar to the method above, for records with \p Locations representing @@ -108,7 +108,8 @@ class DataAccessProfData { const llvm::SmallVector &Locations); Error addKnownSymbolWithoutSamples(SymbolID SymbolID); - /// Returns a iterable StringRef for strings in the order they are added. + /// Returns an iterable StringRef for strings in the order they are added. + /// Each string may be a symbol name or a file name. auto getStrings() const { ArrayRef> RefSymbolNames( StrToIndexMap.begin(), StrToIndexMap.end()); @@ -116,15 +117,15 @@ class DataAccessProfData { } /// Returns array reference for various internal data structures. - inline ArrayRef< + ArrayRef< std::pair, DataAccessProfRecord>> getRecords() const { return Records.getArrayRef(); } - inline ArrayRef getKnownColdSymbols() const { + ArrayRef getKnownColdSymbols() const { return KnownColdSymbols.getArrayRef(); } - inline ArrayRef getKnownColdHashes() const { + ArrayRef getKnownColdHashes() const { return KnownColdHashes.getArrayRef(); } diff --git a/llvm/include/llvm/ProfileData/InstrProf.h b/llvm/include/llvm/ProfileData/InstrProf.h index 8a6be22bdb1a4..710b2f6836064 100644 --- a/llvm/include/llvm/ProfileData/InstrProf.h +++ b/llvm/include/llvm/ProfileData/InstrProf.h @@ -357,7 +357,7 @@ void createPGONameMetadata(GlobalObject &GO, StringRef PGOName); /// the duplicated profile variables for Comdat functions. bool needsComdatForCounter(const GlobalObject &GV, const Module &M); -/// \c NameStrings is a string composed of one of more possibly encoded +/// \c NameStrings is a string composed of one or more possibly encoded /// sub-strings. The substrings are separated by 0 or more zero bytes. This /// method decodes the string and calls `NameCallback` for each substring. Error readAndDecodeStrings(StringRef NameStrings, diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp index c52533c13919c..a42ee41b24358 100644 --- a/llvm/lib/ProfileData/DataAccessProf.cpp +++ b/llvm/lib/ProfileData/DataAccessProf.cpp @@ -141,7 +141,7 @@ Error DataAccessProfData::serializeSymbolsAndFilenames(ProfOStream &OS) const { // Record the length of compressed string. OS.write(CompressedStringLen); // Write the chars in compressed strings. - for (auto &c : CompressedStrings) + for (char C : CompressedStrings) OS.writeByte(static_cast(c)); // Pad up to a multiple of 8. // InstrProfReader could read bytes according to 'CompressedStringLen'. From b69c9930b9584a8dede96b0a9605c03f4d504bb8 Mon Sep 17 00:00:00 2001 From: mingmingl Date: Tue, 6 May 2025 10:50:59 -0700 Subject: [PATCH 4/8] resolve review feedback --- .../include/llvm/ProfileData/DataAccessProf.h | 43 ++--- llvm/include/llvm/ProfileData/InstrProf.h | 5 +- llvm/lib/ProfileData/DataAccessProf.cpp | 72 ++++--- llvm/unittests/ProfileData/CMakeLists.txt | 1 + .../ProfileData/DataAccessProfTest.cpp | 181 ++++++++++++++++++ llvm/unittests/ProfileData/MemProfTest.cpp | 160 ---------------- 6 files changed, 245 insertions(+), 217 deletions(-) create mode 100644 llvm/unittests/ProfileData/DataAccessProfTest.cpp diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h index 81289b60b96ed..a85a3320ae57c 100644 --- a/llvm/include/llvm/ProfileData/DataAccessProf.h +++ b/llvm/include/llvm/ProfileData/DataAccessProf.h @@ -70,9 +70,11 @@ struct DataAccessProfRecord { /// This class provides profile look-up, serialization and deserialization. class DataAccessProfData { public: - // SymbolID is either a string representing symbol name, or a uint64_t - // representing the content hash of a string literal. - using SymbolID = std::variant; + // SymbolID is either a string representing symbol name if the symbol has + // stable mangled name relative to source code, or a uint64_t representing the + // content hash of a string literal (with unstable name patterns like + // `.str.N[.llvm.hash]`). The StringRef is owned by the class's saver object. + using SymbolHandle = std::variant; using StringToIndexMap = llvm::MapVector; DataAccessProfData() : Saver(Allocator) {} @@ -90,38 +92,32 @@ class DataAccessProfData { /// Returns a pointer of profile record for \p SymbolID, or nullptr if there /// isn't a record. Internally, this function will canonicalize the symbol /// name before the lookup. - const DataAccessProfRecord *getProfileRecord(const SymbolID SymID) const; + const DataAccessProfRecord *getProfileRecord(const SymbolHandle SymID) const; /// Returns true if \p SymID is seen in profiled binaries and cold. - bool isKnownColdSymbol(const SymbolID SymID) const; + bool isKnownColdSymbol(const SymbolHandle SymID) const; - /// Methods to add symbolized data access profile. Returns error if duplicated + /// Methods to set symbolized data access profile. Returns error if duplicated /// symbol names or content hashes are seen. The user of this class should /// aggregate counters that correspond to the same symbol name or with the - /// same string literal hash before calling 'add*' methods. - Error setDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount); + /// same string literal hash before calling 'set*' methods. + Error setDataAccessProfile(SymbolHandle SymbolID, uint64_t AccessCount); /// Similar to the method above, for records with \p Locations representing /// the `filename:line` where this symbol shows up. Note because of linker's /// merge of identical symbols (e.g., unnamed_addr string literals), one /// symbol is likely to have multiple locations. - Error setDataAccessProfile(SymbolID SymbolID, uint64_t AccessCount, - const llvm::SmallVector &Locations); - Error addKnownSymbolWithoutSamples(SymbolID SymbolID); + Error setDataAccessProfile(SymbolHandle SymbolID, uint64_t AccessCount, + ArrayRef Locations); + Error addKnownSymbolWithoutSamples(SymbolHandle SymbolID); /// Returns an iterable StringRef for strings in the order they are added. /// Each string may be a symbol name or a file name. auto getStrings() const { - ArrayRef> RefSymbolNames( - StrToIndexMap.begin(), StrToIndexMap.end()); - return llvm::make_first_range(RefSymbolNames); + return llvm::make_first_range(StrToIndexMap.getArrayRef()); } /// Returns array reference for various internal data structures. - ArrayRef< - std::pair, DataAccessProfRecord>> - getRecords() const { - return Records.getArrayRef(); - } + auto getRecords() const { return Records.getArrayRef(); } ArrayRef getKnownColdSymbols() const { return KnownColdSymbols.getArrayRef(); } @@ -137,13 +133,13 @@ class DataAccessProfData { /// start of the next payload. Error deserializeSymbolsAndFilenames(const unsigned char *&Ptr, const uint64_t NumSampledSymbols, - uint64_t NumColdKnownSymbols); + const uint64_t NumColdKnownSymbols); /// Decode the records and increment \p Ptr to the start of the next payload. Error deserializeRecords(const unsigned char *&Ptr); /// A helper function to compute a storage index for \p SymbolID. - uint64_t getEncodedIndex(const SymbolID SymbolID) const; + uint64_t getEncodedIndex(const SymbolHandle SymbolID) const; // Keeps owned copies of the input strings. // NOTE: Keep `Saver` initialized before other class members that reference @@ -151,9 +147,8 @@ class DataAccessProfData { llvm::BumpPtrAllocator Allocator; llvm::UniqueStringSaver Saver; - // `Records` stores the records and `SymbolToRecordIndex` maps a symbol ID to - // its record index. - MapVector Records; + // `Records` stores the records. + MapVector Records; // Use MapVector to keep input order of strings for serialization and // deserialization. diff --git a/llvm/include/llvm/ProfileData/InstrProf.h b/llvm/include/llvm/ProfileData/InstrProf.h index 710b2f6836064..33b93ea0a558a 100644 --- a/llvm/include/llvm/ProfileData/InstrProf.h +++ b/llvm/include/llvm/ProfileData/InstrProf.h @@ -358,8 +358,9 @@ void createPGONameMetadata(GlobalObject &GO, StringRef PGOName); bool needsComdatForCounter(const GlobalObject &GV, const Module &M); /// \c NameStrings is a string composed of one or more possibly encoded -/// sub-strings. The substrings are separated by 0 or more zero bytes. This -/// method decodes the string and calls `NameCallback` for each substring. +/// sub-strings. The substrings are separated by `\01` (returned by +/// InstrProf.h:getInstrProfNameSeparator). This method decodes the string and +/// calls `NameCallback` for each substring. Error readAndDecodeStrings(StringRef NameStrings, std::function NameCallback); diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp index a42ee41b24358..d1c034f639347 100644 --- a/llvm/lib/ProfileData/DataAccessProf.cpp +++ b/llvm/lib/ProfileData/DataAccessProf.cpp @@ -23,11 +23,22 @@ saveStringToMap(MapVector &Map, return Iter; } +// Returns the canonical name or error. +static Expected getCanonicalName(StringRef Name) { + if (Name.empty()) + return make_error("Empty symbol name", + llvm::errc::invalid_argument); + return InstrProfSymtab::getCanonicalName(Name); +} + const DataAccessProfRecord * -DataAccessProfData::getProfileRecord(const SymbolID SymbolID) const { +DataAccessProfData::getProfileRecord(const SymbolHandle SymbolID) const { auto Key = SymbolID; - if (std::holds_alternative(SymbolID)) - Key = InstrProfSymtab::getCanonicalName(std::get(SymbolID)); + if (std::holds_alternative(SymbolID)) { + StringRef Name = std::get(SymbolID); + assert(!Name.empty() && "Empty symbol name"); + Key = InstrProfSymtab::getCanonicalName(Name); + } auto It = Records.find(Key); if (It != Records.end()) @@ -36,30 +47,27 @@ DataAccessProfData::getProfileRecord(const SymbolID SymbolID) const { return nullptr; } -bool DataAccessProfData::isKnownColdSymbol(const SymbolID SymID) const { +bool DataAccessProfData::isKnownColdSymbol(const SymbolHandle SymID) const { if (std::holds_alternative(SymID)) return KnownColdHashes.contains(std::get(SymID)); return KnownColdSymbols.contains(std::get(SymID)); } -Error DataAccessProfData::setDataAccessProfile(SymbolID Symbol, +Error DataAccessProfData::setDataAccessProfile(SymbolHandle Symbol, uint64_t AccessCount) { uint64_t RecordID = -1; bool IsStringLiteral = false; - SymbolID Key; + SymbolHandle Key; if (std::holds_alternative(Symbol)) { RecordID = std::get(Symbol); Key = RecordID; IsStringLiteral = true; } else { - StringRef SymbolName = std::get(Symbol); - if (SymbolName.empty()) - return make_error("Empty symbol name", - llvm::errc::invalid_argument); - - StringRef CanonicalName = InstrProfSymtab::getCanonicalName(SymbolName); - Key = CanonicalName; - RecordID = saveStringToMap(StrToIndexMap, Saver, CanonicalName)->second; + auto CanonicalName = getCanonicalName(std::get(Symbol)); + if (!CanonicalName) + return CanonicalName.takeError(); + std::tie(Key, RecordID) = + *saveStringToMap(StrToIndexMap, Saver, *CanonicalName); IsStringLiteral = false; } @@ -75,8 +83,8 @@ Error DataAccessProfData::setDataAccessProfile(SymbolID Symbol, } Error DataAccessProfData::setDataAccessProfile( - SymbolID SymbolID, uint64_t AccessCount, - const llvm::SmallVector &Locations) { + SymbolHandle SymbolID, uint64_t AccessCount, + ArrayRef Locations) { if (Error E = setDataAccessProfile(SymbolID, AccessCount)) return E; @@ -89,17 +97,15 @@ Error DataAccessProfData::setDataAccessProfile( return Error::success(); } -Error DataAccessProfData::addKnownSymbolWithoutSamples(SymbolID SymbolID) { +Error DataAccessProfData::addKnownSymbolWithoutSamples(SymbolHandle SymbolID) { if (std::holds_alternative(SymbolID)) { KnownColdHashes.insert(std::get(SymbolID)); return Error::success(); } - StringRef SymbolName = std::get(SymbolID); - if (SymbolName.empty()) - return make_error("Empty symbol name", - llvm::errc::invalid_argument); - StringRef CanonicalSymName = InstrProfSymtab::getCanonicalName(SymbolName); - KnownColdSymbols.insert(CanonicalSymName); + auto CanonicalName = getCanonicalName(std::get(SymbolID)); + if (!CanonicalName) + return CanonicalName.takeError(); + KnownColdSymbols.insert(*CanonicalName); return Error::success(); } @@ -142,7 +148,7 @@ Error DataAccessProfData::serializeSymbolsAndFilenames(ProfOStream &OS) const { OS.write(CompressedStringLen); // Write the chars in compressed strings. for (char C : CompressedStrings) - OS.writeByte(static_cast(c)); + OS.writeByte(static_cast(C)); // Pad up to a multiple of 8. // InstrProfReader could read bytes according to 'CompressedStringLen'. const uint64_t PaddedLength = alignTo(CompressedStringLen, 8); @@ -151,11 +157,15 @@ Error DataAccessProfData::serializeSymbolsAndFilenames(ProfOStream &OS) const { return Error::success(); } -uint64_t DataAccessProfData::getEncodedIndex(const SymbolID SymbolID) const { +uint64_t +DataAccessProfData::getEncodedIndex(const SymbolHandle SymbolID) const { if (std::holds_alternative(SymbolID)) return std::get(SymbolID); - return StrToIndexMap.find(std::get(SymbolID))->second; + auto Iter = StrToIndexMap.find(std::get(SymbolID)); + assert(Iter != StrToIndexMap.end() && + "String literals not found in StrToIndexMap"); + return Iter->second; } Error DataAccessProfData::serialize(ProfOStream &OS) const { @@ -179,13 +189,13 @@ Error DataAccessProfData::serialize(ProfOStream &OS) const { } Error DataAccessProfData::deserializeSymbolsAndFilenames( - const unsigned char *&Ptr, uint64_t NumSampledSymbols, - uint64_t NumColdKnownSymbols) { + const unsigned char *&Ptr, const uint64_t NumSampledSymbols, + const uint64_t NumColdKnownSymbols) { uint64_t Len = support::endian::readNext(Ptr); - // With M=NumSampledSymbols and N=NumColdKnownSymbols, the first M strings are - // symbols with samples, and next N strings are known cold symbols. + // The first NumSampledSymbols strings are symbols with samples, and next + // NumColdKnownSymbols strings are known cold symbols. uint64_t StringCnt = 0; std::function addName = [&](StringRef Name) { if (StringCnt < NumSampledSymbols) @@ -219,7 +229,7 @@ Error DataAccessProfData::deserializeRecords(const unsigned char *&Ptr) { uint64_t AccessCount = support::endian::readNext(Ptr); - SymbolID SymbolID; + SymbolHandle SymbolID; if (IsStringLiteral) SymbolID = ID; else diff --git a/llvm/unittests/ProfileData/CMakeLists.txt b/llvm/unittests/ProfileData/CMakeLists.txt index 0a7f7da085950..29b9cb751dabe 100644 --- a/llvm/unittests/ProfileData/CMakeLists.txt +++ b/llvm/unittests/ProfileData/CMakeLists.txt @@ -10,6 +10,7 @@ set(LLVM_LINK_COMPONENTS add_llvm_unittest(ProfileDataTests BPFunctionNodeTest.cpp CoverageMappingTest.cpp + DataAccessProfTest.cpp InstrProfDataTest.cpp InstrProfTest.cpp ItaniumManglingCanonicalizerTest.cpp diff --git a/llvm/unittests/ProfileData/DataAccessProfTest.cpp b/llvm/unittests/ProfileData/DataAccessProfTest.cpp new file mode 100644 index 0000000000000..50c4af49fe76b --- /dev/null +++ b/llvm/unittests/ProfileData/DataAccessProfTest.cpp @@ -0,0 +1,181 @@ + +//===- unittests/Support/DataAccessProfTest.cpp +//----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/ProfileData/DataAccessProf.h" +#include "llvm/Support/raw_ostream.h" +#include "gmock/gmock-more-matchers.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace llvm { +namespace data_access_prof { +namespace { + +using ::llvm::StringRef; +using ::testing::ElementsAre; +using ::testing::HasSubstr; +using ::testing::IsEmpty; + +static std::string ErrorToString(Error E) { + std::string ErrMsg; + llvm::raw_string_ostream OS(ErrMsg); + llvm::logAllUnhandledErrors(std::move(E), OS); + return ErrMsg; +} + +// Test the various scenarios when DataAccessProfData should return error on +// invalid input. +TEST(MemProf, DataAccessProfileError) { + // Returns error if the input symbol name is empty. + DataAccessProfData Data; + EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("", 100)), + HasSubstr("Empty symbol name")); + + // Returns error when the same symbol gets added twice. + ASSERT_FALSE(Data.setDataAccessProfile("foo", 100)); + EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("foo", 100)), + HasSubstr("Duplicate symbol or string literal added")); + + // Returns error when the same string content hash gets added twice. + ASSERT_FALSE(Data.setDataAccessProfile((uint64_t)135246, 1000)); + EXPECT_THAT(ErrorToString(Data.setDataAccessProfile((uint64_t)135246, 1000)), + HasSubstr("Duplicate symbol or string literal added")); +} + +// Test the following operations on DataAccessProfData: +// - Profile record look up. +// - Serialization and de-serialization. +TEST(MemProf, DataAccessProfile) { + DataAccessProfData Data; + + // In the bool conversion, Error is true if it's in a failure state and false + // if it's in an accept state. Use ASSERT_FALSE or EXPECT_FALSE for no error. + ASSERT_FALSE(Data.setDataAccessProfile("foo.llvm.123", 100)); + ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)789)); + ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym2")); + ASSERT_FALSE(Data.setDataAccessProfile("bar.__uniq.321", 123, + { + DataLocation{"file2", 3}, + })); + ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym1")); + ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)678)); + ASSERT_FALSE(Data.setDataAccessProfile( + (uint64_t)135246, 1000, + {DataLocation{"file1", 1}, DataLocation{"file2", 2}})); + + { + // Test that symbol names and file names are stored in the input order. + EXPECT_THAT(llvm::to_vector(Data.getStrings()), + ElementsAre("foo", "bar.__uniq.321", "file2", "file1")); + EXPECT_THAT(Data.getKnownColdSymbols(), ElementsAre("sym2", "sym1")); + EXPECT_THAT(Data.getKnownColdHashes(), ElementsAre(789, 678)); + + // Look up profiles. + EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)789)); + EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)678)); + EXPECT_TRUE(Data.isKnownColdSymbol("sym2")); + EXPECT_TRUE(Data.isKnownColdSymbol("sym1")); + + EXPECT_EQ(Data.getProfileRecord("non-existence"), nullptr); + EXPECT_EQ(Data.getProfileRecord((uint64_t)789987), nullptr); + + EXPECT_THAT( + *Data.getProfileRecord("foo.llvm.123"), + AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0), + testing::Field(&DataAccessProfRecord::AccessCount, 100), + testing::Field(&DataAccessProfRecord::IsStringLiteral, false), + testing::Field(&DataAccessProfRecord::Locations, + testing::IsEmpty()))); + EXPECT_THAT( + *Data.getProfileRecord("bar.__uniq.321"), + AllOf( + testing::Field(&DataAccessProfRecord::SymbolID, 1), + testing::Field(&DataAccessProfRecord::AccessCount, 123), + testing::Field(&DataAccessProfRecord::IsStringLiteral, false), + testing::Field(&DataAccessProfRecord::Locations, + ElementsAre(AllOf( + testing::Field(&DataLocation::FileName, "file2"), + testing::Field(&DataLocation::Line, 3)))))); + EXPECT_THAT( + *Data.getProfileRecord((uint64_t)135246), + AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 135246), + testing::Field(&DataAccessProfRecord::AccessCount, 1000), + testing::Field(&DataAccessProfRecord::IsStringLiteral, true), + testing::Field( + &DataAccessProfRecord::Locations, + ElementsAre( + AllOf(testing::Field(&DataLocation::FileName, "file1"), + testing::Field(&DataLocation::Line, 1)), + AllOf(testing::Field(&DataLocation::FileName, "file2"), + testing::Field(&DataLocation::Line, 2)))))); + } + + // Tests serialization and de-serialization. + DataAccessProfData deserializedData; + { + std::string serializedData; + llvm::raw_string_ostream OS(serializedData); + llvm::ProfOStream POS(OS); + + EXPECT_FALSE(Data.serialize(POS)); + + const unsigned char *p = + reinterpret_cast(serializedData.data()); + ASSERT_THAT(llvm::to_vector(deserializedData.getStrings()), + testing::IsEmpty()); + EXPECT_FALSE(deserializedData.deserialize(p)); + + EXPECT_THAT(llvm::to_vector(deserializedData.getStrings()), + ElementsAre("foo", "bar.__uniq.321", "file2", "file1")); + EXPECT_THAT(deserializedData.getKnownColdSymbols(), + ElementsAre("sym2", "sym1")); + EXPECT_THAT(deserializedData.getKnownColdHashes(), ElementsAre(789, 678)); + + // Look up profiles after deserialization. + EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)789)); + EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)678)); + EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym2")); + EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym1")); + + auto Records = + llvm::to_vector(llvm::make_second_range(deserializedData.getRecords())); + + EXPECT_THAT( + Records, + ElementsAre( + AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0), + testing::Field(&DataAccessProfRecord::AccessCount, 100), + testing::Field(&DataAccessProfRecord::IsStringLiteral, false), + testing::Field(&DataAccessProfRecord::Locations, + testing::IsEmpty())), + AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 1), + testing::Field(&DataAccessProfRecord::AccessCount, 123), + testing::Field(&DataAccessProfRecord::IsStringLiteral, false), + testing::Field( + &DataAccessProfRecord::Locations, + ElementsAre(AllOf( + testing::Field(&DataLocation::FileName, "file2"), + testing::Field(&DataLocation::Line, 3))))), + AllOf( + testing::Field(&DataAccessProfRecord::SymbolID, 135246), + testing::Field(&DataAccessProfRecord::AccessCount, 1000), + testing::Field(&DataAccessProfRecord::IsStringLiteral, true), + testing::Field( + &DataAccessProfRecord::Locations, + ElementsAre( + AllOf(testing::Field(&DataLocation::FileName, "file1"), + testing::Field(&DataLocation::Line, 1)), + AllOf(testing::Field(&DataLocation::FileName, "file2"), + testing::Field(&DataLocation::Line, 2))))))); + } +} +} // namespace +} // namespace data_access_prof +} // namespace llvm diff --git a/llvm/unittests/ProfileData/MemProfTest.cpp b/llvm/unittests/ProfileData/MemProfTest.cpp index b7b8d642ad930..3e430aa4eae58 100644 --- a/llvm/unittests/ProfileData/MemProfTest.cpp +++ b/llvm/unittests/ProfileData/MemProfTest.cpp @@ -10,17 +10,14 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLForwardCompat.h" -#include "llvm/ADT/SmallVector.h" #include "llvm/DebugInfo/DIContext.h" #include "llvm/DebugInfo/Symbolize/SymbolizableModule.h" #include "llvm/IR/Value.h" #include "llvm/Object/ObjectFile.h" -#include "llvm/ProfileData/DataAccessProf.h" #include "llvm/ProfileData/MemProfData.inc" #include "llvm/ProfileData/MemProfReader.h" #include "llvm/ProfileData/MemProfYAML.h" #include "llvm/Support/raw_ostream.h" -#include "gmock/gmock-more-matchers.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -39,8 +36,6 @@ using ::llvm::StringRef; using ::llvm::object::SectionedAddress; using ::llvm::symbolize::SymbolizableModule; using ::testing::ElementsAre; -using ::testing::ElementsAreArray; -using ::testing::HasSubstr; using ::testing::IsEmpty; using ::testing::Pair; using ::testing::Return; @@ -752,161 +747,6 @@ TEST(MemProf, YAMLParser) { ElementsAre(0x3000))))); } -static std::string ErrorToString(Error E) { - std::string ErrMsg; - llvm::raw_string_ostream OS(ErrMsg); - llvm::logAllUnhandledErrors(std::move(E), OS); - return ErrMsg; -} - -// Test the various scenarios when DataAccessProfData should return error on -// invalid input. -TEST(MemProf, DataAccessProfileError) { - // Returns error if the input symbol name is empty. - llvm::data_access_prof::DataAccessProfData Data; - EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("", 100)), - HasSubstr("Empty symbol name")); - - // Returns error when the same symbol gets added twice. - ASSERT_FALSE(Data.setDataAccessProfile("foo", 100)); - EXPECT_THAT(ErrorToString(Data.setDataAccessProfile("foo", 100)), - HasSubstr("Duplicate symbol or string literal added")); - - // Returns error when the same string content hash gets added twice. - ASSERT_FALSE(Data.setDataAccessProfile((uint64_t)135246, 1000)); - EXPECT_THAT(ErrorToString(Data.setDataAccessProfile((uint64_t)135246, 1000)), - HasSubstr("Duplicate symbol or string literal added")); -} - -// Test the following operations on DataAccessProfData: -// - Profile record look up. -// - Serialization and de-serialization. -TEST(MemProf, DataAccessProfile) { - using namespace llvm::data_access_prof; - llvm::data_access_prof::DataAccessProfData Data; - - // In the bool conversion, Error is true if it's in a failure state and false - // if it's in an accept state. Use ASSERT_FALSE or EXPECT_FALSE for no error. - ASSERT_FALSE(Data.setDataAccessProfile("foo.llvm.123", 100)); - ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)789)); - ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym2")); - ASSERT_FALSE(Data.setDataAccessProfile("bar.__uniq.321", 123, - { - DataLocation{"file2", 3}, - })); - ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym1")); - ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)678)); - ASSERT_FALSE(Data.setDataAccessProfile( - (uint64_t)135246, 1000, - {DataLocation{"file1", 1}, DataLocation{"file2", 2}})); - - { - // Test that symbol names and file names are stored in the input order. - EXPECT_THAT(llvm::to_vector(Data.getStrings()), - ElementsAre("foo", "bar.__uniq.321", "file2", "file1")); - EXPECT_THAT(Data.getKnownColdSymbols(), ElementsAre("sym2", "sym1")); - EXPECT_THAT(Data.getKnownColdHashes(), ElementsAre(789, 678)); - - // Look up profiles. - EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)789)); - EXPECT_TRUE(Data.isKnownColdSymbol((uint64_t)678)); - EXPECT_TRUE(Data.isKnownColdSymbol("sym2")); - EXPECT_TRUE(Data.isKnownColdSymbol("sym1")); - - EXPECT_EQ(Data.getProfileRecord("non-existence"), nullptr); - EXPECT_EQ(Data.getProfileRecord((uint64_t)789987), nullptr); - - EXPECT_THAT( - Data.getProfileRecord("foo.llvm.123"), - AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0), - testing::Field(&DataAccessProfRecord::AccessCount, 100), - testing::Field(&DataAccessProfRecord::IsStringLiteral, false), - testing::Field(&DataAccessProfRecord::Locations, - testing::IsEmpty()))); - EXPECT_THAT( - *Data.getProfileRecord("bar.__uniq.321"), - AllOf( - testing::Field(&DataAccessProfRecord::SymbolID, 1), - testing::Field(&DataAccessProfRecord::AccessCount, 123), - testing::Field(&DataAccessProfRecord::IsStringLiteral, false), - testing::Field(&DataAccessProfRecord::Locations, - ElementsAre(AllOf( - testing::Field(&DataLocation::FileName, "file2"), - testing::Field(&DataLocation::Line, 3)))))); - EXPECT_THAT( - *Data.getProfileRecord((uint64_t)135246), - AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 135246), - testing::Field(&DataAccessProfRecord::AccessCount, 1000), - testing::Field(&DataAccessProfRecord::IsStringLiteral, true), - testing::Field( - &DataAccessProfRecord::Locations, - ElementsAre( - AllOf(testing::Field(&DataLocation::FileName, "file1"), - testing::Field(&DataLocation::Line, 1)), - AllOf(testing::Field(&DataLocation::FileName, "file2"), - testing::Field(&DataLocation::Line, 2)))))); - } - - // Tests serialization and de-serialization. - llvm::data_access_prof::DataAccessProfData deserializedData; - { - std::string serializedData; - llvm::raw_string_ostream OS(serializedData); - llvm::ProfOStream POS(OS); - - EXPECT_FALSE(Data.serialize(POS)); - - const unsigned char *p = - reinterpret_cast(serializedData.data()); - ASSERT_THAT(llvm::to_vector(deserializedData.getStrings()), - testing::IsEmpty()); - EXPECT_FALSE(deserializedData.deserialize(p)); - - EXPECT_THAT(llvm::to_vector(deserializedData.getStrings()), - ElementsAre("foo", "bar.__uniq.321", "file2", "file1")); - EXPECT_THAT(deserializedData.getKnownColdSymbols(), - ElementsAre("sym2", "sym1")); - EXPECT_THAT(deserializedData.getKnownColdHashes(), ElementsAre(789, 678)); - - // Look up profiles after deserialization. - EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)789)); - EXPECT_TRUE(deserializedData.isKnownColdSymbol((uint64_t)678)); - EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym2")); - EXPECT_TRUE(deserializedData.isKnownColdSymbol("sym1")); - - auto Records = - llvm::to_vector(llvm::make_second_range(deserializedData.getRecords())); - - EXPECT_THAT( - Records, - ElementsAre( - AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0), - testing::Field(&DataAccessProfRecord::AccessCount, 100), - testing::Field(&DataAccessProfRecord::IsStringLiteral, false), - testing::Field(&DataAccessProfRecord::Locations, - testing::IsEmpty())), - AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 1), - testing::Field(&DataAccessProfRecord::AccessCount, 123), - testing::Field(&DataAccessProfRecord::IsStringLiteral, false), - testing::Field( - &DataAccessProfRecord::Locations, - ElementsAre(AllOf( - testing::Field(&DataLocation::FileName, "file2"), - testing::Field(&DataLocation::Line, 3))))), - AllOf( - testing::Field(&DataAccessProfRecord::SymbolID, 135246), - testing::Field(&DataAccessProfRecord::AccessCount, 1000), - testing::Field(&DataAccessProfRecord::IsStringLiteral, true), - testing::Field( - &DataAccessProfRecord::Locations, - ElementsAre( - AllOf(testing::Field(&DataLocation::FileName, "file1"), - testing::Field(&DataLocation::Line, 1)), - AllOf(testing::Field(&DataLocation::FileName, "file2"), - testing::Field(&DataLocation::Line, 2))))))); - } -} - // Verify that the YAML parser accepts a GUID expressed as a function name. TEST(MemProf, YAMLParserGUID) { StringRef YAMLData = R"YAML( From df080949a1ccc75fceb386b0232005f96397752d Mon Sep 17 00:00:00 2001 From: mingmingl Date: Tue, 6 May 2025 16:35:14 -0700 Subject: [PATCH 5/8] resolve feedback --- llvm/include/llvm/ProfileData/DataAccessProf.h | 6 +++++- llvm/lib/ProfileData/DataAccessProf.cpp | 12 +++++++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h index a85a3320ae57c..91de43fdf60ca 100644 --- a/llvm/include/llvm/ProfileData/DataAccessProf.h +++ b/llvm/include/llvm/ProfileData/DataAccessProf.h @@ -52,7 +52,11 @@ struct DataAccessProfRecord { // Represents a data symbol. The semantic comes in two forms: a symbol index // for symbol name if `IsStringLiteral` is false, or the hash of a string - // content if `IsStringLiteral` is true. Required. + // content if `IsStringLiteral` is true. For most of the symbolizable static + // data, the mangled symbol names remain stable relative to the source code + // and therefore used to identify symbols across binary releases. String + // literals have unstable name patterns like `.str.N[.llvm.hash]`, so we use + // the content hash instead. This is a required field. uint64_t SymbolID; // The access count of symbol. Required. diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp index d1c034f639347..d7e67f5f09cbe 100644 --- a/llvm/lib/ProfileData/DataAccessProf.cpp +++ b/llvm/lib/ProfileData/DataAccessProf.cpp @@ -35,9 +35,15 @@ const DataAccessProfRecord * DataAccessProfData::getProfileRecord(const SymbolHandle SymbolID) const { auto Key = SymbolID; if (std::holds_alternative(SymbolID)) { - StringRef Name = std::get(SymbolID); - assert(!Name.empty() && "Empty symbol name"); - Key = InstrProfSymtab::getCanonicalName(Name); + auto NameOrErr = getCanonicalName(std::get(SymbolID)); + // If name canonicalization fails, suppress the error inside. + if (!NameOrErr) { + assert( + std::get(SymbolID).empty() && + "Name canonicalization only fails when stringified string is empty."); + return nullptr; + } + Key = *NameOrErr; } auto It = Records.find(Key); From 6dd04e46542851b84bf26cd95245399204072085 Mon Sep 17 00:00:00 2001 From: mingmingl Date: Mon, 12 May 2025 17:05:19 -0700 Subject: [PATCH 6/8] resolve comments --- .../include/llvm/ProfileData/DataAccessProf.h | 119 ++++++++++++------ llvm/include/llvm/ProfileData/InstrProf.h | 2 +- llvm/lib/ProfileData/DataAccessProf.cpp | 49 ++++---- .../ProfileData/DataAccessProfTest.cpp | 116 ++++++++--------- 4 files changed, 167 insertions(+), 119 deletions(-) diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h index 91de43fdf60ca..e8504102238d1 100644 --- a/llvm/include/llvm/ProfileData/DataAccessProf.h +++ b/llvm/include/llvm/ProfileData/DataAccessProf.h @@ -30,23 +30,40 @@ #include "llvm/Support/StringSaver.h" #include +#include #include namespace llvm { namespace data_access_prof { -// The location of data in the source code. -struct DataLocation { + +/// The location of data in the source code. Used by profile lookup API. +struct SourceLocation { + SourceLocation(StringRef FileNameRef, uint32_t Line) + : FileName(FileNameRef.str()), Line(Line) {} + /// The filename where the data is located. + std::string FileName; + /// The line number in the source code. + uint32_t Line; +}; + +namespace internal { + +// Conceptually similar to SourceLocation except that FileNames are StringRef of +// which strings are owned by `DataAccessProfData`. Used by `DataAccessProfData` +// to represent data locations internally. +struct SourceLocationRef { // The filename where the data is located. StringRef FileName; // The line number in the source code. uint32_t Line; }; -// The data access profiles for a symbol. -struct DataAccessProfRecord { - DataAccessProfRecord(uint64_t SymbolID, uint64_t AccessCount, - bool IsStringLiteral) +// The data access profiles for a symbol. Used by `DataAccessProfData` +// to represent records internally. +struct DataAccessProfRecordRef { + DataAccessProfRecordRef(uint64_t SymbolID, uint64_t AccessCount, + bool IsStringLiteral) : SymbolID(SymbolID), AccessCount(AccessCount), IsStringLiteral(IsStringLiteral) {} @@ -67,18 +84,43 @@ struct DataAccessProfRecord { bool IsStringLiteral; // The locations of data in the source code. Optional. - llvm::SmallVector Locations; + llvm::SmallVector Locations; }; +} // namespace internal + +// SymbolID is either a string representing symbol name if the symbol has +// stable mangled name relative to source code, or a uint64_t representing the +// content hash of a string literal (with unstable name patterns like +// `.str.N[.llvm.hash]`). The StringRef is owned by the class's saver object. +using SymbolHandleRef = std::variant; -/// Encapsulates the data access profile data and the methods to operate on it. -/// This class provides profile look-up, serialization and deserialization. +// The senamtic is the same as `SymbolHandleRef` above. The strings are owned. +using SymbolHandle = std::variant; + +/// The data access profiles for a symbol. +struct DataAccessProfRecord { +public: + DataAccessProfRecord(SymbolHandleRef SymHandleRef, + ArrayRef LocRefs) { + if (std::holds_alternative(SymHandleRef)) { + SymHandle = std::get(SymHandleRef).str(); + } else + SymHandle = std::get(SymHandleRef); + + for (auto Loc : LocRefs) + Locations.push_back(SourceLocation(Loc.FileName, Loc.Line)); + } + SymbolHandle SymHandle; + + // The locations of data in the source code. Optional. + SmallVector Locations; +}; + +/// Encapsulates the data access profile data and the methods to operate on +/// it. This class provides profile look-up, serialization and +/// deserialization. class DataAccessProfData { public: - // SymbolID is either a string representing symbol name if the symbol has - // stable mangled name relative to source code, or a uint64_t representing the - // content hash of a string literal (with unstable name patterns like - // `.str.N[.llvm.hash]`). The StringRef is owned by the class's saver object. - using SymbolHandle = std::variant; using StringToIndexMap = llvm::MapVector; DataAccessProfData() : Saver(Allocator) {} @@ -93,35 +135,39 @@ class DataAccessProfData { /// Deserialize this class from the given buffer. Error deserialize(const unsigned char *&Ptr); - /// Returns a pointer of profile record for \p SymbolID, or nullptr if there + /// Returns a profile record for \p SymbolID, or std::nullopt if there /// isn't a record. Internally, this function will canonicalize the symbol /// name before the lookup. - const DataAccessProfRecord *getProfileRecord(const SymbolHandle SymID) const; + std::optional + getProfileRecord(const SymbolHandleRef SymID) const; /// Returns true if \p SymID is seen in profiled binaries and cold. - bool isKnownColdSymbol(const SymbolHandle SymID) const; + bool isKnownColdSymbol(const SymbolHandleRef SymID) const; - /// Methods to set symbolized data access profile. Returns error if duplicated - /// symbol names or content hashes are seen. The user of this class should - /// aggregate counters that correspond to the same symbol name or with the - /// same string literal hash before calling 'set*' methods. - Error setDataAccessProfile(SymbolHandle SymbolID, uint64_t AccessCount); + /// Methods to set symbolized data access profile. Returns error if + /// duplicated symbol names or content hashes are seen. The user of this + /// class should aggregate counters that correspond to the same symbol name + /// or with the same string literal hash before calling 'set*' methods. + Error setDataAccessProfile(SymbolHandleRef SymbolID, uint64_t AccessCount); /// Similar to the method above, for records with \p Locations representing /// the `filename:line` where this symbol shows up. Note because of linker's /// merge of identical symbols (e.g., unnamed_addr string literals), one /// symbol is likely to have multiple locations. - Error setDataAccessProfile(SymbolHandle SymbolID, uint64_t AccessCount, - ArrayRef Locations); - Error addKnownSymbolWithoutSamples(SymbolHandle SymbolID); - - /// Returns an iterable StringRef for strings in the order they are added. - /// Each string may be a symbol name or a file name. - auto getStrings() const { - return llvm::make_first_range(StrToIndexMap.getArrayRef()); + Error setDataAccessProfile(SymbolHandleRef SymbolID, uint64_t AccessCount, + ArrayRef Locations); + /// Add a symbol that's seen in the profiled binary without samples. + Error addKnownSymbolWithoutSamples(SymbolHandleRef SymbolID); + + /// The following methods return array reference for various internal data + /// structures. + ArrayRef getStrToIndexMapRef() const { + return StrToIndexMap.getArrayRef(); + } + ArrayRef< + MapVector::value_type> + getRecords() const { + return Records.getArrayRef(); } - - /// Returns array reference for various internal data structures. - auto getRecords() const { return Records.getArrayRef(); } ArrayRef getKnownColdSymbols() const { return KnownColdSymbols.getArrayRef(); } @@ -139,11 +185,12 @@ class DataAccessProfData { const uint64_t NumSampledSymbols, const uint64_t NumColdKnownSymbols); - /// Decode the records and increment \p Ptr to the start of the next payload. + /// Decode the records and increment \p Ptr to the start of the next + /// payload. Error deserializeRecords(const unsigned char *&Ptr); /// A helper function to compute a storage index for \p SymbolID. - uint64_t getEncodedIndex(const SymbolHandle SymbolID) const; + uint64_t getEncodedIndex(const SymbolHandleRef SymbolID) const; // Keeps owned copies of the input strings. // NOTE: Keep `Saver` initialized before other class members that reference @@ -152,7 +199,7 @@ class DataAccessProfData { llvm::UniqueStringSaver Saver; // `Records` stores the records. - MapVector Records; + MapVector Records; // Use MapVector to keep input order of strings for serialization and // deserialization. diff --git a/llvm/include/llvm/ProfileData/InstrProf.h b/llvm/include/llvm/ProfileData/InstrProf.h index 33b93ea0a558a..544a59df43ed3 100644 --- a/llvm/include/llvm/ProfileData/InstrProf.h +++ b/llvm/include/llvm/ProfileData/InstrProf.h @@ -500,7 +500,7 @@ class InstrProfSymtab { public: using AddrHashMap = std::vector>; - // Returns the canonial name of the given PGOName. In a canonical name, all + // Returns the canonical name of the given PGOName. In a canonical name, all // suffixes that begins with "." except ".__uniq." are stripped. // FIXME: Unify this with `FunctionSamples::getCanonicalFnName`. static StringRef getCanonicalName(StringRef PGOName); diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp index d7e67f5f09cbe..c5d0099977cfa 100644 --- a/llvm/lib/ProfileData/DataAccessProf.cpp +++ b/llvm/lib/ProfileData/DataAccessProf.cpp @@ -16,11 +16,11 @@ namespace data_access_prof { // If `Map` has an entry keyed by `Str`, returns the entry iterator. Otherwise, // creates an owned copy of `Str`, adds a map entry for it and returns the // iterator. -static MapVector::iterator -saveStringToMap(MapVector &Map, +static std::pair +saveStringToMap(DataAccessProfData::StringToIndexMap &Map, llvm::UniqueStringSaver &Saver, StringRef Str) { auto [Iter, Inserted] = Map.try_emplace(Saver.save(Str), Map.size()); - return Iter; + return *Iter; } // Returns the canonical name or error. @@ -31,8 +31,8 @@ static Expected getCanonicalName(StringRef Name) { return InstrProfSymtab::getCanonicalName(Name); } -const DataAccessProfRecord * -DataAccessProfData::getProfileRecord(const SymbolHandle SymbolID) const { +std::optional +DataAccessProfData::getProfileRecord(const SymbolHandleRef SymbolID) const { auto Key = SymbolID; if (std::holds_alternative(SymbolID)) { auto NameOrErr = getCanonicalName(std::get(SymbolID)); @@ -41,40 +41,39 @@ DataAccessProfData::getProfileRecord(const SymbolHandle SymbolID) const { assert( std::get(SymbolID).empty() && "Name canonicalization only fails when stringified string is empty."); - return nullptr; + return std::nullopt; } Key = *NameOrErr; } auto It = Records.find(Key); - if (It != Records.end()) - return &It->second; + if (It != Records.end()) { + return DataAccessProfRecord(Key, It->second.Locations); + } - return nullptr; + return std::nullopt; } -bool DataAccessProfData::isKnownColdSymbol(const SymbolHandle SymID) const { +bool DataAccessProfData::isKnownColdSymbol(const SymbolHandleRef SymID) const { if (std::holds_alternative(SymID)) return KnownColdHashes.contains(std::get(SymID)); return KnownColdSymbols.contains(std::get(SymID)); } -Error DataAccessProfData::setDataAccessProfile(SymbolHandle Symbol, +Error DataAccessProfData::setDataAccessProfile(SymbolHandleRef Symbol, uint64_t AccessCount) { uint64_t RecordID = -1; - bool IsStringLiteral = false; - SymbolHandle Key; - if (std::holds_alternative(Symbol)) { + const bool IsStringLiteral = std::holds_alternative(Symbol); + SymbolHandleRef Key; + if (IsStringLiteral) { RecordID = std::get(Symbol); Key = RecordID; - IsStringLiteral = true; } else { auto CanonicalName = getCanonicalName(std::get(Symbol)); if (!CanonicalName) return CanonicalName.takeError(); std::tie(Key, RecordID) = - *saveStringToMap(StrToIndexMap, Saver, *CanonicalName); - IsStringLiteral = false; + saveStringToMap(StrToIndexMap, Saver, *CanonicalName); } auto [Iter, Inserted] = @@ -89,21 +88,22 @@ Error DataAccessProfData::setDataAccessProfile(SymbolHandle Symbol, } Error DataAccessProfData::setDataAccessProfile( - SymbolHandle SymbolID, uint64_t AccessCount, - ArrayRef Locations) { + SymbolHandleRef SymbolID, uint64_t AccessCount, + ArrayRef Locations) { if (Error E = setDataAccessProfile(SymbolID, AccessCount)) return E; auto &Record = Records.back().second; for (const auto &Location : Locations) Record.Locations.push_back( - {saveStringToMap(StrToIndexMap, Saver, Location.FileName)->first, + {saveStringToMap(StrToIndexMap, Saver, Location.FileName).first, Location.Line}); return Error::success(); } -Error DataAccessProfData::addKnownSymbolWithoutSamples(SymbolHandle SymbolID) { +Error DataAccessProfData::addKnownSymbolWithoutSamples( + SymbolHandleRef SymbolID) { if (std::holds_alternative(SymbolID)) { KnownColdHashes.insert(std::get(SymbolID)); return Error::success(); @@ -164,7 +164,7 @@ Error DataAccessProfData::serializeSymbolsAndFilenames(ProfOStream &OS) const { } uint64_t -DataAccessProfData::getEncodedIndex(const SymbolHandle SymbolID) const { +DataAccessProfData::getEncodedIndex(const SymbolHandleRef SymbolID) const { if (std::holds_alternative(SymbolID)) return std::get(SymbolID); @@ -220,7 +220,8 @@ Error DataAccessProfData::deserializeSymbolsAndFilenames( } Error DataAccessProfData::deserializeRecords(const unsigned char *&Ptr) { - SmallVector Strings = llvm::to_vector(getStrings()); + SmallVector Strings = + llvm::to_vector(llvm::make_first_range(getStrToIndexMapRef())); uint64_t NumRecords = support::endian::readNext(Ptr); @@ -235,7 +236,7 @@ Error DataAccessProfData::deserializeRecords(const unsigned char *&Ptr) { uint64_t AccessCount = support::endian::readNext(Ptr); - SymbolHandle SymbolID; + SymbolHandleRef SymbolID; if (IsStringLiteral) SymbolID = ID; else diff --git a/llvm/unittests/ProfileData/DataAccessProfTest.cpp b/llvm/unittests/ProfileData/DataAccessProfTest.cpp index 50c4af49fe76b..127230d4805e7 100644 --- a/llvm/unittests/ProfileData/DataAccessProfTest.cpp +++ b/llvm/unittests/ProfileData/DataAccessProfTest.cpp @@ -1,4 +1,3 @@ - //===- unittests/Support/DataAccessProfTest.cpp //----------------------------------===// // @@ -10,6 +9,7 @@ #include "llvm/ProfileData/DataAccessProf.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Testing/Support/SupportHelpers.h" #include "gmock/gmock-more-matchers.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -19,7 +19,9 @@ namespace data_access_prof { namespace { using ::llvm::StringRef; +using llvm::ValueIs; using ::testing::ElementsAre; +using ::testing::Field; using ::testing::HasSubstr; using ::testing::IsEmpty; @@ -53,6 +55,8 @@ TEST(MemProf, DataAccessProfileError) { // - Profile record look up. // - Serialization and de-serialization. TEST(MemProf, DataAccessProfile) { + using internal::DataAccessProfRecordRef; + using internal::SourceLocationRef; DataAccessProfData Data; // In the bool conversion, Error is true if it's in a failure state and false @@ -62,18 +66,19 @@ TEST(MemProf, DataAccessProfile) { ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym2")); ASSERT_FALSE(Data.setDataAccessProfile("bar.__uniq.321", 123, { - DataLocation{"file2", 3}, + SourceLocation{"file2", 3}, })); ASSERT_FALSE(Data.addKnownSymbolWithoutSamples("sym1")); ASSERT_FALSE(Data.addKnownSymbolWithoutSamples((uint64_t)678)); ASSERT_FALSE(Data.setDataAccessProfile( (uint64_t)135246, 1000, - {DataLocation{"file1", 1}, DataLocation{"file2", 2}})); + {SourceLocation{"file1", 1}, SourceLocation{"file2", 2}})); { // Test that symbol names and file names are stored in the input order. - EXPECT_THAT(llvm::to_vector(Data.getStrings()), - ElementsAre("foo", "bar.__uniq.321", "file2", "file1")); + EXPECT_THAT( + llvm::to_vector(llvm::make_first_range(Data.getStrToIndexMapRef())), + ElementsAre("foo", "bar.__uniq.321", "file2", "file1")); EXPECT_THAT(Data.getKnownColdSymbols(), ElementsAre("sym2", "sym1")); EXPECT_THAT(Data.getKnownColdHashes(), ElementsAre(789, 678)); @@ -83,38 +88,34 @@ TEST(MemProf, DataAccessProfile) { EXPECT_TRUE(Data.isKnownColdSymbol("sym2")); EXPECT_TRUE(Data.isKnownColdSymbol("sym1")); - EXPECT_EQ(Data.getProfileRecord("non-existence"), nullptr); - EXPECT_EQ(Data.getProfileRecord((uint64_t)789987), nullptr); + EXPECT_EQ(Data.getProfileRecord("non-existence"), std::nullopt); + EXPECT_EQ(Data.getProfileRecord((uint64_t)789987), std::nullopt); EXPECT_THAT( - *Data.getProfileRecord("foo.llvm.123"), - AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0), - testing::Field(&DataAccessProfRecord::AccessCount, 100), - testing::Field(&DataAccessProfRecord::IsStringLiteral, false), - testing::Field(&DataAccessProfRecord::Locations, - testing::IsEmpty()))); + Data.getProfileRecord("foo.llvm.123"), + ValueIs(AllOf( + Field(&DataAccessProfRecord::SymHandle, + testing::VariantWith(testing::Eq("foo"))), + Field(&DataAccessProfRecord::Locations, testing::IsEmpty())))); EXPECT_THAT( - *Data.getProfileRecord("bar.__uniq.321"), - AllOf( - testing::Field(&DataAccessProfRecord::SymbolID, 1), - testing::Field(&DataAccessProfRecord::AccessCount, 123), - testing::Field(&DataAccessProfRecord::IsStringLiteral, false), - testing::Field(&DataAccessProfRecord::Locations, - ElementsAre(AllOf( - testing::Field(&DataLocation::FileName, "file2"), - testing::Field(&DataLocation::Line, 3)))))); + Data.getProfileRecord("bar.__uniq.321"), + ValueIs(AllOf( + Field(&DataAccessProfRecord::SymHandle, + testing::VariantWith( + testing::Eq("bar.__uniq.321"))), + Field(&DataAccessProfRecord::Locations, + ElementsAre(AllOf(Field(&SourceLocation::FileName, "file2"), + Field(&SourceLocation::Line, 3))))))); EXPECT_THAT( - *Data.getProfileRecord((uint64_t)135246), - AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 135246), - testing::Field(&DataAccessProfRecord::AccessCount, 1000), - testing::Field(&DataAccessProfRecord::IsStringLiteral, true), - testing::Field( - &DataAccessProfRecord::Locations, - ElementsAre( - AllOf(testing::Field(&DataLocation::FileName, "file1"), - testing::Field(&DataLocation::Line, 1)), - AllOf(testing::Field(&DataLocation::FileName, "file2"), - testing::Field(&DataLocation::Line, 2)))))); + Data.getProfileRecord((uint64_t)135246), + ValueIs(AllOf( + Field(&DataAccessProfRecord::SymHandle, + testing::VariantWith(testing::Eq(135246))), + Field(&DataAccessProfRecord::Locations, + ElementsAre(AllOf(Field(&SourceLocation::FileName, "file1"), + Field(&SourceLocation::Line, 1)), + AllOf(Field(&SourceLocation::FileName, "file2"), + Field(&SourceLocation::Line, 2))))))); } // Tests serialization and de-serialization. @@ -128,11 +129,13 @@ TEST(MemProf, DataAccessProfile) { const unsigned char *p = reinterpret_cast(serializedData.data()); - ASSERT_THAT(llvm::to_vector(deserializedData.getStrings()), + ASSERT_THAT(llvm::to_vector(llvm::make_first_range( + deserializedData.getStrToIndexMapRef())), testing::IsEmpty()); EXPECT_FALSE(deserializedData.deserialize(p)); - EXPECT_THAT(llvm::to_vector(deserializedData.getStrings()), + EXPECT_THAT(llvm::to_vector(llvm::make_first_range( + deserializedData.getStrToIndexMapRef())), ElementsAre("foo", "bar.__uniq.321", "file2", "file1")); EXPECT_THAT(deserializedData.getKnownColdSymbols(), ElementsAre("sym2", "sym1")); @@ -150,30 +153,27 @@ TEST(MemProf, DataAccessProfile) { EXPECT_THAT( Records, ElementsAre( - AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 0), - testing::Field(&DataAccessProfRecord::AccessCount, 100), - testing::Field(&DataAccessProfRecord::IsStringLiteral, false), - testing::Field(&DataAccessProfRecord::Locations, - testing::IsEmpty())), - AllOf(testing::Field(&DataAccessProfRecord::SymbolID, 1), - testing::Field(&DataAccessProfRecord::AccessCount, 123), - testing::Field(&DataAccessProfRecord::IsStringLiteral, false), - testing::Field( - &DataAccessProfRecord::Locations, - ElementsAre(AllOf( - testing::Field(&DataLocation::FileName, "file2"), - testing::Field(&DataLocation::Line, 3))))), AllOf( - testing::Field(&DataAccessProfRecord::SymbolID, 135246), - testing::Field(&DataAccessProfRecord::AccessCount, 1000), - testing::Field(&DataAccessProfRecord::IsStringLiteral, true), - testing::Field( - &DataAccessProfRecord::Locations, - ElementsAre( - AllOf(testing::Field(&DataLocation::FileName, "file1"), - testing::Field(&DataLocation::Line, 1)), - AllOf(testing::Field(&DataLocation::FileName, "file2"), - testing::Field(&DataLocation::Line, 2))))))); + Field(&DataAccessProfRecordRef::SymbolID, 0), + Field(&DataAccessProfRecordRef::AccessCount, 100), + Field(&DataAccessProfRecordRef::IsStringLiteral, false), + Field(&DataAccessProfRecordRef::Locations, testing::IsEmpty())), + AllOf(Field(&DataAccessProfRecordRef::SymbolID, 1), + Field(&DataAccessProfRecordRef::AccessCount, 123), + Field(&DataAccessProfRecordRef::IsStringLiteral, false), + Field(&DataAccessProfRecordRef::Locations, + ElementsAre( + AllOf(Field(&SourceLocationRef::FileName, "file2"), + Field(&SourceLocationRef::Line, 3))))), + AllOf(Field(&DataAccessProfRecordRef::SymbolID, 135246), + Field(&DataAccessProfRecordRef::AccessCount, 1000), + Field(&DataAccessProfRecordRef::IsStringLiteral, true), + Field(&DataAccessProfRecordRef::Locations, + ElementsAre( + AllOf(Field(&SourceLocationRef::FileName, "file1"), + Field(&SourceLocationRef::Line, 1)), + AllOf(Field(&SourceLocationRef::FileName, "file2"), + Field(&SourceLocationRef::Line, 2))))))); } } } // namespace From 4b25d6706b6fc0092a566dddbbe937d325ac7dd8 Mon Sep 17 00:00:00 2001 From: mingmingl Date: Wed, 14 May 2025 21:06:56 -0700 Subject: [PATCH 7/8] Keep copies of owned strings for cold symbol names --- llvm/lib/ProfileData/DataAccessProf.cpp | 3 ++- llvm/unittests/ProfileData/DataAccessProfTest.cpp | 11 ++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/llvm/lib/ProfileData/DataAccessProf.cpp b/llvm/lib/ProfileData/DataAccessProf.cpp index c5d0099977cfa..a31f3db0621fb 100644 --- a/llvm/lib/ProfileData/DataAccessProf.cpp +++ b/llvm/lib/ProfileData/DataAccessProf.cpp @@ -111,7 +111,8 @@ Error DataAccessProfData::addKnownSymbolWithoutSamples( auto CanonicalName = getCanonicalName(std::get(SymbolID)); if (!CanonicalName) return CanonicalName.takeError(); - KnownColdSymbols.insert(*CanonicalName); + KnownColdSymbols.insert( + saveStringToMap(StrToIndexMap, Saver, *CanonicalName).first); return Error::success(); } diff --git a/llvm/unittests/ProfileData/DataAccessProfTest.cpp b/llvm/unittests/ProfileData/DataAccessProfTest.cpp index 127230d4805e7..084a8e96cdafe 100644 --- a/llvm/unittests/ProfileData/DataAccessProfTest.cpp +++ b/llvm/unittests/ProfileData/DataAccessProfTest.cpp @@ -78,7 +78,7 @@ TEST(MemProf, DataAccessProfile) { // Test that symbol names and file names are stored in the input order. EXPECT_THAT( llvm::to_vector(llvm::make_first_range(Data.getStrToIndexMapRef())), - ElementsAre("foo", "bar.__uniq.321", "file2", "file1")); + ElementsAre("foo", "sym2", "bar.__uniq.321", "file2", "sym1", "file1")); EXPECT_THAT(Data.getKnownColdSymbols(), ElementsAre("sym2", "sym1")); EXPECT_THAT(Data.getKnownColdHashes(), ElementsAre(789, 678)); @@ -134,9 +134,10 @@ TEST(MemProf, DataAccessProfile) { testing::IsEmpty()); EXPECT_FALSE(deserializedData.deserialize(p)); - EXPECT_THAT(llvm::to_vector(llvm::make_first_range( - deserializedData.getStrToIndexMapRef())), - ElementsAre("foo", "bar.__uniq.321", "file2", "file1")); + EXPECT_THAT( + llvm::to_vector( + llvm::make_first_range(deserializedData.getStrToIndexMapRef())), + ElementsAre("foo", "sym2", "bar.__uniq.321", "file2", "sym1", "file1")); EXPECT_THAT(deserializedData.getKnownColdSymbols(), ElementsAre("sym2", "sym1")); EXPECT_THAT(deserializedData.getKnownColdHashes(), ElementsAre(789, 678)); @@ -158,7 +159,7 @@ TEST(MemProf, DataAccessProfile) { Field(&DataAccessProfRecordRef::AccessCount, 100), Field(&DataAccessProfRecordRef::IsStringLiteral, false), Field(&DataAccessProfRecordRef::Locations, testing::IsEmpty())), - AllOf(Field(&DataAccessProfRecordRef::SymbolID, 1), + AllOf(Field(&DataAccessProfRecordRef::SymbolID, 2), Field(&DataAccessProfRecordRef::AccessCount, 123), Field(&DataAccessProfRecordRef::IsStringLiteral, false), Field(&DataAccessProfRecordRef::Locations, From 2ecc621c9310b9ea975cba06321fd7623442f343 Mon Sep 17 00:00:00 2001 From: mingmingl Date: Thu, 15 May 2025 16:48:10 -0700 Subject: [PATCH 8/8] move type comment before statement --- llvm/include/llvm/ProfileData/DataAccessProf.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llvm/include/llvm/ProfileData/DataAccessProf.h b/llvm/include/llvm/ProfileData/DataAccessProf.h index e8504102238d1..3cc8835a776dd 100644 --- a/llvm/include/llvm/ProfileData/DataAccessProf.h +++ b/llvm/include/llvm/ProfileData/DataAccessProf.h @@ -121,6 +121,8 @@ struct DataAccessProfRecord { /// deserialization. class DataAccessProfData { public: + // Use MapVector to keep input order of strings for serialization and + // deserialization. using StringToIndexMap = llvm::MapVector; DataAccessProfData() : Saver(Allocator) {} @@ -201,8 +203,6 @@ class DataAccessProfData { // `Records` stores the records. MapVector Records; - // Use MapVector to keep input order of strings for serialization and - // deserialization. StringToIndexMap StrToIndexMap; llvm::SetVector KnownColdHashes; llvm::SetVector KnownColdSymbols;