// 2022年10月1日
/*
gcc -g ql_test.cpp -o ql_test
強化学習(Q-Learning)を理解する為に、中学→高校→大学の学歴を使ってみた
*/
#include <stdio.h>
#include <stdlib.h>
/*
 * Stage of the education path that a state represents.
 * The numeric values are fixed (0..4) and ordered from birth to college.
 */
enum period {
    BIRTH         = 0, /* starting point of every episode */
    JUNIOR_HIGH   = 1, /* junior high school */
    HIGH          = 2, /* high school */
    COLLEGE       = 3, /* ordinary college: terminal, no reward */
    SUPER_COLLEGE = 4  /* terminal college stage that carries the reward */
};
typedef enum period PERIOD;
/*
 * One node of the education decision tree: the stage it represents,
 * its two possible successor stages, and its learned value q.
 */
struct state {
    struct state* future_state[2]; /* the two successor states; not followed once a terminal college stage is reached */
    PERIOD period;                 /* which stage of life this node represents */
    int q;                         /* value estimate (stored as an integer) */
};
typedef struct state STATE;
/*
 * Pick the successor of p_state using an epsilon-greedy policy.
 *
 * With probability 0.3 (epsilon) one of the two successors is chosen
 * uniformly at random (exploration). Otherwise the successor with the
 * larger q is chosen (exploitation); on a tie future_state[1] wins.
 *
 * p_state must have two valid entries in future_state[].
 * Consumes one rand() value when exploiting and two when exploring.
 */
STATE* change_state(STATE* p_state)
{
    STATE* const left = p_state->future_state[0];
    STATE* const right = p_state->future_state[1];
    int take_left;

    if ((double)rand() / RAND_MAX < 0.3) {
        /* explore: fair coin between the two branches */
        take_left = (double)rand() / RAND_MAX < 0.5;
    } else {
        /* exploit: strictly greater q is required to take the left branch */
        take_left = left->q > right->q;
    }
    return take_left ? left : right;
}
/* Round a double to the nearest int (halves away from zero) without needing libm. */
static int round_to_nearest_int(double x)
{
    return (x >= 0.0) ? (int)(x + 0.5) : (int)(x - 0.5);
}

/*
 * Update the value q of the state that has just been entered.
 *
 *  - SUPER_COLLEGE: move q toward the fixed reward 1000 (annual income of
 *    10 million yen) with learning rate alpha = 0.1.
 *  - COLLEGE: left unchanged (terminal state with no reward).
 *  - any other stage: move q toward gamma * max(q of the two successors)
 *    with alpha = 0.1 and discount gamma = 0.9.
 *
 * q is stored as an int. The original `q += ...` truncated every update
 * toward zero, which biased all upward moves. For example, the
 * SUPER_COLLEGE value stalled at 991 and could never approach 1000.
 * The new value is now rounded to the nearest integer, so it settles
 * within 5 of its target from either side (996 for SUPER_COLLEGE).
 *
 * For non-terminal stages p_state must have two valid successors.
 */
void q_renewal(STATE* p_state)
{
    const double learning_rate = 0.1; /* alpha */
    const double discount = 0.9;      /* gamma */
    const double reward = 1000.0;     /* reward source: annual income (10k yen units) */
    double target;

    if (p_state->period == SUPER_COLLEGE) {
        target = reward;
    } else if (p_state->period != COLLEGE) {
        const int q0 = p_state->future_state[0]->q;
        const int q1 = p_state->future_state[1]->q;
        target = discount * (q0 > q1 ? q0 : q1);
    } else {
        return; /* plain COLLEGE: nothing to learn */
    }

    p_state->q = round_to_nearest_int(p_state->q + learning_rate * (target - p_state->q));
}
/*
 * Print the q values of the 15 consecutive states starting at p_state on
 * one line, each value followed by a comma (e.g. "12,34,...,56,").
 * p_state must point to at least 15 STATE elements.
 */
void q_display(STATE* p_state)
{
    const STATE* const stop = p_state + 15;
    const STATE* cur;

    for (cur = p_state; cur != stop; ++cur) {
        printf("%d,", cur->q);
    }
    putchar('\n');
}
int main()
{
srand(13);
// 初期設定
//STATE* state;
STATE state[15];
state[0].period = BIRTH;
state[0].future_state[0] = &(state[1]);
state[0].future_state[1] = &(state[2]);
state[1].period = JUNIOR_HIGH;
state[1].future_state[0] = &(state[3]);
state[1].future_state[1] = &(state[4]);
state[2].period = JUNIOR_HIGH;
state[2].future_state[0] = &(state[5]);
state[2].future_state[1] = &(state[6]);
state[3].period = HIGH;
state[3].future_state[0] = &(state[7]);
state[3].future_state[1] = &(state[8]);
state[4].period = HIGH;
state[4].future_state[0] = &(state[9]);
state[4].future_state[1] = &(state[10]);
state[5].period = HIGH;
state[5].future_state[0] = &(state[11]);
state[5].future_state[1] = &(state[12]);
state[6].period = HIGH;
state[6].future_state[0] = &(state[13]);
state[6].future_state[1] = &(state[14]);
state[7].period = COLLEGE;
state[8].period = COLLEGE;
state[9].period = COLLEGE;
state[10].period = SUPER_COLLEGE;
state[11].period = COLLEGE;
state[12].period = COLLEGE;
state[13].period = COLLEGE;
state[14].period = COLLEGE;
for (int i = 0; i < 15; i++){
state[i].q = (int)rand() % 100;
}
printf("誕生,A中学,B中学,C高校,D高校,E高校,F高校,G大学,H大学,I大学,J大学,K大学,L大学,M大学,N大学\n");
STATE* s = state;
//q_display(s);
q_display(state);
for (int i = 0; i < 1000; i++){ // 300:学習回数
STATE* s = state; // 初期値に戻しているだけ
do{
s = change_state(s);
q_renewal(s);
}while( (s->period != COLLEGE) && (s->period != SUPER_COLLEGE));
q_display(state);
}
printf("\n[after]\n");
//q_display(s);
q_display(state);
}