diff mbox series

[bug#59607,4/8] gnu: Add python-basicsr.

Message ID d06f0940078e449917736ca7f65e1411553ab42f.camel@gmail.com
State New
Headers show
Series Upscale your anime pictures, now with 99% less malware | expand

Commit Message

Liliana Marie Prikler Nov. 20, 2022, 4:25 p.m. UTC
* gnu/packages/patches/python-basicsr-fuck-nvidia.patch: New file.
* gnu/local.mk (dist_patch_DATA): Register it here.
* gnu/packages/machine-learning.scm (python-basicsr): New variable.
---
 gnu/local.mk                                  |    1 +
 gnu/packages/machine-learning.scm             |   66 +
 .../patches/python-basicsr-fuck-nvidia.patch  | 3233 +++++++++++++++++
 3 files changed, 3300 insertions(+)
 create mode 100644 gnu/packages/patches/python-basicsr-fuck-nvidia.patch
diff mbox series

Patch

diff --git a/gnu/local.mk b/gnu/local.mk
index 7278c50e4f..8dd1abe07a 100644
--- a/gnu/local.mk
+++ b/gnu/local.mk
@@ -1720,6 +1720,7 @@  dist_patch_DATA =						\
   %D%/packages/patches/python-apsw-3.39.2.1-test-fix.patch	\
   %D%/packages/patches/python-aionotify-0.2.0-py3.8.patch	\
   %D%/packages/patches/python-argcomplete-1.11.1-fish31.patch	\
+  %D%/packages/patches/python-basicsr-fuck-nvidia.patch	\
   %D%/packages/patches/python-cross-compile.patch		\
   %D%/packages/patches/python-configobj-setuptools.patch	\
   %D%/packages/patches/python-dateutil-pytest-compat.patch	\
diff --git a/gnu/packages/machine-learning.scm b/gnu/packages/machine-learning.scm
index 0566f4bd69..a5767a2c31 100644
--- a/gnu/packages/machine-learning.scm
+++ b/gnu/packages/machine-learning.scm
@@ -750,6 +750,72 @@  (define (delete-ifdefs file)
 in terms of new algorithms.")
     (license license:gpl3+)))
 
