PyTorch 构造 GRUCell
CLASS torch.nn.GRUCell(input_size, hidden_size, bias=True)
import torch
import numpy as np
import torch.nn as nn
# Demo: unroll a single nn.GRUCell over a sequence, one time step at a time.

# Problem dimensions used throughout the demo.
BATCH_SIZE = 3
SEQ_LEN = 6
INPUT_SIZE = 10
HIDDEN_SIZE = 20

# The cell handles one time step per call; the loop below does the unrolling.
gru_cell = nn.GRUCell(INPUT_SIZE, HIDDEN_SIZE)

# Time-major random input, shape: (seq_len, batch_size, input_size).
# NOTE: `input` shadows the Python builtin; the name is kept so that any
# later code relying on this module-level variable keeps working.
input = torch.randn(SEQ_LEN, BATCH_SIZE, INPUT_SIZE)

# Initial hidden state, shape: (batch_size, hidden_size).
hx = torch.randn(BATCH_SIZE, HIDDEN_SIZE)

# Hidden state after each time step, in order.
output = []
# Iterating a tensor walks its first dimension, i.e. one
# (batch_size, input_size) slice per time step.
for step_input in input:
    # Feed the previous hidden state back in to get the next one.
    hx = gru_cell(step_input, hx)
    output.append(hx)

print(output[0].shape)