Add embedding mode with arg flag. Currently working #282

Merged: 11 commits, Mar 24, 2023
48 changes: 43 additions & 5 deletions llama.cpp
@@ -101,6 +101,8 @@ struct llama_context {

// decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits;
// input embedding (1-dimensional array: [n_embd])
std::vector<float> embedding;
bool logits_all = false;
};

@@ -112,6 +114,7 @@ struct llama_context_params llama_context_default_params() {
/*.f16_kv =*/ false,
/*.logits_all =*/ false,
/*.vocab_only =*/ false,
/*.embedding =*/ false,
};

return result;
@@ -127,7 +130,8 @@ static bool llama_model_load(
int n_ctx,
int n_parts,
ggml_type memory_type,
bool vocab_only) {
bool vocab_only,
bool embedding) {
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());

const int64_t t_start_us = ggml_time_us();
@@ -594,11 +598,27 @@ static bool llama_model_load(

lctx.logits.reserve(lctx.model.hparams.n_ctx);

if (embedding){
lctx.embedding.reserve(lctx.model.hparams.n_embd);
}

lctx.t_load_us = ggml_time_us() - t_start_us;

return true;
}

// Prints the provided embedding vector to stdout
// in a neat format
void display_embedding(const std::vector<float> & embedding_representation){
fprintf(stdout, "\n[\n");
for (int j = 0; j < embedding_representation.size()-1 ; j++){
fprintf(stdout, "%f, ", embedding_representation[j]);
}
fprintf(stdout, "%f", embedding_representation[embedding_representation.size()-1]);
fprintf(stdout, "\n]\n");
}
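For illustration (not part of the diff), calling display_embedding on a small vector shows the output format these fprintf calls produce:

// fragment assuming the display_embedding() definition above
std::vector<float> v = {0.1f, -0.2f, 0.3f};
display_embedding(v);
// prints (using %f's default six decimal places):
// [
// 0.100000, -0.200000, 0.300000
// ]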


// evaluate the transformer
//
// - lctx: llama context
@@ -611,7 +631,8 @@ static bool llama_eval_internal(
const llama_token * tokens,
const int n_tokens,
const int n_past,
const int n_threads) {
const int n_threads,
const bool embedding_mode = false) {
const int64_t t_start_us = ggml_time_us();

const int N = n_tokens;
@@ -799,6 +820,18 @@ static bool llama_eval_internal(
inpL);
}

if(embedding_mode){
// capture input sentence embedding
ggml_build_forward_expand(&gf, inpL);
ggml_graph_compute (ctx0, &gf);

Review comment: Do we really need to update and compute the graph in embedding mode?

std::vector<float> embedding_representation;
embedding_representation.resize(n_embd);
memcpy(embedding_representation.data(), (float *) ggml_get_data(inpL) + (n_embd * (N - 1)), sizeof(float) * n_embd);
display_embedding(embedding_representation);
ggml_free(ctx0);
return true;
}

// lm_head
{
inpL = ggml_mul_mat(ctx0, model.output, inpL);
@@ -1408,7 +1441,7 @@ struct llama_context * llama_init_from_file(

ggml_type type_memory = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;

if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only)) {
if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only, params.embedding)) {
fprintf(stderr, "%s: failed to load model\n", __func__);
delete ctx;
return nullptr;
@@ -1441,8 +1474,9 @@ int llama_eval(
const llama_token * tokens,
int n_tokens,
int n_past,
int n_threads) {
if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads)) {
int n_threads,
bool embedding_mode = false) {
if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads, embedding_mode)) {
fprintf(stderr, "%s: failed to eval\n", __func__);
return 1;
}
@@ -1482,6 +1516,10 @@ float * llama_get_logits(struct llama_context * ctx) {
return ctx->logits.data();
}

float * llama_get_embeddings(struct llama_context * ctx) {
return ctx->embedding.data();
}

const char * llama_token_to_str(struct llama_context * ctx, llama_token token) {
if (token >= llama_n_vocab(ctx)) {
return nullptr;
8 changes: 7 additions & 1 deletion llama.h
@@ -53,6 +53,7 @@ extern "C" {
bool f16_kv; // use fp16 for KV cache
bool logits_all; // the llama_eval() call computes all logits, not just the last one
bool vocab_only; // only load the vocabulary, no weights
bool embedding; // embedding mode only
};

LLAMA_API struct llama_context_params llama_context_default_params();
@@ -84,7 +85,8 @@ extern "C" {
const llama_token * tokens,
int n_tokens,
int n_past,
int n_threads);
int n_threads,
bool embedding_mode);

// Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens.
@@ -108,6 +110,10 @@ extern "C" {
// Cols: n_vocab
LLAMA_API float * llama_get_logits(struct llama_context * ctx);

// Get the embeddings for the input
// shape: [n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);

// Token Id -> String. Uses the vocabulary in the provided context
LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);
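To put the extended llama.h surface in context, here is a minimal caller sketch (not part of this PR), assuming the existing llama_tokenize, llama_init_from_file and llama_free declarations in llama.h; the model path, thread count, and token buffer size are placeholders:

#include "llama.h"
#include <cstdio>
#include <vector>

int main() {
    // enable the new embedding field added by this PR
    llama_context_params lparams = llama_context_default_params();
    lparams.embedding = true;

    llama_context * ctx = llama_init_from_file("models/7B/ggml-model.bin", lparams);
    if (ctx == nullptr) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // tokenize a prompt; llama_tokenize returns the number of tokens written
    std::vector<llama_token> tokens(512);
    const int n_tokens = llama_tokenize(ctx, "Hello world", tokens.data(), (int) tokens.size(), true);

    // single eval with embedding_mode = true: with this PR, llama_eval_internal
    // prints the sentence embedding to stdout and returns before the lm_head
    if (llama_eval(ctx, tokens.data(), n_tokens, 0, 4, true)) {
        fprintf(stderr, "failed to eval\n");
        return 1;
    }

    llama_free(ctx);
    return 0;
}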

23 changes: 20 additions & 3 deletions main.cpp
@@ -98,7 +98,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
int end = start + params.n_ctx - 1;
std::vector<llama_token> embd(tokens.begin() + start, tokens.begin() + end);
auto start_t = std::chrono::high_resolution_clock::now();
if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) {
if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads, false)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
}
@@ -199,6 +199,7 @@ int main(int argc, char ** argv) {
lparams.seed = params.seed;
lparams.f16_kv = params.memory_f16;
lparams.logits_all = params.perplexity;
lparams.embedding = params.embedding;

ctx = llama_init_from_file(params.model.c_str(), lparams);

@@ -219,7 +220,7 @@ int main(int argc, char ** argv) {
// TODO: better way to do that
{
const std::vector<llama_token> tmp = { 0, 1, 2, 3 };
llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads, false);
}

if (params.perplexity) {
@@ -289,6 +290,7 @@ int main(int argc, char ** argv) {

std::vector<llama_token> embd;


int last_n_size = params.repeat_last_n;
std::vector<llama_token> last_n_tokens(last_n_size);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
@@ -321,10 +323,25 @@ int main(int argc, char ** argv) {
// the first thing we will do is to output the prompt, so set color accordingly
set_console_state(CONSOLE_STATE_PROMPT);

if (params.embedding){
embd = embd_inp;
if (embd.size() > 0) {
if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads, params.embedding)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}
}

if (params.use_color) {
printf(ANSI_COLOR_RESET);
}
return 0;
}

while (remaining_tokens > 0 || params.interactive) {
// predict
if (embd.size() > 0) {
if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads, params.embedding)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}
4 changes: 4 additions & 0 deletions utils.cpp
@@ -63,6 +63,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.model = argv[++i];
} else if (arg == "-i" || arg == "--interactive") {
params.interactive = true;
} else if (arg == "--embedding") {
params.embedding = true;
} else if (arg == "--interactive-start") {
params.interactive = true;
} else if (arg == "--interactive-first") {
params.interactive_start = true;
} else if (arg == "-ins" || arg == "--instruct") {
4 changes: 4 additions & 0 deletions utils.h
@@ -32,13 +32,17 @@ struct gpt_params {
std::string model = "models/lamma-7B/ggml-model.bin"; // model path
std::string prompt = "";


std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted

bool memory_f16 = false; // use f16 instead of f32 for memory kv
bool random_prompt = false; // do not randomize prompt if none provided
bool use_color = false; // use color to distinguish generations and inputs
bool interactive = false; // interactive mode

bool embedding = false; // get only sentence embedding
bool interactive_start = false; // wait for user input immediately

bool instruct = false; // instruction mode (used for Alpaca models)
bool ignore_eos = false; // do not stop generating after eos
bool perplexity = false; // compute perplexity over the prompt