
MXNetError: unknown type for MKLDNN :2 when training Mask RCNN with mxnet-cu101==1.7.0 #19631

Closed
karan6181 opened this issue Dec 4, 2020 · 8 comments
Labels
Bug, MKLDNN, v1.x (Targeting v1.x branch)

Comments

@karan6181
Contributor

Description

  • The GluonCV Mask RCNN script, with and without Horovod, fails with MXNetError: unknown type for MKLDNN :2 when using mxnet-cu101==1.7.0

Error Message

Traceback (most recent call last):
  File "/shared/mx_170_mkl_env/lib/python3.8/multiprocessing/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
  File "/shared/mx_170_mkl_env/lib/python3.8/site-packages/mxnet/gluon/data/dataloader.py", line 429, in _worker_fn
    batch = batchify_fn([_worker_dataset[i] for i in samples])
  File "/shared/mx_170_mkl_env/lib/python3.8/site-packages/mxnet/gluon/data/dataloader.py", line 429, in <listcomp>
    batch = batchify_fn([_worker_dataset[i] for i in samples])
  File "/shared/mx_170_mkl_env/lib/python3.8/site-packages/mxnet/gluon/data/dataset.py", line 219, in __getitem__
    return self._fn(*item)
  File "/shared/mx_170_mkl_env/lib/python3.8/site-packages/gluoncv-0.8.0-py3.8-linux-x86_64.egg/gluoncv/data/transforms/presets/rcnn.py", line 407, in __call__
    cls_target, box_target, box_mask = self._target_generator(
  File "/shared/mx_170_mkl_env/lib/python3.8/site-packages/mxnet/gluon/block.py", line 747, in __call__
    out = self.forward(*args)
  File "/shared/mx_170_mkl_env/lib/python3.8/site-packages/gluoncv-0.8.0-py3.8-linux-x86_64.egg/gluoncv/model_zoo/rcnn/rpn/rpn_target.py", line 157, in forward
    ious = mx.nd.contrib.box_iou(anchor, bbox, format='corner').asnumpy()
  File "/shared/mx_170_mkl_env/lib/python3.8/site-packages/mxnet/ndarray/ndarray.py", line 2563, in asnumpy
    check_call(_LIB.MXNDArraySyncCopyToCPU(
  File "/shared/mx_170_mkl_env/lib/python3.8/site-packages/mxnet/base.py", line 246, in check_call
    raise get_last_ffi_error()
mxnet.base.MXNetError: Traceback (most recent call last):
  File "src/ndarray/./../operator/tensor/.././../common/../operator/nn/mkldnn/mkldnn_base-inl.h", line 246
MXNetError: unknown type for MKLDNN :2
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/shared/gluoncv_master/scripts/instance/mask_rcnn/train_mask_rcnn.py", line 737, in <module>
    train(net, train_data, val_data, eval_metric, batch_size, ctx, logger, args)
  File "/shared/gluoncv_master/scripts/instance/mask_rcnn/train_mask_rcnn.py", line 559, in train
    next_data_batch = next(train_data_iter)
  File "/shared/mx_170_mkl_env/lib/python3.8/site-packages/mxnet/gluon/data/dataloader.py", line 484, in __next__
    batch = pickle.loads(ret.get(self._timeout))
  File "/shared/mx_170_mkl_env/lib/python3.8/multiprocessing/pool.py", line 771, in get
    raise self._value
mxnet.base.MXNetError: Traceback (most recent call last):
  File "src/ndarray/./../operator/tensor/.././../common/../operator/nn/mkldnn/mkldnn_base-inl.h", line 246
MXNetError: unknown type for MKLDNN :2

GluonCV: v0.9.0
Horovod: v0.21.0

To Reproduce

Without Horovod

python gluon-cv/scripts/instance/mask_rcnn/train_mask_rcnn.py --gpus 0,1,2,3,4,5,6,7 --num-workers 4 --amp --lr-decay-epoch 8,10 --epochs 6 --log-interval 10 --val-interval 12 --batch-size 8 --use-fpn --lr 0.01 --lr-warmup-factor 0.001 --lr-warmup 1600 --static-alloc --clip-gradient 1.5 --use-ext --seed 987

Full log: https://gist.github.com/karan6181/efa4ad8f61c3e21cbee9c55fea98b2f0

With Horovod

horovodrun -np 8 -H localhost:8 python gluon-cv/scripts/instance/mask_rcnn/train_mask_rcnn.py --horovod --num-workers 4 --amp --lr-decay-epoch 8,10 --epochs 6 --log-interval 10 --val-interval 12 --batch-size 8 --use-fpn --lr 0.01 --lr-warmup-factor 0.001 --lr-warmup 1600 --static-alloc --clip-gradient 1.5 --use-ext --seed 987

Environment

We recommend using our script for collecting the diagnostic information with the following command
curl --retry 10 -s https://raw.githubusercontent.com/apache/incubator-mxnet/master/tools/diagnose.py | python3

----------Python Info----------
Version      : 3.8.5
Compiler     : GCC 7.3.0
Build        : ('default', 'Sep  4 2020 07:30:14')
Arch         : ('64bit', 'ELF')
------------Pip Info-----------
Version      : 20.2.4
Directory    : /shared/mx_170_mkl_env/lib/python3.8/site-packages/pip
----------MXNet Info-----------
Version      : 1.7.0
Directory    : /shared/mx_170_mkl_env/lib/python3.8/site-packages/mxnet
Commit Hash   : 64f737cdd59fe88d2c5b479f25d011c5156b6a8a
Library      : ['/shared/mx_170_mkl_env/lib/python3.8/site-packages/mxnet/libmxnet.so']
Build features:
✔ CUDA
✔ CUDNN
✔ NCCL
✔ CUDA_RTC
✖ TENSORRT
✔ CPU_SSE
✔ CPU_SSE2
✔ CPU_SSE3
✔ CPU_SSE4_1
✔ CPU_SSE4_2
✖ CPU_SSE4A
✔ CPU_AVX
✖ CPU_AVX2
✔ OPENMP
✖ SSE
✔ F16C
✖ JEMALLOC
✔ BLAS_OPEN
✖ BLAS_ATLAS
✖ BLAS_MKL
✖ BLAS_APPLE
✔ LAPACK
✔ MKLDNN
✔ OPENCV
✖ CAFFE
✖ PROFILER
✔ DIST_KVSTORE
✖ CXX14
✖ INT64_TENSOR_SIZE
✔ SIGNAL_HANDLER
✖ DEBUG
✖ TVM_OP
----------System Info----------
Platform     : Linux-4.15.0-1060-aws-x86_64-with-glibc2.10
system       : Linux
node         : ip-192-168-70-159
release      : 4.15.0-1060-aws
version      : #62-Ubuntu SMP Tue Feb 11 21:23:22 UTC 2020
----------Hardware Info----------
machine      : x86_64
processor    : x86_64
Architecture:        x86_64
CPU op-mode(s):      32-bit, 64-bit
Byte Order:          Little Endian
CPU(s):              96
On-line CPU(s) list: 0-95
Thread(s) per core:  2
Core(s) per socket:  24
Socket(s):           2
NUMA node(s):        2
Vendor ID:           GenuineIntel
CPU family:          6
Model:               85
Model name:          Intel(R) Xeon(R) Platinum 8175M CPU @ 2.50GHz
Stepping:            4
CPU MHz:             1200.134
BogoMIPS:            4999.99
Hypervisor vendor:   KVM
Virtualization type: full
L1d cache:           32K
L1i cache:           32K
L2 cache:            1024K
L3 cache:            33792K
NUMA node0 CPU(s):   0-23,48-71
NUMA node1 CPU(s):   24-47,72-95
Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke
----------Network Test----------
Setting timeout: 10
Timing for MXNet: https://github.com/apache/incubator-mxnet, DNS: 0.0025 sec, LOAD: 0.4890 sec.
Timing for Gluon Tutorial(en): http://gluon.mxnet.io, DNS: 0.0145 sec, LOAD: 0.0717 sec.
Error open Gluon Tutorial(cn): https://zh.gluon.ai, <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1123)>, DNS finished in 0.20147395133972168 sec.
Timing for FashionMNIST: https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/fashion-mnist/train-labels-idx1-ubyte.gz, DNS: 0.0093 sec, LOAD: 0.4630 sec.
Timing for PYPI: https://pypi.python.org/pypi/pip, DNS: 0.0033 sec, LOAD: 0.0823 sec.
Error open Conda: https://repo.continuum.io/pkgs/free/, HTTP Error 403: Forbidden, DNS finished in 0.0021767616271972656 sec.
----------Environment----------
KMP_DUPLICATE_LIB_OK="True"
@samskalicky
Contributor

File "src/ndarray/./../operator/tensor/.././../common/../operator/nn/mkldnn/mkldnn_base-inl.h", line 246
MXNetError: unknown type for MKLDNN :2

If you look at this file:
https://github.com/apache/incubator-mxnet/blob/64f737cdd59fe88d2c5b479f25d011c5156b6a8a/src/operator/nn/mkldnn/mkldnn_base-inl.h#L234-L247

You'll see it's running dtype checking and doesn't find a case in the switch for value 2. If you look at what value 2 is:
https://github.com/apache/incubator-mxnet/blob/64f737cdd59fe88d2c5b479f25d011c5156b6a8a/3rdparty/mshadow/mshadow/base.h#L351-L354

You'll see that it's kFloat16. So your script is somehow using FP16, and MKLDNN doesn't support that type.
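For reference, the same mapping can be checked from Python via MXNet's private dtype table (a sketch, assuming the _DTYPE_NP_TO_MX lookup table in mxnet/ndarray/ndarray.py, which mirrors mshadow's TypeFlag enum; it is not a public API):

import numpy as np
# private table mapping numpy dtypes to MXNet/mshadow type codes
from mxnet.ndarray.ndarray import _DTYPE_NP_TO_MX

print(_DTYPE_NP_TO_MX[np.float16])  # 2 -- the type code from the error message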

@karan6181
Contributor Author

Thank you @samskalicky. Mask RCNN does use AMP, and it casts the weights and gradients to FP16 here: https://github.com/dmlc/gluon-cv/blob/master/scripts/instance/mask_rcnn/train_mask_rcnn.py#L705
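If fp16 tensors reaching the MKLDNN-backed CPU path are indeed the trigger, a minimal sketch along these lines should hit the same error outside the training script (hypothetical shapes and values; only the float16 dtype and CPU context matter for the hypothesis):

import mxnet as mx

# fp16 boxes on the CPU context, mimicking what reaches rpn_target.py
# once the network and its outputs have been cast to float16
anchors = mx.nd.random.uniform(shape=(8, 4), ctx=mx.cpu()).astype('float16')
bboxes = mx.nd.random.uniform(shape=(8, 4), ctx=mx.cpu()).astype('float16')

ious = mx.nd.contrib.box_iou(anchors, bboxes, format='corner')
ious.asnumpy()  # expected to raise MXNetError: unknown type for MKLDNN :2 on an MKLDNN build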

@karan6181
Contributor Author

I also tried running the Mask RCNN script on a single node using mxnet-cu101mkl==1.6.0.post0 with gluon-cv==0.8.0, and it ran successfully without any issue.

Below is the output from the run:

python gluon-cv/scripts/instance/mask_rcnn/train_mask_rcnn.py --gpus 0,1,2,3,4,5,6,7 --num-workers 4 --amp --lr-decay-epoch 8,10 --epochs 6 --log-interval 10 --val-interval 12 --batch-size 8 --use-fpn --lr 0.01 --lr-warmup-factor 0.001 --lr-warmup 1600 --static-alloc --clip-gradient 1.5 --use-ext --seed 987
/shared/mx_oob_env/lib/python3.8/site-packages/mxnet/gluon/block.py:1389: UserWarning: Cannot decide type for the following arguments. Consider providing them as input:
	data: None
  input_sym_arg_type = in_param.infer_type()[0]
[23:15:28] src/storage/storage.cc:110: Using GPUPooledRoundedStorageManager.
[23:15:31] src/storage/storage.cc:110: Using GPUPooledRoundedStorageManager.
[23:15:34] src/storage/storage.cc:110: Using GPUPooledRoundedStorageManager.
[23:15:36] src/storage/storage.cc:110: Using GPUPooledRoundedStorageManager.
[23:15:39] src/storage/storage.cc:110: Using GPUPooledRoundedStorageManager.
[23:15:41] src/storage/storage.cc:110: Using GPUPooledRoundedStorageManager.
[23:15:44] src/storage/storage.cc:110: Using GPUPooledRoundedStorageManager.
[23:15:46] src/storage/storage.cc:110: Using GPUPooledRoundedStorageManager.
loading annotations into memory...
Done (t=14.20s)
creating index...
index created!
loading annotations into memory...
Done (t=0.39s)
creating index...
index created!
creating index...
/shared/mx_oob_env/lib/python3.8/site-packages/mxnet/gluon/parameter.py:701: UserWarning: Constant parameter "maskrcnn0_rpn0_rpnanchorgenerator0_anchor_" does not support grad_req other than "null", and new value "write" is ignored.
  warnings.warn('Constant parameter "{}" does not support '
/shared/mx_oob_env/lib/python3.8/site-packages/mxnet/gluon/parameter.py:701: UserWarning: Constant parameter "maskrcnn0_rpn0_rpnanchorgenerator1_anchor_" does not support grad_req other than "null", and new value "write" is ignored.
  warnings.warn('Constant parameter "{}" does not support '
/shared/mx_oob_env/lib/python3.8/site-packages/mxnet/gluon/parameter.py:701: UserWarning: Constant parameter "maskrcnn0_rpn0_rpnanchorgenerator2_anchor_" does not support grad_req other than "null", and new value "write" is ignored.
  warnings.warn('Constant parameter "{}" does not support '
/shared/mx_oob_env/lib/python3.8/site-packages/mxnet/gluon/parameter.py:701: UserWarning: Constant parameter "maskrcnn0_rpn0_rpnanchorgenerator3_anchor_" does not support grad_req other than "null", and new value "write" is ignored.
  warnings.warn('Constant parameter "{}" does not support '
/shared/mx_oob_env/lib/python3.8/site-packages/mxnet/gluon/parameter.py:701: UserWarning: Constant parameter "maskrcnn0_rpn0_rpnanchorgenerator4_anchor_" does not support grad_req other than "null", and new value "write" is ignored.
  warnings.warn('Constant parameter "{}" does not support '
INFO:root:Namespace(amp=True, batch_size=8, clip_gradient=1.5, custom_model=None, dataset='coco', disable_hybridization=False, epochs=6, executor_threads=1, gpus='0,1,2,3,4,5,6,7', horovod=False, kv_store='device', log_interval=10, lr=0.01, lr_decay=0.1, lr_decay_epoch='8,10', lr_warmup='1600', lr_warmup_factor=0.001, momentum=0.9, network='resnet50_v1b', norm_layer=None, num_workers=4, rcnn_smoothl1_rho=1.0, resume='', rpn_smoothl1_rho=0.1111111111111111, save_interval=1, save_prefix='mask_rcnn_fpn_resnet50_v1b_coco', seed=987, start_epoch=0, static_alloc=True, train_datapath='/scratch/data/mask_rcnn/mxnet/', use_ext=True, use_fpn=True, val_datapath='/scratch/data/mask_rcnn/mxnet/', val_interval=12, verbose=False, wd=0.0001)
INFO:root:Start training from [Epoch 0]
INFO:root:[Epoch 0 Iteration 0] Set learning rate to 1e-05
[23:16:40] src/imperative/cached_op.cc:192: Disabling fusion due to altered topological order of inputs.
[23:16:40] src/imperative/cached_op.cc:192: Disabling fusion due to altered topological order of inputs.
[23:16:41] src/imperative/cached_op.cc:192: Disabling fusion due to altered topological order of inputs.
[23:16:42] src/imperative/cached_op.cc:192: Disabling fusion due to altered topological order of inputs.
[23:16:42] src/imperative/cached_op.cc:192: Disabling fusion due to altered topological order of inputs.
[23:16:43] src/imperative/cached_op.cc:192: Disabling fusion due to altered topological order of inputs.
[23:16:43] src/imperative/cached_op.cc:192: Disabling fusion due to altered topological order of inputs.
[23:16:44] src/imperative/cached_op.cc:192: Disabling fusion due to altered topological order of inputs.
[23:16:47] src/kvstore/././comm.h:744: only 32 out of 56 GPU pairs are enabled direct access. It may affect the performance. You can set MXNET_ENABLE_GPU_P2P=0 to turn it off
[23:16:47] src/kvstore/././comm.h:753: .vvvv...
[23:16:47] src/kvstore/././comm.h:753: v.vv.v..
[23:16:47] src/kvstore/././comm.h:753: vv.v..v.
[23:16:47] src/kvstore/././comm.h:753: vvv....v
[23:16:47] src/kvstore/././comm.h:753: v....vvv
[23:16:47] src/kvstore/././comm.h:753: .v..v.vv
[23:16:47] src/kvstore/././comm.h:753: ..v.vv.v
[23:16:47] src/kvstore/././comm.h:753: ...vvvv.
INFO:root:AMP: decreasing loss scale to 32768.000000
INFO:root:AMP: decreasing loss scale to 16384.000000
INFO:root:AMP: decreasing loss scale to 8192.000000
INFO:root:AMP: decreasing loss scale to 4096.000000
INFO:root:AMP: decreasing loss scale to 2048.000000
INFO:root:AMP: decreasing loss scale to 1024.000000
INFO:root:[Epoch 0][Batch 9], Speed: 4.880 samples/sec, RPN_Conf=0.606,RPN_SmoothL1=0.156,RCNN_CrossEntropy=4.487,RCNN_SmoothL1=0.021,RCNN_Mask=1.882,RPNAcc=0.751,RPNL1Loss=1.384,RCNNAcc=0.004,RCNNL1Loss=0.947,MaskAcc=0.518,MaskFGAcc=0.522
INFO:root:[Epoch 0 Iteration 10] Set learning rate to 7.24375e-05
INFO:root:[Epoch 0][Batch 19], Speed: 15.230 samples/sec, RPN_Conf=0.577,RPN_SmoothL1=0.151,RCNN_CrossEntropy=4.018,RCNN_SmoothL1=0.019,RCNN_Mask=1.803,RPNAcc=0.797,RPNL1Loss=1.402,RCNNAcc=0.314,RCNNL1Loss=0.879,MaskAcc=0.513,MaskFGAcc=0.523
INFO:root:[Epoch 0 Iteration 20] Set learning rate to 0.000134875
INFO:root:[Epoch 0][Batch 29], Speed: 11.308 samples/sec, RPN_Conf=0.526,RPN_SmoothL1=0.143,RCNN_CrossEntropy=3.149,RCNN_SmoothL1=0.020,RCNN_Mask=1.635,RPNAcc=0.828,RPNL1Loss=1.323,RCNNAcc=0.535,RCNNL1Loss=0.927,MaskAcc=0.514,MaskFGAcc=0.525
INFO:root:[Epoch 0 Iteration 30] Set learning rate to 0.0001973125
INFO:root:[Epoch 0][Batch 39], Speed: 16.033 samples/sec, RPN_Conf=0.479,RPN_SmoothL1=0.139,RCNN_CrossEntropy=2.477,RCNN_SmoothL1=0.022,RCNN_Mask=1.504,RPNAcc=0.842,RPNL1Loss=1.277,RCNNAcc=0.645,RCNNL1Loss=0.998,MaskAcc=0.514,MaskFGAcc=0.527
INFO:root:[Epoch 0 Iteration 40] Set learning rate to 0.00025975
INFO:root:[Epoch 0][Batch 49], Speed: 13.002 samples/sec, RPN_Conf=0.435,RPN_SmoothL1=0.129,RCNN_CrossEntropy=2.047,RCNN_SmoothL1=0.026,RCNN_Mask=1.390,RPNAcc=0.856,RPNL1Loss=1.229,RCNNAcc=0.711,RCNNL1Loss=1.112,MaskAcc=0.518,MaskFGAcc=0.529
INFO:root:[Epoch 0 Iteration 50] Set learning rate to 0.0003221875
INFO:root:[Epoch 0][Batch 59], Speed: 13.902 samples/sec, RPN_Conf=0.423,RPN_SmoothL1=0.127,RCNN_CrossEntropy=1.780,RCNN_SmoothL1=0.032,RCNN_Mask=1.303,RPNAcc=0.858,RPNL1Loss=1.156,RCNNAcc=0.753,RCNNL1Loss=1.227,MaskAcc=0.516,MaskFGAcc=0.532
INFO:root:[Epoch 0 Iteration 60] Set learning rate to 0.000384625
INFO:root:[Epoch 0][Batch 69], Speed: 13.738 samples/sec, RPN_Conf=0.402,RPN_SmoothL1=0.122,RCNN_CrossEntropy=1.582,RCNN_SmoothL1=0.039,RCNN_Mask=1.230,RPNAcc=0.862,RPNL1Loss=1.104,RCNNAcc=0.782,RCNNL1Loss=1.369,MaskAcc=0.517,MaskFGAcc=0.534
INFO:root:[Epoch 0 Iteration 70] Set learning rate to 0.0004470625
INFO:root:[Epoch 0][Batch 79], Speed: 12.123 samples/sec, RPN_Conf=0.385,RPN_SmoothL1=0.116,RCNN_CrossEntropy=1.440,RCNN_SmoothL1=0.048,RCNN_Mask=1.172,RPNAcc=0.865,RPNL1Loss=1.055,RCNNAcc=0.802,RCNNL1Loss=1.537,MaskAcc=0.517,MaskFGAcc=0.536

However, running it with mxnet-cu101==1.7.0 and gluoncv==0.8.0 fails with:

Traceback (most recent call last):
  File "/shared/mx_oob_env/lib/python3.8/multiprocessing/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
  File "/shared/mx_oob_env/lib/python3.8/site-packages/mxnet/gluon/data/dataloader.py", line 429, in _worker_fn
    batch = batchify_fn([_worker_dataset[i] for i in samples])
  File "/shared/mx_oob_env/lib/python3.8/site-packages/mxnet/gluon/data/dataloader.py", line 429, in <listcomp>
    batch = batchify_fn([_worker_dataset[i] for i in samples])
  File "/shared/mx_oob_env/lib/python3.8/site-packages/mxnet/gluon/data/dataset.py", line 219, in __getitem__
    return self._fn(*item)
  File "/shared/mx_oob_env/lib/python3.8/site-packages/gluoncv/data/transforms/presets/rcnn.py", line 407, in __call__
    cls_target, box_target, box_mask = self._target_generator(
  File "/shared/mx_oob_env/lib/python3.8/site-packages/mxnet/gluon/block.py", line 682, in __call__
    out = self.forward(*args)
  File "/shared/mx_oob_env/lib/python3.8/site-packages/gluoncv/model_zoo/rcnn/rpn/rpn_target.py", line 157, in forward
    ious = mx.nd.contrib.box_iou(anchor, bbox, format='corner').asnumpy()
  File "/shared/mx_oob_env/lib/python3.8/site-packages/mxnet/ndarray/ndarray.py", line 2563, in asnumpy
    check_call(_LIB.MXNDArraySyncCopyToCPU(
  File "/shared/mx_oob_env/lib/python3.8/site-packages/mxnet/base.py", line 246, in check_call
    raise get_last_ffi_error()
mxnet.base.MXNetError: Traceback (most recent call last):
  File "src/ndarray/./../operator/tensor/.././../common/../operator/nn/mkldnn/mkldnn_base-inl.h", line 246
MXNetError: unknown type for MKLDNN :2

@samskalicky
Contributor

@bgawrych @bartekkuncer @grygielski FYI, it looks like something changed from 1.6.0 to 1.7.0 that causes this issue when running on CPU with MKLDNN.

@zhreshold
Member

I would suspect that merging MKLDNN as the default build caused some issue in the contrib operators.
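If so, disabling MKLDNN at runtime would be a quick way to confirm (a sketch, assuming this build honors the documented MXNET_MKLDNN_ENABLED environment variable; it must be set before libmxnet is first loaded):

import os

# must run before the first `import mxnet`, since the flag is read
# when libmxnet.so is loaded
os.environ['MXNET_MKLDNN_ENABLED'] = '0'

import mxnet as mx  # CPU ops now fall back to the non-MKLDNN implementations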

@karan6181
Contributor Author

Update: Commenting out these lines of code (https://github.com/dmlc/gluon-cv/blob/master/scripts/instance/mask_rcnn/train_mask_rcnn.py#L705-L710) seems to work with Horovod v0.21.0, mxnet-cu101==1.7.0, and gluoncv==0.8.0. However, running the same script without Horovod fails with a different issue, shown below:

INFO:root:[Epoch 0 Iteration 0] Set learning rate to 1e-05
[00:26:11] src/imperative/./cached_op.h:257: Disabling fusion due to altered topological order of inputs.
[00:26:12] src/imperative/./cached_op.h:257: Disabling fusion due to altered topological order of inputs.
Exception in thread Thread-7:
Traceback (most recent call last):
  File "/shared/mx_oob_env/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/shared/mx_oob_env/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/shared/mx_oob_env/lib/python3.8/site-packages/gluoncv/utils/parallel.py", line 105, in _worker
    out = parallel.forward_backward(x)
  File "/shared/mx_oob_env/lib/python3.8/site-packages/gluoncv/model_zoo/rcnn/mask_rcnn/data_parallel.py", line 48, in forward_backward
    cls_targets, box_targets, box_masks, indices = self.net(data, gt_box, gt_label)
  File "/shared/mx_oob_env/lib/python3.8/site-packages/mxnet/gluon/block.py", line 747, in __call__
    out = self.forward(*args)
  File "/shared/mx_oob_env/lib/python3.8/site-packages/mxnet/gluon/block.py", line 1309, in forward
    return self._call_cached_op(x, *args)
  File "/shared/mx_oob_env/lib/python3.8/site-packages/mxnet/gluon/block.py", line 1093, in _call_cached_op
    out = self._cached_op(*cargs)
  File "/shared/mx_oob_env/lib/python3.8/site-packages/mxnet/_ctypes/ndarray.py", line 148, in __call__
    check_call(_LIB.MXInvokeCachedOpEx(
  File "/shared/mx_oob_env/lib/python3.8/site-packages/mxnet/base.py", line 246, in check_call
    raise get_last_ffi_error()
mxnet.base.MXNetError: Traceback (most recent call last):
  File "src/imperative/cached_op.cc", line 777
MXNetError: Check failed: inputs[i]->ctx() == default_ctx (gpu(0) vs. gpu(1)) : CachedOp requires all inputs to live on the same context. But data0 is on gpu(1) while maskrcnn0_normalizedperclassboxcenterencoder0_means is on gpu(0)
  • Conclusion: Manually casting the model to FP16 doesn't work with mxnet-cu101 1.7.0; however, it works with mxnet-cu101mkl 1.6.0.
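The second failure looks like a plain device-placement problem rather than an MKLDNN one: data0 sits on gpu(1) while the constant maskrcnn0_normalizedperclassboxcenterencoder0_means is on gpu(0). A sketch of the usual Gluon remedy, assuming net is the model object built by the training script:

import mxnet as mx

ctx = [mx.gpu(i) for i in range(8)]  # the training contexts from --gpus 0,...,7

# re-home every parameter, including constants such as the box-encoder
# means, onto all training contexts so CachedOp sees consistent inputs
net.collect_params().reset_ctx(ctx)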

@anko-intel
Contributor

I will try to analyze the issue.

@szha added the MKLDNN and v1.x (Targeting v1.x branch) labels and removed the needs triage label on Feb 8, 2021
@szha
Member

szha commented Feb 8, 2021

@anko-intel thanks for the fix!
