Skip to content

Commit

Permalink
add option to save prediction to normalized or original data scale
Browse files Browse the repository at this point in the history
  • Loading branch information
zimenglyu committed Jul 29, 2024
1 parent b2c8a7b commit 61750c2
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 11 deletions.
24 changes: 17 additions & 7 deletions rnn/rnn.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ vector<double> RNN::get_predictions(
void RNN::write_predictions(
string output_filename, const vector<string>& input_parameter_names, const vector<string>& output_parameter_names,
const vector<vector<double> >& series_data, const vector<vector<double> >& expected_outputs,
TimeSeriesSets* time_series_sets, bool using_dropout, double dropout_probability
TimeSeriesSets* time_series_sets, bool using_dropout, double dropout_probability, bool normalize_predictions
) {
forward_pass(series_data, using_dropout, false, dropout_probability);

Expand Down Expand Up @@ -629,20 +629,30 @@ void RNN::write_predictions(
if (i > 0) {
outfile << ",";
}
outfile << series_data[i][j];
// outfile << time_series_sets->denormalize(input_parameter_names[i], series_data[i][j]);
if (normalize_predictions) {
outfile << series_data[i][j];
} else {
outfile << time_series_sets->denormalize(input_parameter_names[i], series_data[i][j]);
}
}

for (int32_t i = 0; i < (int32_t) output_nodes.size(); i++) {
outfile << ",";
outfile << expected_outputs[i][j];
// outfile << time_series_sets->denormalize(output_parameter_names[i], expected_outputs[i][j]);
if (normalize_predictions) {
outfile << expected_outputs[i][j];
} else {
outfile << time_series_sets->denormalize(output_parameter_names[i], expected_outputs[i][j]);
}

}

for (int32_t i = 0; i < (int32_t) output_nodes.size(); i++) {
outfile << ",";
outfile << output_nodes[i]->output_values[j];
// outfile << time_series_sets->denormalize(output_parameter_names[i], output_nodes[i]->output_values[j]);
if (normalize_predictions) {
outfile << output_nodes[i]->output_values[j];
} else {
outfile << time_series_sets->denormalize(output_parameter_names[i], output_nodes[i]->output_values[j]);
}
}
outfile << endl;
}
Expand Down
2 changes: 1 addition & 1 deletion rnn/rnn.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class RNN {
string output_filename, const vector<string>& input_parameter_names,
const vector<string>& output_parameter_names, const vector<vector<double> >& series_data,
const vector<vector<double> >& expected_outputs, TimeSeriesSets* time_series_sets, bool using_dropout,
double dropout_probability
double dropout_probability, bool normalize_predictions
);

void initialize_randomly();
Expand Down
4 changes: 2 additions & 2 deletions rnn/rnn_genome.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -1315,7 +1315,7 @@ vector<vector<double> > RNN_Genome::get_predictions(
void RNN_Genome::write_predictions(
string output_directory, const vector<string>& input_filenames, const vector<double>& parameters,
const vector<vector<vector<double> > >& inputs, const vector<vector<vector<double> > >& outputs,
TimeSeriesSets* time_series_sets
TimeSeriesSets* time_series_sets, bool normalize_predictions
) {
RNN* rnn = get_rnn();
rnn->set_weights(parameters);
Expand All @@ -1335,7 +1335,7 @@ void RNN_Genome::write_predictions(

rnn->write_predictions(
output_filename, input_parameter_names, output_parameter_names, inputs[i], outputs[i], time_series_sets,
use_dropout, dropout_probability
use_dropout, dropout_probability, normalize_predictions
);
}

Expand Down
2 changes: 1 addition & 1 deletion rnn/rnn_genome.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ class RNN_Genome {
void write_predictions(
string output_directory, const vector<string>& input_filenames, const vector<double>& parameters,
const vector<vector<vector<double> > >& inputs, const vector<vector<vector<double> > >& outputs,
TimeSeriesSets* time_series_sets
TimeSeriesSets* time_series_sets, bool normalize_predictions
);
// void write_predictions(string output_directory, const vector<string> &input_filenames, const vector<double>
// &parameters, const vector< vector< vector<double> > > &inputs, const vector< vector< vector<double> > > &outputs,
Expand Down
3 changes: 3 additions & 0 deletions rnn_examples/evaluate_rnn.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ int main(int argc, char** argv) {
int32_t time_offset = 1;
get_argument(arguments, "--time_offset", true, time_offset);

bool normalize_predictions = false;
get_argument(arguments, "--normalize_predictions", true, normalize_predictions);

time_series_sets->export_test_series(time_offset, testing_inputs, testing_outputs);

vector<double> best_parameters = genome->get_best_parameters();
Expand Down

0 comments on commit 61750c2

Please sign in to comment.