Skip to content

Commit

Permalink
Fix wrong n_trial number in autotvm tutorials' progress bar (apache#4070
Browse files Browse the repository at this point in the history
)

if n_trial is larger then config space.
  • Loading branch information
dati91 authored and Animesh Jain committed Oct 17, 2019
1 parent 7d29745 commit 66a5c6a
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 7 deletions.
3 changes: 2 additions & 1 deletion nnvm/tutorials/tune_nnvm_arm.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,8 @@ def tune_tasks(tasks,
tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))

# do tuning
tuner_obj.tune(n_trial=min(n_trial, len(tsk.config_space)),
n_trial = min(n_trial, len(tsk.config_space))
tuner_obj.tune(n_trial=n_trial,
early_stopping=early_stopping,
measure_option=measure_option,
callbacks=[
Expand Down
3 changes: 2 additions & 1 deletion nnvm/tutorials/tune_nnvm_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,8 @@ def tune_tasks(tasks,
tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))

# do tuning
tuner_obj.tune(n_trial=min(n_trial, len(tsk.config_space)),
n_trial = min(n_trial, len(tsk.config_space))
tuner_obj.tune(n_trial=n_trial,
early_stopping=early_stopping,
measure_option=measure_option,
callbacks=[
Expand Down
3 changes: 2 additions & 1 deletion nnvm/tutorials/tune_nnvm_mobile_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,8 @@ def tune_tasks(tasks,
tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))

# do tuning
tuner_obj.tune(n_trial=min(n_trial, len(tsk.config_space)),
n_trial = min(n_trial, len(tsk.config_space))
tuner_obj.tune(n_trial=n_trial,
early_stopping=early_stopping,
measure_option=measure_option,
callbacks=[
Expand Down
3 changes: 2 additions & 1 deletion tutorials/autotvm/tune_relay_arm.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,8 @@ def tune_tasks(tasks,
tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))

# do tuning
tuner_obj.tune(n_trial=min(n_trial, len(tsk.config_space)),
n_trial = min(n_trial, len(tsk.config_space))
tuner_obj.tune(n_trial=n_trial,
early_stopping=early_stopping,
measure_option=measure_option,
callbacks=[
Expand Down
3 changes: 2 additions & 1 deletion tutorials/autotvm/tune_relay_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ def tune_tasks(tasks,
tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))

# do tuning
tuner_obj.tune(n_trial=min(n_trial, len(tsk.config_space)),
n_trial = min(n_trial, len(tsk.config_space))
tuner_obj.tune(n_trial=n_trial,
early_stopping=early_stopping,
measure_option=measure_option,
callbacks=[
Expand Down
3 changes: 2 additions & 1 deletion tutorials/autotvm/tune_relay_mobile_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,8 @@ def tune_tasks(tasks,
tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))

# do tuning
tuner_obj.tune(n_trial=min(n_trial, len(tsk.config_space)),
n_trial = min(n_trial, len(tsk.config_space))
tuner_obj.tune(n_trial=n_trial,
early_stopping=early_stopping,
measure_option=measure_option,
callbacks=[
Expand Down
3 changes: 2 additions & 1 deletion vta/tutorials/autotvm/tune_relay_vta.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,8 @@ def tune_tasks(tasks,
tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))

# do tuning
tuner_obj.tune(n_trial=min(n_trial, len(tsk.config_space)),
n_trial = min(n_trial, len(tsk.config_space))
tuner_obj.tune(n_trial=n_trial,
early_stopping=early_stopping,
measure_option=measure_option,
callbacks=[
Expand Down

0 comments on commit 66a5c6a

Please sign in to comment.