Skip to content

Commit

Permalink
PR #17544: Fix ConvNeXt classifier activation bug
Browse files Browse the repository at this point in the history
Imported from GitHub PR #17544

Fixes the issue #17386.

Please find the bug, fix, and test [gist here](https://colab.research.google.com/gist/Frightera/93a784fc0e03a19f5804e7e14bfbecba/-17386.ipynb).
Copybara import of the project:

--
eedabb6 by Kaan <[email protected]>:

Pass classifier_activation arg to "Head"
--
e226091 by Kaan <[email protected]>:

Add classifier_activation to the docstring
--
3856fed by Kaan <[email protected]>:

Add classifier_activation unit test
--
c3dfc34 by Kaan <[email protected]>:

Move classifier_activation validation before head creation
--
71eaa69 by Kaan <[email protected]>:

Update test_application_classifier_activation
--
ebd6940 by Kaan Bıçakcı <[email protected]>:

Reformatting using format.sh

--
bce4ac9 by Kaan Bıçakcı <[email protected]>:

Fix test_application_classifier_activation test

--
00d4889 by Kaan Bıçakcı <[email protected]>:

Fix Head params to accept classifier_activation

--
2c22d37 by Kaan Bıçakcı <[email protected]>:

Revert "Fix Head params to accept classifier_activation"

This reverts commit 00d4889.

--
3abd441 by Kaan Bıçakcı <[email protected]>:

Exclude RegNet in test_application_classifier_activation

Merging this change closes #17544

FUTURE_COPYBARA_INTEGRATE_REVIEW=#17544 from Frightera:frightera_fix_17386 3abd441
PiperOrigin-RevId: 509550999
  • Loading branch information
tensorflower-gardener committed Feb 14, 2023
1 parent 5bc61d2 commit 34feca3
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
10 changes: 10 additions & 0 deletions keras/applications/applications_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,16 @@ def test_application_pooling(self, app, last_dim):
)
self.assertShapeEqual(output_shape, (None, last_dim))

@parameterized.parameters(MODEL_LIST)
def test_application_classifier_activation(self, app, _):
if "RegNet" in app.__name__:
self.skipTest("RegNet models do not support classifier activation")
model = app(
weights=None, include_top=True, classifier_activation="softmax"
)
last_layer_act = model.layers[-1].activation.__name__
self.assertEqual(last_layer_act, "softmax")

@parameterized.parameters(*MODEL_LIST_NO_NASNET)
def test_application_variable_input_channels(self, app, last_dim):
if backend.image_data_format() == "channels_first":
Expand Down
15 changes: 12 additions & 3 deletions keras/applications/convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,11 +324,12 @@ def apply(x):
return apply


def Head(num_classes=1000, name=None):
def Head(num_classes=1000, classifier_activation=None, name=None):
"""Implementation of classification head of RegNet.
Args:
num_classes: number of classes for Dense layer
classifier_activation: activation function for the Dense layer
name: name prefix
Returns:
Expand All @@ -342,7 +343,11 @@ def apply(x):
x = layers.LayerNormalization(
epsilon=1e-6, name=name + "_head_layernorm"
)(x)
x = layers.Dense(num_classes, name=name + "_head_dense")(x)
x = layers.Dense(
num_classes,
activation=classifier_activation,
name=name + "_head_dense",
)(x)
return x

return apply
Expand Down Expand Up @@ -522,8 +527,12 @@ def ConvNeXt(
cur += depths[i]

if include_top:
x = Head(num_classes=classes, name=model_name)(x)
imagenet_utils.validate_activation(classifier_activation, weights)
x = Head(
num_classes=classes,
classifier_activation=classifier_activation,
name=model_name,
)(x)

else:
if pooling == "avg":
Expand Down

0 comments on commit 34feca3

Please sign in to comment.