์ผ | ์ | ํ | ์ | ๋ชฉ | ๊ธ | ํ |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | 6 | 7 |
8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 | 18 | 19 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 30 |
- ๋ฐ์ดํฐ
- ๋ถ์
- ์๊ณ ๋ฆฌ์ฆ
- ์ธํ๋ฐ
- ๋ฆฌ์กํธ
- ์ ํ๋์ํ
- Titanic
- native
- nlp
- ๋จธ์ ๋ฌ๋
- ๊ฒฐ์ ํธ๋ฆฌ
- AI
- react
- ํ์ดํ๋
- c++
- cs231n
- ์๋๋ก์ด๋์คํ๋์ค
- ๋ฐ์ดํฐ์๊ฐํ
- Kaggle
- ํ๊ตญ์ด์๋ฒ ๋ฉ
- ๋ฐฑ์ค
- ๋ค์ดํฐ๋ธ
- ์๋ฒ ๋ฉ
- ๋์
- ๋ฅ๋ฌ๋
- Git
- linearalgebra
- ๋ฐ์ดํฐ๋ถ์
- ์ํ์ฝ๋ฉ
- ๊นํ
- Today
- Total
yeon's ๐ฉ๐ป๐ป
RNN์ด๋? ๋ณธ๋ฌธ
https://velog.io/@sksmslhy/RNN
Recurrent Neural Network, RNN(์ํ์ ๊ฒฝ๋ง)์ด๋?
RNN์ Recurrent ๋จ์ด ๊ทธ๋๋ก ๋ฐ๋ณต๋๋ ์ ๊ฒฝ๋ง์ด๋ค. ์ฆ, ์ค์ค๋ก๋ฅผ ๋ฐ๋ณตํ๋ฉด์ ์ด์ ๋จ๊ณ์์ ์ป์ ์ ๋ณด๊ฐ ์ง์๋๋๋ก ํ๋ค.
velog.io
์ ๋ธ๋ก๊ทธ๋ฅผ ํ์ฌํ๋ฉฐ ๊ณต๋ถ
* ๋ชจ๋ ํ ์คํธ์ ์ด๋ฏธ์ง์ ์ถ์ฒ๋ ์ ๋ธ๋ก๊ทธ์ ๋๋ค.
RNN (Recurrent Neural Network, ์ํ ์ ๊ฒฝ๋ง)
RNN์ด๋?
- RNN์ Recurrent ๋จ์ด ๊ทธ๋๋ก ๋ฐ๋ณต๋๋ ์ ๊ฒฝ๋ง์ด๋ค.
- ์ฆ, ์ค์ค๋ก๊ฐ ๋ฐ๋ณตํ๋ฉด์ ์ด์ ๋จ๊ณ์์ ์ป์ ์ ๋ณด๊ฐ ์ง์๋๋๋ก ํ๋ค.
- RNN์ ๊ธฐ์กด Neural Network์ ๊ตฌ์กฐ๊ฐ ์๋นํ ๋น์ทํ๋ค.
- CNN๊ณผ ๊ฐ์ ์ ๊ฒฝ๋ง๋ค์ ์ ๋ถ hidden layer์์ activation function์ ์ง๋ ๊ฐ์ ์ค์ง ์ถ๋ ฅ์ธต์ผ๋ก๋ง ํฅํ๋ค.
(์ด๋ฐ ์ ๊ฒฝ๋ง์ Feed Forward Neural Network๋ผ๊ณ ํจ)
- ๊ทธ๋ฌ๋ RNN์ hidden node์์ activation function์ ํตํด ๋์จ ์ถ๋ ฅ์ ์ถ๋ ฅ์ธต์ผ๋ก๋ ๋ด๋ณด๋ด๊ณ ,
hidden node์ ๋ค์ ์ฐ์ฐ์ ์ ๋ ฅ์ผ๋ก๋ ๋ด๋ณด๋ด๋ ํน์ง์ ๊ฐ์ง๊ณ ์๋ค.
# RNN์ ๊ตฌ์กฐ
- A๋ ์ ๋ ฅ์ผ๋ก Xt๋ฅผ ๋ฐ์ ht๋ฅผ ์ถ๋ ฅํ๋ค.
- A๋ฅผ ๋๋ฌ์ผ ๋ฐ๋ณต์ ๋ค์ ๋จ๊ณ์์์ network๊ฐ ์ด์ ๋จ๊ณ์ ์ ๋ณด๋ฅผ ๋ฐ๋๋ค๋ ๊ฒ์ ์๋ฏธํ๋ค.
- ์ผ์ชฝ์ ๋ฐ๋ณต์ ํ์ด์ ๋ณด๋ฉด ์ค๋ฅธ์ชฝ์ด ๋๋ค.
- ์ด์ ๋จ๊ณ์์์ ์ ๋ณด๊ฐ ๋ค์ ๋จ๊ณ์์ ์ฌ์ฉ๋๋ค.
- CNN์์์ ๋ง์ฐฌ๊ฐ์ง๋ก, bias๋ ์ ๋ ฅ์ผ๋ก ์กด์ฌํ ์ ์๋ค.
- RNN์์ hidden layer์์ activation function์ ํตํด ์ถ๋ ฅ์ ๋ด๋ณด๋ด๋ ์ญํ ์ ํ๋ node๋ฅผ cell์ด๋ผ๊ณ ํ๋ค
- (์ ๊ทธ๋ฆผ์์ A) ์ด cell์ ์ด์ ์ ๊ฐ์ ๊ธฐ์ตํ๋ ์ญํ ์ ํ๋ฏ๋ก memory cell ๋๋ RNN cell์ด๋ผ๊ณ ๋ถ๋ฅธ๋ค.
- Memory cell์ด ์ถ๋ ฅ์ธต ๋ฐฉํฅ์ผ๋ก ๋๋ ๋ค์ ์์ ์ธ t+1์ ์์ ์๊ฒ ๋ณด๋ด๋ ๊ฐ์ hidden state๋ผ๊ณ ํ๋ค.
- ์ฆ, t ์์ ์ memory cell์ t-1 ์์ ์ memory cell์ด ๋ณด๋ธ hidden state ๊ฐ์ t ์์ ์ hidden state ๊ณ์ฐ์ ์ํ ์ ๋ ฅ๊ฐ์ผ๋ก ์ฌ์ฉํ๋ค.
- RNN์์ Xt์ ht๋ ๊ธฐ๋ณธ์ ์ผ๋ก ๋ฒกํฐ ๋จ์์ด๋ค.
- ์ด๋ฌํ ๊ตฌ์กฐ๋ก ์ธํด RNN์ ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ ๋ ์ ๋ฆฌํ๋ค.
- RNN์ ์ ๋ ค๊ณ ๊ฐ ์ถ๋ ฅ์ ๊ธธ์ด(Xt, ht์ ๊ฐ์)๋ฅผ ๋ค๋ฅด๊ฒ ์ค๊ณ ๊ฐ๋ฅํ๋ค.
- ์ ๋ ฅ๊ณผ ์ถ๋ ฅ์ ๊ธธ์ด์ ๋ฐ๋ผ ์ผ๋์ผ, ๋ค๋์ผ, ๋ค๋๋ค ๋ฑ ์กด์ฌ ๊ฐ๋ฅ
- 'pytorch'๋ผ๋ ๋จ์ด๋ฅผ ์์ธกํ๋ ๋ชจ๋ธ์ ์๊ฐํด๋ณด์
- p๋ผ๋ ์ ๋ ฅ์ด ๋ค์ด์์ ๋ ๋ค์์ ๋์ฌ ์ํ๋ฒณ์ผ๋ก y๋ฅผ ๊ธฐ๋ํ๊ฒ ๋๋ค.
- ์ด๋ฌํ ๊ฒฝ์ฐ, ์์ ๋ค์ด์๋ ์ ๋ ฅ๊ฐ์ ๋ํ ์ ๋ณด๊ฐ ์๋ค๋ฉด ์ ๋๋ก ์์ธก x
- ์ด๋ฌํ ๊ณผ์ ์ด ์ง์ ํ ์๊ฐ๋งํผ ๋ฐ๋ณต๋จ
- ์ผ์ ์๊ฐ ๋์ ๋ชจ๋ ๊ฐ์ด ๊ณ์ฐ๋๋ฉด, ๋ชจ๋ธ์ ํ์ตํ๊ธฐ ์ํด ๊ฒฐ๊ด๊ฐ๊ณผ ๋ชฉํฏ๊ฐ์ ์ฐจ์ด๋ฅผ loss function์ ํตํด ๊ณ์ฐํ๊ณ ์ญ์ ํ(back propagation)ํด์ผ ํ๋ค.
- ๊ธฐ์กด์ ์ญ์ ํ์๋ ๋ฌ๋ฆฌ RNN์ ๊ณ์ฐ์ ์ฌ์ฉ๋ ์์ ์ ์์ ์ํฅ์ ๋ฐ๋๋ค.
ex) t=0 ~ t=2 ๊ฐ ๊ณ์ฐ์ ์ฌ์ฉ๋์๋ค๋ฉด, ๊ทธ ์๊ฐ ์ ์ฒด์ ๋ํด ์ญ์ ํ๋ฅผ ํด์ผ ํจ
: BPTT (Backpropogation Trough Time), ์๊ฐ์ ๋ฐ๋ฅธ ์ญ์ ํ
- ์ ๊ทธ๋ฆผ์์ t=0, 1, 2์ธ ์์ ์์ ๊ฐ๊ฐ ๊ฒฐ๊ณผ๊ฐ์ด ๋์ค๊ณ ๋ชฉํ๊ฐ๊ณผ ๋น๊ต๋๋ ๊ฒ์ ๋ณผ ์ ์๋ค.
- ๋ค์ ๋จ์ด 'pytorch'๋ก ์๋ฅผ ๋ค๋ฉด, t=0์์ p๊ฐ input์ผ๋ก ๋ค์ด๊ฐ๊ณ , t=0์์ output์ผ๋ก y๊ฐ ์ถ๋ ฅ๋๊ธฐ๋ฅผ ๊ธฐ๋ํ๋ค.
- ๋ฐ๋ผ์ target_0์๋ y๊ฐ ๋ค์ด๊ฐ๊ณ , ๊ฒฐ๊ณผ๊ฐ y์ ๊ฐใ ์ง ์๋ค๋ฉด loss๊ฐ ์๊ธธ ๊ฒ์ด๋ค.
- t=1์์ y๊ฐ ๋ค์ด๊ฐ๋ฉด target_1์ t์ด๊ณ , t=2์์๋ input์ด t, target_2๋ 0์ด ๋๋ค.
- ์ด ๋, ๋ชจ๋ธ์ ํ์ตํ๋ ค๋ฉด t=2 ์์ ์์ ๋ฐ์ํ loss๋ฅผ back propagation ํ๊ธฐ ์ํด
loss๋ฅผ input๊ณผ hidden layer ์ฌ์ด์ ๊ฐ์ค์น๋ก ๋ฏธ๋ถํ์ฌ loss์ ๋ํ ๊ฐ ๋น์ค์ ๊ตฌํด ์ ๋ฐ์ดํธ ํ๋ฉด ๋๋ค.
- ์ด ๊ณผ์ ์์ ์ด์ ์์ ์ ๊ฐ๋ค์ด ์ฐ์ฐ์ ํฌํจํ๊ฒ ๋๋๋ฐ, ์ด์ ์์ ์ ๊ฐ๋ค์ ๋ค์ ๊ฐ์ค์น, input, ๊ทธ ์ ์์ ์ ๊ฐ์ ์กฐํฉ์ด๋ค.
- RNN์ ๊ฐ ์์น๋ณ๋ก ๊ฐ์ ๊ฐ์ค์น๋ฅผ ๊ณต์ ํ๋ฏ๋ก
t=2 ์์ ์ loss๋ฅผ back propagationํ๋ ค๋ฉด ๊ฒฐ๊ณผ์ ์ผ๋ก t=0 ์์ ์ ๋ ธ๋ ๊ฐ๋ค์๋ ๋ชจ๋ ์ํฅ์ ์ฃผ์ด์ผ ํ๋ค.
- ์ฆ, ์๊ฐ์ ์ญ์ผ๋ก ๊ฑฐ์ฌ๋ฌ ์ฌ๋ผ๊ฐ๋ ๋ฐฉ์์ผ๋ก ๊ฐ ๊ฐ์ค์น๋ค์ ์ ๋ฐ์ดํธ ํด์ผ ํ๋ ๊ฒ์ด๋ค.
- ๊ฐ์ค์น๋ฅผ ํ์ํ๋ฉด ๋ค์๊ณผ ๊ฐ๋ค.
- ์ด ๊ทธ๋ฆผ์์ t=2์ธ ์์ ๋ง ๋ณด๋ฉด ๋ค์๊ณผ ๊ฐ๋ค.
- ๊ธฐ๋ณธ์ ์ผ๋ก RNN์์๋ activation function์ผ๋ก tanh ํจ์๋ฅผ ์ฌ์ฉ
'Computer ๐ป > Deep Learning' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
RNN in TensorFlow(ํ ์ํ๋ก์ฐ) (0) | 2021.10.28 |
---|---|
RNN์ด๋? (2) (0) | 2021.10.28 |
Model Subclassing API (CNN) (0) | 2021.10.27 |
Functional API (CNN) (0) | 2021.10.27 |
Sequential API (CNN) (0) | 2021.10.27 |