Skip to content

Commit

Permalink
speculative : fix handling of some input params (ggml-org#9963)
Browse files Browse the repository at this point in the history
* speculative : fix batch sizes at initialization

ggml-ci

* speculative : handle params.n_predict == -1

* speculative : limit batch size to llama_n_batch
  • Loading branch information
ggerganov authored and dsx1986 committed Oct 29, 2024
1 parent ec2a378 commit 44c0943
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions examples/speculative/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ int main(int argc, char ** argv) {
return 1;
}

+ if (params.n_predict < -1) {
+ LOG_ERR("%s: --n-predict must be >= -1\n", __func__);
+ return 1;
+ }

common_init();

if (params.model_draft.empty()) {
Expand Down Expand Up @@ -190,8 +195,8 @@ int main(int argc, char ** argv) {
drafts[s].smpl = common_sampler_init(model_dft, params.sparams);
}

- llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
- llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft);
+ llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
+ llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft);

const auto t_dec_start = ggml_time_us();

Expand Down Expand Up @@ -441,7 +446,7 @@ int main(int argc, char ** argv) {
++n_past_dft;
}

- if (n_predict > params.n_predict || has_eos) {
+ if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
break;
}

Expand Down

0 comments on commit 44c0943

Please sign in to comment.