RL4J – Reinforcement Learning usando Java

Aparte de mis labores diarias como desarrollador de software, práctico el Ajedrez. No soy muy bueno en mi hobby pero me entretiene bastante.

Desde que la supercomputadora Deep Blue vence al campeón mundial Garry Kasparov en el match disputado en mayo de 1997, las máquinas han conseguido ser invencibles incluso para los mejores jugadores del mundo. En diciembre de 2017 DeepMind introdujo a AlphaZero, un programa de computadoras desarrollado con Inteligencia Artificial (IA) que logró vencer al por entonces mejor modulo de análisis del Ajedrez, Stockfish.

Stockfish representa el clásico enfoque de calcular y buscar la mejor jugada usando algoritmos de fuerza bruta y la capacidad computacional disponible contra el enfoque de la IA. Fue tal el impacto de ese match que surgieron otros motores de análisis como Lc0 y para junio de 2020 se anunció el desarrollo de Stockfish NNUE el cual incorpora redes neuronales a su poderoso módulo de análisis.

Inteligencia Artificial, Software y Ajedrez: era evidente que todo esto junto iba a despertar mi interés. De IA conocía menos que el ajedrez, comencé estudiando Python para continuar con Tensor Flow y Keras, entre otros. A parte de la teoría sobre IA en diversos libros, tengo mayor experiencia como desarrollador Java, por lo cual mezclar la teoría con la práctica en Python no estaba siendo una tarea sencilla. Entonces decidí dejar de momento Python, para enfocar mi práctica utilizando Java.

Tres enfoques básicos para Machine Learning son: Supervised Learning, Unsupervised Learning y Reinforcement Learning (RL). RL fue el enfoque utilizado por DeepMind para desarrollar a AlphaZero por lo que RL fue lo que más me interesó indagar. 

¿Qué es Reinforcement Learning?

Es un enfoque computacional para entender y automatizar aprendizaje guiado por objetivos y la toma de decisiones. RL usa el framework de Markov decision processes para definir la interacción entre un agente que aprende y su entorno, en términos de estados, acciones y recompensas. Puede parecer sencillo el resumen, pero no lo es, y como soy más de práctica que de teoría, a continuación presentaré un pequeño ejercicio utilizando RL4J. Este es un framework para RL que hace parte de Eclipse Deeplearning4j, una librería de código abierto para Java.

¿En qué consiste el ejercicio?

El Ajedrez es demasiado complejo para este ejercicio, por lo cual lo cambiare por un juego más sencillo, el Tic Tac Toe o tres en línea. Este consiste de un tablero de 3×3 y dos jugadores que intentan formar una línea en el tablero en su turno.

Conociendo el problema y teniendo en cuenta que el jugador debe tomar una decisión podemos aplicar RL para solucionarlo.

Primero necesitamos definir qué estructura se va a utilizar para el aprendizaje. Podríamos usar tablas de valores, o cualquier otro algoritmo de búsqueda para definir lo que se quiere aprender. En estos momentos lo mejor que se conoce son las redes neuronales y la siguiente puede ser una configuración apropiada con RL4J:

public static QLearningConfiguration TTT_QL = QLearningConfiguration.builder().doubleDQN(true)     .epsilonNbStep(10000).minEpsilon(0.3f).errorClamp(10.0).gamma(0.99).rewardFactor(0.05)     .targetDqnUpdateFreq(100).batchSize(1).expRepMaxSize(150000)     .maxStep(EPOCH_STEP * MAX_STEP)     .maxEpochStep(EPOCH_STEP).seed(123L).build();
public static DQNDenseNetworkConfiguration TTT_NET = DQNDenseNetworkConfiguration.builder()      .l2(0.00).updater(new Adam(0.001)).numHiddenNodes(20).numLayers(3).build();

Podemos dejar la mayoría de los atributos por defecto. Los atributos por resaltar son: maxStep y maxEpochStep, los cuales son el número de iteraciones que se ejecutarán con los datos disponibles definiendo episodios. Por ejemplo, en el ajedrez desde la apertura hasta el jaque mate o las tablas que dan el empate y en nuestro caso cuando el tablero de 3×3 esté lleno, o cuando uno de los jugadores forme las 3 en línea.

Con respecto a la configuración de la red neuronal se definieron 3 capas de 20 nodos. Tomé estos valores teniendo en cuenta que pueden haber más de 26.000 juegos posibles. La configuración no es camisa de fuerza y los valores seleccionados dependen más de la práctica que de alguna teoría.

Luego de definir de donde vamos a aprender, proseguimos con cómo se va a aprender, o lo que es lo mismo, la forma de entrenar esa red neuronal. El framework nos facilita las cosas con los siguientes sencillos pasos:

  private static void trainTTTDemo() throws IOException {    TTTEnv mdp = createTTTEnvironment();
    QLearningDiscreteDense<Observation> dql =        new QLearningDiscreteDense<Observation>(mdp, TTT_NET, TTT_QL);
    dql.train();
    DQNPolicy<Observation> policy = dql.getPolicy();
    policy.save(“E:\\Dev\\tmp\\ttt.policy”);
    mdp.close();  }

Primero definimos nuestro Markov Decision Process (MDP).En este representaremos los actores principales: el agente, la recompensa, las acciones y los estados que ayudarán a nuestra red neuronal a aprender a tomar sus decisiones. Luego el objeto de tipo QLearningDiscreteDense que realizará el entrenamiento, el cual luego de terminar de entrenar será el encargado de devolver lo aprendido como una política en un objeto de tipo DQNPolicy. Esta política la podemos guardar para luego ser utilizada en la aplicación de toma de decisiones del juego.

En la creación del MDP se definen las clases ActionSpace y ObservationSpace, las cuales representan las posibles acciones que el agente puede decidir y los estados que el agente puede enfrentar para tomar esas decisiones.

private static TTTEnv createTTTEnvironment() {    TTTActionSpaceImpl actionSpace = new TTTActionSpaceImpl(TTTActionSpace.NORTH_WEST,        TTTActionSpace.NORTH, TTTActionSpace.NORTH_EAST,         TTTActionSpace.WEST, TTTActionSpace.CENTER,        TTTActionSpace.EAST, TTTActionSpace.SOUTH_WEST, TTTActionSpace.SOUTH,        TTTActionSpace.SOUTH_EAST);    actionSpace.setRandomSeed(123);    TTTObservationSpaceImpl observationSpace = new TTTObservationSpaceImpl();
    return new TTTEnv(observationSpace, actionSpace);  }

La clase TTTObservationSpace se encarga de tomar los datos relevantes del problema y representarlos de una manera apropiada para que sean consumidos por la red neuronal. Para hacer esto contamos con la libreria para computación científica llamada Nd4j. En nuestro caso utilizamos una matriz de 3×3 para representar el tablero de juego y lo transformamos de una manera apropiada para la observación.

  @Override  public Observation getObservation(int[][] tttState) {    int[] flattened = ArrayUtil.flatten(tttState);    double[] doubles = ArrayUtil.toDoubles(flattened);    INDArray data = Nd4j.create(doubles, getShape());
    return new Observation(data);  }

Con las acciones y las observaciones configuradas, continuamos con la representación del ambiente de juego o la clase MDP, que se encarga de procesar la información o generarla. En nuestro caso, eligiendo de manera aleatoria las jugadas en el tablero y calculando la recompensa entregada que servirá de señal a la red neuronal indicando el objetivo a optimizar con las decisiones del agente.

  public StepReply<Observation> step(Integer action) {    int[][] lastBoardState = this.tttBoard.getBoard();    boolean played = this.tttBoard.play(this.actionSpace.encode(action));    double reward = 0.0;    if (this.tttBoard.won()) {      reward = 50;    } else if (this.tttBoard.fullBoard()) {      reward = 15;    } else if (this.tttBoard.lost() || !played) {      reward = -10;    }
    return new StepReply<Observation>(this.observationSpace.getObservation(lastBoardState), reward,        isDone(), null);  }

Luego de realizar el entrenamiento podemos cargar la política aprendida en nuestra aplicación de la siguiente manera:

public TicTacToeAgent() throws IOException {    this.policy = DQNPolicy.load(“E:\\Dev\\tmp\\ttt.policy”);    this.observationSpace = new TTTObservationSpaceImpl();  }

y finalmente utilizar lo aprendido en nuestra aplicación, en este caso para jugar Tic Tac Toe:

public int move(int[][] board) {    return this.policy.nextAction(this.observationSpace.getObservation(board));  }
  public static void main(String[] args) throws IOException {    TicTacToeAgent ticTacToeAgent = new TicTacToeAgent();    in = new Scanner(System.in);    board = new int[3][3];    turn = 1;
    log.info(“Welcome to Tic Tac Toe.”);    printBoard();    log.info(“X’s turn [1-9]: “);
    while (winner == null) {      int numInput;      try {        if (turn == -1) {          numInput = in.nextInt();        } else {          numInput = ticTacToeAgent.move(board);          numInput += 1;        }
        if (!(numInput > 0 && numInput <= 9)) {          log.info(“Invalid move. Try again [1-9]: “);          continue;        }      } catch (InputMismatchException e) {        log.info(“Invalid move. Try again [1-9]: “);        continue;      }      if (move(numInput)) {        printBoard();        turn = turn * -1;        winner = gameOver();      } else {        continue;      }    }    if (winner.equalsIgnoreCase(“draw”)) {      log.info(“It’s a draw!”);    } else {      log.info(“Congratulations ” + winner + ” player you’re the winner!.”);    }  }

Conclusión

RL4J es un framework de Reinforcement Learning integrado con DL4J, por lo tanto, es open-source distribuido usando Apache Spark y Hadoop. Puede entrenarse utilizando múltiples GPUs de manera sencilla cambiando su dependencia con Maven. Es compatible con cualquier lenguaje de la JVM, Scala, Clojure, Kotlin. RL4J soporta los algoritmos deep Q-Learning  (usado en el ejercicio) y A3C. 

Aunque Python es el lenguaje dominante en las área de IA, utilizar las librerías de DL4J me ayudó a comprender más las librerías de Python. Como desarrollador Java esta herramienta simplifica mucho el desarrollo de servicios que implementen IA y podría hacer uso de la gran cantidad de código Java. Por ejemplo, utilizar el código de módulos de ajedrez ya escritos en Java para empoderarlo con redes neuronales, o ser usado con data suministrada desde otros servicios. Es innegable la importancia de la plataforma Java en sus 25 años de existencia así que las posibilidades son innumerables.

Para más detalle de la implementación, el código fuente está acá.

También puedes leer cómo la IA ayuda a las empresas a acelerar la colaboración ante la continuidad del trabajo remoto aquí.

Suscríbete a nuestro newsletter

Recibe nuestras últimas noticias, publicaciones seleccionadas y aspectos destacados. Nunca enviaremos spam, lo prometemos.

Más de

Nuestro equipo del Engineering Studio se encarga de diseñar, construir y desarrollar soluciones digitales integrales de primera clase. Desde el diseño de interfaces humanas hasta plataformas escalables, nuestras capacidades full-stack generan mejores experiencias de cliente y más personalizadas.