Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic TopK implementation #744

Merged
merged 11 commits into from
Apr 25, 2017
Merged

Generic TopK implementation #744

merged 11 commits into from
Apr 25, 2017

Conversation

killeent
Copy link
Contributor

@killeent killeent commented Apr 7, 2017

This PR makes TopK generic. The main changes required to do this were:

  1. Extend the float --> unsigned integer bijective function to all other Tensor types
  2. Make it possible to act on data types whose underlying representation requires 64bits

@killeent
Copy link
Contributor Author

@soumith @wickedfoo This is ready.

@killeent killeent changed the title WIP: Generic TopK implementation Generic TopK implementation Apr 11, 2017
Copy link
Contributor

@wickedfoo wickedfoo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this looks good to me.

typedef unsigned int RadixType;

static inline __device__ RadixType convert(short v) {
return 32768u + v;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a static assert that sizeof(short) == 2 to ensure that this is a correct constant

typedef unsigned int RadixType;

static inline __device__ RadixType convert(int v) {
return 2147483648u + v;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a static assert that sizeof(int) == 4 to ensure that this is a correct constant

typedef unsigned long long int RadixType;

static inline __device__ RadixType convert(long v) {
return 9223372036854775808ull + v;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a static assert that sizeof(long) == 8 to ensure that this is a correct constant

@soumith soumith merged commit 93a6864 into torch:master Apr 25, 2017
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants