Commit 7a3c3b40 authored by pbethge's avatar pbethge
Browse files

add frozen graph support

parent 3037843a
......@@ -22,7 +22,8 @@
namespace ofxTF2 {
Model::Model(const std::string & modelPath) {
Model::Model(const std::string & modelPath, const cppflow::model::TYPE type) {
this->type = type;
Model::load(modelPath);
}
......@@ -30,15 +31,32 @@ Model::~Model(){
Model::clear();
}
void Model::setModelType(const cppflow::model::TYPE type) {
this->type = type;
}
bool Model::load(const std::string & modelPath) {
Model::clear();
std::string path = ofToDataPath(modelPath);
if(!ofDirectory::doesDirectoryExist(path)) {
ofLogError("ofxTensorFlow2") << "Model: model path not found: "
<< modelPath;
return false;
if (this->type == cppflow::model::SAVED_MODEL){
if(!ofDirectory::doesDirectoryExist(path)) {
ofLogError("ofxTensorFlow2") << "Model: model path not found: "
<< modelPath;
return false;
}
}
else if (this->type == cppflow::model::FROZEN_GRAPH){
if(!ofFile::doesFileExist(path)) {
ofLogError("ofxTensorFlow2") << "Model: model path not found: "
<< modelPath;
return false;
}
}
else {
ofLogError("ofxTensorFlow2") << "Model: model type unknown";
return false;
}
auto model = new cppflow::model(path);
auto model = new cppflow::model(path, this->type);
if(!model) {
modelPath_ = "";
ofLogError("ofxTensorFlow2") << "Model: model could not be initialized!";
......
......@@ -64,14 +64,20 @@ class Model {
public:
Model() = default;
Model(const std::string & modelPath);
Model(const std::string & modelPath, const cppflow::model::TYPE type=cppflow::model::SAVED_MODEL);
virtual ~Model();
/// load a SavedModel directory relative to bin/data
/// load a model from directory or file relative to bin/data
/// type maybe either cppflow::model::TYPE::SAVED_MODEL or cppflow::model::TYPE::FROZEN_GRAPH
/// directories for SavedModels include assets, variables, and a .pb file
/// file for frozen graph is a .pb file
/// returns true on success
virtual bool load(const std::string & modelPath);
/// set the model type
/// type maybe either cppflow::model::TYPE::SAVED_MODEL or cppflow::model::TYPE::FROZEN_GRAPH
virtual void setModelType(const cppflow::model::TYPE type);
/// set input and output names or reset to defaults (call without args)
/// use the CLI tool "saved_model_cli" to inspect the SavedModel e.g.
/// saved_model_cli show --dir path/to/model/ --tag_set serve
......@@ -101,7 +107,7 @@ public:
virtual void printOperations();
protected:
cppflow::model::TYPE type = cppflow::model::TYPE::SAVED_MODEL;
std::string modelPath_ = "";
std::vector<std::string> inputNames_ = {{"serving_default_input_1"}};
std::vector<std::string> outputNames_ = {{"StatefulPartitionedCall"}};
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment