Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Offload] Only initialize a plugin if it is needed #92765

Merged
merged 1 commit into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions offload/plugins-nextgen/common/include/JIT.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,6 @@ struct JITEngine {
process(const __tgt_device_image &Image,
target::plugin::GenericDeviceTy &Device);

/// Return true if \p Image is a bitcode image that can be JITed for the given
/// architecture.
Expected<bool> checkBitcodeImage(StringRef Buffer) const;

private:
/// Compile the bitcode image \p Image and generate the binary image that can
/// be loaded to the target device of the triple \p Triple architecture \p
Expand Down
12 changes: 11 additions & 1 deletion offload/plugins-nextgen/common/include/PluginInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -1052,6 +1052,10 @@ struct GenericPluginTy {
/// given target. Returns true if the \p Image is compatible with the plugin.
Expected<bool> checkELFImage(StringRef Image) const;

/// Return true if the \p Image can be compiled to run on the platform's
/// target architecture.
Expected<bool> checkBitcodeImage(StringRef Image) const;

/// Indicate if an image is compatible with the plugin devices. Notice that
/// this function may be called before actually initializing the devices. So
/// we could not move this function into GenericDeviceTy.
Expand All @@ -1066,8 +1070,11 @@ struct GenericPluginTy {
public:
// TODO: This plugin interface needs to be cleaned up.

/// Returns true if the plugin has been initialized.
int32_t is_initialized() const;

/// Returns non-zero if the provided \p Image can be executed by the runtime.
int32_t is_valid_binary(__tgt_device_image *Image);
int32_t is_valid_binary(__tgt_device_image *Image, bool Initialized = true);

/// Initialize the device inside of the plugin.
int32_t init_device(int32_t DeviceId);
Expand Down Expand Up @@ -1187,6 +1194,9 @@ struct GenericPluginTy {
void **KernelPtr);

private:
/// Indicates if the platform runtime has been fully initialized.
bool Initialized = false;

/// Number of devices available for the plugin.
int32_t NumDevices = 0;

Expand Down
16 changes: 0 additions & 16 deletions offload/plugins-nextgen/common/src/JIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,19 +323,3 @@ JITEngine::process(const __tgt_device_image &Image,

return &Image;
}

Expected<bool> JITEngine::checkBitcodeImage(StringRef Buffer) const {
TimeTraceScope TimeScope("Check bitcode image");

assert(identify_magic(Buffer) == file_magic::bitcode &&
"Input is not bitcode");

LLVMContext Context;
auto ModuleOrErr = getLazyBitcodeModule(MemoryBufferRef(Buffer, ""), Context,
/*ShouldLazyLoadMetadata=*/true);
if (!ModuleOrErr)
return ModuleOrErr.takeError();
Module &M = **ModuleOrErr;

return Triple(M.getTargetTriple()).getArch() == TT.getArch();
}
34 changes: 28 additions & 6 deletions offload/plugins-nextgen/common/src/PluginInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "omp-tools.h"
#endif

#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/JSON.h"
Expand Down Expand Up @@ -1495,6 +1496,7 @@ Error GenericPluginTy::init() {
if (!NumDevicesOrErr)
return NumDevicesOrErr.takeError();

Initialized = true;
NumDevices = *NumDevicesOrErr;
if (NumDevices == 0)
return Plugin::success();
Expand Down Expand Up @@ -1578,14 +1580,27 @@ Expected<bool> GenericPluginTy::checkELFImage(StringRef Image) const {
if (!MachineOrErr)
return MachineOrErr.takeError();

if (!*MachineOrErr)
return MachineOrErr;
}

Expected<bool> GenericPluginTy::checkBitcodeImage(StringRef Image) const {
if (identify_magic(Image) != file_magic::bitcode)
return false;

// Perform plugin-dependent checks for the specific architecture if needed.
return isELFCompatible(Image);
LLVMContext Context;
auto ModuleOrErr = getLazyBitcodeModule(MemoryBufferRef(Image, ""), Context,
/*ShouldLazyLoadMetadata=*/true);
if (!ModuleOrErr)
return ModuleOrErr.takeError();
Module &M = **ModuleOrErr;

return Triple(M.getTargetTriple()).getArch() == getTripleArch();
jplehr marked this conversation as resolved.
Show resolved Hide resolved
}

int32_t GenericPluginTy::is_valid_binary(__tgt_device_image *Image) {
int32_t GenericPluginTy::is_initialized() const { return Initialized; }

int32_t GenericPluginTy::is_valid_binary(__tgt_device_image *Image,
bool Initialized) {
StringRef Buffer(reinterpret_cast<const char *>(Image->ImageStart),
target::getPtrDiff(Image->ImageEnd, Image->ImageStart));

Expand All @@ -1603,10 +1618,17 @@ int32_t GenericPluginTy::is_valid_binary(__tgt_device_image *Image) {
auto MatchOrErr = checkELFImage(Buffer);
if (Error Err = MatchOrErr.takeError())
return HandleError(std::move(Err));
return *MatchOrErr;
if (!Initialized || !*MatchOrErr)
return *MatchOrErr;

// Perform plugin-dependent checks for the specific architecture if needed.
auto CompatibleOrErr = isELFCompatible(Buffer);
if (Error Err = CompatibleOrErr.takeError())
return HandleError(std::move(Err));
return *CompatibleOrErr;
}
case file_magic::bitcode: {
auto MatchOrErr = getJIT().checkBitcodeImage(Buffer);
auto MatchOrErr = checkBitcodeImage(Buffer);
if (Error Err = MatchOrErr.takeError())
return HandleError(std::move(Err));
return *MatchOrErr;
Expand Down
34 changes: 24 additions & 10 deletions offload/src/PluginManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,8 @@ void PluginManager::init() {
// Attempt to create an instance of each supported plugin.
#define PLUGIN_TARGET(Name) \
do { \
auto Plugin = std::unique_ptr<GenericPluginTy>(createPlugin_##Name()); \
if (auto Err = Plugin->init()) { \
[[maybe_unused]] std::string InfoMsg = toString(std::move(Err)); \
DP("Failed to init plugin: %s\n", InfoMsg.c_str()); \
} else { \
DP("Registered plugin %s with %d visible device(s)\n", \
Plugin->getName(), Plugin->number_of_devices()); \
Plugins.emplace_back(std::move(Plugin)); \
} \
Plugins.emplace_back( \
std::unique_ptr<GenericPluginTy>(createPlugin_##Name())); \
} while (false);
#include "Shared/Targets.def"

Expand Down Expand Up @@ -160,6 +153,27 @@ void PluginManager::registerLib(__tgt_bin_desc *Desc) {
if (Entry.flags == OMP_REGISTER_REQUIRES)
PM->addRequirements(Entry.data);

// Initialize all the plugins that have associated images.
for (auto &Plugin : Plugins) {
if (Plugin->is_initialized())
continue;

// Extract the exectuable image and extra information if availible.
for (int32_t i = 0; i < Desc->NumDeviceImages; ++i) {
if (!Plugin->is_valid_binary(&Desc->DeviceImages[i],
/*Initialized=*/false))
continue;

if (auto Err = Plugin->init()) {
[[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
DP("Failed to init plugin: %s\n", InfoMsg.c_str());
} else {
DP("Registered plugin %s with %d visible device(s)\n",
Plugin->getName(), Plugin->number_of_devices());
}
}
}

// Extract the exectuable image and extra information if availible.
for (int32_t i = 0; i < Desc->NumDeviceImages; ++i)
PM->addDeviceImage(*Desc, Desc->DeviceImages[i]);
Expand All @@ -177,7 +191,7 @@ void PluginManager::registerLib(__tgt_bin_desc *Desc) {
if (!R.number_of_devices())
continue;

if (!R.is_valid_binary(Img)) {
if (!R.is_valid_binary(Img, /*Initialized=*/true)) {
DP("Image " DPxMOD " is NOT compatible with RTL %s!\n",
DPxPTR(Img->ImageStart), R.getName());
continue;
Expand Down
Loading