強化学習シミュレーション

2022年10月1日

/* 
   gcc -g ql_test.cpp -o ql_test
 
   強化学習(Q-Learning)を理解する為に、中学→高校→大学の学歴を使ってみた
 
*/
 
#include <stdio.h>
#include <stdlib.h>
 
/*
 * Life stage attached to each node of the state tree.
 *
 * main() starts every episode at the BIRTH node and stops it as soon as a
 * COLLEGE or SUPER_COLLEGE node is reached. q_renewal() applies the reward
 * only to SUPER_COLLEGE nodes and leaves COLLEGE nodes untouched.
 */
typedef enum period {
  BIRTH         = 0,
  JUNIOR_HIGH   = 1,
  HIGH          = 2,
  COLLEGE       = 3,
  SUPER_COLLEGE = 4
} PERIOD;
 
/*
 * One node of the state tree walked by the simulation.
 */
typedef struct state {
  /* Paths into the future: the two successor nodes (two for now). */
  struct state *future_state[2];
  /* Life stage this node represents. */
  PERIOD period;
  /* Current value estimate of this node. It is an int, so the fractional
     part of every floating-point update in q_renewal() is truncated. */
  int q;
} STATE;
 
 
/*
 * Pick the successor of p_state with an epsilon-greedy rule.
 *
 * With probability of about 0.3 (epsilon) the successor is chosen by a coin
 * flip; otherwise the successor with the strictly larger q is taken, and a
 * tie goes to future_state[1].
 *
 * p_state must have two valid successors: the greedy branch dereferences
 * both of them.
 *
 * Returns the chosen successor node.
 */
STATE* change_state(STATE* p_state)
{
  const double epsilon = 0.3;   /* exploration rate */
  int chosen;

  if ((double)rand() / RAND_MAX < epsilon) {
    /* Explore: 50/50 between the two branches (consumes a second rand()). */
    chosen = ((double)rand() / RAND_MAX < 0.5) ? 0 : 1;
  } else {
    /* Exploit: follow the branch whose q is currently larger. */
    chosen = (p_state->future_state[0]->q > p_state->future_state[1]->q) ? 0 : 1;
  }

  return p_state->future_state[chosen];
}
	
/*
 * Update the q value of p_state in place.
 *
 *  - SUPER_COLLEGE: move q toward the reward of 1000 (the original comment
 *    gives the source of the reward as an annual income of 10 million yen).
 *  - COLLEGE: q is left unchanged.
 *  - any other period: move q toward 0.9 times the larger q of the two
 *    successors, which must both be valid pointers.
 *
 * The learning rate is 0.1. Because q is an int, the result of each update
 * is converted back to int and its fractional part is dropped.
 */
void  q_renewal(STATE* p_state)
{
  const double learning_rate = 0.1;  /* alpha */
  const double discount = 0.9;       /* gamma */

  switch (p_state->period) {
  case SUPER_COLLEGE:
    p_state->q += learning_rate * (1000 - p_state->q);
    break;

  case COLLEGE:
    /* No reward and no successors are consulted: nothing to update. */
    break;

  default: {
    STATE *left = p_state->future_state[0];
    STATE *right = p_state->future_state[1];
    int best_next_q = (left->q > right->q) ? left->q : right->q;

    p_state->q += learning_rate * (discount * best_next_q - p_state->q);
    break;
  }
  }
}
	
/*
 * Print the q values of 15 consecutive STATE entries starting at p_state,
 * each followed by a comma, then a newline.
 *
 * The count is fixed at 15, so p_state must point to at least that many
 * elements.
 */
void q_display(STATE* p_state)
{
  int idx = 0;

  while (idx < 15) {
    printf("%d,", p_state[idx].q);
    idx++;
  }
  printf("\n");
}
 
 
 
int main()
{
  srand(13);
 
 
  // 初期設定
  //STATE* state;
  STATE state[15];
 
  state[0].period = BIRTH;
  state[0].future_state[0] = &(state[1]);
  state[0].future_state[1] = &(state[2]);
 
  state[1].period = JUNIOR_HIGH;
  state[1].future_state[0] = &(state[3]);
  state[1].future_state[1] = &(state[4]);
 
  state[2].period = JUNIOR_HIGH;
  state[2].future_state[0] = &(state[5]);
  state[2].future_state[1] = &(state[6]);
 
  state[3].period = HIGH;
  state[3].future_state[0] = &(state[7]);
  state[3].future_state[1] = &(state[8]);
 
  state[4].period = HIGH;
  state[4].future_state[0] = &(state[9]);
  state[4].future_state[1] = &(state[10]);
 
  state[5].period = HIGH;
  state[5].future_state[0] = &(state[11]);
  state[5].future_state[1] = &(state[12]);
 
  state[6].period = HIGH;
  state[6].future_state[0] = &(state[13]);
  state[6].future_state[1] = &(state[14]);
 
  state[7].period = COLLEGE;
  state[8].period = COLLEGE;
  state[9].period = COLLEGE;
  state[10].period = SUPER_COLLEGE;
  state[11].period = COLLEGE;
  state[12].period = COLLEGE;
  state[13].period = COLLEGE;
  state[14].period = COLLEGE;
  
  for (int i = 0; i < 15; i++){
	state[i].q = (int)rand() % 100;
  }
 
  printf("誕生,A中学,B中学,C高校,D高校,E高校,F高校,G大学,H大学,I大学,J大学,K大学,L大学,M大学,N大学\n");

  STATE* s = state;
  //q_display(s);
  q_display(state);
 
  for (int i = 0; i < 1000; i++){  // 300:学習回数
	STATE* s = state; // 初期値に戻しているだけ
	
	do{ 
	  s = change_state(s);
	  q_renewal(s);
	}while( (s->period != COLLEGE) && (s->period != SUPER_COLLEGE));

	q_display(state);

  }
 
  printf("\n[after]\n");
  //q_display(s);
  q_display(state);
 
}

2022年10月1日 2017, 江端さんの技術メモ

Posted by ebata