Building a GRUCell in PyTorch

CLASS torch.nn.GRUCell(input_size, hidden_size, bias=True)
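Each call to a GRUCell advances the hidden state by exactly one time step. The sketch below recomputes that step by hand from the cell's own `weight_ih`, `weight_hh`, `bias_ih` and `bias_hh` parameters and checks it against the module. It assumes the default `bias=True`, CPU float32, and the reset/update/new (r, z, n) gate ordering used in the PyTorch documentation. The helper `gru_cell_manual` is a hypothetical name for illustration, not part of the PyTorch API.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
INPUT_SIZE, HIDDEN_SIZE, BATCH_SIZE = 10, 20, 3

cell = nn.GRUCell(INPUT_SIZE, HIDDEN_SIZE)  # bias=True by default
x = torch.randn(BATCH_SIZE, INPUT_SIZE)
h = torch.randn(BATCH_SIZE, HIDDEN_SIZE)

def gru_cell_manual(x, h, cell):
    # Hypothetical helper: one GRU step computed from the cell's parameters.
    # weight_ih: (3*hidden, input), weight_hh: (3*hidden, hidden),
    # with rows stacked in gate order r, z, n.
    gi = x @ cell.weight_ih.T + cell.bias_ih  # input contributions, (batch, 3*hidden)
    gh = h @ cell.weight_hh.T + cell.bias_hh  # hidden contributions, (batch, 3*hidden)
    i_r, i_z, i_n = gi.chunk(3, dim=1)
    h_r, h_z, h_n = gh.chunk(3, dim=1)
    r = torch.sigmoid(i_r + h_r)   # reset gate
    z = torch.sigmoid(i_z + h_z)   # update gate
    n = torch.tanh(i_n + r * h_n)  # candidate state; r scales the hidden term, including its bias
    return (1 - z) * n + z * h     # blend candidate state and previous hidden state

with torch.no_grad():
    h_ref = cell(x, h)
    h_manual = gru_cell_manual(x, h, cell)

print(h_ref.shape)                                 # torch.Size([3, 20])
print(torch.allclose(h_ref, h_manual, atol=1e-5))  # True
```

The output has shape (batch_size, hidden_size), the same as the incoming hidden state. That is why the cell can be called repeatedly in a loop.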

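import torch
import torch.nn as nn

BATCH_SIZE = 3
SEQ_LEN = 6
INPUT_SIZE = 10
HIDDEN_SIZE = 20

# A single GRU cell: processes one time step per call
gru_cell = nn.GRUCell(INPUT_SIZE, HIDDEN_SIZE)
# Input sequence, shape: (seq_len, batch_size, input_size)
inputs = torch.randn(SEQ_LEN, BATCH_SIZE, INPUT_SIZE)
# Initial hidden state, shape: (batch_size, hidden_size)
hx = torch.randn(BATCH_SIZE, HIDDEN_SIZE)

# Unroll the cell over the sequence by hand; each new hx has shape (batch_size, hidden_size)
output = []
for i in range(SEQ_LEN):
    hx = gru_cell(inputs[i], hx)
    output.append(hx)

print(output[0].shape)  # torch.Size([3, 20])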
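Unrolling a GRUCell by hand like this should reproduce a single-layer `nn.GRU` when both share the same weights. The sketch below is a minimal consistency check under that assumption, run on CPU in float32. It copies the cell's parameters into `nn.GRU`'s layer-0 parameters (`weight_ih_l0`, `weight_hh_l0`, `bias_ih_l0`, `bias_hh_l0`) and stacks the per-step outputs with `torch.stack` into a single (seq_len, batch_size, hidden_size) tensor.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
BATCH_SIZE, SEQ_LEN, INPUT_SIZE, HIDDEN_SIZE = 3, 6, 10, 20

gru_cell = nn.GRUCell(INPUT_SIZE, HIDDEN_SIZE)
gru = nn.GRU(INPUT_SIZE, HIDDEN_SIZE)  # one layer, input layout (seq_len, batch, input_size)

# Copy the cell's parameters into layer 0 of nn.GRU (both stack gates as r, z, n)
with torch.no_grad():
    gru.weight_ih_l0.copy_(gru_cell.weight_ih)
    gru.weight_hh_l0.copy_(gru_cell.weight_hh)
    gru.bias_ih_l0.copy_(gru_cell.bias_ih)
    gru.bias_hh_l0.copy_(gru_cell.bias_hh)

inputs = torch.randn(SEQ_LEN, BATCH_SIZE, INPUT_SIZE)
hx = torch.randn(BATCH_SIZE, HIDDEN_SIZE)

with torch.no_grad():
    # Manual unrolling with GRUCell
    h = hx
    steps = []
    for t in range(SEQ_LEN):
        h = gru_cell(inputs[t], h)
        steps.append(h)
    cell_out = torch.stack(steps)  # (SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE)

    # nn.GRU expects h0 with a leading num_layers dimension: (1, batch, hidden)
    gru_out, h_n = gru(inputs, hx.unsqueeze(0))

print(cell_out.shape)                                   # torch.Size([6, 3, 20])
print(torch.allclose(cell_out, gru_out, atol=1e-5))     # True
print(torch.allclose(h_n[0], cell_out[-1], atol=1e-5))  # True
```

`nn.GRU` runs the whole loop internally. GRUCell is useful when you need per-step control, such as feeding the previous output back in as the next input.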

 
