Skip to content

Commit

Permalink
some quick hacks
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Dec 3, 2018
1 parent 9895787 commit d92f41e
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
1 change: 1 addition & 0 deletions python/tvm/relay/quantize/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def visit_func(e):
scale = power2_scale(arr)
else:
scale = cfg.global_scale

const_params[ndom_scale] = _scalar(scale / 2**valid_bit, 'float32')
const_params[nbit] = _scalar(field_bit, 'int32')
const_params[nclip_min] = _scalar(- (2**valid_bit - 1), 'float32')
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/quantize/quantize_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def conv2d_rewrite(ref_call, new_args, ctx):
if cfg.counter < cfg.skip_k_conv:
cfg.counter += 1
return None
if cfg.counter > cfg.skip_k_conv:
return None
#if cfg.counter > cfg.skip_k_conv:
#return None
cfg.counter += 1

lhs, rhs = map(_prepare, new_args)
Expand Down
6 changes: 4 additions & 2 deletions src/relay/pass/quantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ Expr QuantizeQStateRewrite(const Call& ref_call,
data = Cast(data, Float(32));
Expr scaled_data = Multiply(data, Divide(n->dom_scale, dom_scale));
Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
round_data = Cast(round_data, Int(32));
return QIntStateNode::make(round_data, dom_scale, bit_imm, Float(32));
}
} else if (const auto* n = new_args[0].as<QRealStateNode>()) {
Expand Down Expand Up @@ -255,8 +256,7 @@ Expr Conv2dQStateRewrite(const Call& ref_call,
const NodeRef& ctx) {
CHECK_EQ(new_args.size(), 2);
if (!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>()) {
Expr ret = ForwardOp(ref_call, new_args);
return ret;
return Expr(nullptr);
}
CHECK(new_args[0].as<QIntStateNode>() && new_args[1].as<QIntStateNode>());

Expand Down Expand Up @@ -287,6 +287,7 @@ RELAY_REGISTER_OP("nn.conv2d")
Expr MulQStateRewrite(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
return Expr(nullptr);
CHECK_EQ(new_args.size(), 2);
bool int_domain = (new_args[0].as<QIntStateNode>() || new_args[1].as<QIntStateNode>());
if (int_domain) {
Expand Down Expand Up @@ -404,6 +405,7 @@ Expr AddQStateRewrite(const Call& ref_call,
Expr ReluQStateRewrite(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
return Expr(nullptr);
CHECK_EQ(new_args.size(), 1);
if (const auto* n = new_args[0].as<QIntStateNode>()) {
Expr ret = ForwardOp(ref_call, {n->data});
Expand Down

0 comments on commit d92f41e

Please sign in to comment.