Answers for "pytorch cuda tensor in module"

0

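PyTorch: keeping a constant tensor on the same device as its module

Register constant tensors with `register_buffer` instead of assigning them as plain attributes. Buffers are moved by `model.cuda()` / `model.to(device)` together with the parameters. A plain tensor attribute would stay on the CPU and cause a device-mismatch error in `forward`.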

import torch
import torch.nn as nn

class Model(nn.Module):
    """Linear regression written from scratch, plus a fixed constant offset.

    The offset is registered as a buffer rather than stored as a plain tensor
    attribute. Buffers are part of the module's state:

    - ``.cuda()`` / ``.to(device)`` move them together with the parameters.
    - They appear in ``state_dict()``.
    - They are not returned by ``parameters()``, so an optimizer never
      updates them.

    Args:
        in_features: number of input features (columns of ``x``).
            Defaults to 2, the original hard-coded width.
        constant: value of the fixed offset added to every output.
            Defaults to 0.5, the original hard-coded value.
    """

    def __init__(self, in_features=2, constant=0.5):
        super().__init__()
        # (in_features, 1) weight, so torch.mm(x, weight) maps
        # (batch, in_features) -> (batch, 1).
        self.weight = nn.Parameter(torch.zeros(in_features, 1))
        self.bias = nn.Parameter(torch.zeros(1))
        # float() keeps the buffer floating-point even if an int is passed.
        self.register_buffer('a_constant_tensor', torch.tensor([float(constant)]))

    def forward(self, x):
        """Return ``x @ weight + bias + a_constant_tensor``.

        Args:
            x: 2-D tensor of shape (batch, in_features), on the same device
                as the module (``torch.mm`` requires 2-D inputs).

        Returns:
            Tensor of shape (batch, 1).
        """
        return torch.mm(x, self.weight) + self.bias + self.a_constant_tensor


# .to(device) moves both the parameters and the registered buffer.
# Fall back to the CPU when no CUDA device is present, so the snippet still
# runs instead of raising.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model().to(device)
Posted by: Guest on September-03-2021

Python Answers by Framework

Browse Popular Code Answers by Language