// // ShapeInference.h // MNN // // Created by MNN on 2020/04/05. // Copyright © 2018, Alibaba Group Holding Limited // #ifndef MNN_PLUGIN_PLUGIN_SHAPE_INFERENCE_HPP_ #define MNN_PLUGIN_PLUGIN_SHAPE_INFERENCE_HPP_ #include <functional> #include <string> #include <unordered_map> #include <MNN/plugin/PluginContext.hpp> namespace MNN { namespace plugin { class MNN_PUBLIC InferShapeKernel { public: virtual ~InferShapeKernel() = default; virtual bool compute(InferShapeContext* ctx) = 0; }; class MNN_PUBLIC InferShapeKernelRegister { public: // typedef InferShapeKernel* (*Factory)(); typedef std::function<InferShapeKernel*()> Factory; static std::unordered_map<std::string, Factory>* getFactoryMap(); static bool add(const std::string& name, Factory factory); static InferShapeKernel* get(const std::string& name); }; template <typename PluginKernel> struct InferShapeKernelRegistrar { InferShapeKernelRegistrar(const std::string& name) { InferShapeKernelRegister::add(name, []() { // NOLINT return new PluginKernel; // NOLINT }); } }; #define REGISTER_PLUGIN_OP(name, inferShapeKernel) \ namespace { \ static auto _plugin_infer_shape_##name##_ __attribute__((unused)) = \ InferShapeKernelRegistrar<inferShapeKernel>(#name); \ } // namespace } // namespace plugin } // namespace MNN #endif // MNN_PLUGIN_PLUGIN_SHAPE_INFERENCE_HPP_