Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add "stop" keywords as alternative to eot token #769

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions examples/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.antiprompt.push_back(argv[i]);
} else if (arg == "--stop") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.stop_keywords.push_back(argv[i]);
} else if (arg == "--perplexity") {
params.perplexity = true;
} else if (arg == "--ignore-eos") {
Expand Down Expand Up @@ -209,8 +215,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " --interactive-first run in interactive mode and wait for input right away\n");
fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n");
fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n");
fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n");
fprintf(stderr, " specified more than once for multiple prompts).\n");
fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT\n");
fprintf(stderr, " (can be specified more than once for multiple reverse prompts).\n");
fprintf(stderr, " --stop KEYWORD a string that, when output by the model, will stop generation\n");
fprintf(stderr, " (can be specified more than once for multiple keywords).\n");
fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n");
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for <= 0)\n");
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
Expand Down
1 change: 1 addition & 0 deletions examples/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ struct gpt_params {


std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::vector<std::string> stop_keywords; // string upon seeing which the model will stop

bool memory_f16 = true; // use f16 instead of f32 for memory kv
bool random_prompt = false; // do not randomize prompt if none provided
Expand Down
46 changes: 43 additions & 3 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,13 @@ int main(int argc, char ** argv) {
fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
}
}

if (params.stop_keywords.size()) {
for (auto stop_keyword : params.stop_keywords) {
fprintf(stderr, "Stop keyword: '%s'\n", stop_keyword.c_str());
}
}

fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
Expand Down Expand Up @@ -344,13 +351,28 @@ int main(int argc, char ** argv) {
// check if we should prompt the user for more
if (params.interactive && (int) embd_inp.size() <= n_consumed) {

// check for reverse prompt
if (params.antiprompt.size()) {
std::string last_output;
std::string last_output;
if (params.antiprompt.size() || params.stop_keywords.size()) {
for (auto id : last_n_tokens) {
last_output += llama_token_to_str(ctx, id);
}
}

// Check for stop keywords, a configurable alternative to the end-of-text token
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should check the token id instead of the string for stop.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For context, the precedent set by #330 is to check for the string (in reverse prompts). I think checking for tokens caused a bug, or at least unintuitive behavior (#292).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, this should work in the same way as "antiprompts" to provide a better UX to users. Users should be able to add to the CLI parameters --stop "### Assistant" (for example, in the spirit of the trained vicuna model) or --stop PAUSE (for example, to implement Simon Willinson's ReAct Python example), even though these are multi-token markers.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we know why? It will be a good learning for me to understand why token will not work? Because one stop string can be generated by the different tokens??

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, see #292 (comment)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also noticed that the stop words weren't always consistently formatted. Sometimes it would do STOP instead of stop. I'd appreciate a normalize function to down case before comparing.

// This should stop also the interactive mode, useful to stop interactive mode without SIGTERM
bool stop = false;
for (std::string stop_keyword : params.stop_keywords) {
if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) {
stop = true;
break;
}
}
if (stop) {
break;
}

// check for reverse prompt
if (params.antiprompt.size()) {
is_antiprompt = false;
// Check if each of the reverse prompts appears at the end of the output.
for (std::string & antiprompt : params.antiprompt) {
Expand Down Expand Up @@ -430,6 +452,24 @@ int main(int argc, char ** argv) {
}
}

// Check for stop keywords, a configurable alternative to the end-of-text token
if (!params.interactive && params.stop_keywords.size() && !is_interacting) {
std::string last_output;
for (auto id : last_n_tokens) {
last_output += llama_token_to_str(ctx, id);
}
bool stop = false;
for (std::string stop_keyword : params.stop_keywords) {
if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) {
stop = true;
break;
}
}
if (stop) {
break;
}
}

// end of text token
if (!embd.empty() && embd.back() == llama_token_eos()) {
if (params.instruct) {
Expand Down