Module.hpp 4.7 KB
Newer Older
姜天宇's avatar
姜天宇 committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
//
//  Module.hpp
//  MNN
//
//  Created by MNN on 2019/11/25.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#ifndef MNN_Train_Module_hpp
#define MNN_Train_Module_hpp

#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <MNN/expr/Expr.hpp>
#include <MNN/expr/Executor.hpp>
#include <MNN/MNNForwardType.h>

namespace MNN {
namespace Express {
struct SubGraph;
class MNN_PUBLIC Module {
public:
    Module()                                                                               = default;
    virtual ~Module()                                                                      = default;
    virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) = 0;
    Express::VARP forward(Express::VARP input);
    std::vector<Express::VARP> parameters() const;
    bool loadParameters(const std::vector<Express::VARP>& parameters);
    void setIsTraining(const bool isTraining);
    bool getIsTraining();
    void clearCache();

    const std::string& name() const {
        return mName;
    };
    void setName(std::string name) {
        mName = std::move(name);
    }
    const std::string type() const {
        return mType;
    }
    void setType(std::string type) {
        mType = std::move(type);
    }
    // Return the parameter index
    int addParameter(Express::VARP parameter);

    void setParameter(Express::VARP parameter, int index);
    static Module* createEmpty(const std::vector<Express::VARP>& parameters);
    
    struct BackendInfo {
        MNNForwardType type = MNN_FORWARD_CPU;
        BackendConfig* config = nullptr;
    };
    
    struct Config {
        // Load module as dynamic, default static
        bool dynamic = false;

        // for static mode, if the shape is mutable, set true, otherwise set false to avoid resizeSession freqencily
        bool shapeMutable = true;
        // Pre-rearrange weights or not. Disabled by default.
        // The weights will be rearranged in a general way, so the best implementation
        // may not be adopted if `rearrange` is enabled.
        bool rearrange = false;
        
        BackendInfo* backend = nullptr;
    };
    static Module* load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const Config* config = nullptr);
    static Module* load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const char* fileName, const Config* config = nullptr);
    // Shared RuntimeManager
    static Module* load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const char* fileName, const std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtMgr, const Config* config = nullptr);
    static Module* load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtMgr, const Config* config = nullptr);
    
    static Module* extract(std::vector<Express::VARP> inputs, std::vector<Express::VARP> outputs, bool fortrain, const std::map<std::string, SubGraph>& subGraph = {});

    static Module* clone(const Module* module, const bool shareParams = false);

    struct Info {
        // Input info load from model
        // If the ith input has no info, it will be nullptr
        std::vector<Variable::Info> inputs;
        Dimensionformat defaultFormat;
        std::shared_ptr<MNN::Express::Executor::RuntimeManager> runTimeManager;
    };
    const Info* getInfo() const;
    class CloneContext {
    public:
        CloneContext() = default;
        explicit CloneContext(const bool shareParams)
            : mShareParams(shareParams) {}
        virtual ~CloneContext() = default;

        const bool shareParams() const { return mShareParams; }

        EXPRP getOrClone(const EXPRP expr);
        VARP getOrClone(const VARP var);

    private:
        bool mShareParams = false;
        std::unordered_map<const Expr*, EXPRP> mExprMap;
        std::unordered_map<const Variable*, VARP> mVarMap;
    };

    virtual Module* clone(CloneContext* ctx) const {
        return nullptr;
    }
    void registerModel(const std::vector<std::shared_ptr<Module>>& children);

protected:
    virtual void onClearCache() {
    }

    Module* cloneBaseTo(CloneContext* ctx, Module* module) const;

private:
    void _collectParameters(std::vector<Express::VARP>& result) const;
    std::vector<std::shared_ptr<Module>> mChildren;
    std::vector<Express::VARP> mParameters;
    bool mIsTraining = true;
    std::string mName;
    std::string mType;
};

// A module together with two lists of names; used as the mapped value of the
// `subGraph` argument of Module::extract.
struct SubGraph {
    // NOTE(review): presumably the names of the sub-graph's input variables -- confirm.
    std::vector<std::string> inputs;
    // NOTE(review): presumably the names of the sub-graph's output variables -- confirm.
    std::vector<std::string> outputs;
    // The module associated with this sub-graph (shared ownership).
    std::shared_ptr<Module> m;
};

} // namespace Express
} // namespace MNN

#endif