// // ShapeInference.h // MNN // // Created by MNN on 2020/04/05. // Copyright © 2018, Alibaba Group Holding Limited // #ifndef MNN_PLUGIN_PLUGIN_KERNEL_HPP_ #define MNN_PLUGIN_PLUGIN_KERNEL_HPP_ #include <functional> #include <string> #include <unordered_map> #include <MNN/plugin/PluginContext.hpp> namespace MNN { namespace plugin { template <typename KernelContextT> class MNN_PUBLIC ComputeKernel { public: ComputeKernel() = default; virtual ~ComputeKernel() = default; virtual bool compute(KernelContextT* ctx) = 0; }; class MNN_PUBLIC CPUComputeKernel : public ComputeKernel<CPUKernelContext> { public: using ContextT = CPUKernelContext; using KernelT = CPUComputeKernel; CPUComputeKernel() = default; virtual ~CPUComputeKernel() = default; virtual bool init(CPUKernelContext* ctx) = 0; virtual bool compute(CPUKernelContext* ctx) = 0; }; template <typename PluginKernelT> class MNN_PUBLIC ComputeKernelRegistry { public: typedef std::function<PluginKernelT*()> Factory; static std::unordered_map<std::string, Factory>* getFactoryMap(); static bool add(const std::string& name, Factory factory); static PluginKernelT* get(const std::string& name); }; template <typename PluginKernelT> struct ComputeKernelRegistrar { ComputeKernelRegistrar(const std::string& name) { ComputeKernelRegistry<typename PluginKernelT::KernelT>::add(name, []() { // NOLINT return new PluginKernelT; // NOLINT }); } }; #define REGISTER_PLUGIN_COMPUTE_KERNEL(name, computeKernel) \ namespace { \ static auto _plugin_compute_kernel_##name##_ __attribute__((unused)) = \ ComputeKernelRegistrar<computeKernel>(#name); \ } // namespace } // namespace plugin } // namespace MNN #endif // MNN_PLUGIN_PLUGIN_KERNEL_HPP_