//
//  Interpreter.hpp
//  MNN
//
//  Created by MNN on 2018/07/23.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#ifndef Interpreter_hpp
#define Interpreter_hpp

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <MNN/ErrorCode.hpp>
#include <MNN/MNNForwardType.h>
#include <MNN/Tensor.hpp>

namespace MNN {

/** session schedule config */
struct ScheduleConfig {
    /** which tensors should be kept */
    std::vector<std::string> saveTensors;
    /** forward type */
    MNNForwardType type = MNN_FORWARD_CPU;
    /** CPU: number of threads in parallel. GPU: mode setting. */
    union {
        int numThread = 4;
        int mode;
    };

    /** subpath to run */
    struct Path {
        std::vector<std::string> inputs;
        std::vector<std::string> outputs;
        enum Mode {
            /**
             * Op Mode
             * - inputs means the source ops, can NOT be empty.
             * - outputs means the sink ops, can be empty.
             * The path starts from the source ops and flows until it encounters a sink op.
             * The sink ops are not computed in this path.
             */
            Op = 0,

            /**
             * Tensor Mode
             * - inputs means the input tensors, can NOT be empty.
             * - outputs means the output tensors, can NOT be empty.
             * It finds the pipeline that computes the outputs from the inputs.
             */
            Tensor = 1
        };

        /** running mode */
        Mode mode = Op;
    };
    Path path;

    /** backup backend used to create execution when the designated backend does NOT support some op */
    MNNForwardType backupType = MNN_FORWARD_CPU;

    /** extra backend config */
    BackendConfig* backendConfig = nullptr;
};

class Session;
struct Content;
class Tensor;
class Backend;
class Runtime;

class MNN_PUBLIC OperatorInfo {
    struct Info;

public:
    /** Operator's name */
    const std::string& name() const;

    /** Operator's type */
    const std::string& type() const;

    /** Operator's flops, in M */
    float flops() const;

protected:
    OperatorInfo();
    ~OperatorInfo();
    Info* mContent;
};

typedef std::function<bool(const std::vector<Tensor*>&, const std::string& /*opName*/)> TensorCallBack;
typedef std::function<bool(const std::vector<Tensor*>&, const OperatorInfo*)> TensorCallBackWithInfo;
typedef std::pair<std::map<MNNForwardType, std::shared_ptr<Runtime>>, std::shared_ptr<Runtime>> RuntimeInfo;

/** net data holder. multiple sessions could share the same net. */
class MNN_PUBLIC Interpreter {
public:
    /**
     * @brief create net from file.
     * @param file given file.
     * @return created net if success, NULL otherwise.
     */
    static Interpreter* createFromFile(const char* file);
    /**
     * @brief create net from buffer.
     * @param buffer given data buffer.
     * @param size size of data buffer.
     * @return created net if success, NULL otherwise.
     */
    static Interpreter* createFromBuffer(const void* buffer, size_t size);
    ~Interpreter();
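    /*
     * A minimal creation sketch (illustrative only; "model.mnn" is a placeholder
     * path, and the shared_ptr wrapper is one possible ownership choice):
     *
     *     #include <MNN/Interpreter.hpp>
     *     #include <memory>
     *
     *     // createFromFile returns NULL on failure, so check before use.
     *     std::shared_ptr<MNN::Interpreter> net(MNN::Interpreter::createFromFile("model.mnn"));
     *     if (nullptr == net) {
     *         // handle load failure
     *     }
     */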
    enum SessionMode {
        /** About callbacks. Default: Session_Debug. */
        /** runSessionWithCallBack is allowed and can get internal op info. */
        Session_Debug = 0,
        /** runSessionWithCallBack is not valid and can't get any info about ops in the session. */
        Session_Release = 1,

        /** About input tensors. Default: Session_Input_Inside. */
        /** The input tensor is allocated by the session; set input data after the session is resized. */
        Session_Input_Inside = 2,
        /** The input tensor is allocated by the user; set input data before the session is resized. */
        Session_Input_User = 3,

        /** The output tensor depends on the session and can't be used separately. */
        Session_Output_Inside = 4,
        /** The output tensor can be separated from the session. */
        Session_Output_User = 5,

        /** Whether to try resizing the session when it is created. Default: resize directly. */
        Session_Resize_Direct = 6,
        Session_Resize_Defer = 7,

        /** Whether each Execution's forward type is fixed by the user or determined automatically. */
        Session_Backend_Fix = 8,  // Use the backend set by the user; fall back to the default backend for unsupported ops.
        Session_Backend_Auto = 9, // Let MNN determine the backend per op automatically.
    };

    /**
     * @brief The API should be called before creating a session.
     * @param mode session mode
     */
    void setSessionMode(SessionMode mode);

    /**
     * @brief The API should be called before creating a session.
     * If the cache exists, try to load the cache from file.
     * After createSession, try to save the cache to file.
     * @param cacheFile cache file name
     * @param keySize deprecated, reserved for future use.
     */
    void setCacheFile(const char* cacheFile, size_t keySize = 128);

    /**
     * @brief The API should be called after the last resizeSession.
     * If resizing the session generated new cache info, rewrite the cache file;
     * otherwise do nothing.
     * @param session given session
     * @param flag protected param, not used now
     */
    ErrorCode updateCacheFile(Session* session, int flag = 0);

    enum HintMode {
        // Max op number for async tuning
        MAX_TUNING_NUMBER = 0,
    };

    /**
     * @brief The API should be called before creating a session.
     * @param mode hint type
     * @param value hint value
     */
    void setSessionHint(HintMode mode, int value);

public:
    /**
     * @brief create runtimeInfo separately with schedule configs.
     * @param configs session schedule configs.
     */
    static RuntimeInfo createRuntime(const std::vector<ScheduleConfig>& configs);

    /**
     * @brief create session with schedule config. created session will be managed in net.
     * @param config session schedule config.
     * @return created session if success, NULL otherwise.
     */
    Session* createSession(const ScheduleConfig& config);

    /**
     * @brief create session with schedule config and user-specified runtime.
     * @param config session schedule config.
     * @param runtime runtimeInfo used by the created session.
     * @return created session if success, NULL otherwise.
     */
    Session* createSession(const ScheduleConfig& config, const RuntimeInfo& runtime);

    /**
     * @brief create multi-path session with schedule configs. created session will be managed in net.
     * @param configs session schedule configs.
     * @return created session if success, NULL otherwise.
     */
    Session* createMultiPathSession(const std::vector<ScheduleConfig>& configs);

    /**
     * @brief create multi-path session with schedule configs and user-specified runtime. created session will be managed in net.
     * @param configs session schedule configs.
     * @param runtime runtimeInfo used by the created session.
     * @return created session if success, NULL otherwise.
     */
    Session* createMultiPathSession(const std::vector<ScheduleConfig>& configs, const RuntimeInfo& runtime);
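    /*
     * A minimal session-creation sketch (illustrative; the forward type and
     * thread count are arbitrary example values, and `net` is the interpreter
     * created above):
     *
     *     MNN::ScheduleConfig config;
     *     config.type      = MNN_FORWARD_CPU; // backupType covers unsupported ops
     *     config.numThread = 4;
     *     MNN::Session* session = net->createSession(config);
     *     if (nullptr == session) {
     *         // handle schedule failure
     *     }
     */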
    /**
     * @brief release session.
     * @param session given session.
     * @return true if given session is held by net and is freed.
     */
    bool releaseSession(Session* session);

    /**
     * @brief call this function to get tensors ready. Output tensor buffers (host or deviceId)
     * should be retrieved again after any input tensor is resized.
     * @param session given session.
     */
    void resizeSession(Session* session);

    /**
     * @brief call this function if you no longer need to resize or create sessions;
     * it saves memory roughly equal to the size of the model buffer.
     */
    void releaseModel();

    /**
     * @brief get the model buffer for the user to save.
     * @return std::make_pair(modelBuffer, modelSize).
     * @example:
     * std::ofstream output("trainResult.alinn");
     * auto buffer = net->getModelBuffer();
     * output.write((const char*)buffer.first, buffer.second);
     */
    std::pair<const void*, size_t> getModelBuffer() const;

    /**
     * @brief update the session's tensors to the model's Const ops.
     * @param session given session.
     * @return result of running.
     */
    ErrorCode updateSessionToModel(Session* session);

    /**
     * @brief run session.
     * @param session given session.
     * @return result of running.
     */
    ErrorCode runSession(Session* session) const;

    /**
     * @brief run session with callbacks.
     * @param session given session.
     * @param before callback before each op. return true to run the op; return false to skip the op.
     * @param end callback after each op. return true to continue running; return false to interrupt the session.
     * @param sync synchronously wait for finish of execution or not.
     * @return result of running.
     */
    ErrorCode runSessionWithCallBack(const Session* session, const TensorCallBack& before,
                                     const TensorCallBack& end, bool sync = false) const;

    /**
     * @brief run session with callbacks that carry operator info.
     * @param session given session.
     * @param before callback before each op. return true to run the op; return false to skip the op.
     * @param end callback after each op. return true to continue running; return false to interrupt the session.
     * @param sync synchronously wait for finish of execution or not.
     * @return result of running.
     */
    ErrorCode runSessionWithCallBackInfo(const Session* session, const TensorCallBackWithInfo& before,
                                         const TensorCallBackWithInfo& end, bool sync = false) const;

    /**
     * @brief get input tensor for given name.
     * @param session given session.
     * @param name given name. if NULL, return first input.
     * @return tensor if found, NULL otherwise.
     */
    Tensor* getSessionInput(const Session* session, const char* name);

    /**
     * @brief get output tensor for given name.
     * @param session given session.
     * @param name given name. if NULL, return first output.
     * @return tensor if found, NULL otherwise.
     */
    Tensor* getSessionOutput(const Session* session, const char* name);

    enum SessionInfoCode {
        /** memory used by the session, in MB, float* */
        MEMORY = 0,

        /** floating-point operations needed by the session, in M (millions), float* */
        FLOPS = 1,

        /** backends used by the session, int*, length >= 1 + number of configs when creating the session */
        BACKENDS = 2,

        ALL
    };

    /**
     * @brief get session info.
     * @param session given session.
     * @param code given info code.
     * @param ptr given info ptr, see SessionInfoCode for details.
     * @return true if the code is supported, false otherwise.
     */
    bool getSessionInfo(const Session* session, SessionInfoCode code, void* ptr);

    /**
     * @brief get all output tensors.
     * @param session given session.
     * @return all output tensors mapped with name.
     */
    const std::map<std::string, Tensor*>& getSessionOutputAll(const Session* session) const;
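    /*
     * A minimal inference sketch (illustrative; the host-copy pattern relies on
     * Tensor's constructor, getDimensionType(), host<T>() and
     * copyFromHostTensor / copyToHostTensor declared in MNN/Tensor.hpp, and also
     * works for device backends):
     *
     *     MNN::Tensor* input  = net->getSessionInput(session, nullptr);  // NULL name: first input
     *     MNN::Tensor* output = net->getSessionOutput(session, nullptr); // NULL name: first output
     *
     *     std::shared_ptr<MNN::Tensor> inputHost(new MNN::Tensor(input, input->getDimensionType()));
     *     // ... fill inputHost->host<float>() with input data ...
     *     input->copyFromHostTensor(inputHost.get());
     *
     *     net->runSession(session);
     *
     *     std::shared_ptr<MNN::Tensor> outputHost(new MNN::Tensor(output, output->getDimensionType()));
     *     output->copyToHostTensor(outputHost.get());
     *     // read results from outputHost->host<float>()
     */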
    /**
     * @brief get all input tensors.
     * @param session given session.
     * @return all input tensors mapped with name.
     */
    const std::map<std::string, Tensor*>& getSessionInputAll(const Session* session) const;

public:
    /**
     * @brief resize given tensor.
     * @param tensor given tensor.
     * @param dims new dims. at most 6 dims.
     */
    void resizeTensor(Tensor* tensor, const std::vector<int>& dims);

    /**
     * @brief resize given tensor by NCHW.
     * @param tensor given tensor.
     * @param batch batch / N.
     * @param channel channel / C.
     * @param height height / H.
     * @param width width / W.
     */
    void resizeTensor(Tensor* tensor, int batch, int channel, int height, int width);

    /**
     * @brief get backend used to create given tensor.
     * @param session given session.
     * @param tensor given tensor.
     * @return backend used to create given tensor, may be NULL.
     */
    const Backend* getBackend(const Session* session, const Tensor* tensor) const;

    /**
     * @brief get business code (model identifier).
     * @return business code.
     */
    const char* bizCode() const;

    /**
     * @brief get model UUID.
     * @return model UUID.
     */
    const char* uuid() const;

private:
    static Interpreter* createFromBufferInternal(Content* net);

    Content* mNet = nullptr;
    Interpreter(Content* net);

    Interpreter(const Interpreter&)  = delete;
    Interpreter(const Interpreter&&) = delete;
    Interpreter& operator=(const Interpreter&) = delete;
    Interpreter& operator=(const Interpreter&&) = delete;
};
} // namespace MNN

#endif /* Interpreter_hpp */
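/*
 * A minimal dynamic-shape sketch (illustrative; the 1x3x224x224 shape is an
 * arbitrary example). After changing an input's dimensions, resize the session
 * before the next run so tensor memory is re-allocated:
 *
 *     net->resizeTensor(input, {1, 3, 224, 224});     // or the NCHW overload
 *     net->resizeSession(session);                    // get tensors ready again
 *     input = net->getSessionInput(session, nullptr); // re-fetch after resize
 *     net->runSession(session);
 */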