Импорт обученной и сохраненной в Python модели нейронной сети в Java

0
01 июн 2020 17:02
Здравствуйте!

Пытаюсь импортировать в Java обученную и сохраненную в Python модель нейронной сети.
Вы дает следующее исключение:
Quote:
Exception in thread “main” java.lang.NoClassDefFoundError: org/deeplearning4j/nn/weights/IWeightInit
at org.deeplearning4j.nn.modelimport.keras.layers.core.KerasDense.<init>(KerasDense.java:96)
at org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.getKerasLayerFromConfig(KerasLayerUtils.java:220)
at org.deeplearning4j.nn.modelimport.keras.KerasModel.prepareLayers(KerasModel.java:218)
at org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel.<init>(KerasSequentialModel.java:110)
at org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel.<init>(KerasSequentialModel.java:57)
at org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder.buildSequential(KerasModelBuilder.java:322)
at org.deeplearning4j.nn.modelimport.keras.KerasModelImport.importKerasSequentialModelAndWeights(KerasModelImport.java:223)
at NeuralNetwork.main(NeuralNetwork.java:21)
Caused by: java.lang.ClassNotFoundException: org.deeplearning4j.nn.weights.IWeightInit
at java.net.URLClassLoader.findClass(URLClassLoader.java:382)
at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
at sun.misc.Launcher$AppClassLoader.loadClass(Launcher.java:349)
at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
… 8 more


При чем если смотреть через отладчик, то он вроде как считывает файл H5 (см. рисунок).

Модель нейронной сети, построенная и сохраненная в Python:
model_fully_connected = Sequential()
model_fully_connected.add(keras.layers.Dense(17, activation='tanh', input_shape=(x_train.shape[1],), W_regularizer=l2(l2_lambda)))
model_fully_connected.add(keras.layers.Dense(17, activation='tanh', W_regularizer=l2(l2_lambda)))
model_fully_connected.add(keras.layers.LeakyReLU (alpha=0.1))
model_fully_connected.add(keras.layers.Dense(17, activation='tanh', W_regularizer=l2(l2_lambda)))
model_fully_connected.add(keras.layers.LeakyReLU (alpha=0.1))
model_fully_connected.add(keras.layers.Dense(17, activation='tanh', W_regularizer=l2(l2_lambda)))
model_fully_connected.add(keras.layers.Dense(1))
model_fully_connected.compile(optimizer='adam', loss='mse', metrics=["mae", "mse"])
history=model_fully_connected.fit(x_train, y_train, epochs=10, batch_size=1, verbose=2, validation_data=(x_test, y_test))
# #Сохранение обученной нейронной сети
model_fully_connected.save("trained _neural_network.H5",True,True)


Код импорта в Java:
MultiLayerNetwork modelMultiLayer=null;
        KerasModelImport kerasModelImport=new KerasModelImport();
        try {            modelMultiLayer=kerasModelImport.importKerasSequentialModelAndWeights("E:\\Java\\neuralwork\\trained _neural_network.H5");
        } catch (IOException e) {
            e.printStackTrace();
        } catch (InvalidKerasConfigurationException e) {
            e.printStackTrace();
        } catch (UnsupportedKerasConfigurationException e) {
            e.printStackTrace();
        }
        System.out.println(modelMultiLayer.conf());


Библиотеки, которые использую в Java для импорта:
<dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>1.0.0-beta2</version>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-native-platform</artifactId>
            <version>1.0.0-beta2</version>
        </dependency>
        <dependency>
            <groupId>com.google.cloud.dataflow</groupId>
            <artifactId>google-cloud-dataflow-java-sdk-all</artifactId>
            <version>2.2.0</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-modelimport</artifactId>
            <version>1.0.0-beta7</version>
        </dependency>


В чем здесь проблема может быть?

Ответов: 1

1
02 июн 2020 05:27
Вероятнее всего не может найти нужный класс.
Caused by: java.lang.ClassNotFoundException: org.deeplearning4j.nn.weights.IWeightInit


<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-nn</artifactId>
    <version>1.0.0-beta7</version>
</dependency>
Модераторы: Нет
Сейчас эту тему просматривают: Нет