
Add embedding mode with arg flag. Currently working #282

Merged (11 commits) on Mar 24, 2023
32 changes: 29 additions & 3 deletions llama.cpp
@@ -598,6 +598,18 @@ static bool llama_model_load(
return true;
}

// Prints the provided embedding vector to stdout
// in a neat format
void display_embedding(const std::vector<float> & embedding_representation){
fprintf(stdout, "\n[\n");
for (int j = 0; j < embedding_representation.size()-1 ; j++){
fprintf(stdout, "%f, ", embedding_representation[j]);
}
fprintf(stdout, "%f", embedding_representation[embedding_representation.size()-1]);
fprintf(stdout, "\n]\n");
}


// evaluate the transformer
//
// - lctx: llama context
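Side note, not part of the diff: given a non-empty vector, the new helper prints the values as a bracketed, comma-separated list. A minimal standalone sketch of that behavior, with the helper copied from above (loop index widened to size_t) and a hypothetical main:

#include <cstdio>
#include <vector>

// Same helper as in the diff above, reproduced for a standalone illustration.
void display_embedding(const std::vector<float> & embedding_representation) {
    fprintf(stdout, "\n[\n");
    for (size_t j = 0; j < embedding_representation.size() - 1; j++) {
        fprintf(stdout, "%f, ", embedding_representation[j]);
    }
    fprintf(stdout, "%f", embedding_representation[embedding_representation.size() - 1]);
    fprintf(stdout, "\n]\n");
}

int main() {
    // Prints:
    // [
    // 0.100000, -0.200000, 0.300000
    // ]
    display_embedding({0.1f, -0.2f, 0.3f});
    return 0;
}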
@@ -610,7 +622,8 @@ static bool llama_eval_internal(
const llama_token * tokens,
const int n_tokens,
const int n_past,
const int n_threads) {
const int n_threads,
const bool embedding_mode = false) {
const int64_t t_start_us = ggml_time_us();

const int N = n_tokens;
@@ -798,6 +811,18 @@ static bool llama_eval_internal(
inpL);
}

if(embedding_mode){
// capture input sentence embedding
ggml_build_forward_expand(&gf, inpL);
ggml_graph_compute (ctx0, &gf);

Review comment on this hunk: Do we really need to update and compute the graph in embedding mode?

std::vector<float> embedding_representation;
embedding_representation.resize(n_embd);
memcpy(embedding_representation.data(), (float *) ggml_get_data(inpL) + (n_embd * (N - 1)), sizeof(float) * n_embd);
display_embedding(embedding_representation);
ggml_free(ctx0);
return true;
}

// lm_head
{
inpL = ggml_mul_mat(ctx0, model.output, inpL);
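The memcpy in the new block assumes ggml stores the activations for the N tokens contiguously, one row of n_embd floats per token, so the sentence embedding is read from the last row at offset n_embd * (N - 1). An illustrative sketch of that indexing with a plain array (names are mine, not ggml's):

#include <cstring>
#include <vector>

// Extract the last row of a row-major [N, n_embd] float buffer.
std::vector<float> last_token_embedding(const float * data, int N, int n_embd) {
    std::vector<float> out(n_embd);
    std::memcpy(out.data(), data + (size_t) n_embd * (N - 1), sizeof(float) * n_embd);
    return out;
}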
@@ -1440,8 +1465,9 @@ int llama_eval(
const llama_token * tokens,
int n_tokens,
int n_past,
int n_threads) {
if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads)) {
int n_threads,
bool embedding_mode = false) {
if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads, embedding_mode)) {
fprintf(stderr, "%s: failed to eval\n", __func__);
return 1;
}
3 changes: 2 additions & 1 deletion llama.h
@@ -84,7 +84,8 @@ extern "C" {
const llama_token * tokens,
int n_tokens,
int n_past,
int n_threads);
int n_threads,
bool embedding_mode);

// Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens.
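For callers outside main.cpp, the extended signature just means passing one extra boolean. A hedged end-to-end sketch, assuming the llama_init_from_file / llama_tokenize declarations that sit alongside llama_eval in llama.h at this point in the tree (verify the exact signatures against the header; the model path and prompt are placeholders):

#include <cstdio>
#include <vector>
#include "llama.h"

int main() {
    llama_context_params cparams = llama_context_default_params();
    llama_context * ctx = llama_init_from_file("models/7B/ggml-model.bin", cparams);
    if (ctx == NULL) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // Tokenize a prompt (add_bos = true, as main.cpp does for the initial prompt).
    std::vector<llama_token> tokens(512);
    const int n = llama_tokenize(ctx, "This is a test", tokens.data(), (int) tokens.size(), true);
    if (n < 0) {
        fprintf(stderr, "failed to tokenize\n");
        return 1;
    }
    tokens.resize(n);

    // Final argument true: print the sentence embedding and return without sampling.
    if (llama_eval(ctx, tokens.data(), (int) tokens.size(), 0, 4, true)) {
        fprintf(stderr, "failed to eval\n");
        return 1;
    }

    llama_free(ctx);
    return 0;
}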
22 changes: 19 additions & 3 deletions main.cpp
@@ -98,7 +98,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
int end = start + params.n_ctx - 1;
std::vector<llama_token> embd(tokens.begin() + start, tokens.begin() + end);
auto start_t = std::chrono::high_resolution_clock::now();
if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) {
if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads, false)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
}
@@ -219,7 +219,7 @@ int main(int argc, char ** argv) {
// TODO: better way to do that
{
const std::vector<llama_token> tmp = { 0, 1, 2, 3 };
llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads, false);
}

if (params.perplexity) {
@@ -285,6 +285,7 @@ int main(int argc, char ** argv) {

std::vector<llama_token> embd;


int last_n_size = params.repeat_last_n;
std::vector<llama_token> last_n_tokens(last_n_size);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
@@ -317,10 +318,25 @@ int main(int argc, char ** argv) {
// the first thing we will do is to output the prompt, so set color accordingly
set_console_state(CONSOLE_STATE_PROMPT);

if (params.embedding){
embd = embd_inp;
if (embd.size() > 0) {
if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads, params.embedding)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}
}

if (params.use_color) {
printf(ANSI_COLOR_RESET);
}
return 0;
}

while (remaining_tokens > 0 || params.interactive) {
// predict
if (embd.size() > 0) {
if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads, params.embedding)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}
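Taken together, the main.cpp changes make the new flag a one-shot path: the prompt is evaluated once, its embedding is printed by the code added in llama.cpp, and the program exits instead of generating text. An illustrative invocation (model path and prompt are placeholders):

./main -m ./models/7B/ggml-model.bin -p "This is a test" --embedding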
5 changes: 5 additions & 0 deletions utils.cpp
@@ -63,6 +63,11 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.model = argv[++i];
} else if (arg == "-i" || arg == "--interactive") {
params.interactive = true;
} else if (arg == "--embedding") {
params.embedding = true;
} else if (arg == "--interactive-start") {
params.interactive = true;
params.interactive_start = true;
} else if (arg == "-ins" || arg == "--instruct") {
params.instruct = true;
} else if (arg == "--color") {
2 changes: 2 additions & 0 deletions utils.h
@@ -32,12 +32,14 @@ struct gpt_params {
std::string model = "models/lamma-7B/ggml-model.bin"; // model path
std::string prompt = "";


std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted

bool memory_f16 = false; // use f16 instead of f32 for memory kv
bool random_prompt = false; // do not randomize prompt if none provided
bool use_color = false; // use color to distinguish generations and inputs
bool interactive = false; // interactive mode
bool embedding = false; // get only sentence embedding
bool interactive_start = false; // reverse prompt immediately
bool instruct = false; // instruction mode (used for Alpaca models)
bool ignore_eos = false; // do not stop generating after eos