Recurrent Neural Networks

Update: 2023-10-26

A Recurrent Neural Network (RNN) [1] maintains hidden states of previous inputs and uses them to predict outputs, allowing it to model temporal dependencies in sequential data.

The hidden state is a vector representing the network's internal memory at a given time step. It captures information from previous time steps, influences the prediction made at the current time step, and is updated at every time step as the RNN processes a sequence of inputs.

RNN for Sequence Tagging

Given an input sequence $X = [x_1, \ldots, x_n]$ where $x_i \in \mathbb{R}^{d \times 1}$, an RNN for sequence tagging defines two functions, $f$ and $g$:

  • $f$ takes the current input $x_i \in X$ and the hidden state $h_{i-1}$ of the previous input $x_{i-1}$, and returns a hidden state $h_i \in \mathbb{R}^{e \times 1}$ such that $f(x_i, h_{i-1}) = \alpha(W^x x_i + W^h h_{i-1}) = h_i$, where $W^x \in \mathbb{R}^{e \times d}$, $W^h \in \mathbb{R}^{e \times e}$, and $\alpha$ is an activation function.

  • $g$ takes the hidden state $h_i$ and returns an output $y_i \in \mathbb{R}^{o \times 1}$ such that $g(h_i) = W^o h_i = y_i$, where $W^o \in \mathbb{R}^{o \times e}$ (see the code sketch after this list).
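
The two functions above can be transcribed almost directly into code. The following is a minimal NumPy sketch, not part of the original formulation: the dimensions $d = 4$, $e = 3$, $o = 2$, the choice of $\tanh$ as the activation $\alpha$, and the random (untrained) weight matrices are all illustrative assumptions.

```python
import numpy as np

d, e, o = 4, 3, 2                    # assumed input, hidden, and output dimensions
rng = np.random.default_rng(0)

W_x = rng.standard_normal((e, d))    # W^x in R^{e x d}
W_h = rng.standard_normal((e, e))    # W^h in R^{e x e}
W_o = rng.standard_normal((o, e))    # W^o in R^{o x e}

def f(x_i, h_prev):
    """f(x_i, h_{i-1}) = alpha(W^x x_i + W^h h_{i-1}); alpha is tanh here."""
    return np.tanh(W_x @ x_i + W_h @ h_prev)

def g(h_i):
    """g(h_i) = W^o h_i."""
    return W_o @ h_i

x_i = rng.standard_normal(d)         # a single input vector x_i
h_i = f(x_i, np.zeros(e))            # hidden state, taking the previous state as the zero vector
y_i = g(h_i)                         # output y_i
print(h_i.shape, y_i.shape)          # (3,) (2,)
```

Vectors are represented as 1-D NumPy arrays rather than explicit $d \times 1$ column matrices, which does not change the computation.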

Figure 1 shows an example of an RNN for sequence tagging, such as part-of-speech tagging:

Notice that the output $y_1$ for the first input $x_1$ is predicted by considering only the input itself, such that $f(x_1, \mathbf{0}) = \alpha(W^x x_1) = h_1$ (e.g., the POS tag of the first word "I" is predicted solely from that word). However, the output $y_i$ for every other input $x_i$ is predicted by considering both $x_i$ and $h_{i-1}$, an intermediate representation created explicitly for the task. This enables RNNs to capture sequential information that Feedforward Neural Networks cannot.
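
Building on the same assumptions as the sketch above (arbitrary dimensions, random untrained weights, $\tanh$ as $\alpha$), the self-contained sketch below runs the recurrence over an entire input sequence, starting from $h_0 = \mathbf{0}$ so that the first output depends only on $x_1$:

```python
import numpy as np

d, e, o, n = 4, 3, 2, 5                # assumed dimensions and sequence length
rng = np.random.default_rng(0)

W_x = rng.standard_normal((e, d))
W_h = rng.standard_normal((e, e))
W_o = rng.standard_normal((o, e))

X = [rng.standard_normal(d) for _ in range(n)]   # input sequence x_1 .. x_n

Y, h = [], np.zeros(e)                 # h_0 = 0: the first output uses only x_1
for x_i in X:
    h = np.tanh(W_x @ x_i + W_h @ h)   # h_i = f(x_i, h_{i-1})
    Y.append(W_o @ h)                  # y_i = g(h_i), one output per input

assert len(Y) == n                     # sequence tagging: one y_i per x_i
```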

What does each hidden state $h_i$ represent in the RNN for sequence tagging?

RNN for Text Classification

Unlike sequence tagging, where the RNN predicts a sequence of outputs $Y = [y_1, \ldots, y_n]$ for the input $X = [x_1, \ldots, x_n]$, an RNN designed for text classification predicts only one output $y$ for the entire input sequence such that:

  • Sequence Tagging: $\text{RNN}_{st}(X) \rightarrow Y$

  • Text Classification: $\text{RNN}_{tc}(X) \rightarrow y$

To accomplish this, a common practice is to predict the output $y$ from the last hidden state $h_n$ using the function $g$. Figure 2 shows an example of an RNN for text classification, such as sentiment analysis:
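
To make the contrast with sequence tagging concrete, the sketch below (same illustrative assumptions as before: arbitrary dimensions, random untrained weights, $\tanh$ as $\alpha$) applies $g$ only to the last hidden state $h_n$, producing a single output $y$ for the whole sequence:

```python
import numpy as np

d, e, o, n = 4, 3, 2, 5
rng = np.random.default_rng(0)

W_x = rng.standard_normal((e, d))
W_h = rng.standard_normal((e, e))
W_o = rng.standard_normal((o, e))

X = [rng.standard_normal(d) for _ in range(n)]

h = np.zeros(e)
for x_i in X:                          # same recurrence as in sequence tagging
    h = np.tanh(W_x @ x_i + W_h @ h)

y = W_o @ h                            # g is applied only to the last hidden state h_n
print(y.shape)                         # (2,): one prediction for the entire sequence
```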

What does the hidden state $h_n$ represent in the RNN for text classification?

Bidirectional RNN

The RNN for sequence tagging above does not consider the words that follow the current word when predicting the output. This limitation can significantly impact model performance since contextual information following the current word can be crucial.

For example, let us consider the word "early" in the following two sentences:

  • They are early birds -> "early" is an adjective.

  • They are early today -> "early" is an adverb.

The POS tag of "early" depends on the word that follows, "birds" or "today", so making the correct prediction is challenging without access to the context that follows.

To overcome this challenge, a Bidirectional RNN has been proposed [2] that processes the input in both the forward and backward directions, creating twice as many hidden states and capturing a more comprehensive context. Figure 3 illustrates a bidirectional RNN for sequence tagging:

For every $x_i$, the hidden states $\overrightarrow{h}_i$ and $\overleftarrow{h}_i$ are created by considering $\overrightarrow{h}_{i-1}$ and $\overleftarrow{h}_{i+1}$, respectively. The function $g$ takes both $\overrightarrow{h}_i$ and $\overleftarrow{h}_i$ and returns an output $y_i \in \mathbb{R}^{o \times 1}$ such that $g(\overrightarrow{h}_i, \overleftarrow{h}_i) = W^o (\overrightarrow{h}_i \oplus \overleftarrow{h}_i) = y_i$, where $\overrightarrow{h}_i \oplus \overleftarrow{h}_i \in \mathbb{R}^{2e \times 1}$ is the concatenation of the two hidden states and $W^o \in \mathbb{R}^{o \times 2e}$.
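
The sketch below illustrates this bidirectional formulation under the same assumptions as the earlier sketches (random untrained weights, $\tanh$ as $\alpha$); keeping separate forward and backward weight matrices is an additional assumption, and $W^o$ now has shape $o \times 2e$ because it acts on the concatenated hidden states:

```python
import numpy as np

d, e, o, n = 4, 3, 2, 5
rng = np.random.default_rng(0)

W_xf, W_hf = rng.standard_normal((e, d)), rng.standard_normal((e, e))  # forward weights
W_xb, W_hb = rng.standard_normal((e, d)), rng.standard_normal((e, e))  # backward weights
W_o = rng.standard_normal((o, 2 * e))      # W^o in R^{o x 2e}

X = [rng.standard_normal(d) for _ in range(n)]

H_f, h = [], np.zeros(e)                   # forward hidden states, computed left to right
for x_i in X:
    h = np.tanh(W_xf @ x_i + W_hf @ h)
    H_f.append(h)

H_b, h = [], np.zeros(e)                   # backward hidden states, computed right to left
for x_i in reversed(X):
    h = np.tanh(W_xb @ x_i + W_hb @ h)
    H_b.append(h)
H_b.reverse()                              # re-align so H_b[i] corresponds to the same x as H_f[i]

# y_i = W^o (forward h_i concatenated with backward h_i)
Y = [W_o @ np.concatenate([hf, hb]) for hf, hb in zip(H_f, H_b)]
```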

Does it make sense to use a bidirectional RNN for text classification? Explain your answer.

Advanced Topics

  • Long Short-Term Memory (LSTM) Networks [3-5]

  • Gated Recurrent Units (GRUs) [6-7]

References

  1. Finding Structure in Time, Elman, Cognitive Science, 14(2), 1990.

  2. Bidirectional Recurrent Neural Networks, Schuster and Paliwal, IEEE Transactions on Signal Processing, 45(11), 1997.

  3. Long Short-Term Memory, Hochreiter and Schmidhuber, Neural Computation, 9(8), 1997.

  4. Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling, Chung et al., NeurIPS Workshop on Deep Learning and Representation Learning, 2014.
