Skip to content

Commit 1307ace

Browse files
authored
fix: Use correct torchvision model weights (#8628)
1 parent b1bacfc commit 1307ace

1 file changed

Lines changed: 2 additions & 4 deletions

File tree

  • qa/python_models/torchvision/resnet50

qa/python_models/torchvision/resnet50/model.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626

2727
import torch
28-
import torchvision
2928
import triton_python_backend_utils as pb_utils
3029
from torch.utils.dlpack import to_dlpack
30+
from torchvision import models
3131

3232

3333
class TritonPythonModel:
@@ -37,9 +37,7 @@ def initialize(self, args):
3737
"""
3838
self.device = "cuda" if args["model_instance_kind"] == "GPU" else "cpu"
3939
self.model = (
40-
torchvision.models.resnet50(
41-
weights=torchvision.models.ResNet50_Weights.DEFAULT
42-
)
40+
models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
4341
.to(self.device)
4442
.eval()
4543
)

0 commit comments

Comments
 (0)