+(define-public python-basicsr
+  (package
+    (name "python-basicsr")
+    (version "1.4.2")
+    (source (origin
+              (method git-fetch)
+              (uri
+               (git-reference
+                (url "https://github.com/XPixelGroup/BasicSR")
+                (commit (string-append "v" version))))
+              (patches
+               (search-patches
+                "python-basicsr-fuck-nvidia.patch"))
+              (modules '((guix build utils)))
+              (snippet
+               #~(begin (substitute* (find-files "." "\\.py")
+                          (("\\.cuda\\(\\)") "")
+                          (("pretrained=True") "weights=None"))
+                        ;; Instead of images files, a custom lmdb is used
+                        (delete-file-recursively "tests/data")))
+              (sha256
+               (base32
+                "0qjk1hf1qjla3f6hb8fd6dv9w3b77568z8g17mlcxl91bp031z2i"))))
+    (build-system python-build-system)
+    (arguments
+     (list
+      #:phases
+      #~(modify-phases %standard-phases
+          (add-after 'unpack 'fix-requirements
+            (lambda _
+              (substitute* "requirements.txt"
+                (("opencv-python") "")  ; installed without egg-info
+                (("tb-nightly") ""))))
+          (add-before 'check 'pre-check
+            (lambda _
+              (setenv "HOME" (getcwd))
+              ;; Missing data...
+              (delete-file-recursively "tests/test_data")
+              ;; Model is fetched over the web
+              (delete-file-recursively "tests/test_models")))
+          (replace 'check
+            (lambda* (#:key tests? #:allow-other-keys)
+              (when tests?
+                (invoke "pytest" "-vv")))))))
+    (propagated-inputs (list opencv     ; used via python bindings
+                             python-addict
+                             python-future
+                             python-lmdb
+                             python-numpy
+                             python-pillow
+                             python-pyyaml
+                             python-requests
+                             python-scikit-image
+                             python-scipy
+                             python-pytorch
+                             python-torchvision
+                             python-tqdm
+                             python-yapf))
+    (native-inputs (list lmdb python-cython python-pytest))
+    (home-page "https://github.com/xinntao/BasicSR")
+    (synopsis "Image and Video Super-Resolution Toolbox")
+    (description "BasicSR is a pytorch-based toolbox to perform image restoration
+tasks such as super-scaling, denoising, deblurring, and removal of JPEG
+artifacts.")
+    (license license:asl2.0)))
+
 (define-public ncnn
   (package
     (name "ncnn")
diff --git a/gnu/packages/patches/python-basicsr-fuck-nvidia.patch b/gnu/packages/patches/python-basicsr-fuck-nvidia.patch
new file mode 100644
index 0000000000..30cc1cb9ad
--- /dev/null
+++ b/gnu/packages/patches/python-basicsr-fuck-nvidia.patch
@@ -0,0 +1,3233 @@ 
+diff --git a/basicsr/archs/arch_util.py b/basicsr/archs/arch_util.py
+index 11b82a7..875b2b6 100644
+--- a/basicsr/archs/arch_util.py
++++ b/basicsr/archs/arch_util.py
+@@ -10,7 +10,7 @@ from torch.nn import functional as F
+ from torch.nn import init as init
+ from torch.nn.modules.batchnorm import _BatchNorm
+ 
+-from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
++from basicsr.ops.dcn import ModulatedDeformConvPack
+ from basicsr.utils import get_root_logger
+ 
+ 
+@@ -228,12 +228,8 @@ class DCNv2Pack(ModulatedDeformConvPack):
+             logger = get_root_logger()
+             logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
+ 
+-        if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
+-            return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
+-                                                 self.dilation, mask)
+-        else:
+-            return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
+-                                         self.dilation, self.groups, self.deformable_groups)
++        return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
++                                             self.dilation, mask)
+ 
+ 
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+diff --git a/basicsr/archs/basicvsrpp_arch.py b/basicsr/archs/basicvsrpp_arch.py
+index d9699cb..e726b8b 100644
+--- a/basicsr/archs/basicvsrpp_arch.py
++++ b/basicsr/archs/basicvsrpp_arch.py
+@@ -69,14 +69,6 @@ class BasicVSRPlusPlus(nn.Module):
+         self.backbone = nn.ModuleDict()
+         modules = ['backward_1', 'forward_1', 'backward_2', 'forward_2']
+         for i, module in enumerate(modules):
+-            if torch.cuda.is_available():
+-                self.deform_align[module] = SecondOrderDeformableAlignment(
+-                    2 * mid_channels,
+-                    mid_channels,
+-                    3,
+-                    padding=1,
+-                    deformable_groups=16,
+-                    max_residue_magnitude=max_residue_magnitude)
+             self.backbone[module] = ConvResidualBlocks((2 + i) * mid_channels, mid_channels, num_blocks)
+ 
+         # upsampling module
+diff --git a/basicsr/archs/stylegan2_arch.py b/basicsr/archs/stylegan2_arch.py
+index 9ab37f5..42cb08c 100644
+--- a/basicsr/archs/stylegan2_arch.py
++++ b/basicsr/archs/stylegan2_arch.py
+@@ -4,7 +4,6 @@ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ 
+-from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu
+ from basicsr.ops.upfirdn2d import upfirdn2d
+ from basicsr.utils.registry import ARCH_REGISTRY
+ 
+@@ -141,8 +140,7 @@ class EqualLinear(nn.Module):
+             bias. Default: ``True``.
+         bias_init_val (float): Bias initialized value. Default: 0.
+         lr_mul (float): Learning rate multiplier. Default: 1.
+-        activation (None | str): The activation after ``linear`` operation.
+-            Supported: 'fused_lrelu', None. Default: None.
++        activation (None | str): Ignored.
+     """
+ 
+     def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None):
+@@ -150,10 +148,7 @@ class EqualLinear(nn.Module):
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.lr_mul = lr_mul
+-        self.activation = activation
+-        if self.activation not in ['fused_lrelu', None]:
+-            raise ValueError(f'Wrong activation value in EqualLinear: {activation}'
+-                             "Supported ones are: ['fused_lrelu', None].")
++        self.activation = None
+         self.scale = (1 / math.sqrt(in_channels)) * lr_mul
+ 
+         self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
+@@ -167,12 +162,7 @@ class EqualLinear(nn.Module):
+             bias = None
+         else:
+             bias = self.bias * self.lr_mul
+-        if self.activation == 'fused_lrelu':
+-            out = F.linear(x, self.weight * self.scale)
+-            out = fused_leaky_relu(out, bias)
+-        else:
+-            out = F.linear(x, self.weight * self.scale, bias=bias)
+-        return out
++        return F.linear(x, self.weight * self.scale, bias=bias)
+ 
+     def __repr__(self):
+         return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
+@@ -318,7 +308,7 @@ class StyleConv(nn.Module):
+             sample_mode=sample_mode,
+             resample_kernel=resample_kernel)
+         self.weight = nn.Parameter(torch.zeros(1))  # for noise injection
+-        self.activate = FusedLeakyReLU(out_channels)
++        self.activate = ScaledLeakyReLU()
+ 
+     def forward(self, x, style, noise=None):
+         # modulate
+@@ -693,10 +683,7 @@ class ConvLayer(nn.Sequential):
+                 and not activate))
+         # activation
+         if activate:
+-            if bias:
+-                layers.append(FusedLeakyReLU(out_channels))
+-            else:
+-                layers.append(ScaledLeakyReLU(0.2))
++            layers.append(ScaledLeakyReLU(0.2))
+ 
+         super(ConvLayer, self).__init__(*layers)
+ 
+diff --git a/basicsr/data/prefetch_dataloader.py b/basicsr/data/prefetch_dataloader.py
+index 5088425..0cf35e6 100644
+--- a/basicsr/data/prefetch_dataloader.py
++++ b/basicsr/data/prefetch_dataloader.py
+@@ -99,7 +99,7 @@ class CUDAPrefetcher():
+         self.loader = iter(loader)
+         self.opt = opt
+         self.stream = torch.cuda.Stream()
+-        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
++        self.device = torch.device('cpu')
+         self.preload()
+ 
+     def preload(self):
+diff --git a/basicsr/models/base_model.py b/basicsr/models/base_model.py
+index 05c8d2e..36442a2 100644
+--- a/basicsr/models/base_model.py
++++ b/basicsr/models/base_model.py
+@@ -15,7 +15,7 @@ class BaseModel():
+ 
+     def __init__(self, opt):
+         self.opt = opt
+-        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
++        self.device = torch.device('cpu')
+         self.is_train = opt['is_train']
+         self.schedulers = []
+         self.optimizers = []
+@@ -91,14 +91,7 @@ class BaseModel():
+         Args:
+             net (nn.Module)
+         """
+-        net = net.to(self.device)
+-        if self.opt['dist']:
+-            find_unused_parameters = self.opt.get('find_unused_parameters', False)
+-            net = DistributedDataParallel(
+-                net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
+-        elif self.opt['num_gpu'] > 1:
+-            net = DataParallel(net)
+-        return net
++        return net.to(self.device)
+ 
+     def get_optimizer(self, optim_type, params, lr, **kwargs):
+         if optim_type == 'Adam':
+diff --git a/basicsr/ops/dcn/__init__.py b/basicsr/ops/dcn/__init__.py
+index 32e3592..68033e0 100644
+--- a/basicsr/ops/dcn/__init__.py
++++ b/basicsr/ops/dcn/__init__.py
+@@ -1,7 +1,4 @@
+-from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
+-                          modulated_deform_conv)
++from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack)
+ 
+ __all__ = [
+-    'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
+-    'modulated_deform_conv'
+-]
++    'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack',]
+diff --git a/basicsr/ops/dcn/deform_conv.py b/basicsr/ops/dcn/deform_conv.py
+index 6268ca8..38ced57 100644
+--- a/basicsr/ops/dcn/deform_conv.py
++++ b/basicsr/ops/dcn/deform_conv.py
+@@ -2,191 +2,9 @@ import math
+ import os
+ import torch
+ from torch import nn as nn
+-from torch.autograd import Function
+-from torch.autograd.function import once_differentiable
+ from torch.nn import functional as F
+ from torch.nn.modules.utils import _pair, _single
+ 
+-BASICSR_JIT = os.getenv('BASICSR_JIT')
+-if BASICSR_JIT == 'True':
+-    from torch.utils.cpp_extension import load
+-    module_path = os.path.dirname(__file__)
+-    deform_conv_ext = load(
+-        'deform_conv',
+-        sources=[
+-            os.path.join(module_path, 'src', 'deform_conv_ext.cpp'),
+-            os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'),
+-            os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'),
+-        ],
+-    )
+-else:
+-    try:
+-        from . import deform_conv_ext
+-    except ImportError:
+-        pass
+-        # avoid annoying print output
+-        # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n '
+-        #       '1. compile with BASICSR_EXT=True. or\n '
+-        #       '2. set BASICSR_JIT=True during running')
+-
+-
+-class DeformConvFunction(Function):
+-
+-    @staticmethod
+-    def forward(ctx,
+-                input,
+-                offset,
+-                weight,
+-                stride=1,
+-                padding=0,
+-                dilation=1,
+-                groups=1,
+-                deformable_groups=1,
+-                im2col_step=64):
+-        if input is not None and input.dim() != 4:
+-            raise ValueError(f'Expected 4D tensor as input, got {input.dim()}D tensor instead.')
+-        ctx.stride = _pair(stride)
+-        ctx.padding = _pair(padding)
+-        ctx.dilation = _pair(dilation)
+-        ctx.groups = groups
+-        ctx.deformable_groups = deformable_groups
+-        ctx.im2col_step = im2col_step
+-
+-        ctx.save_for_backward(input, offset, weight)
+-
+-        output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))
+-
+-        ctx.bufs_ = [input.new_empty(0), input.new_empty(0)]  # columns, ones
+-
+-        if not input.is_cuda:
+-            raise NotImplementedError
+-        else:
+-            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
+-            assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
+-            deform_conv_ext.deform_conv_forward(input, weight,
+-                                                offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
+-                                                weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
+-                                                ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
+-                                                ctx.deformable_groups, cur_im2col_step)
+-        return output
+-
+-    @staticmethod
+-    @once_differentiable
+-    def backward(ctx, grad_output):
+-        input, offset, weight = ctx.saved_tensors
+-
+-        grad_input = grad_offset = grad_weight = None
+-
+-        if not grad_output.is_cuda:
+-            raise NotImplementedError
+-        else:
+-            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
+-            assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
+-
+-            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+-                grad_input = torch.zeros_like(input)
+-                grad_offset = torch.zeros_like(offset)
+-                deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input,
+-                                                           grad_offset, weight, ctx.bufs_[0], weight.size(3),
+-                                                           weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
+-                                                           ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
+-                                                           ctx.deformable_groups, cur_im2col_step)
+-
+-            if ctx.needs_input_grad[2]:
+-                grad_weight = torch.zeros_like(weight)
+-                deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight,
+-                                                                ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
+-                                                                weight.size(2), ctx.stride[1], ctx.stride[0],
+-                                                                ctx.padding[1], ctx.padding[0], ctx.dilation[1],
+-                                                                ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
+-                                                                cur_im2col_step)
+-
+-        return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
+-
+-    @staticmethod
+-    def _output_size(input, weight, padding, dilation, stride):
+-        channels = weight.size(0)
+-        output_size = (input.size(0), channels)
+-        for d in range(input.dim() - 2):
+-            in_size = input.size(d + 2)
+-            pad = padding[d]
+-            kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
+-            stride_ = stride[d]
+-            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
+-        if not all(map(lambda s: s > 0, output_size)):
+-            raise ValueError(f'convolution input is too small (output would be {"x".join(map(str, output_size))})')
+-        return output_size
+-
+-
+-class ModulatedDeformConvFunction(Function):
+-
+-    @staticmethod
+-    def forward(ctx,
+-                input,
+-                offset,
+-                mask,
+-                weight,
+-                bias=None,
+-                stride=1,
+-                padding=0,
+-                dilation=1,
+-                groups=1,
+-                deformable_groups=1):
+-        ctx.stride = stride
+-        ctx.padding = padding
+-        ctx.dilation = dilation
+-        ctx.groups = groups
+-        ctx.deformable_groups = deformable_groups
+-        ctx.with_bias = bias is not None
+-        if not ctx.with_bias:
+-            bias = input.new_empty(1)  # fake tensor
+-        if not input.is_cuda:
+-            raise NotImplementedError
+-        if weight.requires_grad or mask.requires_grad or offset.requires_grad or input.requires_grad:
+-            ctx.save_for_backward(input, offset, mask, weight, bias)
+-        output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
+-        ctx._bufs = [input.new_empty(0), input.new_empty(0)]
+-        deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output,
+-                                                      ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
+-                                                      ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
+-                                                      ctx.groups, ctx.deformable_groups, ctx.with_bias)
+-        return output
+-
+-    @staticmethod
+-    @once_differentiable
+-    def backward(ctx, grad_output):
+-        if not grad_output.is_cuda:
+-            raise NotImplementedError
+-        input, offset, mask, weight, bias = ctx.saved_tensors
+-        grad_input = torch.zeros_like(input)
+-        grad_offset = torch.zeros_like(offset)
+-        grad_mask = torch.zeros_like(mask)
+-        grad_weight = torch.zeros_like(weight)
+-        grad_bias = torch.zeros_like(bias)
+-        deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
+-                                                       grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
+-                                                       grad_output, weight.shape[2], weight.shape[3], ctx.stride,
+-                                                       ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
+-                                                       ctx.groups, ctx.deformable_groups, ctx.with_bias)
+-        if not ctx.with_bias:
+-            grad_bias = None
+-
+-        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None)
+-
+-    @staticmethod
+-    def _infer_shape(ctx, input, weight):
+-        n = input.size(0)
+-        channels_out = weight.size(0)
+-        height, width = input.shape[2:4]
+-        kernel_h, kernel_w = weight.shape[2:4]
+-        height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
+-        width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
+-        return n, channels_out, height_out, width_out
+-
+-
+-deform_conv = DeformConvFunction.apply
+-modulated_deform_conv = ModulatedDeformConvFunction.apply
+-
+ 
+ class DeformConv(nn.Module):
+ 
+@@ -230,19 +48,7 @@ class DeformConv(nn.Module):
+         self.weight.data.uniform_(-stdv, stdv)
+ 
+     def forward(self, x, offset):
+-        # To fix an assert error in deform_conv_cuda.cpp:128
+-        # input image is smaller than kernel
+-        input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1])
+-        if input_pad:
+-            pad_h = max(self.kernel_size[0] - x.size(2), 0)
+-            pad_w = max(self.kernel_size[1] - x.size(3), 0)
+-            x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
+-            offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
+-        out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
+-                          self.deformable_groups)
+-        if input_pad:
+-            out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous()
+-        return out
++        return NotImplemented
+ 
+ 
+ class DeformConvPack(DeformConv):
+@@ -281,9 +87,7 @@ class DeformConvPack(DeformConv):
+         self.conv_offset.bias.data.zero_()
+ 
+     def forward(self, x):
+-        offset = self.conv_offset(x)
+-        return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
+-                           self.deformable_groups)
++        return NotImplemented
+ 
+ 
+ class ModulatedDeformConv(nn.Module):
+@@ -329,8 +133,7 @@ class ModulatedDeformConv(nn.Module):
+             self.bias.data.zero_()
+ 
+     def forward(self, x, offset, mask):
+-        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
+-                                     self.groups, self.deformable_groups)
++        return NotImplemented
+ 
+ 
+ class ModulatedDeformConvPack(ModulatedDeformConv):
+@@ -371,9 +174,4 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
+             self.conv_offset.bias.data.zero_()
+ 
+     def forward(self, x):
+-        out = self.conv_offset(x)
+-        o1, o2, mask = torch.chunk(out, 3, dim=1)
+-        offset = torch.cat((o1, o2), dim=1)
+-        mask = torch.sigmoid(mask)
+-        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
+-                                     self.groups, self.deformable_groups)
++        return NotImplemented
+diff --git a/basicsr/ops/dcn/src/deform_conv_cuda.cpp b/basicsr/ops/dcn/src/deform_conv_cuda.cpp
+deleted file mode 100644
+index b465c49..0000000
+--- a/basicsr/ops/dcn/src/deform_conv_cuda.cpp
++++ /dev/null
+@@ -1,685 +0,0 @@
+-// modify from
+-// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
+-
+-#include <torch/extension.h>
+-#include <ATen/DeviceGuard.h>
+-
+-#include <cmath>
+-#include <vector>
+-
+-void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
+-                       const int channels, const int height, const int width,
+-                       const int ksize_h, const int ksize_w, const int pad_h,
+-                       const int pad_w, const int stride_h, const int stride_w,
+-                       const int dilation_h, const int dilation_w,
+-                       const int parallel_imgs, const int deformable_group,
+-                       at::Tensor data_col);
+-
+-void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
+-                       const int channels, const int height, const int width,
+-                       const int ksize_h, const int ksize_w, const int pad_h,
+-                       const int pad_w, const int stride_h, const int stride_w,
+-                       const int dilation_h, const int dilation_w,
+-                       const int parallel_imgs, const int deformable_group,
+-                       at::Tensor grad_im);
+-
+-void deformable_col2im_coord(
+-    const at::Tensor data_col, const at::Tensor data_im,
+-    const at::Tensor data_offset, const int channels, const int height,
+-    const int width, const int ksize_h, const int ksize_w, const int pad_h,
+-    const int pad_w, const int stride_h, const int stride_w,
+-    const int dilation_h, const int dilation_w, const int parallel_imgs,
+-    const int deformable_group, at::Tensor grad_offset);
+-
+-void modulated_deformable_im2col_cuda(
+-    const at::Tensor data_im, const at::Tensor data_offset,
+-    const at::Tensor data_mask, const int batch_size, const int channels,
+-    const int height_im, const int width_im, const int height_col,
+-    const int width_col, const int kernel_h, const int kenerl_w,
+-    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+-    const int dilation_h, const int dilation_w, const int deformable_group,
+-    at::Tensor data_col);
+-
+-void modulated_deformable_col2im_cuda(
+-    const at::Tensor data_col, const at::Tensor data_offset,
+-    const at::Tensor data_mask, const int batch_size, const int channels,
+-    const int height_im, const int width_im, const int height_col,
+-    const int width_col, const int kernel_h, const int kenerl_w,
+-    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+-    const int dilation_h, const int dilation_w, const int deformable_group,
+-    at::Tensor grad_im);
+-
+-void modulated_deformable_col2im_coord_cuda(
+-    const at::Tensor data_col, const at::Tensor data_im,
+-    const at::Tensor data_offset, const at::Tensor data_mask,
+-    const int batch_size, const int channels, const int height_im,
+-    const int width_im, const int height_col, const int width_col,
+-    const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
+-    const int stride_h, const int stride_w, const int dilation_h,
+-    const int dilation_w, const int deformable_group, at::Tensor grad_offset,
+-    at::Tensor grad_mask);
+-
+-void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
+-                 at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
+-                 int padW, int dilationH, int dilationW, int group,
+-                 int deformable_group) {
+-  TORCH_CHECK(weight.ndimension() == 4,
+-           "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
+-           "but got: %s",
+-           weight.ndimension());
+-
+-  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+-
+-  TORCH_CHECK(kW > 0 && kH > 0,
+-           "kernel size should be greater than zero, but got kH: %d kW: %d", kH,
+-           kW);
+-
+-  TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
+-           "kernel size should be consistent with weight, ",
+-           "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
+-           kW, weight.size(2), weight.size(3));
+-
+-  TORCH_CHECK(dW > 0 && dH > 0,
+-           "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
+-
+-  TORCH_CHECK(
+-      dilationW > 0 && dilationH > 0,
+-      "dilation should be greater than 0, but got dilationH: %d dilationW: %d",
+-      dilationH, dilationW);
+-
+-  int ndim = input.ndimension();
+-  int dimf = 0;
+-  int dimh = 1;
+-  int dimw = 2;
+-
+-  if (ndim == 4) {
+-    dimf++;
+-    dimh++;
+-    dimw++;
+-  }
+-
+-  TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
+-           ndim);
+-
+-  long nInputPlane = weight.size(1) * group;
+-  long inputHeight = input.size(dimh);
+-  long inputWidth = input.size(dimw);
+-  long nOutputPlane = weight.size(0);
+-  long outputHeight =
+-      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+-  long outputWidth =
+-      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+-
+-  TORCH_CHECK(nInputPlane % deformable_group == 0,
+-           "input channels must divide deformable group size");
+-
+-  if (outputWidth < 1 || outputHeight < 1)
+-    AT_ERROR(
+-        "Given input size: (%ld x %ld x %ld). "
+-        "Calculated output size: (%ld x %ld x %ld). Output size is too small",
+-        nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
+-        outputWidth);
+-
+-  TORCH_CHECK(input.size(1) == nInputPlane,
+-           "invalid number of input planes, expected: %d, but got: %d",
+-           nInputPlane, input.size(1));
+-
+-  TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
+-           "input image is smaller than kernel");
+-
+-  TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
+-           "invalid spatial size of offset, expected height: %d width: %d, but "
+-           "got height: %d width: %d",
+-           outputHeight, outputWidth, offset.size(2), offset.size(3));
+-
+-  TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
+-           "invalid number of channels of offset");
+-
+-  if (gradOutput != NULL) {
+-    TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane,
+-             "invalid number of gradOutput planes, expected: %d, but got: %d",
+-             nOutputPlane, gradOutput->size(dimf));
+-
+-    TORCH_CHECK((gradOutput->size(dimh) == outputHeight &&
+-              gradOutput->size(dimw) == outputWidth),
+-             "invalid size of gradOutput, expected height: %d width: %d , but "
+-             "got height: %d width: %d",
+-             outputHeight, outputWidth, gradOutput->size(dimh),
+-             gradOutput->size(dimw));
+-  }
+-}
+-
+-int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
+-                             at::Tensor offset, at::Tensor output,
+-                             at::Tensor columns, at::Tensor ones, int kW,
+-                             int kH, int dW, int dH, int padW, int padH,
+-                             int dilationW, int dilationH, int group,
+-                             int deformable_group, int im2col_step) {
+-  // todo: resize columns to include im2col: done
+-  // todo: add im2col_step as input
+-  // todo: add new output buffer and transpose it to output (or directly
+-  // transpose output) todo: possibly change data indexing because of
+-  // parallel_imgs
+-
+-  shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
+-              dilationH, dilationW, group, deformable_group);
+-  at::DeviceGuard guard(input.device());
+-
+-  input = input.contiguous();
+-  offset = offset.contiguous();
+-  weight = weight.contiguous();
+-
+-  int batch = 1;
+-  if (input.ndimension() == 3) {
+-    // Force batch
+-    batch = 0;
+-    input.unsqueeze_(0);
+-    offset.unsqueeze_(0);
+-  }
+-
+-  // todo: assert batchsize dividable by im2col_step
+-
+-  long batchSize = input.size(0);
+-  long nInputPlane = input.size(1);
+-  long inputHeight = input.size(2);
+-  long inputWidth = input.size(3);
+-
+-  long nOutputPlane = weight.size(0);
+-
+-  long outputWidth =
+-      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+-  long outputHeight =
+-      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+-
+-  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+-
+-  output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
+-                        outputHeight, outputWidth});
+-  columns = at::zeros(
+-      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+-      input.options());
+-
+-  if (ones.ndimension() != 2 ||
+-      ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
+-    ones = at::ones({outputHeight, outputWidth}, input.options());
+-  }
+-
+-  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+-                      inputHeight, inputWidth});
+-  offset =
+-      offset.view({batchSize / im2col_step, im2col_step,
+-                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+-
+-  at::Tensor output_buffer =
+-      at::zeros({batchSize / im2col_step, nOutputPlane,
+-                 im2col_step * outputHeight, outputWidth},
+-                output.options());
+-
+-  output_buffer = output_buffer.view(
+-      {output_buffer.size(0), group, output_buffer.size(1) / group,
+-       output_buffer.size(2), output_buffer.size(3)});
+-
+-  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+-    deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+-                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+-                      dilationW, im2col_step, deformable_group, columns);
+-
+-    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+-    weight = weight.view({group, weight.size(0) / group, weight.size(1),
+-                          weight.size(2), weight.size(3)});
+-
+-    for (int g = 0; g < group; g++) {
+-      output_buffer[elt][g] = output_buffer[elt][g]
+-                                  .flatten(1)
+-                                  .addmm_(weight[g].flatten(1), columns[g])
+-                                  .view_as(output_buffer[elt][g]);
+-    }
+-  }
+-
+-  output_buffer = output_buffer.view(
+-      {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
+-       output_buffer.size(3), output_buffer.size(4)});
+-
+-  output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
+-                                      im2col_step, outputHeight, outputWidth});
+-  output_buffer.transpose_(1, 2);
+-  output.copy_(output_buffer);
+-  output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+-
+-  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+-  offset = offset.view(
+-      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+-
+-  if (batch == 0) {
+-    output = output.view({nOutputPlane, outputHeight, outputWidth});
+-    input = input.view({nInputPlane, inputHeight, inputWidth});
+-    offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+-  }
+-
+-  return 1;
+-}
+-
+-int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
+-                                    at::Tensor gradOutput, at::Tensor gradInput,
+-                                    at::Tensor gradOffset, at::Tensor weight,
+-                                    at::Tensor columns, int kW, int kH, int dW,
+-                                    int dH, int padW, int padH, int dilationW,
+-                                    int dilationH, int group,
+-                                    int deformable_group, int im2col_step) {
+-  shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
+-              dilationH, dilationW, group, deformable_group);
+-  at::DeviceGuard guard(input.device());
+-
+-  input = input.contiguous();
+-  offset = offset.contiguous();
+-  gradOutput = gradOutput.contiguous();
+-  weight = weight.contiguous();
+-
+-  int batch = 1;
+-
+-  if (input.ndimension() == 3) {
+-    // Force batch
+-    batch = 0;
+-    input = input.view({1, input.size(0), input.size(1), input.size(2)});
+-    offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
+-    gradOutput = gradOutput.view(
+-        {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+-  }
+-
+-  long batchSize = input.size(0);
+-  long nInputPlane = input.size(1);
+-  long inputHeight = input.size(2);
+-  long inputWidth = input.size(3);
+-
+-  long nOutputPlane = weight.size(0);
+-
+-  long outputWidth =
+-      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+-  long outputHeight =
+-      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+-
+-  TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
+-  gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+-  columns = at::zeros(
+-      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+-      input.options());
+-
+-  // change order of grad output
+-  gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+-                                nOutputPlane, outputHeight, outputWidth});
+-  gradOutput.transpose_(1, 2);
+-
+-  gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
+-                              inputHeight, inputWidth});
+-  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+-                      inputHeight, inputWidth});
+-  gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
+-                                deformable_group * 2 * kH * kW, outputHeight,
+-                                outputWidth});
+-  offset =
+-      offset.view({batchSize / im2col_step, im2col_step,
+-                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+-
+-  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+-    // divide into groups
+-    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+-    weight = weight.view({group, weight.size(0) / group, weight.size(1),
+-                          weight.size(2), weight.size(3)});
+-    gradOutput = gradOutput.view(
+-        {gradOutput.size(0), group, gradOutput.size(1) / group,
+-         gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
+-
+-    for (int g = 0; g < group; g++) {
+-      columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+-                                     gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
+-    }
+-
+-    columns =
+-        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+-    gradOutput = gradOutput.view(
+-        {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
+-         gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
+-
+-    deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
+-                            inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
+-                            dilationH, dilationW, im2col_step, deformable_group,
+-                            gradOffset[elt]);
+-
+-    deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
+-                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+-                      dilationW, im2col_step, deformable_group, gradInput[elt]);
+-  }
+-
+-  gradOutput.transpose_(1, 2);
+-  gradOutput =
+-      gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+-
+-  gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+-  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+-  gradOffset = gradOffset.view(
+-      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+-  offset = offset.view(
+-      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+-
+-  if (batch == 0) {
+-    gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+-    input = input.view({nInputPlane, inputHeight, inputWidth});
+-    gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
+-    offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+-    gradOffset =
+-        gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
+-  }
+-
+-  return 1;
+-}
+-
+-int deform_conv_backward_parameters_cuda(
+-    at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+-    at::Tensor gradWeight,  // at::Tensor gradBias,
+-    at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+-    int padW, int padH, int dilationW, int dilationH, int group,
+-    int deformable_group, float scale, int im2col_step) {
+-  // todo: transpose and reshape outGrad
+-  // todo: reshape columns
+-  // todo: add im2col_step as input
+-
+-  shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
+-              padW, dilationH, dilationW, group, deformable_group);
+-  at::DeviceGuard guard(input.device());
+-
+-  input = input.contiguous();
+-  offset = offset.contiguous();
+-  gradOutput = gradOutput.contiguous();
+-
+-  int batch = 1;
+-
+-  if (input.ndimension() == 3) {
+-    // Force batch
+-    batch = 0;
+-    input = input.view(
+-        at::IntList({1, input.size(0), input.size(1), input.size(2)}));
+-    gradOutput = gradOutput.view(
+-        {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+-  }
+-
+-  long batchSize = input.size(0);
+-  long nInputPlane = input.size(1);
+-  long inputHeight = input.size(2);
+-  long inputWidth = input.size(3);
+-
+-  long nOutputPlane = gradWeight.size(0);
+-
+-  long outputWidth =
+-      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+-  long outputHeight =
+-      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+-
+-  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+-
+-  columns = at::zeros(
+-      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+-      input.options());
+-
+-  gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+-                                nOutputPlane, outputHeight, outputWidth});
+-  gradOutput.transpose_(1, 2);
+-
+-  at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
+-  gradOutputBuffer =
+-      gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
+-                             outputHeight, outputWidth});
+-  gradOutputBuffer.copy_(gradOutput);
+-  gradOutputBuffer =
+-      gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
+-                             im2col_step * outputHeight, outputWidth});
+-
+-  gradOutput.transpose_(1, 2);
+-  gradOutput =
+-      gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+-
+-  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+-                      inputHeight, inputWidth});
+-  offset =
+-      offset.view({batchSize / im2col_step, im2col_step,
+-                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+-
+-  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+-    deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+-                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+-                      dilationW, im2col_step, deformable_group, columns);
+-
+-    // divide into group
+-    gradOutputBuffer = gradOutputBuffer.view(
+-        {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
+-         gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
+-    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+-    gradWeight =
+-        gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
+-                         gradWeight.size(2), gradWeight.size(3)});
+-
+-    for (int g = 0; g < group; g++) {
+-      gradWeight[g] = gradWeight[g]
+-                          .flatten(1)
+-                          .addmm_(gradOutputBuffer[elt][g].flatten(1),
+-                                  columns[g].transpose(1, 0), 1.0, scale)
+-                          .view_as(gradWeight[g]);
+-    }
+-    gradOutputBuffer = gradOutputBuffer.view(
+-        {gradOutputBuffer.size(0),
+-         gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
+-         gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
+-    columns =
+-        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+-    gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
+-                                  gradWeight.size(2), gradWeight.size(3),
+-                                  gradWeight.size(4)});
+-  }
+-
+-  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+-  offset = offset.view(
+-      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+-
+-  if (batch == 0) {
+-    gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+-    input = input.view({nInputPlane, inputHeight, inputWidth});
+-  }
+-
+-  return 1;
+-}
+-
+-void modulated_deform_conv_cuda_forward(
+-    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+-    at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+-    int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+-    const int pad_h, const int pad_w, const int dilation_h,
+-    const int dilation_w, const int group, const int deformable_group,
+-    const bool with_bias) {
+-  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+-  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+-  at::DeviceGuard guard(input.device());
+-
+-  const int batch = input.size(0);
+-  const int channels = input.size(1);
+-  const int height = input.size(2);
+-  const int width = input.size(3);
+-
+-  const int channels_out = weight.size(0);
+-  const int channels_kernel = weight.size(1);
+-  const int kernel_h_ = weight.size(2);
+-  const int kernel_w_ = weight.size(3);
+-
+-  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+-    AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
+-             kernel_h_, kernel_w, kernel_h_, kernel_w_);
+-  if (channels != channels_kernel * group)
+-    AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
+-             channels, channels_kernel * group);
+-
+-  const int height_out =
+-      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+-  const int width_out =
+-      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+-
+-  if (ones.ndimension() != 2 ||
+-      ones.size(0) * ones.size(1) < height_out * width_out) {
+-    // Resize plane and fill with ones...
+-    ones = at::ones({height_out, width_out}, input.options());
+-  }
+-
+-  // resize output
+-  output = output.view({batch, channels_out, height_out, width_out}).zero_();
+-  // resize temporary columns
+-  columns =
+-      at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
+-                input.options());
+-
+-  output = output.view({output.size(0), group, output.size(1) / group,
+-                        output.size(2), output.size(3)});
+-
+-  for (int b = 0; b < batch; b++) {
+-    modulated_deformable_im2col_cuda(
+-        input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+-        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+-        dilation_h, dilation_w, deformable_group, columns);
+-
+-    // divide into group
+-    weight = weight.view({group, weight.size(0) / group, weight.size(1),
+-                          weight.size(2), weight.size(3)});
+-    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+-
+-    for (int g = 0; g < group; g++) {
+-      output[b][g] = output[b][g]
+-                         .flatten(1)
+-                         .addmm_(weight[g].flatten(1), columns[g])
+-                         .view_as(output[b][g]);
+-    }
+-
+-    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+-                          weight.size(3), weight.size(4)});
+-    columns =
+-        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+-  }
+-
+-  output = output.view({output.size(0), output.size(1) * output.size(2),
+-                        output.size(3), output.size(4)});
+-
+-  if (with_bias) {
+-    output += bias.view({1, bias.size(0), 1, 1});
+-  }
+-}
+-
+-void modulated_deform_conv_cuda_backward(
+-    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+-    at::Tensor offset, at::Tensor mask, at::Tensor columns,
+-    at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+-    at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+-    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+-    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+-    const bool with_bias) {
+-  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+-  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+-  at::DeviceGuard guard(input.device());
+-
+-  const int batch = input.size(0);
+-  const int channels = input.size(1);
+-  const int height = input.size(2);
+-  const int width = input.size(3);
+-
+-  const int channels_kernel = weight.size(1);
+-  const int kernel_h_ = weight.size(2);
+-  const int kernel_w_ = weight.size(3);
+-  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+-    AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
+-             kernel_h_, kernel_w, kernel_h_, kernel_w_);
+-  if (channels != channels_kernel * group)
+-    AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
+-             channels, channels_kernel * group);
+-
+-  const int height_out =
+-      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+-  const int width_out =
+-      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+-
+-  if (ones.ndimension() != 2 ||
+-      ones.size(0) * ones.size(1) < height_out * width_out) {
+-    // Resize plane and fill with ones...
+-    ones = at::ones({height_out, width_out}, input.options());
+-  }
+-
+-  grad_input = grad_input.view({batch, channels, height, width});
+-  columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
+-                      input.options());
+-
+-  grad_output =
+-      grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
+-                        grad_output.size(2), grad_output.size(3)});
+-
+-  for (int b = 0; b < batch; b++) {
+-    // divide int group
+-    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+-    weight = weight.view({group, weight.size(0) / group, weight.size(1),
+-                          weight.size(2), weight.size(3)});
+-
+-    for (int g = 0; g < group; g++) {
+-      columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+-                        grad_output[b][g].flatten(1), 0.0f, 1.0f);
+-    }
+-
+-    columns =
+-        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+-    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+-                          weight.size(3), weight.size(4)});
+-
+-    // gradient w.r.t. input coordinate data
+-    modulated_deformable_col2im_coord_cuda(
+-        columns, input[b], offset[b], mask[b], 1, channels, height, width,
+-        height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
+-        stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
+-        grad_mask[b]);
+-    // gradient w.r.t. input data
+-    modulated_deformable_col2im_cuda(
+-        columns, offset[b], mask[b], 1, channels, height, width, height_out,
+-        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+-        dilation_h, dilation_w, deformable_group, grad_input[b]);
+-
+-    // gradient w.r.t. weight, dWeight should accumulate across the batch and
+-    // group
+-    modulated_deformable_im2col_cuda(
+-        input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+-        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+-        dilation_h, dilation_w, deformable_group, columns);
+-
+-    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+-    grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
+-                                    grad_weight.size(1), grad_weight.size(2),
+-                                    grad_weight.size(3)});
+-    if (with_bias)
+-      grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
+-
+-    for (int g = 0; g < group; g++) {
+-      grad_weight[g] =
+-          grad_weight[g]
+-              .flatten(1)
+-              .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
+-              .view_as(grad_weight[g]);
+-      if (with_bias) {
+-        grad_bias[g] =
+-            grad_bias[g]
+-                .view({-1, 1})
+-                .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
+-                .view(-1);
+-      }
+-    }
+-
+-    columns =
+-        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+-    grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
+-                                    grad_weight.size(2), grad_weight.size(3),
+-                                    grad_weight.size(4)});
+-    if (with_bias)
+-      grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
+-  }
+-  grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
+-                                  grad_output.size(2), grad_output.size(3),
+-                                  grad_output.size(4)});
+-}
+diff --git a/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu b/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu
+deleted file mode 100644
+index 98752dc..0000000
+--- a/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu
++++ /dev/null
+@@ -1,867 +0,0 @@
+-/*!
+- ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
+- *
+- * COPYRIGHT
+- *
+- * All contributions by the University of California:
+- * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
+- * All rights reserved.
+- *
+- * All other contributions:
+- * Copyright (c) 2014-2017, the respective contributors
+- * All rights reserved.
+- *
+- * Caffe uses a shared copyright model: each contributor holds copyright over
+- * their contributions to Caffe. The project versioning records all such
+- * contribution and copyright details. If a contributor wants to further mark
+- * their specific copyright on a particular contribution, they should indicate
+- * their copyright solely in the commit message of the change when it is
+- * committed.
+- *
+- * LICENSE
+- *
+- * Redistribution and use in source and binary forms, with or without
+- * modification, are permitted provided that the following conditions are met:
+- *
+- * 1. Redistributions of source code must retain the above copyright notice, this
+- * list of conditions and the following disclaimer.
+- * 2. Redistributions in binary form must reproduce the above copyright notice,
+- * this list of conditions and the following disclaimer in the documentation
+- * and/or other materials provided with the distribution.
+- *
+- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+- * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+- * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+- * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+- * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+- * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+- * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+- * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+- * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+- *
+- * CONTRIBUTION AGREEMENT
+- *
+- * By contributing to the BVLC/caffe repository through pull-request, comment,
+- * or otherwise, the contributor releases their content to the
+- * license and copyright terms herein.
+- *
+- ***************** END Caffe Copyright Notice and Disclaimer ********************
+- *
+- * Copyright (c) 2018 Microsoft
+- * Licensed under The MIT License [see LICENSE for details]
+- * \file modulated_deformable_im2col.cuh
+- * \brief Function definitions of converting an image to
+- * column matrix based on kernel, padding, dilation, and offset.
+- * These functions are mainly used in deformable convolution operators.
+- * \ref: https://arxiv.org/abs/1703.06211
+- * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
+- */
+-
+-// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
+-
+-#include <ATen/ATen.h>
+-#include <ATen/cuda/CUDAContext.h>
+-#include <THC/THCAtomics.cuh>
+-#include <stdio.h>
+-#include <math.h>
+-#include <float.h>
+-
+-using namespace at;
+-
+-#define CUDA_KERNEL_LOOP(i, n)                                 \
+-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+-       i += blockDim.x * gridDim.x)
+-
+-const int CUDA_NUM_THREADS = 1024;
+-const int kMaxGridNum = 65535;
+-
+-inline int GET_BLOCKS(const int N)
+-{
+-  return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
+-}
+-
+-template <typename scalar_t>
+-__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
+-                                               const int height, const int width, scalar_t h, scalar_t w)
+-{
+-
+-  int h_low = floor(h);
+-  int w_low = floor(w);
+-  int h_high = h_low + 1;
+-  int w_high = w_low + 1;
+-
+-  scalar_t lh = h - h_low;
+-  scalar_t lw = w - w_low;
+-  scalar_t hh = 1 - lh, hw = 1 - lw;
+-
+-  scalar_t v1 = 0;
+-  if (h_low >= 0 && w_low >= 0)
+-    v1 = bottom_data[h_low * data_width + w_low];
+-  scalar_t v2 = 0;
+-  if (h_low >= 0 && w_high <= width - 1)
+-    v2 = bottom_data[h_low * data_width + w_high];
+-  scalar_t v3 = 0;
+-  if (h_high <= height - 1 && w_low >= 0)
+-    v3 = bottom_data[h_high * data_width + w_low];
+-  scalar_t v4 = 0;
+-  if (h_high <= height - 1 && w_high <= width - 1)
+-    v4 = bottom_data[h_high * data_width + w_high];
+-
+-  scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+-
+-  scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+-  return val;
+-}
+-
+-template <typename scalar_t>
+-__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
+-                                        const int h, const int w, const int height, const int width)
+-{
+-
+-  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+-  {
+-    //empty
+-    return 0;
+-  }
+-
+-  int argmax_h_low = floor(argmax_h);
+-  int argmax_w_low = floor(argmax_w);
+-  int argmax_h_high = argmax_h_low + 1;
+-  int argmax_w_high = argmax_w_low + 1;
+-
+-  scalar_t weight = 0;
+-  if (h == argmax_h_low && w == argmax_w_low)
+-    weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+-  if (h == argmax_h_low && w == argmax_w_high)
+-    weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+-  if (h == argmax_h_high && w == argmax_w_low)
+-    weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+-  if (h == argmax_h_high && w == argmax_w_high)
+-    weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+-  return weight;
+-}
+-
+-template <typename scalar_t>
+-__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
+-                                          const int height, const int width, const scalar_t *im_data,
+-                                          const int data_width, const int bp_dir)
+-{
+-
+-  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+-  {
+-    //empty
+-    return 0;
+-  }
+-
+-  int argmax_h_low = floor(argmax_h);
+-  int argmax_w_low = floor(argmax_w);
+-  int argmax_h_high = argmax_h_low + 1;
+-  int argmax_w_high = argmax_w_low + 1;
+-
+-  scalar_t weight = 0;
+-
+-  if (bp_dir == 0)
+-  {
+-    if (argmax_h_low >= 0 && argmax_w_low >= 0)
+-      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
+-    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+-      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
+-    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+-      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
+-    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+-      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+-  }
+-  else if (bp_dir == 1)
+-  {
+-    if (argmax_h_low >= 0 && argmax_w_low >= 0)
+-      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
+-    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+-      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
+-    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+-      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
+-    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+-      weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+-  }
+-
+-  return weight;
+-}
+-
+-template <typename scalar_t>
+-__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
+-                                             const int height, const int width, const int kernel_h, const int kernel_w,
+-                                             const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+-                                             const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
+-                                             const int batch_size, const int num_channels, const int deformable_group,
+-                                             const int height_col, const int width_col,
+-                                             scalar_t *data_col)
+-{
+-  CUDA_KERNEL_LOOP(index, n)
+-  {
+-    // index index of output matrix
+-    const int w_col = index % width_col;
+-    const int h_col = (index / width_col) % height_col;
+-    const int b_col = (index / width_col / height_col) % batch_size;
+-    const int c_im = (index / width_col / height_col) / batch_size;
+-    const int c_col = c_im * kernel_h * kernel_w;
+-
+-    // compute deformable group index
+-    const int deformable_group_index = c_im / channel_per_deformable_group;
+-
+-    const int h_in = h_col * stride_h - pad_h;
+-    const int w_in = w_col * stride_w - pad_w;
+-    scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
+-    //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
+-    const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
+-    const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+-
+-    for (int i = 0; i < kernel_h; ++i)
+-    {
+-      for (int j = 0; j < kernel_w; ++j)
+-      {
+-        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+-        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
+-        const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+-        const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+-        scalar_t val = static_cast<scalar_t>(0);
+-        const scalar_t h_im = h_in + i * dilation_h + offset_h;
+-        const scalar_t w_im = w_in + j * dilation_w + offset_w;
+-        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+-        {
+-          //const scalar_t map_h = i * dilation_h + offset_h;
+-          //const scalar_t map_w = j * dilation_w + offset_w;
+-          //const int cur_height = height - h_in;
+-          //const int cur_width = width - w_in;
+-          //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
+-          val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
+-        }
+-        *data_col_ptr = val;
+-        data_col_ptr += batch_size * height_col * width_col;
+-      }
+-    }
+-  }
+-}
+-
+-void deformable_im2col(
+-    const at::Tensor data_im, const at::Tensor data_offset, const int channels,
+-    const int height, const int width, const int ksize_h, const int ksize_w,
+-    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+-    const int dilation_h, const int dilation_w, const int parallel_imgs,
+-    const int deformable_group, at::Tensor data_col)
+-{
+-  // num_axes should be smaller than block size
+-  // todo: check parallel_imgs is correctly passed in
+-  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+-  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+-  int num_kernels = channels * height_col * width_col * parallel_imgs;
+-  int channel_per_deformable_group = channels / deformable_group;
+-
+-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+-      data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
+-        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+-        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+-        scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+-
+-        deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+-            num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
+-            pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+-            channel_per_deformable_group, parallel_imgs, channels, deformable_group,
+-            height_col, width_col, data_col_);
+-      }));
+-
+-  cudaError_t err = cudaGetLastError();
+-  if (err != cudaSuccess)
+-  {
+-    printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
+-  }
+-}
+-
+-template <typename scalar_t>
+-__global__ void deformable_col2im_gpu_kernel(
+-    const int n, const scalar_t *data_col, const scalar_t *data_offset,
+-    const int channels, const int height, const int width,
+-    const int kernel_h, const int kernel_w,
+-    const int pad_h, const int pad_w,
+-    const int stride_h, const int stride_w,
+-    const int dilation_h, const int dilation_w,
+-    const int channel_per_deformable_group,
+-    const int batch_size, const int deformable_group,
+-    const int height_col, const int width_col,
+-    scalar_t *grad_im)
+-{
+-  CUDA_KERNEL_LOOP(index, n)
+-  {
+-    const int j = (index / width_col / height_col / batch_size) % kernel_w;
+-    const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
+-    const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
+-    // compute the start and end of the output
+-
+-    const int deformable_group_index = c / channel_per_deformable_group;
+-
+-    int w_out = index % width_col;
+-    int h_out = (index / width_col) % height_col;
+-    int b = (index / width_col / height_col) % batch_size;
+-    int w_in = w_out * stride_w - pad_w;
+-    int h_in = h_out * stride_h - pad_h;
+-
+-    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
+-                                                        2 * kernel_h * kernel_w * height_col * width_col;
+-    const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
+-    const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
+-    const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+-    const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+-    const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
+-    const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
+-
+-    const scalar_t cur_top_grad = data_col[index];
+-    const int cur_h = (int)cur_inv_h_data;
+-    const int cur_w = (int)cur_inv_w_data;
+-    for (int dy = -2; dy <= 2; dy++)
+-    {
+-      for (int dx = -2; dx <= 2; dx++)
+-      {
+-        if (cur_h + dy >= 0 && cur_h + dy < height &&
+-            cur_w + dx >= 0 && cur_w + dx < width &&
+-            abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
+-            abs(cur_inv_w_data - (cur_w + dx)) < 1)
+-        {
+-          int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
+-          scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
+-          atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
+-        }
+-      }
+-    }
+-  }
+-}
+-
+-void deformable_col2im(
+-    const at::Tensor data_col, const at::Tensor data_offset, const int channels,
+-    const int height, const int width, const int ksize_h,
+-    const int ksize_w, const int pad_h, const int pad_w,
+-    const int stride_h, const int stride_w,
+-    const int dilation_h, const int dilation_w,
+-    const int parallel_imgs, const int deformable_group,
+-    at::Tensor grad_im)
+-{
+-
+-  // todo: make sure parallel_imgs is passed in correctly
+-  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+-  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+-  int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
+-  int channel_per_deformable_group = channels / deformable_group;
+-
+-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+-      data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
+-        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+-        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+-        scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
+-
+-        deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+-            num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
+-            ksize_w, pad_h, pad_w, stride_h, stride_w,
+-            dilation_h, dilation_w, channel_per_deformable_group,
+-            parallel_imgs, deformable_group, height_col, width_col, grad_im_);
+-      }));
+-
+-  cudaError_t err = cudaGetLastError();
+-  if (err != cudaSuccess)
+-  {
+-    printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
+-  }
+-}
+-
+-template <typename scalar_t>
+-__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,
+-                                                   const scalar_t *data_im, const scalar_t *data_offset,
+-                                                   const int channels, const int height, const int width,
+-                                                   const int kernel_h, const int kernel_w,
+-                                                   const int pad_h, const int pad_w,
+-                                                   const int stride_h, const int stride_w,
+-                                                   const int dilation_h, const int dilation_w,
+-                                                   const int channel_per_deformable_group,
+-                                                   const int batch_size, const int offset_channels, const int deformable_group,
+-                                                   const int height_col, const int width_col, scalar_t *grad_offset)
+-{
+-  CUDA_KERNEL_LOOP(index, n)
+-  {
+-    scalar_t val = 0;
+-    int w = index % width_col;
+-    int h = (index / width_col) % height_col;
+-    int c = (index / width_col / height_col) % offset_channels;
+-    int b = (index / width_col / height_col) / offset_channels;
+-    // compute the start and end of the output
+-
+-    const int deformable_group_index = c / (2 * kernel_h * kernel_w);
+-    const int col_step = kernel_h * kernel_w;
+-    int cnt = 0;
+-    const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *
+-                                                  batch_size * width_col * height_col;
+-    const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *
+-                                                channel_per_deformable_group / kernel_h / kernel_w * height * width;
+-    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *
+-                                                        kernel_h * kernel_w * height_col * width_col;
+-
+-    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
+-
+-    for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
+-    {
+-      const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
+-      const int bp_dir = offset_c % 2;
+-
+-      int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
+-      int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
+-      int w_out = col_pos % width_col;
+-      int h_out = (col_pos / width_col) % height_col;
+-      int w_in = w_out * stride_w - pad_w;
+-      int h_in = h_out * stride_h - pad_h;
+-      const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
+-      const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
+-      const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+-      const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+-      scalar_t inv_h = h_in + i * dilation_h + offset_h;
+-      scalar_t inv_w = w_in + j * dilation_w + offset_w;
+-      if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
+-      {
+-        inv_h = inv_w = -2;
+-      }
+-      const scalar_t weight = get_coordinate_weight(
+-          inv_h, inv_w,
+-          height, width, data_im_ptr + cnt * height * width, width, bp_dir);
+-      val += weight * data_col_ptr[col_pos];
+-      cnt += 1;
+-    }
+-
+-    grad_offset[index] = val;
+-  }
+-}
+-
+-void deformable_col2im_coord(
+-    const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
+-    const int channels, const int height, const int width, const int ksize_h,
+-    const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
+-    const int stride_w, const int dilation_h, const int dilation_w,
+-    const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)
+-{
+-
+-  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+-  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+-  int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
+-  int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;
+-
+-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+-      data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
+-        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+-        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+-        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+-        scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
+-
+-        deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+-            num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
+-            ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
+-            dilation_h, dilation_w, channel_per_deformable_group,
+-            parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,
+-            height_col, width_col, grad_offset_);
+-      }));
+-}
+-
+-template <typename scalar_t>
+-__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
+-                                         const int height, const int width, scalar_t h, scalar_t w)
+-{
+-  int h_low = floor(h);
+-  int w_low = floor(w);
+-  int h_high = h_low + 1;
+-  int w_high = w_low + 1;
+-
+-  scalar_t lh = h - h_low;
+-  scalar_t lw = w - w_low;
+-  scalar_t hh = 1 - lh, hw = 1 - lw;
+-
+-  scalar_t v1 = 0;
+-  if (h_low >= 0 && w_low >= 0)
+-    v1 = bottom_data[h_low * data_width + w_low];
+-  scalar_t v2 = 0;
+-  if (h_low >= 0 && w_high <= width - 1)
+-    v2 = bottom_data[h_low * data_width + w_high];
+-  scalar_t v3 = 0;
+-  if (h_high <= height - 1 && w_low >= 0)
+-    v3 = bottom_data[h_high * data_width + w_low];
+-  scalar_t v4 = 0;
+-  if (h_high <= height - 1 && w_high <= width - 1)
+-    v4 = bottom_data[h_high * data_width + w_high];
+-
+-  scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+-
+-  scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+-  return val;
+-}
+-
+-template <typename scalar_t>
+-__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
+-                                             const int h, const int w, const int height, const int width)
+-{
+-  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+-  {
+-    //empty
+-    return 0;
+-  }
+-
+-  int argmax_h_low = floor(argmax_h);
+-  int argmax_w_low = floor(argmax_w);
+-  int argmax_h_high = argmax_h_low + 1;
+-  int argmax_w_high = argmax_w_low + 1;
+-
+-  scalar_t weight = 0;
+-  if (h == argmax_h_low && w == argmax_w_low)
+-    weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+-  if (h == argmax_h_low && w == argmax_w_high)
+-    weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+-  if (h == argmax_h_high && w == argmax_w_low)
+-    weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+-  if (h == argmax_h_high && w == argmax_w_high)
+-    weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+-  return weight;
+-}
+-
+-template <typename scalar_t>
+-__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
+-                                               const int height, const int width, const scalar_t *im_data,
+-                                               const int data_width, const int bp_dir)
+-{
+-  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+-  {
+-    //empty
+-    return 0;
+-  }
+-
+-  int argmax_h_low = floor(argmax_h);
+-  int argmax_w_low = floor(argmax_w);
+-  int argmax_h_high = argmax_h_low + 1;
+-  int argmax_w_high = argmax_w_low + 1;
+-
+-  scalar_t weight = 0;
+-
+-  if (bp_dir == 0)
+-  {
+-    if (argmax_h_low >= 0 && argmax_w_low >= 0)
+-      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
+-    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+-      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
+-    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+-      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
+-    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+-      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+-  }
+-  else if (bp_dir == 1)
+-  {
+-    if (argmax_h_low >= 0 && argmax_w_low >= 0)
+-      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
+-    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+-      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
+-    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+-      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
+-    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+-      weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+-  }
+-
+-  return weight;
+-}
+-
+-template <typename scalar_t>
+-__global__ void modulated_deformable_im2col_gpu_kernel(const int n,
+-                                                       const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,
+-                                                       const int height, const int width, const int kernel_h, const int kernel_w,
+-                                                       const int pad_h, const int pad_w,
+-                                                       const int stride_h, const int stride_w,
+-                                                       const int dilation_h, const int dilation_w,
+-                                                       const int channel_per_deformable_group,
+-                                                       const int batch_size, const int num_channels, const int deformable_group,
+-                                                       const int height_col, const int width_col,
+-                                                       scalar_t *data_col)
+-{
+-  CUDA_KERNEL_LOOP(index, n)
+-  {
+-    // index index of output matrix
+-    const int w_col = index % width_col;
+-    const int h_col = (index / width_col) % height_col;
+-    const int b_col = (index / width_col / height_col) % batch_size;
+-    const int c_im = (index / width_col / height_col) / batch_size;
+-    const int c_col = c_im * kernel_h * kernel_w;
+-
+-    // compute deformable group index
+-    const int deformable_group_index = c_im / channel_per_deformable_group;
+-
+-    const int h_in = h_col * stride_h - pad_h;
+-    const int w_in = w_col * stride_w - pad_w;
+-
+-    scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
+-    //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
+-    const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
+-    const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+-
+-    const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+-
+-    for (int i = 0; i < kernel_h; ++i)
+-    {
+-      for (int j = 0; j < kernel_w; ++j)
+-      {
+-        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+-        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
+-        const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
+-        const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+-        const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+-        const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+-        scalar_t val = static_cast<scalar_t>(0);
+-        const scalar_t h_im = h_in + i * dilation_h + offset_h;
+-        const scalar_t w_im = w_in + j * dilation_w + offset_w;
+-        //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
+-        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+-        {
+-          //const float map_h = i * dilation_h + offset_h;
+-          //const float map_w = j * dilation_w + offset_w;
+-          //const int cur_height = height - h_in;
+-          //const int cur_width = width - w_in;
+-          //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
+-          val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
+-        }
+-        *data_col_ptr = val * mask;
+-        data_col_ptr += batch_size * height_col * width_col;
+-        //data_col_ptr += height_col * width_col;
+-      }
+-    }
+-  }
+-}
+-
+-template <typename scalar_t>
+-__global__ void modulated_deformable_col2im_gpu_kernel(const int n,
+-                                                       const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,
+-                                                       const int channels, const int height, const int width,
+-                                                       const int kernel_h, const int kernel_w,
+-                                                       const int pad_h, const int pad_w,
+-                                                       const int stride_h, const int stride_w,
+-                                                       const int dilation_h, const int dilation_w,
+-                                                       const int channel_per_deformable_group,
+-                                                       const int batch_size, const int deformable_group,
+-                                                       const int height_col, const int width_col,
+-                                                       scalar_t *grad_im)
+-{
+-  CUDA_KERNEL_LOOP(index, n)
+-  {
+-    const int j = (index / width_col / height_col / batch_size) % kernel_w;
+-    const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
+-    const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
+-    // compute the start and end of the output
+-
+-    const int deformable_group_index = c / channel_per_deformable_group;
+-
+-    int w_out = index % width_col;
+-    int h_out = (index / width_col) % height_col;
+-    int b = (index / width_col / height_col) % batch_size;
+-    int w_in = w_out * stride_w - pad_w;
+-    int h_in = h_out * stride_h - pad_h;
+-
+-    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+-    const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+-    const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
+-    const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
+-    const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
+-    const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+-    const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+-    const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+-    const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
+-    const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
+-
+-    const scalar_t cur_top_grad = data_col[index] * mask;
+-    const int cur_h = (int)cur_inv_h_data;
+-    const int cur_w = (int)cur_inv_w_data;
+-    for (int dy = -2; dy <= 2; dy++)
+-    {
+-      for (int dx = -2; dx <= 2; dx++)
+-      {
+-        if (cur_h + dy >= 0 && cur_h + dy < height &&
+-            cur_w + dx >= 0 && cur_w + dx < width &&
+-            abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
+-            abs(cur_inv_w_data - (cur_w + dx)) < 1)
+-        {
+-          int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
+-          scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
+-          atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
+-        }
+-      }
+-    }
+-  }
+-}
+-
+-template <typename scalar_t>
+-__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
+-                                                             const scalar_t *data_col, const scalar_t *data_im,
+-                                                             const scalar_t *data_offset, const scalar_t *data_mask,
+-                                                             const int channels, const int height, const int width,
+-                                                             const int kernel_h, const int kernel_w,
+-                                                             const int pad_h, const int pad_w,
+-                                                             const int stride_h, const int stride_w,
+-                                                             const int dilation_h, const int dilation_w,
+-                                                             const int channel_per_deformable_group,
+-                                                             const int batch_size, const int offset_channels, const int deformable_group,
+-                                                             const int height_col, const int width_col,
+-                                                             scalar_t *grad_offset, scalar_t *grad_mask)
+-{
+-  CUDA_KERNEL_LOOP(index, n)
+-  {
+-    scalar_t val = 0, mval = 0;
+-    int w = index % width_col;
+-    int h = (index / width_col) % height_col;
+-    int c = (index / width_col / height_col) % offset_channels;
+-    int b = (index / width_col / height_col) / offset_channels;
+-    // compute the start and end of the output
+-
+-    const int deformable_group_index = c / (2 * kernel_h * kernel_w);
+-    const int col_step = kernel_h * kernel_w;
+-    int cnt = 0;
+-    const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
+-    const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
+-    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+-    const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+-
+-    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
+-
+-    for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
+-    {
+-      const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
+-      const int bp_dir = offset_c % 2;
+-
+-      int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
+-      int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
+-      int w_out = col_pos % width_col;
+-      int h_out = (col_pos / width_col) % height_col;
+-      int w_in = w_out * stride_w - pad_w;
+-      int h_in = h_out * stride_h - pad_h;
+-      const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
+-      const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
+-      const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
+-      const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+-      const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+-      const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+-      scalar_t inv_h = h_in + i * dilation_h + offset_h;
+-      scalar_t inv_w = w_in + j * dilation_w + offset_w;
+-      if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
+-      {
+-        inv_h = inv_w = -2;
+-      }
+-      else
+-      {
+-        mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
+-      }
+-      const scalar_t weight = dmcn_get_coordinate_weight(
+-          inv_h, inv_w,
+-          height, width, data_im_ptr + cnt * height * width, width, bp_dir);
+-      val += weight * data_col_ptr[col_pos] * mask;
+-      cnt += 1;
+-    }
+-    // KERNEL_ASSIGN(grad_offset[index], offset_req, val);
+-    grad_offset[index] = val;
+-    if (offset_c % 2 == 0)
+-      // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
+-      grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
+-  }
+-}
+-
+-void modulated_deformable_im2col_cuda(
+-    const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
+-    const int batch_size, const int channels, const int height_im, const int width_im,
+-    const int height_col, const int width_col, const int kernel_h, const int kenerl_w,
+-    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+-    const int dilation_h, const int dilation_w,
+-    const int deformable_group, at::Tensor data_col)
+-{
+-  // num_axes should be smaller than block size
+-  const int channel_per_deformable_group = channels / deformable_group;
+-  const int num_kernels = channels * batch_size * height_col * width_col;
+-
+-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+-      data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] {
+-        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+-        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+-        const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
+-        scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+-
+-        modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+-            num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w,
+-            pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
+-            batch_size, channels, deformable_group, height_col, width_col, data_col_);
+-      }));
+-
+-  cudaError_t err = cudaGetLastError();
+-  if (err != cudaSuccess)
+-  {
+-    printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+-  }
+-}
+-
+-void modulated_deformable_col2im_cuda(
+-    const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
+-    const int batch_size, const int channels, const int height_im, const int width_im,
+-    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+-    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+-    const int dilation_h, const int dilation_w,
+-    const int deformable_group, at::Tensor grad_im)
+-{
+-
+-  const int channel_per_deformable_group = channels / deformable_group;
+-  const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
+-
+-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+-      data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] {
+-        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+-        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+-        const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
+-        scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
+-
+-        modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+-            num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
+-            kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+-            dilation_h, dilation_w, channel_per_deformable_group,
+-            batch_size, deformable_group, height_col, width_col, grad_im_);
+-      }));
+-
+-  cudaError_t err = cudaGetLastError();
+-  if (err != cudaSuccess)
+-  {
+-    printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+-  }
+-}
+-
+-void modulated_deformable_col2im_coord_cuda(
+-    const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
+-    const int batch_size, const int channels, const int height_im, const int width_im,
+-    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+-    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+-    const int dilation_h, const int dilation_w,
+-    const int deformable_group,
+-    at::Tensor grad_offset, at::Tensor grad_mask)
+-{
+-  const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
+-  const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
+-
+-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+-      data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] {
+-        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+-        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+-        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+-        const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
+-        scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
+-        scalar_t *grad_mask_ = grad_mask.data_ptr<scalar_t>();
+-
+-        modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
+-            num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
+-            kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+-            dilation_h, dilation_w, channel_per_deformable_group,
+-            batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
+-            grad_offset_, grad_mask_);
+-      }));
+-  cudaError_t err = cudaGetLastError();
+-  if (err != cudaSuccess)
+-  {
+-    printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
+-  }
+-}
+diff --git a/basicsr/ops/dcn/src/deform_conv_ext.cpp b/basicsr/ops/dcn/src/deform_conv_ext.cpp
+deleted file mode 100644
+index 41c6df6..0000000
+--- a/basicsr/ops/dcn/src/deform_conv_ext.cpp
++++ /dev/null
+@@ -1,164 +0,0 @@
+-// modify from
+-// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
+-
+-#include <torch/extension.h>
+-#include <ATen/DeviceGuard.h>
+-
+-#include <cmath>
+-#include <vector>
+-
+-#define WITH_CUDA  // always use cuda
+-#ifdef WITH_CUDA
+-int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
+-                             at::Tensor offset, at::Tensor output,
+-                             at::Tensor columns, at::Tensor ones, int kW,
+-                             int kH, int dW, int dH, int padW, int padH,
+-                             int dilationW, int dilationH, int group,
+-                             int deformable_group, int im2col_step);
+-
+-int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
+-                                    at::Tensor gradOutput, at::Tensor gradInput,
+-                                    at::Tensor gradOffset, at::Tensor weight,
+-                                    at::Tensor columns, int kW, int kH, int dW,
+-                                    int dH, int padW, int padH, int dilationW,
+-                                    int dilationH, int group,
+-                                    int deformable_group, int im2col_step);
+-
+-int deform_conv_backward_parameters_cuda(
+-    at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+-    at::Tensor gradWeight,  // at::Tensor gradBias,
+-    at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+-    int padW, int padH, int dilationW, int dilationH, int group,
+-    int deformable_group, float scale, int im2col_step);
+-
+-void modulated_deform_conv_cuda_forward(
+-    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+-    at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+-    int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+-    const int pad_h, const int pad_w, const int dilation_h,
+-    const int dilation_w, const int group, const int deformable_group,
+-    const bool with_bias);
+-
+-void modulated_deform_conv_cuda_backward(
+-    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+-    at::Tensor offset, at::Tensor mask, at::Tensor columns,
+-    at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+-    at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+-    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+-    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+-    const bool with_bias);
+-#endif
+-
+-int deform_conv_forward(at::Tensor input, at::Tensor weight,
+-                             at::Tensor offset, at::Tensor output,
+-                             at::Tensor columns, at::Tensor ones, int kW,
+-                             int kH, int dW, int dH, int padW, int padH,
+-                             int dilationW, int dilationH, int group,
+-                             int deformable_group, int im2col_step) {
+-  if (input.device().is_cuda()) {
+-#ifdef WITH_CUDA
+-    return deform_conv_forward_cuda(input, weight, offset, output, columns,
+-        ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group,
+-        deformable_group, im2col_step);
+-#else
+-    AT_ERROR("deform conv is not compiled with GPU support");
+-#endif
+-  }
+-  AT_ERROR("deform conv is not implemented on CPU");
+-}
+-
+-int deform_conv_backward_input(at::Tensor input, at::Tensor offset,
+-                                    at::Tensor gradOutput, at::Tensor gradInput,
+-                                    at::Tensor gradOffset, at::Tensor weight,
+-                                    at::Tensor columns, int kW, int kH, int dW,
+-                                    int dH, int padW, int padH, int dilationW,
+-                                    int dilationH, int group,
+-                                    int deformable_group, int im2col_step) {
+-  if (input.device().is_cuda()) {
+-#ifdef WITH_CUDA
+-    return deform_conv_backward_input_cuda(input, offset, gradOutput,
+-        gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH,
+-        dilationW, dilationH, group, deformable_group, im2col_step);
+-#else
+-    AT_ERROR("deform conv is not compiled with GPU support");
+-#endif
+-  }
+-  AT_ERROR("deform conv is not implemented on CPU");
+-}
+-
+-int deform_conv_backward_parameters(
+-    at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+-    at::Tensor gradWeight,  // at::Tensor gradBias,
+-    at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+-    int padW, int padH, int dilationW, int dilationH, int group,
+-    int deformable_group, float scale, int im2col_step) {
+-  if (input.device().is_cuda()) {
+-#ifdef WITH_CUDA
+-    return deform_conv_backward_parameters_cuda(input, offset, gradOutput,
+-        gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW,
+-        dilationH, group, deformable_group, scale, im2col_step);
+-#else
+-    AT_ERROR("deform conv is not compiled with GPU support");
+-#endif
+-  }
+-  AT_ERROR("deform conv is not implemented on CPU");
+-}
+-
+-void modulated_deform_conv_forward(
+-    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+-    at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+-    int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+-    const int pad_h, const int pad_w, const int dilation_h,
+-    const int dilation_w, const int group, const int deformable_group,
+-    const bool with_bias) {
+-  if (input.device().is_cuda()) {
+-#ifdef WITH_CUDA
+-    return modulated_deform_conv_cuda_forward(input, weight, bias, ones,
+-        offset, mask, output, columns, kernel_h, kernel_w, stride_h,
+-        stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
+-        deformable_group, with_bias);
+-#else
+-    AT_ERROR("modulated deform conv is not compiled with GPU support");
+-#endif
+-  }
+-  AT_ERROR("modulated deform conv is not implemented on CPU");
+-}
+-
+-void modulated_deform_conv_backward(
+-    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+-    at::Tensor offset, at::Tensor mask, at::Tensor columns,
+-    at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+-    at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+-    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+-    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+-    const bool with_bias) {
+-  if (input.device().is_cuda()) {
+-#ifdef WITH_CUDA
+-    return modulated_deform_conv_cuda_backward(input, weight, bias, ones,
+-        offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset,
+-        grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w,
+-        pad_h, pad_w, dilation_h, dilation_w, group, deformable_group,
+-        with_bias);
+-#else
+-    AT_ERROR("modulated deform conv is not compiled with GPU support");
+-#endif
+-  }
+-  AT_ERROR("modulated deform conv is not implemented on CPU");
+-}
+-
+-
+-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+-  m.def("deform_conv_forward", &deform_conv_forward,
+-        "deform forward");
+-  m.def("deform_conv_backward_input", &deform_conv_backward_input,
+-        "deform_conv_backward_input");
+-  m.def("deform_conv_backward_parameters",
+-        &deform_conv_backward_parameters,
+-        "deform_conv_backward_parameters");
+-  m.def("modulated_deform_conv_forward",
+-        &modulated_deform_conv_forward,
+-        "modulated deform conv forward");
+-  m.def("modulated_deform_conv_backward",
+-        &modulated_deform_conv_backward,
+-        "modulated deform conv backward");
+-}
+diff --git a/basicsr/ops/fused_act/__init__.py b/basicsr/ops/fused_act/__init__.py
+deleted file mode 100644
+index 241dc07..0000000
+--- a/basicsr/ops/fused_act/__init__.py
++++ /dev/null
+@@ -1,3 +0,0 @@
+-from .fused_act import FusedLeakyReLU, fused_leaky_relu
+-
+-__all__ = ['FusedLeakyReLU', 'fused_leaky_relu']
+diff --git a/basicsr/ops/fused_act/fused_act.py b/basicsr/ops/fused_act/fused_act.py
+deleted file mode 100644
+index 88edc44..0000000
+--- a/basicsr/ops/fused_act/fused_act.py
++++ /dev/null
+@@ -1,95 +0,0 @@
+-# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
+-
+-import os
+-import torch
+-from torch import nn
+-from torch.autograd import Function
+-
+-BASICSR_JIT = os.getenv('BASICSR_JIT')
+-if BASICSR_JIT == 'True':
+-    from torch.utils.cpp_extension import load
+-    module_path = os.path.dirname(__file__)
+-    fused_act_ext = load(
+-        'fused',
+-        sources=[
+-            os.path.join(module_path, 'src', 'fused_bias_act.cpp'),
+-            os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'),
+-        ],
+-    )
+-else:
+-    try:
+-        from . import fused_act_ext
+-    except ImportError:
+-        pass
+-        # avoid annoying print output
+-        # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n '
+-        #       '1. compile with BASICSR_EXT=True. or\n '
+-        #       '2. set BASICSR_JIT=True during running')
+-
+-
+-class FusedLeakyReLUFunctionBackward(Function):
+-
+-    @staticmethod
+-    def forward(ctx, grad_output, out, negative_slope, scale):
+-        ctx.save_for_backward(out)
+-        ctx.negative_slope = negative_slope
+-        ctx.scale = scale
+-
+-        empty = grad_output.new_empty(0)
+-
+-        grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale)
+-
+-        dim = [0]
+-
+-        if grad_input.ndim > 2:
+-            dim += list(range(2, grad_input.ndim))
+-
+-        grad_bias = grad_input.sum(dim).detach()
+-
+-        return grad_input, grad_bias
+-
+-    @staticmethod
+-    def backward(ctx, gradgrad_input, gradgrad_bias):
+-        out, = ctx.saved_tensors
+-        gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope,
+-                                                    ctx.scale)
+-
+-        return gradgrad_out, None, None, None
+-
+-
+-class FusedLeakyReLUFunction(Function):
+-
+-    @staticmethod
+-    def forward(ctx, input, bias, negative_slope, scale):
+-        empty = input.new_empty(0)
+-        out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
+-        ctx.save_for_backward(out)
+-        ctx.negative_slope = negative_slope
+-        ctx.scale = scale
+-
+-        return out
+-
+-    @staticmethod
+-    def backward(ctx, grad_output):
+-        out, = ctx.saved_tensors
+-
+-        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale)
+-
+-        return grad_input, grad_bias, None, None
+-
+-
+-class FusedLeakyReLU(nn.Module):
+-
+-    def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
+-        super().__init__()
+-
+-        self.bias = nn.Parameter(torch.zeros(channel))
+-        self.negative_slope = negative_slope
+-        self.scale = scale
+-
+-    def forward(self, input):
+-        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+-
+-
+-def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
+-    return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
+diff --git a/basicsr/ops/fused_act/src/fused_bias_act.cpp b/basicsr/ops/fused_act/src/fused_bias_act.cpp
+deleted file mode 100644
+index 85ed0a7..0000000
+--- a/basicsr/ops/fused_act/src/fused_bias_act.cpp
++++ /dev/null
+@@ -1,26 +0,0 @@
+-// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp
+-#include <torch/extension.h>
+-
+-
+-torch::Tensor fused_bias_act_op(const torch::Tensor& input,
+-                                const torch::Tensor& bias,
+-                                const torch::Tensor& refer,
+-                                int act, int grad, float alpha, float scale);
+-
+-#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+-#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+-#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+-
+-torch::Tensor fused_bias_act(const torch::Tensor& input,
+-                             const torch::Tensor& bias,
+-                             const torch::Tensor& refer,
+-                             int act, int grad, float alpha, float scale) {
+-    CHECK_CUDA(input);
+-    CHECK_CUDA(bias);
+-
+-    return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
+-}
+-
+-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+-    m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
+-}
+diff --git a/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu b/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu
+deleted file mode 100644
+index 54c7ff5..0000000
+--- a/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu
++++ /dev/null
+@@ -1,100 +0,0 @@
+-// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu
+-// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+-//
+-// This work is made available under the Nvidia Source Code License-NC.
+-// To view a copy of this license, visit
+-// https://nvlabs.github.io/stylegan2/license.html
+-
+-#include <torch/types.h>
+-
+-#include <ATen/ATen.h>
+-#include <ATen/AccumulateType.h>
+-#include <ATen/cuda/CUDAContext.h>
+-#include <ATen/cuda/CUDAApplyUtils.cuh>
+-
+-#include <cuda.h>
+-#include <cuda_runtime.h>
+-
+-
+-template <typename scalar_t>
+-static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
+-    int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
+-    int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
+-
+-    scalar_t zero = 0.0;
+-
+-    for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
+-        scalar_t x = p_x[xi];
+-
+-        if (use_bias) {
+-            x += p_b[(xi / step_b) % size_b];
+-        }
+-
+-        scalar_t ref = use_ref ? p_ref[xi] : zero;
+-
+-        scalar_t y;
+-
+-        switch (act * 10 + grad) {
+-            default:
+-            case 10: y = x; break;
+-            case 11: y = x; break;
+-            case 12: y = 0.0; break;
+-
+-            case 30: y = (x > 0.0) ? x : x * alpha; break;
+-            case 31: y = (ref > 0.0) ? x : x * alpha; break;
+-            case 32: y = 0.0; break;
+-        }
+-
+-        out[xi] = y * scale;
+-    }
+-}
+-
+-
+-torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+-    int act, int grad, float alpha, float scale) {
+-    int curDevice = -1;
+-    cudaGetDevice(&curDevice);
+-    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+-
+-    auto x = input.contiguous();
+-    auto b = bias.contiguous();
+-    auto ref = refer.contiguous();
+-
+-    int use_bias = b.numel() ? 1 : 0;
+-    int use_ref = ref.numel() ? 1 : 0;
+-
+-    int size_x = x.numel();
+-    int size_b = b.numel();
+-    int step_b = 1;
+-
+-    for (int i = 1 + 1; i < x.dim(); i++) {
+-        step_b *= x.size(i);
+-    }
+-
+-    int loop_x = 4;
+-    int block_size = 4 * 32;
+-    int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
+-
+-    auto y = torch::empty_like(x);
+-
+-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
+-        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
+-            y.data_ptr<scalar_t>(),
+-            x.data_ptr<scalar_t>(),
+-            b.data_ptr<scalar_t>(),
+-            ref.data_ptr<scalar_t>(),
+-            act,
+-            grad,
+-            alpha,
+-            scale,
+-            loop_x,
+-            size_x,
+-            step_b,
+-            size_b,
+-            use_bias,
+-            use_ref
+-        );
+-    });
+-
+-    return y;
+-}
+diff --git a/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp b/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp
+deleted file mode 100644
+index 43d0b67..0000000
+--- a/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp
++++ /dev/null
+@@ -1,24 +0,0 @@
+-// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp
+-#include <torch/extension.h>
+-
+-
+-torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
+-                            int up_x, int up_y, int down_x, int down_y,
+-                            int pad_x0, int pad_x1, int pad_y0, int pad_y1);
+-
+-#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+-#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+-#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+-
+-torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
+-                        int up_x, int up_y, int down_x, int down_y,
+-                        int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
+-    CHECK_CUDA(input);
+-    CHECK_CUDA(kernel);
+-
+-    return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
+-}
+-
+-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+-    m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
+-}
+diff --git a/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu b/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu
+deleted file mode 100644
+index 8870063..0000000
+--- a/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu
++++ /dev/null
+@@ -1,370 +0,0 @@
+-// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu
+-// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+-//
+-// This work is made available under the Nvidia Source Code License-NC.
+-// To view a copy of this license, visit
+-// https://nvlabs.github.io/stylegan2/license.html
+-
+-#include <torch/types.h>
+-
+-#include <ATen/ATen.h>
+-#include <ATen/AccumulateType.h>
+-#include <ATen/cuda/CUDAApplyUtils.cuh>
+-#include <ATen/cuda/CUDAContext.h>
+-
+-#include <cuda.h>
+-#include <cuda_runtime.h>
+-
+-static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
+-  int c = a / b;
+-
+-  if (c * b > a) {
+-    c--;
+-  }
+-
+-  return c;
+-}
+-
+-struct UpFirDn2DKernelParams {
+-  int up_x;
+-  int up_y;
+-  int down_x;
+-  int down_y;
+-  int pad_x0;
+-  int pad_x1;
+-  int pad_y0;
+-  int pad_y1;
+-
+-  int major_dim;
+-  int in_h;
+-  int in_w;
+-  int minor_dim;
+-  int kernel_h;
+-  int kernel_w;
+-  int out_h;
+-  int out_w;
+-  int loop_major;
+-  int loop_x;
+-};
+-
+-template <typename scalar_t>
+-__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
+-                                       const scalar_t *kernel,
+-                                       const UpFirDn2DKernelParams p) {
+-  int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
+-  int out_y = minor_idx / p.minor_dim;
+-  minor_idx -= out_y * p.minor_dim;
+-  int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
+-  int major_idx_base = blockIdx.z * p.loop_major;
+-
+-  if (out_x_base >= p.out_w || out_y >= p.out_h ||
+-      major_idx_base >= p.major_dim) {
+-    return;
+-  }
+-
+-  int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
+-  int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
+-  int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
+-  int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
+-
+-  for (int loop_major = 0, major_idx = major_idx_base;
+-       loop_major < p.loop_major && major_idx < p.major_dim;
+-       loop_major++, major_idx++) {
+-    for (int loop_x = 0, out_x = out_x_base;
+-         loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
+-      int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
+-      int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
+-      int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
+-      int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
+-
+-      const scalar_t *x_p =
+-          &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
+-                 minor_idx];
+-      const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
+-      int x_px = p.minor_dim;
+-      int k_px = -p.up_x;
+-      int x_py = p.in_w * p.minor_dim;
+-      int k_py = -p.up_y * p.kernel_w;
+-
+-      scalar_t v = 0.0f;
+-
+-      for (int y = 0; y < h; y++) {
+-        for (int x = 0; x < w; x++) {
+-          v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
+-          x_p += x_px;
+-          k_p += k_px;
+-        }
+-
+-        x_p += x_py - w * x_px;
+-        k_p += k_py - w * k_px;
+-      }
+-
+-      out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
+-          minor_idx] = v;
+-    }
+-  }
+-}
+-
+-template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
+-          int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
+-__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
+-                                 const scalar_t *kernel,
+-                                 const UpFirDn2DKernelParams p) {
+-  const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
+-  const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
+-
+-  __shared__ volatile float sk[kernel_h][kernel_w];
+-  __shared__ volatile float sx[tile_in_h][tile_in_w];
+-
+-  int minor_idx = blockIdx.x;
+-  int tile_out_y = minor_idx / p.minor_dim;
+-  minor_idx -= tile_out_y * p.minor_dim;
+-  tile_out_y *= tile_out_h;
+-  int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
+-  int major_idx_base = blockIdx.z * p.loop_major;
+-
+-  if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
+-      major_idx_base >= p.major_dim) {
+-    return;
+-  }
+-
+-  for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
+-       tap_idx += blockDim.x) {
+-    int ky = tap_idx / kernel_w;
+-    int kx = tap_idx - ky * kernel_w;
+-    scalar_t v = 0.0;
+-
+-    if (kx < p.kernel_w & ky < p.kernel_h) {
+-      v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
+-    }
+-
+-    sk[ky][kx] = v;
+-  }
+-
+-  for (int loop_major = 0, major_idx = major_idx_base;
+-       loop_major < p.loop_major & major_idx < p.major_dim;
+-       loop_major++, major_idx++) {
+-    for (int loop_x = 0, tile_out_x = tile_out_x_base;
+-         loop_x < p.loop_x & tile_out_x < p.out_w;
+-         loop_x++, tile_out_x += tile_out_w) {
+-      int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
+-      int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
+-      int tile_in_x = floor_div(tile_mid_x, up_x);
+-      int tile_in_y = floor_div(tile_mid_y, up_y);
+-
+-      __syncthreads();
+-
+-      for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
+-           in_idx += blockDim.x) {
+-        int rel_in_y = in_idx / tile_in_w;
+-        int rel_in_x = in_idx - rel_in_y * tile_in_w;
+-        int in_x = rel_in_x + tile_in_x;
+-        int in_y = rel_in_y + tile_in_y;
+-
+-        scalar_t v = 0.0;
+-
+-        if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
+-          v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
+-                        p.minor_dim +
+-                    minor_idx];
+-        }
+-
+-        sx[rel_in_y][rel_in_x] = v;
+-      }
+-
+-      __syncthreads();
+-      for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
+-           out_idx += blockDim.x) {
+-        int rel_out_y = out_idx / tile_out_w;
+-        int rel_out_x = out_idx - rel_out_y * tile_out_w;
+-        int out_x = rel_out_x + tile_out_x;
+-        int out_y = rel_out_y + tile_out_y;
+-
+-        int mid_x = tile_mid_x + rel_out_x * down_x;
+-        int mid_y = tile_mid_y + rel_out_y * down_y;
+-        int in_x = floor_div(mid_x, up_x);
+-        int in_y = floor_div(mid_y, up_y);
+-        int rel_in_x = in_x - tile_in_x;
+-        int rel_in_y = in_y - tile_in_y;
+-        int kernel_x = (in_x + 1) * up_x - mid_x - 1;
+-        int kernel_y = (in_y + 1) * up_y - mid_y - 1;
+-
+-        scalar_t v = 0.0;
+-
+-#pragma unroll
+-        for (int y = 0; y < kernel_h / up_y; y++)
+-#pragma unroll
+-          for (int x = 0; x < kernel_w / up_x; x++)
+-            v += sx[rel_in_y + y][rel_in_x + x] *
+-                 sk[kernel_y + y * up_y][kernel_x + x * up_x];
+-
+-        if (out_x < p.out_w & out_y < p.out_h) {
+-          out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
+-              minor_idx] = v;
+-        }
+-      }
+-    }
+-  }
+-}
+-
+-torch::Tensor upfirdn2d_op(const torch::Tensor &input,
+-                           const torch::Tensor &kernel, int up_x, int up_y,
+-                           int down_x, int down_y, int pad_x0, int pad_x1,
+-                           int pad_y0, int pad_y1) {
+-  int curDevice = -1;
+-  cudaGetDevice(&curDevice);
+-  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+-
+-  UpFirDn2DKernelParams p;
+-
+-  auto x = input.contiguous();
+-  auto k = kernel.contiguous();
+-
+-  p.major_dim = x.size(0);
+-  p.in_h = x.size(1);
+-  p.in_w = x.size(2);
+-  p.minor_dim = x.size(3);
+-  p.kernel_h = k.size(0);
+-  p.kernel_w = k.size(1);
+-  p.up_x = up_x;
+-  p.up_y = up_y;
+-  p.down_x = down_x;
+-  p.down_y = down_y;
+-  p.pad_x0 = pad_x0;
+-  p.pad_x1 = pad_x1;
+-  p.pad_y0 = pad_y0;
+-  p.pad_y1 = pad_y1;
+-
+-  p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
+-            p.down_y;
+-  p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
+-            p.down_x;
+-
+-  auto out =
+-      at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
+-
+-  int mode = -1;
+-
+-  int tile_out_h = -1;
+-  int tile_out_w = -1;
+-
+-  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
+-      p.kernel_h <= 4 && p.kernel_w <= 4) {
+-    mode = 1;
+-    tile_out_h = 16;
+-    tile_out_w = 64;
+-  }
+-
+-  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
+-      p.kernel_h <= 3 && p.kernel_w <= 3) {
+-    mode = 2;
+-    tile_out_h = 16;
+-    tile_out_w = 64;
+-  }
+-
+-  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
+-      p.kernel_h <= 4 && p.kernel_w <= 4) {
+-    mode = 3;
+-    tile_out_h = 16;
+-    tile_out_w = 64;
+-  }
+-
+-  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
+-      p.kernel_h <= 2 && p.kernel_w <= 2) {
+-    mode = 4;
+-    tile_out_h = 16;
+-    tile_out_w = 64;
+-  }
+-
+-  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
+-      p.kernel_h <= 4 && p.kernel_w <= 4) {
+-    mode = 5;
+-    tile_out_h = 8;
+-    tile_out_w = 32;
+-  }
+-
+-  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
+-      p.kernel_h <= 2 && p.kernel_w <= 2) {
+-    mode = 6;
+-    tile_out_h = 8;
+-    tile_out_w = 32;
+-  }
+-
+-  dim3 block_size;
+-  dim3 grid_size;
+-
+-  if (tile_out_h > 0 && tile_out_w > 0) {
+-    p.loop_major = (p.major_dim - 1) / 16384 + 1;
+-    p.loop_x = 1;
+-    block_size = dim3(32 * 8, 1, 1);
+-    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
+-                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
+-                     (p.major_dim - 1) / p.loop_major + 1);
+-  } else {
+-    p.loop_major = (p.major_dim - 1) / 16384 + 1;
+-    p.loop_x = 4;
+-    block_size = dim3(4, 32, 1);
+-    grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
+-                     (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
+-                     (p.major_dim - 1) / p.loop_major + 1);
+-  }
+-
+-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
+-    switch (mode) {
+-    case 1:
+-      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
+-          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+-                                                 x.data_ptr<scalar_t>(),
+-                                                 k.data_ptr<scalar_t>(), p);
+-
+-      break;
+-
+-    case 2:
+-      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
+-          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+-                                                 x.data_ptr<scalar_t>(),
+-                                                 k.data_ptr<scalar_t>(), p);
+-
+-      break;
+-
+-    case 3:
+-      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
+-          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+-                                                 x.data_ptr<scalar_t>(),
+-                                                 k.data_ptr<scalar_t>(), p);
+-
+-      break;
+-
+-    case 4:
+-      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
+-          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+-                                                 x.data_ptr<scalar_t>(),
+-                                                 k.data_ptr<scalar_t>(), p);
+-
+-      break;
+-
+-    case 5:
+-      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
+-          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+-                                                 x.data_ptr<scalar_t>(),
+-                                                 k.data_ptr<scalar_t>(), p);
+-
+-      break;
+-
+-    case 6:
+-      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
+-          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+-                                                 x.data_ptr<scalar_t>(),
+-                                                 k.data_ptr<scalar_t>(), p);
+-
+-      break;
+-
+-    default:
+-      upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
+-          out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
+-          k.data_ptr<scalar_t>(), p);
+-    }
+-  });
+-
+-  return out;
+-}
+diff --git a/basicsr/ops/upfirdn2d/upfirdn2d.py b/basicsr/ops/upfirdn2d/upfirdn2d.py
+index d6122d5..4768ec9 100644
+--- a/basicsr/ops/upfirdn2d/upfirdn2d.py
++++ b/basicsr/ops/upfirdn2d/upfirdn2d.py
+@@ -2,161 +2,11 @@
+ 
+ import os
+ import torch
+-from torch.autograd import Function
+ from torch.nn import functional as F
+ 
+-BASICSR_JIT = os.getenv('BASICSR_JIT')
+-if BASICSR_JIT == 'True':
+-    from torch.utils.cpp_extension import load
+-    module_path = os.path.dirname(__file__)
+-    upfirdn2d_ext = load(
+-        'upfirdn2d',
+-        sources=[
+-            os.path.join(module_path, 'src', 'upfirdn2d.cpp'),
+-            os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'),
+-        ],
+-    )
+-else:
+-    try:
+-        from . import upfirdn2d_ext
+-    except ImportError:
+-        pass
+-        # avoid annoying print output
+-        # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n '
+-        #       '1. compile with BASICSR_EXT=True. or\n '
+-        #       '2. set BASICSR_JIT=True during running')
+-
+-
+-class UpFirDn2dBackward(Function):
+-
+-    @staticmethod
+-    def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size):
+-
+-        up_x, up_y = up
+-        down_x, down_y = down
+-        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
+-
+-        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
+-
+-        grad_input = upfirdn2d_ext.upfirdn2d(
+-            grad_output,
+-            grad_kernel,
+-            down_x,
+-            down_y,
+-            up_x,
+-            up_y,
+-            g_pad_x0,
+-            g_pad_x1,
+-            g_pad_y0,
+-            g_pad_y1,
+-        )
+-        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
+-
+-        ctx.save_for_backward(kernel)
+-
+-        pad_x0, pad_x1, pad_y0, pad_y1 = pad
+-
+-        ctx.up_x = up_x
+-        ctx.up_y = up_y
+-        ctx.down_x = down_x
+-        ctx.down_y = down_y
+-        ctx.pad_x0 = pad_x0
+-        ctx.pad_x1 = pad_x1
+-        ctx.pad_y0 = pad_y0
+-        ctx.pad_y1 = pad_y1
+-        ctx.in_size = in_size
+-        ctx.out_size = out_size
+-
+-        return grad_input
+-
+-    @staticmethod
+-    def backward(ctx, gradgrad_input):
+-        kernel, = ctx.saved_tensors
+-
+-        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
+-
+-        gradgrad_out = upfirdn2d_ext.upfirdn2d(
+-            gradgrad_input,
+-            kernel,
+-            ctx.up_x,
+-            ctx.up_y,
+-            ctx.down_x,
+-            ctx.down_y,
+-            ctx.pad_x0,
+-            ctx.pad_x1,
+-            ctx.pad_y0,
+-            ctx.pad_y1,
+-        )
+-        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
+-        #                                  ctx.out_size[1], ctx.in_size[3])
+-        gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1])
+-
+-        return gradgrad_out, None, None, None, None, None, None, None, None
+-
+-
+-class UpFirDn2d(Function):
+-
+-    @staticmethod
+-    def forward(ctx, input, kernel, up, down, pad):
+-        up_x, up_y = up
+-        down_x, down_y = down
+-        pad_x0, pad_x1, pad_y0, pad_y1 = pad
+-
+-        kernel_h, kernel_w = kernel.shape
+-        _, channel, in_h, in_w = input.shape
+-        ctx.in_size = input.shape
+-
+-        input = input.reshape(-1, in_h, in_w, 1)
+-
+-        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
+-
+-        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+-        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+-        ctx.out_size = (out_h, out_w)
+-
+-        ctx.up = (up_x, up_y)
+-        ctx.down = (down_x, down_y)
+-        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
+-
+-        g_pad_x0 = kernel_w - pad_x0 - 1
+-        g_pad_y0 = kernel_h - pad_y0 - 1
+-        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
+-        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
+-
+-        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
+-
+-        out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1)
+-        # out = out.view(major, out_h, out_w, minor)
+-        out = out.view(-1, channel, out_h, out_w)
+-
+-        return out
+-
+-    @staticmethod
+-    def backward(ctx, grad_output):
+-        kernel, grad_kernel = ctx.saved_tensors
+-
+-        grad_input = UpFirDn2dBackward.apply(
+-            grad_output,
+-            kernel,
+-            grad_kernel,
+-            ctx.up,
+-            ctx.down,
+-            ctx.pad,
+-            ctx.g_pad,
+-            ctx.in_size,
+-            ctx.out_size,
+-        )
+-
+-        return grad_input, None, None, None, None
+-
+ 
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+-    if input.device.type == 'cpu':
+-        out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
+-    else:
+-        out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]))
+-
+-    return out
++    return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
+ 
+ 
+ def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
+diff --git a/basicsr/utils/dist_util.py b/basicsr/utils/dist_util.py
+index 0fab887..2d0ae71 100644
+--- a/basicsr/utils/dist_util.py
++++ b/basicsr/utils/dist_util.py
+@@ -20,8 +20,6 @@ def init_dist(launcher, backend='nccl', **kwargs):
+ 
+ def _init_dist_pytorch(backend, **kwargs):
+     rank = int(os.environ['RANK'])
+-    num_gpus = torch.cuda.device_count()
+-    torch.cuda.set_device(rank % num_gpus)
+     dist.init_process_group(backend=backend, **kwargs)
+ 
+ 
+@@ -39,8 +37,6 @@ def _init_dist_slurm(backend, port=None):
+     proc_id = int(os.environ['SLURM_PROCID'])
+     ntasks = int(os.environ['SLURM_NTASKS'])
+     node_list = os.environ['SLURM_NODELIST']
+-    num_gpus = torch.cuda.device_count()
+-    torch.cuda.set_device(proc_id % num_gpus)
+     addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
+     # specify master port
+     if port is not None:
+@@ -52,8 +48,6 @@ def _init_dist_slurm(backend, port=None):
+         os.environ['MASTER_PORT'] = '29500'
+     os.environ['MASTER_ADDR'] = addr
+     os.environ['WORLD_SIZE'] = str(ntasks)
+-    os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
+-    os.environ['RANK'] = str(proc_id)
+     dist.init_process_group(backend=backend)
+ 
+ 
+diff --git a/basicsr/utils/options.py b/basicsr/utils/options.py
+index 09bfa5a..f4333e9 100644
+--- a/basicsr/utils/options.py
++++ b/basicsr/utils/options.py
+@@ -134,8 +134,7 @@ def parse_options(root_path, is_train=True):
+     if args.debug and not opt['name'].startswith('debug'):
+         opt['name'] = 'debug_' + opt['name']
+ 
+-    if opt['num_gpu'] == 'auto':
+-        opt['num_gpu'] = torch.cuda.device_count()
++    opt['num_gpu'] = 0
+ 
+     # datasets
+     for phase, dataset in opt['datasets'].items():
+diff --git a/inference/inference_basicvsr.py b/inference/inference_basicvsr.py
+index 7b5e4b9..a55a182 100644
+--- a/inference/inference_basicvsr.py
++++ b/inference/inference_basicvsr.py
+@@ -30,7 +30,7 @@ def main():
+     parser.add_argument('--interval', type=int, default=15, help='interval size')
+     args = parser.parse_args()
+ 
+-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
++    device = torch.device('cpu')
+ 
+     # set up model
+     model = BasicVSR(num_feat=64, num_block=30)
+diff --git a/inference/inference_basicvsrpp.py b/inference/inference_basicvsrpp.py
+index b44aaa4..9cbb988 100644
+--- a/inference/inference_basicvsrpp.py
++++ b/inference/inference_basicvsrpp.py
+@@ -30,7 +30,7 @@ def main():
+     parser.add_argument('--interval', type=int, default=100, help='interval size')
+     args = parser.parse_args()
+ 
+-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
++    device = torch.device('cpu')
+ 
+     # set up model
+     model = BasicVSRPlusPlus(mid_channels=64, num_blocks=7)
+diff --git a/inference/inference_dfdnet.py b/inference/inference_dfdnet.py
+index 64a7a64..3a594ad 100644
+--- a/inference/inference_dfdnet.py
++++ b/inference/inference_dfdnet.py
+@@ -60,7 +60,7 @@ if __name__ == '__main__':
+     differences: 1) we use dlib for 68 landmark detection; 2) the used image
+     package are different (especially for reading and writing.)
+     """
+-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
++    device = torch.device('cpu')
+     parser = argparse.ArgumentParser()
+ 
+     parser.add_argument('--upscale_factor', type=int, default=2)
+diff --git a/inference/inference_esrgan.py b/inference/inference_esrgan.py
+index e425b13..4dd3e6c 100644
+--- a/inference/inference_esrgan.py
++++ b/inference/inference_esrgan.py
+@@ -20,7 +20,7 @@ def main():
+     parser.add_argument('--output', type=str, default='results/ESRGAN', help='output folder')
+     args = parser.parse_args()
+ 
+-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
++    device = torch.device('cpu')
+     # set up model
+     model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32)
+     model.load_state_dict(torch.load(args.model_path)['params'], strict=True)
+diff --git a/inference/inference_ridnet.py b/inference/inference_ridnet.py
+index 9825ba8..608efe2 100644
+--- a/inference/inference_ridnet.py
++++ b/inference/inference_ridnet.py
+@@ -10,7 +10,7 @@ from basicsr.archs.ridnet_arch import RIDNet
+ from basicsr.utils.img_util import img2tensor, tensor2img
+ 
+ if __name__ == '__main__':
+-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
++    device = torch.device('cpu')
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--test_path', type=str, default='datasets/denoise/RNI15')
+     parser.add_argument('--noise_g', type=int, default=25)
+diff --git a/inference/inference_stylegan2.py b/inference/inference_stylegan2.py
+index 52545ac..9348cab 100644
+--- a/inference/inference_stylegan2.py
++++ b/inference/inference_stylegan2.py
+@@ -30,7 +30,7 @@ def generate(args, g_ema, device, mean_latent, randomize_noise):
+ 
+ 
+ if __name__ == '__main__':
+-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
++    device = torch.device('cpu')
+ 
+     parser = argparse.ArgumentParser()
+ 
+diff --git a/inference/inference_swinir.py b/inference/inference_swinir.py
+index 28e9bde..cfa59b0 100644
+--- a/inference/inference_swinir.py
++++ b/inference/inference_swinir.py
+@@ -33,7 +33,7 @@ def main():
+     args = parser.parse_args()
+ 
+     os.makedirs(args.output, exist_ok=True)
+-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
++    device = torch.device('cpu')
+     # set up model
+     model = define_model(args)
+     model.eval()
+diff --git a/scripts/metrics/calculate_fid_folder.py b/scripts/metrics/calculate_fid_folder.py
+index 71b02e1..33bfe92 100644
+--- a/scripts/metrics/calculate_fid_folder.py
++++ b/scripts/metrics/calculate_fid_folder.py
+@@ -9,7 +9,7 @@ from basicsr.metrics.fid import calculate_fid, extract_inception_features, load_
+ 
+ 
+ def calculate_fid_folder():
+-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
++    device = torch.device('cpu')
+ 
+     parser = argparse.ArgumentParser()
+     parser.add_argument('folder', type=str, help='Path to the folder.')
+diff --git a/scripts/metrics/calculate_fid_stats_from_datasets.py b/scripts/metrics/calculate_fid_stats_from_datasets.py
+index 56e3529..d5858d7 100644
+--- a/scripts/metrics/calculate_fid_stats_from_datasets.py
++++ b/scripts/metrics/calculate_fid_stats_from_datasets.py
+@@ -9,7 +9,7 @@ from basicsr.metrics.fid import extract_inception_features, load_patched_incepti
+ 
+ 
+ def calculate_stats_from_dataset():
+-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
++    device = torch.device('cpu')
+ 
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--num_sample', type=int, default=50000)
+diff --git a/scripts/metrics/calculate_stylegan2_fid.py b/scripts/metrics/calculate_stylegan2_fid.py
+index c5564b8..1723374 100644
+--- a/scripts/metrics/calculate_stylegan2_fid.py
++++ b/scripts/metrics/calculate_stylegan2_fid.py
+@@ -9,7 +9,7 @@ from basicsr.metrics.fid import calculate_fid, extract_inception_features, load_
+ 
+ 
+ def calculate_stylegan2_fid():
+-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
++    device = torch.device('cpu')
+ 
+     parser = argparse.ArgumentParser()
+     parser.add_argument('ckpt', type=str, help='Path to the stylegan2 checkpoint.')
+diff --git a/setup.py b/setup.py
+index bc228e4..e8999c2 100644
+--- a/setup.py
++++ b/setup.py
+@@ -79,32 +79,6 @@ def get_version():
+     return locals()['__version__']
+ 
+ 
+-def make_cuda_ext(name, module, sources, sources_cuda=None):
+-    if sources_cuda is None:
+-        sources_cuda = []
+-    define_macros = []
+-    extra_compile_args = {'cxx': []}
+-
+-    if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
+-        define_macros += [('WITH_CUDA', None)]
+-        extension = CUDAExtension
+-        extra_compile_args['nvcc'] = [
+-            '-D__CUDA_NO_HALF_OPERATORS__',
+-            '-D__CUDA_NO_HALF_CONVERSIONS__',
+-            '-D__CUDA_NO_HALF2_OPERATORS__',
+-        ]
+-        sources += sources_cuda
+-    else:
+-        print(f'Compiling {name} without CUDA')
+-        extension = CppExtension
+-
+-    return extension(
+-        name=f'{module}.{name}',
+-        sources=[os.path.join(*module.split('.'), p) for p in sources],
+-        define_macros=define_macros,
+-        extra_compile_args=extra_compile_args)
+-
+-
+ def get_requirements(filename='requirements.txt'):
+     here = os.path.dirname(os.path.realpath(__file__))
+     with open(os.path.join(here, filename), 'r') as f:
+@@ -113,36 +87,6 @@ def get_requirements(filename='requirements.txt'):
+ 
+ 
+ if __name__ == '__main__':
+-    cuda_ext = os.getenv('BASICSR_EXT')  # whether compile cuda ext
+-    if cuda_ext == 'True':
+-        try:
+-            import torch
+-            from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
+-        except ImportError:
+-            raise ImportError('Unable to import torch - torch is needed to build cuda extensions')
+-
+-        ext_modules = [
+-            make_cuda_ext(
+-                name='deform_conv_ext',
+-                module='basicsr.ops.dcn',
+-                sources=['src/deform_conv_ext.cpp'],
+-                sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']),
+-            make_cuda_ext(
+-                name='fused_act_ext',
+-                module='basicsr.ops.fused_act',
+-                sources=['src/fused_bias_act.cpp'],
+-                sources_cuda=['src/fused_bias_act_kernel.cu']),
+-            make_cuda_ext(
+-                name='upfirdn2d_ext',
+-                module='basicsr.ops.upfirdn2d',
+-                sources=['src/upfirdn2d.cpp'],
+-                sources_cuda=['src/upfirdn2d_kernel.cu']),
+-        ]
+-        setup_kwargs = dict(cmdclass={'build_ext': BuildExtension})
+-    else:
+-        ext_modules = []
+-        setup_kwargs = dict()
+-
+     write_version_py()
+     setup(
+         name='basicsr',
+@@ -167,6 +111,4 @@ if __name__ == '__main__':
+         license='Apache License 2.0',
+         setup_requires=['cython', 'numpy', 'torch'],
+         install_requires=get_requirements(),
+-        ext_modules=ext_modules,
+-        zip_safe=False,
+-        **setup_kwargs)
++        zip_safe=False)