Commit 0ced7564 authored by Dan Wilcox's avatar Dan Wilcox
Browse files

char_rnn format editing

parent 49e6967d
......@@ -8,7 +8,7 @@
# The location of your root openFrameworks installation
# (default) OF_ROOT = ../../..
################################################################################
# OF_ROOT = ../../..
OF_ROOT = /Users/wilcox/src/of/of_v0.11.2_osx_release
################################################################################
# PROJECT ROOT
......
......@@ -12,50 +12,46 @@
* This code has been developed at ZKM | Hertz-Lab as part of „The Intelligent
* Museum“ generously funded by the German Federal Cultural Foundation.
*
* This code is based on Memo Akten's ofxMSATensorFlow example.
* This code is based on Memo Akten's ofxMSATensorFlow example-char-rnn:
* https://github.com/memo/ofxMSATensorFlow
*/
#include "ofApp.h"
//--------------------------------------------------------------
// from msa::tf:: utilities
// from msa::tf::utilities
template<typename T> vector<T> adjustProbsWithTemp(const vector<T>& p_in, float t) {
if(t>0) {
if(t > 0) {
vector<T> p_out(p_in.size());
T sum = 0;
for(size_t i=0; i<p_in.size(); i++) {
p_out[i] = exp( log((double)p_in[i]) / (double)t );
for(size_t i = 0; i < p_in.size(); i++) {
p_out[i] = exp(log((double)p_in[i]) / (double)t);
sum += p_out[i];
}
if(sum > 0)
for(size_t i=0; i<p_out.size(); i++) p_out[i] /= sum;
if(sum > 0) {
for(size_t i = 0; i < p_out.size(); i++) {p_out[i] /= sum;}
}
return p_out;
}
return p_in;
}
//--------------------------------------------------------------
// from msa::tf:: utilities
// from msa::tf::utilities
template<typename T> int sample_from_prob(std::default_random_engine& rng, const vector<T>& p) {
std::discrete_distribution<int> rdist (p.begin(),p.end());
std::discrete_distribution<int> rdist(p.begin(), p.end());
int r = rdist(rng);
return r;
}
//--------------------------------------------------------------
void ofApp::setup() {
ofSetColor(255);
ofBackground(0);
ofSetVerticalSync(true);
ofSetLogLevel(OF_LOG_VERBOSE);
ofSetFrameRate(20); // generating a character per frame at 60fps is too fast to read in realtime
ofSetWindowTitle("example_frozen_graph_char_rnn");
ofSetFrameRate(20); // generating a character per frame at 60fps is too fast to read in realtime
ofSetColor(255);
ofBackground(0);
ofSetLogLevel(OF_LOG_VERBOSE);
// set model type and i/o names
model.setModelType(cppflow::model::TYPE::FROZEN_GRAPH);
......@@ -71,7 +67,7 @@ void ofApp::setup() {
// scan models dir
modelsDir.listDir("models");
if(modelsDir.size()==0) {
if(modelsDir.size() == 0) {
ofLogError() << "Couldn't find models folder.";
assert(false);
ofExit(1);
......@@ -86,14 +82,12 @@ void ofApp::setup() {
font.load(OF_TTF_SANS, 14);
}
//--------------------------------------------------------------
void ofApp::loadModelIndex(int index) {
curModelIndex = ofClamp(index, 0, modelsDir.size()-1);
loadModel(modelsDir.getPath(curModelIndex));
}
//--------------------------------------------------------------
void ofApp::loadModel(std::string dir) {
......@@ -116,7 +110,6 @@ void ofApp::loadModel(std::string dir) {
primeModel(textFull, primeLength);
}
//--------------------------------------------------------------
void ofApp::loadChars(string path) {
ofLogVerbose() << "load_chars : " << path;
......@@ -133,16 +126,14 @@ void ofApp::loadChars(string path) {
}
}
//--------------------------------------------------------------
void ofApp::primeModel(string primeData, int primeLength) {
outputReady = false;
for(unsigned int i=MAX(0, primeData.size()-primeLength); i<primeData.size(); i++) {
for(unsigned int i = MAX(0, primeData.size()-primeLength); i < primeData.size(); i++) {
runModel(primeData[i]);
}
}
//--------------------------------------------------------------
void ofApp::runModel(char ch) {
......@@ -168,13 +159,13 @@ void ofApp::runModel(char ch) {
outputReady = true;
}
//--------------------------------------------------------------
void ofApp::addChar(char ch) {
// add sampled char to text
if(ch == '\n') {
textLines.push_back("");
} else {
}
else {
textLines.back() += ch;
}
......@@ -188,11 +179,11 @@ void ofApp::addChar(char ch) {
}
// ghetto scroll
while(textLines.size() > maxLineNum) textLines.pop_front();
while(textLines.size() > maxLineNum) {textLines.pop_front();}
// rebuild text
textFull.clear();
for(auto&& text_line : textLines) {
for(auto &&text_line : textLines) {
textFull += "\n" + text_line;
}
......@@ -200,7 +191,6 @@ void ofApp::addChar(char ch) {
runModel(ch);
}
//--------------------------------------------------------------
void ofApp::draw() {
stringstream str;
......@@ -212,8 +202,8 @@ void ofApp::draw() {
str << endl;
str << "Press number key to load model: " << endl;
for(unsigned int i=0; i<modelsDir.size(); i++) {
auto marker = (i==curModelIndex) ? ">" : " ";
for(unsigned int i = 0; i < modelsDir.size(); i++) {
auto marker = (i == curModelIndex ? ">" : " ");
str << " " << (i+1) << " : " << marker << " " << modelsDir.getName(i) << endl;
}
......@@ -222,17 +212,13 @@ void ofApp::draw() {
str << "(and prime the model accordingly)" << endl;
str << endl;
if(outputReady) {
// sample one character from probability distribution
int curCharIndex = sample_from_prob(rng, lastModelOutput);
char curChar = intToChar[curCharIndex];
str << "Next char : " << curCharIndex << " | " << curChar << endl;
if(doAutoRun || doRunOnce) {
if(doRunOnce) doRunOnce = false;
if(doRunOnce) {doRunOnce = false;}
addChar(curChar);
}
}
......@@ -245,44 +231,34 @@ void ofApp::draw() {
ofDrawBitmapString(textFull + "_", 320, 10);
}
//--------------------------------------------------------------
void ofApp::keyPressed(int key) {
switch(key) {
case '1':
case '2':
case '3':
case '4':
case '5':
case '6':
case '7':
case '8':
case '9':
loadModelIndex(key-'1');
break;
case OF_KEY_DEL:
textLines = { "The" };
break;
case OF_KEY_RETURN:
doAutoRun ^= true;
break;
case OF_KEY_RIGHT:
doRunOnce = true;
doAutoRun = false;
break;
default:
doAutoRun = false;
if(charToInt.count(key) > 0) addChar(key);
break;
case '1': case '2': case '3': case '4': case '5':
case '6': case '7': case '8': case '9':
loadModelIndex(key-'1');
break;
case OF_KEY_DEL:
textLines = { "The" };
break;
case OF_KEY_RETURN:
doAutoRun ^= true;
break;
case OF_KEY_RIGHT:
doRunOnce = true;
doAutoRun = false;
break;
default:
doAutoRun = false;
if(charToInt.count(key) > 0) {addChar(key);}
break;
}
}
//--------------------------------------------------------------
void ofApp::update() {
......
......@@ -12,12 +12,12 @@
* This code has been developed at ZKM | Hertz-Lab as part of „The Intelligent
* Museum“ generously funded by the German Federal Cultural Foundation.
*
* This code is based on Memo Akten's ofxMSATensorFlow example.
* This code is based on Memo Akten's ofxMSATensorFlow example-char-rnn:
* https://github.com/memo/ofxMSATensorFlow
*/
/* Memo Akten writes:
/* From Memo Akten's ofxMSATensorFlow example */
/*
Generative character based Long Short-Term Memory (LSTM) Recurrent Neural Network (RNN) demo,
ala Karpathy's char-rnn(http://karpathy.github.io/2015/05/21/rnn-effectiveness/)
and Graves2013(https://arxiv.org/abs/1308.0850).
......@@ -39,14 +39,12 @@ So they're not great. A bit of hyperparameter tuning would give much better resu
The openframeworks code won't change at all, it'll just load the better model.
*/
#pragma once
#include "ofMain.h"
#include "ofxTensorFlow2.h"
#include <random>
//--------------------------------------------------------------
class ofApp : public ofBaseApp {
public:
......@@ -89,9 +87,9 @@ public:
std::map<int, char> charToInt;
// tensors in and out of model
cppflow::tensor t_dataIn; // data in
cppflow::tensor t_state; // current lstm state
std::vector<float> lastModelOutput; // probabilities
cppflow::tensor t_dataIn; // data in
cppflow::tensor t_state; // current lstm state
std::vector<float> lastModelOutput; // probabilities
// generated text
std::string textFull;
......@@ -103,7 +101,7 @@ public:
float sampleTemp = 0.5f;
// model file management
ofDirectory modelsDir; // data/models folder which contains subfolders for each model
ofDirectory modelsDir; // data/models folder which contains subfolders for each model
unsigned int curModelIndex = 0; // which model (i.e. folder) we're currently using
// random generator for sampling
......@@ -111,7 +109,6 @@ public:
// other vars
bool outputReady = false;
bool doAutoRun = true; // auto run every frame
bool doRunOnce = false; // only run one character
bool doAutoRun = true; // auto run every frame
bool doRunOnce = false; // only run one character
};
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment