diff --git a/qa/python_models/torchvision/resnet50/model.py b/qa/python_models/torchvision/resnet50/model.py index 46f83bcafd..0e6405d302 100644 --- a/qa/python_models/torchvision/resnet50/model.py +++ b/qa/python_models/torchvision/resnet50/model.py @@ -25,9 +25,9 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import torch -import torchvision import triton_python_backend_utils as pb_utils from torch.utils.dlpack import to_dlpack +from torchvision import models class TritonPythonModel: @@ -37,9 +37,7 @@ def initialize(self, args): """ self.device = "cuda" if args["model_instance_kind"] == "GPU" else "cpu" self.model = ( - torchvision.models.resnet50( - weights=torchvision.models.ResNet50_Weights.DEFAULT - ) + models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2) .to(self.device) .eval() )