We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent b1bacfc commit 1307aceCopy full SHA for 1307ace
1 file changed
qa/python_models/torchvision/resnet50/model.py
@@ -25,9 +25,9 @@
25
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
27
import torch
28
-import torchvision
29
import triton_python_backend_utils as pb_utils
30
from torch.utils.dlpack import to_dlpack
+from torchvision import models
31
32
33
class TritonPythonModel:
@@ -37,9 +37,7 @@ def initialize(self, args):
37
"""
38
self.device = "cuda" if args["model_instance_kind"] == "GPU" else "cpu"
39
self.model = (
40
- torchvision.models.resnet50(
41
- weights=torchvision.models.ResNet50_Weights.DEFAULT
42
- )
+ models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
43
.to(self.device)
44
.eval()
45
)
0 commit comments