diff --git a/docs/docs/ProgrammingGuide/pytorch.md b/docs/docs/ProgrammingGuide/pytorch.md new file mode 100644 index 00000000000..1161b68d32e --- /dev/null +++ b/docs/docs/ProgrammingGuide/pytorch.md @@ -0,0 +1,155 @@ +Analytics-Zoo supports distributed Pytorch training and inferenceon on Apache Spark. User can +define their model and loss function with Pytorch API, and run it in a distributed environment +with the wrapper layers provided by Analytics Zoo. + +# System Requirement +Pytorch version: 1.1.0 +torchvision: 2.2.0 + +tested OS version (all 64-bit): __Ubuntu 16.04 or later__ . We expect it to +support a wide range of Operating Systems, yet other systems have not been fully tested with. +Please create issues on [issue page](https://github.com/intel-analytics/analytics-zoo/issues) +if any error is found. + + +# Pytorch API + +Two wrappers are defined in Analytics Zoo for Pytorch: + +1. TorchNet: TorchNet is a wrapper class for Pytorch model. +User may create a TorchNet by providing a Pytorch model and example input or expected size, e.g. +```python + from zoo.pipeline.api.net.torch_net import TorchNet + TorchNet.from_pytorch(torchvision.models.resnet18(pretrained=True).eval(), [1, 3, 224, 224]) +``` +The above line creates TorchNet wrapping a ResNet model, and user can use the TorchNet for +training or inference with Analytics Zoo. Internally, we create a sample input +from the input_shape provided, and use torch script module to trace the tensor operations +performed on the input sample. The result TorchNet extends from BigDL module, and can be used +with local or distributed data (RDD or DataFrame) just like other layers. For multi-input +models, please use tuple of tensors or tuple of expected tensor sizes as example input. + +2. TorchCriterion: TorchCriterion is a wrapper for loss functions defined by Pytorch. +User may create a TorchCriterion from a Pytorch Criterion, +```python + from torch import nn + from zoo.pipeline.api.net.torch_criterion import TorchCriterion + + az_criterion = TorchCriterion.from_pytorch(loss=nn.MSELoss(), + input=[1, 1], + label=[1, 1]) +``` +or from a custom loss function, which takes input and label as parameters + +```python + from torch import nn + from zoo.pipeline.api.net.torch_criterion import TorchCriterion + + criterion = nn.MSELoss() + + # this loss function is calculating loss for a multi-output model + def lossFunc(input, label): + loss1 = criterion(input[0], label[0]) + loss2 = criterion(input[1], label[1]) + loss = loss1 + 0.4 * loss2 + return loss + + az_criterion = TorchCriterion.from_pytorch(loss=lossFunc, + input=(torch.ones(2, 2), torch.ones(2, 1)), + label=(torch.ones(2, 2), torch.ones(2, 1))) +``` +Similar to TorchNet, we also need users to provide example input shape or example input data, +to trace the operations in the loss functions. The created TorchCriterion extends BigDL +criterion, and can be used similarly as other criterions. + +# Examples +Here we provide a simple end to end example, where we use TorchNet and TorchCriterion to +train a simple model with Spark DataFrame. +```python +# +# Copyright 2018 Analytics Zoo Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import torch +import torch.nn as nn +from bigdl.optim.optimizer import Adam +from zoo.common.nncontext import * +from zoo.pipeline.api.net.torch_net import TorchNet +from zoo.pipeline.api.net.torch_criterion import TorchCriterion +from zoo.pipeline.nnframes import * + +from pyspark.ml.linalg import Vectors +from pyspark.sql import SparkSession + + +# define model with Pytorch +class SimpleTorchModel(nn.Module): + def __init__(self): + super(SimpleTorchModel, self).__init__() + self.dense1 = nn.Linear(2, 4) + self.dense2 = nn.Linear(4, 1) + + def forward(self, x): + x = self.dense1(x) + x = torch.sigmoid(self.dense2(x)) + return x + +if __name__ == '__main__': + sparkConf = init_spark_conf().setAppName("example_pytorch").setMaster('local[1]') + sc = init_nncontext(sparkConf) + spark = SparkSession \ + .builder \ + .getOrCreate() + + df = spark.createDataFrame( + [(Vectors.dense([2.0, 1.0]), 1.0), + (Vectors.dense([1.0, 2.0]), 0.0), + (Vectors.dense([2.0, 1.0]), 1.0), + (Vectors.dense([1.0, 2.0]), 0.0)], + ["features", "label"]) + + torch_model = SimpleTorchModel() + torch_criterion = nn.MSELoss() + + az_model = TorchNet.from_pytorch(torch_model, [1, 2]) + az_criterion = TorchCriterion.from_pytorch(torch_criterion, [1, 1], [1, 1]) + + classifier = NNClassifier(az_model, az_criterion) \ + .setBatchSize(4) \ + .setOptimMethod(Adam()) \ + .setLearningRate(0.01) \ + .setMaxEpoch(10) + + nnClassifierModel = classifier.fit(df) + + print("After training: ") + res = nnClassifierModel.transform(df) + res.show(10, False) + +``` + +and we expects to see the output like: +```python ++---------+-----+----------+ +|features |label|prediction| ++---------+-----+----------+ +|[2.0,1.0]|1.0 |1.0 | +|[1.0,2.0]|0.0 |0.0 | +|[2.0,1.0]|1.0 |1.0 | +|[1.0,2.0]|0.0 |0.0 | ++---------+-----+----------+ +``` + +More Pytorch examples (ResNet, Lenet etc.) are available [here](../../../pyzoo/zoo/examples/pytorch). + diff --git a/pyzoo/dev/run-pytests b/pyzoo/dev/run-pytests index 23b90ce536d..b61833c73eb 100755 --- a/pyzoo/dev/run-pytests +++ b/pyzoo/dev/run-pytests @@ -22,7 +22,6 @@ cd "`dirname $0`" export PYSPARK_PYTHON=python export PYSPARK_DRIVER_PYTHON=python - py_version="$(python -V 2>&1)" python -m pytest -v --doctest-modules ../zoo \ @@ -30,7 +29,6 @@ python -m pytest -v --doctest-modules ../zoo \ --ignore=../zoo/tfpark/text \ --ignore=../zoo/examples \ --ignore=../zoo/ray/ - exit_status_1=$? if [ $exit_status_1 -ne 0 ]; then diff --git a/pyzoo/test/zoo/pipeline/api/keras/test_simple_integration.py b/pyzoo/test/zoo/pipeline/api/keras/test_simple_integration.py index fff7f94c5ea..e8c9438ca84 100644 --- a/pyzoo/test/zoo/pipeline/api/keras/test_simple_integration.py +++ b/pyzoo/test/zoo/pipeline/api/keras/test_simple_integration.py @@ -63,7 +63,8 @@ def test_training_with_tensorboard_checkpoint_gradientclipping(self): y_train = np.random.randint(4, size=(200, )) X_test = np.random.random([40, 32, 32]) y_test = np.random.randint(4, size=(40, )) - model.compile(optimizer="adam", + from zoo.pipeline.api.keras.optimizers import Adam, EpochStep + model.compile(optimizer=Adam(lr=0.003, schedule=EpochStep(1, 0.75)), loss="sparse_categorical_crossentropy", metrics=['accuracy']) tmp_log_dir = create_tmp_path() @@ -72,7 +73,7 @@ def test_training_with_tensorboard_checkpoint_gradientclipping(self): model.set_tensorboard(tmp_log_dir, "training_test") model.set_checkpoint(tmp_checkpoint_path) model.set_constant_gradient_clipping(0.01, 0.03) - model.fit(X_train, y_train, batch_size=112, nb_epoch=2, validation_data=(X_test, y_test)) + model.fit(X_train, y_train, batch_size=32, nb_epoch=20, validation_data=(X_test, y_test)) model.clear_gradient_clipping() model.fit(X_train, y_train, batch_size=112, nb_epoch=2, validation_data=(X_test, y_test)) model.set_gradient_clipping_by_l2_norm(0.2) diff --git a/pyzoo/test/zoo/pipeline/api/test_torch_net.py b/pyzoo/test/zoo/pipeline/api/test_torch_net.py index 6236c58b581..6059099d3cf 100644 --- a/pyzoo/test/zoo/pipeline/api/test_torch_net.py +++ b/pyzoo/test/zoo/pipeline/api/test_torch_net.py @@ -27,6 +27,55 @@ class TestTF(ZooTestCase): + + def test_torchnet_constructor(self): + # two inputs test + class TwoInputModel(nn.Module): + def __init__(self): + super(TwoInputModel, self).__init__() + self.dense1 = nn.Linear(2, 2) + self.dense2 = nn.Linear(3, 1) + + def forward(self, x1, x2): + x1 = self.dense1(x1) + x2 = self.dense2(x2) + return x1, x2 + + TorchNet.from_pytorch(TwoInputModel(), (torch.ones(2, 2), torch.ones(2, 3))) + TorchNet.from_pytorch(TwoInputModel(), ([2, 2], [2, 3])) + TorchNet.from_pytorch(TwoInputModel(), [torch.ones(2, 2), torch.ones(2, 3)]) + TorchNet.from_pytorch(TwoInputModel(), [[2, 2], [2, 3]]) + + # one input + input = [[0.5, 1.], [-0.3, 1.2]] + torch_input = torch.tensor(input) + model = nn.Linear(2, 1) + TorchNet.from_pytorch(model, torch_input) + TorchNet.from_pytorch(model, [1, 2]) + + def test_torchcriterion_constructor(self): + # two inputs test + criterion = nn.MSELoss() + + def lossFunc(input, label): + loss1 = criterion(input[0], label[0]) + loss2 = criterion(input[1], label[1]) + loss = loss1 + 0.4 * loss2 + return loss + + TorchCriterion.from_pytorch(lossFunc, + (torch.ones(2, 2), torch.ones(2, 3)), + (torch.ones(2, 2), torch.ones(2, 3))) + TorchCriterion.from_pytorch(lossFunc, ([2, 2], [2, 3]), ([2, 2], [2, 3])) + TorchCriterion.from_pytorch(lossFunc, + [torch.ones(2, 2), torch.ones(2, 3)], + [torch.ones(2, 2), torch.ones(2, 3)]) + TorchCriterion.from_pytorch(lossFunc, [[2, 2], [2, 3]], [[2, 2], [2, 3]]) + + # one inputs test + TorchCriterion.from_pytorch(criterion, [2, 1], [2, 1]) + TorchCriterion.from_pytorch(criterion, torch.ones(2, 2), torch.ones(2, 2)) + def test_torch_net_predict_resnet(self): model = torchvision.models.resnet18(pretrained=True).eval() net = TorchNet.from_pytorch(model, [1, 3, 224, 224]) @@ -54,8 +103,7 @@ def test_linear_gradient_match(self): # AZ part az_net = TorchNet.from_pytorch(model, [1, 2]) - az_criterion = TorchCriterion.from_pytorch(loss=criterion, input_shape=[1, 1], - label_shape=[1, 1]) + az_criterion = TorchCriterion.from_pytorch(criterion, [1, 1], [1, 1]) az_input = np.array(input) az_label = np.array(label) @@ -107,9 +155,7 @@ def forward(self, x): # AZ part az_net = TorchNet.from_pytorch(torch_model, [1, 2]) - az_criterion = TorchCriterion.from_pytorch(loss=torch_criterion.forward, - input_shape=[1, 1], - label_shape=[1, 1]) + az_criterion = TorchCriterion.from_pytorch(torch_criterion.forward, [1, 1], [1, 1]) az_input = np.array(input) az_label = np.array(label) @@ -142,8 +188,7 @@ def lossFunc(input, target): # AZ part az_net = TorchNet.from_pytorch(model, [1, 2]) - az_criterion = TorchCriterion.from_pytorch(loss=lossFunc, input_shape=[1, 10], - label_shape=[1, 1]) + az_criterion = TorchCriterion.from_pytorch(lossFunc, [1, 10], [1, 1]) az_input = np.array(input) az_label = np.array(label) @@ -198,13 +243,12 @@ def forward(self, x): torch_model.fc2.bias.grad.flatten().tolist() # AZ part - az_net = TorchNet.from_pytorch(torch_model, input_shape=[1, 1, 28, 28]) + az_net = TorchNet.from_pytorch(torch_model, [1, 1, 28, 28]) def lossFunc(input, target): return torch_criterion.forward(input, target.flatten().long()) - az_criterion = TorchCriterion.from_pytorch(loss=lossFunc, input_shape=[1, 10], - label_shape=[1, 1]) + az_criterion = TorchCriterion.from_pytorch(lossFunc, [1, 10], [1, 1]) az_input = np.array(input) az_label = np.array(label) @@ -267,9 +311,9 @@ def lossFunc(input, label): az_net = TorchNet.from_pytorch(model, [1, 2]) az_criterion = TorchCriterion.from_pytorch( - loss=lossFunc, - sample_input=(torch.ones(2, 2), torch.ones(2, 1)), - sample_label=(torch.ones(2, 2), torch.ones(2, 1))) + lossFunc, + (torch.ones(2, 2), torch.ones(2, 1)), + (torch.ones(2, 2), torch.ones(2, 1))) az_input = np.array(input) az_label = [np.ones([2, 2]), np.ones([2, 1])] @@ -283,37 +327,6 @@ def lossFunc(input, label): assert np.allclose(torch_loss.tolist(), az_loss_output) assert np.allclose(torch_grad, az_grad.tolist()) - def test_torchnet_constructor(self): - class TwoInputModel(nn.Module): - def __init__(self): - super(TwoInputModel, self).__init__() - self.dense1 = nn.Linear(2, 2) - self.dense2 = nn.Linear(3, 1) - - def forward(self, x1, x2): - x1 = self.dense1(x1) - x2 = self.dense2(x2) - return x1, x2 - - az_net = TorchNet.from_pytorch( - TwoInputModel(), sample_input=(torch.ones(2, 2), torch.ones(2, 3))) - az_net = TorchNet.from_pytorch(TwoInputModel(), ([2, 2], [2, 3])) - - def test_torchcriterion_constructor(self): - criterion = nn.MSELoss() - - def lossFunc(input, label): - loss1 = criterion(input[0], label[0]) - loss2 = criterion(input[1], label[1]) - loss = loss1 + 0.4 * loss2 - return loss - - az_criterion = TorchCriterion.from_pytorch( - lossFunc, - sample_input=(torch.ones(2, 2), torch.ones(2, 3)), - sample_label=(torch.ones(2, 2), torch.ones(2, 3))) - az_criterion = TorchCriterion.from_pytorch(lossFunc, ([2, 2], [2, 3]), ([2, 2], [2, 3])) - def test_model_train_with_multiple_input(self): class TwoInputModel(nn.Module): def __init__(self): @@ -349,11 +362,11 @@ def lossFunc(input, label): model.dense2.weight.grad.tolist()[0] + \ model.dense2.bias.grad.tolist() - az_net = TorchNet.from_pytorch(model, sample_input=(torch.ones(2, 2), torch.ones(2, 2))) + az_net = TorchNet.from_pytorch(model, (torch.ones(2, 2), torch.ones(2, 2))) az_criterion = TorchCriterion.from_pytorch( - loss=lossFunc, - sample_input=(torch.ones(2, 2), torch.ones(2, 1)), - sample_label=(torch.ones(2, 2), torch.ones(2, 1))) + lossFunc, + (torch.ones(2, 2), torch.ones(2, 1)), + (torch.ones(2, 2), torch.ones(2, 1))) az_input = [np.array(input), np.array(input)] az_label = [np.ones([2, 2]), np.ones([2, 1])] diff --git a/pyzoo/test/zoo/ray/test_ray_on_local.py b/pyzoo/test/zoo/ray/test_ray_on_local.py index 563421f53db..61dc3703d23 100644 --- a/pyzoo/test/zoo/ray/test_ray_on_local.py +++ b/pyzoo/test/zoo/ray/test_ray_on_local.py @@ -19,6 +19,7 @@ import psutil import pytest import ray +import time from zoo import init_spark_on_local from zoo.ray.util.raycontext import RayContext @@ -44,6 +45,7 @@ def test_local(self): print([ray.get(actor.hostname.remote()) for actor in actors]) ray_ctx.stop() sc.stop() + time.sleep(1) for process_info in ray_ctx.ray_processesMonitor.process_infos: for pid in process_info.pids: assert not psutil.pid_exists(pid) diff --git a/pyzoo/test/zoo/ray/test_util.py b/pyzoo/test/zoo/ray/test_util.py index ee2b4d70bfe..abbf5d55620 100644 --- a/pyzoo/test/zoo/ray/test_util.py +++ b/pyzoo/test/zoo/ray/test_util.py @@ -25,14 +25,14 @@ class TestUtil(TestCase): - def test_split(self): - vector = np.ones([10]) - result = rutils.split(vector, 4) - assert len(result) == 4 - assert len(result[0]) == 3 - assert len(result[1]) == 3 - assert len(result[2]) == 2 - assert len(result[3]) == 2 + # def test_split(self): + # vector = np.ones([10]) + # result = rutils.split(vector, 4) + # assert len(result) == 4 + # assert len(result[0]) == 3 + # assert len(result[1]) == 3 + # assert len(result[2]) == 2 + # assert len(result[3]) == 2 def test_resource_to_bytes(self): assert 10 == rutils.resourceToBytes("10b") diff --git a/pyzoo/test/zoo/tfpark/test_tfpark_model.py b/pyzoo/test/zoo/tfpark/test_tfpark_model.py index abd0a198ad2..13b50a4a3c0 100644 --- a/pyzoo/test/zoo/tfpark/test_tfpark_model.py +++ b/pyzoo/test/zoo/tfpark/test_tfpark_model.py @@ -316,6 +316,33 @@ def test_tf_optimizer_with_sparse_gradient_using_keras(self): optimizer = TFOptimizer.from_keras(model, dataset) optimizer.optimize() + def test_tf_optimizer_variable_length(self): + from random import randrange + ids = [np.random.randint(0, 10, size=[randrange(20)+1]) for i in range(0, 20)] + labels = [np.array([1]) for i in range(0, 20)] + id_rdd = self.sc.parallelize(ids) + label_rdd = self.sc.parallelize(labels) + training_rdd = id_rdd.zip(label_rdd) + dataset = TFDataset.from_rdd(training_rdd, + features=(tf.int32, [None]), + labels=(tf.int32, [1]), + names=["features", "labels"], + ) + # model = tf.keras.models.Sequential() + # model.add(tf.keras.layers.Dense(2, input_shape=(20, ), activation="softmax")) + # model.compile(optimizer="sgd", loss="sparse_categorical_crossentropy") + words_input = tf.keras.layers.Input(shape=(20, ), name='words_input') + embedding_layer = tf.keras.layers.Embedding(input_dim=10, + output_dim=5, name='word_embedding') + word_embeddings = embedding_layer(words_input) + embedding = tf.keras.layers.Flatten()(word_embeddings) + model = tf.keras.models.Model(inputs=[words_input], outputs=[embedding]) + model.compile(optimizer="sgd", loss="mse") + optimizer = TFOptimizer.from_keras(model, dataset) + optimizer.optimize() + print("111") + + def test_tensorflow_optimizer(self): data = tf.keras.layers.Input(shape=[10]) diff --git a/pyzoo/zoo/examples/pytorch/README.md b/pyzoo/zoo/examples/pytorch/inference/README.md similarity index 68% rename from pyzoo/zoo/examples/pytorch/README.md rename to pyzoo/zoo/examples/pytorch/inference/README.md index 214d4ac0268..1aadb0835a3 100644 --- a/pyzoo/zoo/examples/pytorch/README.md +++ b/pyzoo/zoo/examples/pytorch/inference/README.md @@ -1,25 +1,22 @@ ## Torch ResNet Prediction Example -TorchNet wraps a TorchScript model as a single layer, thus the Pytorch model can be used for -distributed inference. This example illustrates that a PyTorch program, with One line of change, +TorchNet wraps a Pytorch model as Analytics Zoo module, thus the Pytorch model can be used for +distributed inference. This example illustrates that a PyTorch program, with few lines of change, can be executed on Apache Spark. ## Install or download Analytics Zoo -Follow the instructions [here](https://analytics-zoo.github.io/master/#PythonUserGuide/install/) to install analytics-zoo via __pip__ or __download the prebuilt package__. +Follow the instructions [here](https://analytics-zoo.github.io/master/#PythonUserGuide/install/) +to install analytics-zoo via __pip__ or __download the prebuilt package__. ## Model and Data Preparation -1. Prepare the image dataset for inference. Put the images to do prediction in the same folder. - +We use ResNet 18 from torchvision and run inference on some images, e.g. images from ImageNet. ## Run this example after pip install ```bash python predict.py --image path_to_image_folder ``` -__Options:__ -* `--image` The path where the images are stored. - ## Run this example with prebuilt package ```bash export SPARK_HOME=the root directory of Spark diff --git a/pyzoo/zoo/examples/pytorch/inference/__init__.py b/pyzoo/zoo/examples/pytorch/inference/__init__.py new file mode 100644 index 00000000000..5976dc4df02 --- /dev/null +++ b/pyzoo/zoo/examples/pytorch/inference/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2018 Analytics Zoo Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/pyzoo/zoo/examples/pytorch/predict.py b/pyzoo/zoo/examples/pytorch/inference/predict.py similarity index 87% rename from pyzoo/zoo/examples/pytorch/predict.py rename to pyzoo/zoo/examples/pytorch/inference/predict.py index ff1a8c6edc7..affe2119669 100644 --- a/pyzoo/zoo/examples/pytorch/predict.py +++ b/pyzoo/zoo/examples/pytorch/inference/predict.py @@ -49,10 +49,6 @@ def predict(img_path): parser.add_option("--image", type=str, dest="img_path", help="The path where the images are stored, " "can be either a folder or an image path") - parser.add_option("--model", type=str, dest="model_path", - help="The path of the TensorFlow object detection model") - parser.add_option("--partition_num", type=int, dest="partition_num", default=4, - help="The number of partitions") (options, args) = parser.parse_args(sys.argv) sc = init_nncontext("Torch ResNet Prediction Example") diff --git a/pyzoo/zoo/examples/pytorch/train/Lenet_mnist.py b/pyzoo/zoo/examples/pytorch/train/Lenet_mnist.py index 2d495374eaf..428beaae0dd 100644 --- a/pyzoo/zoo/examples/pytorch/train/Lenet_mnist.py +++ b/pyzoo/zoo/examples/pytorch/train/Lenet_mnist.py @@ -49,8 +49,7 @@ def forward(self, x): if __name__ == '__main__': - sparkConf = init_spark_conf().setAppName("test_pytorch_lenet").setMaster("local[1]")\ - .set('spark.driver.memory', '10g') + sparkConf = init_spark_conf().setAppName("test_pytorch_lenet") sc = init_nncontext(sparkConf) spark = SparkSession.builder.config(conf=sparkConf).getOrCreate() @@ -70,9 +69,8 @@ def lossFunc(input, target): return nn.CrossEntropyLoss().forward(input, target.flatten().long()) torch_model = LeNet() - model = TorchNet.from_pytorch(module=torch_model, input_shape=[1, 1, 28, 28]) - criterion = TorchCriterion.from_pytorch(loss=lossFunc, input_shape=[1, 10], - sample_label=torch.LongTensor([5])) + model = TorchNet.from_pytorch(torch_model, [1, 1, 28, 28]) + criterion = TorchCriterion.from_pytorch(lossFunc, [1, 10], torch.LongTensor([5])) classifier = NNClassifier(model, criterion, SeqToTensor([1, 28, 28])) \ .setBatchSize(64) \ .setOptimMethod(Adam()) \ diff --git a/pyzoo/zoo/examples/pytorch/train/SimpleTrainingExample.py b/pyzoo/zoo/examples/pytorch/train/SimpleTrainingExample.py index 6643b107470..46e80bf72c7 100644 --- a/pyzoo/zoo/examples/pytorch/train/SimpleTrainingExample.py +++ b/pyzoo/zoo/examples/pytorch/train/SimpleTrainingExample.py @@ -13,32 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import torch import torch.nn as nn -import torch.nn.functional as F -from torch.nn.modules.loss import BCELoss -from bigdl.nn.criterion import * -from bigdl.nn.layer import * from bigdl.optim.optimizer import Adam -from pyspark.sql.types import * -from zoo import init_nncontext from zoo.common.nncontext import * -from zoo.pipeline.api.net.torch_net import TorchNet, TorchIdentityCriterion +from zoo.pipeline.api.net.torch_net import TorchNet +from zoo.pipeline.api.net.torch_criterion import TorchCriterion from zoo.pipeline.nnframes import * - -# create training data as Spark DataFrame -def get_df(sqlContext): - data = sc.parallelize([ - ((2.0, 1.0), 1.0), - ((1.0, 2.0), 0.0), - ((2.0, 1.0), 1.0), - ((1.0, 2.0), 0.0)]) - - schema = StructType([ - StructField("features", ArrayType(DoubleType(), False), False), - StructField("label", DoubleType(), False)]) - df = sqlContext.createDataFrame(data, schema) - return df +from pyspark.ml.linalg import Vectors +from pyspark.sql import SparkSession # define model with Pytorch @@ -46,33 +30,38 @@ class SimpleTorchModel(nn.Module): def __init__(self): super(SimpleTorchModel, self).__init__() self.dense1 = nn.Linear(2, 4) - self.dense2 = nn.Linear(4, 8) - self.dense3 = nn.Linear(8, 1) + self.dense2 = nn.Linear(4, 1) def forward(self, x): x = self.dense1(x) - x = self.dense2(x) - x = F.sigmoid(self.dense3(x)) + x = torch.sigmoid(self.dense2(x)) return x if __name__ == '__main__': sparkConf = init_spark_conf().setAppName("testNNClassifer").setMaster('local[1]') sc = init_nncontext(sparkConf) - sqlContext = SQLContext(sc) - df = get_df(sqlContext) + spark = SparkSession \ + .builder \ + .getOrCreate() + + df = spark.createDataFrame( + [(Vectors.dense([2.0, 1.0]), 1.0), + (Vectors.dense([1.0, 2.0]), 0.0), + (Vectors.dense([2.0, 1.0]), 1.0), + (Vectors.dense([1.0, 2.0]), 0.0)], + ["features", "label"]) torch_model = SimpleTorchModel() - becloss = BCELoss() + torch_criterion = nn.MSELoss() + + az_model = TorchNet.from_pytorch(torch_model, [1, 2]) + az_criterion = TorchCriterion.from_pytorch(torch_criterion, [1, 1], [1, 1]) - model = TorchNet.from_pytorch(module=torch_model, - input_shape=[1, 2], - lossFunc=becloss.forward, - pred_shape=[1, 1], label_shape=[1, 1]) - classifier = NNEstimator(model, TorchIdentityCriterion(), SeqToTensor([2])) \ - .setBatchSize(2) \ + classifier = NNClassifier(az_model, az_criterion) \ + .setBatchSize(4) \ .setOptimMethod(Adam()) \ - .setLearningRate(0.1) \ - .setMaxEpoch(20) + .setLearningRate(0.01) \ + .setMaxEpoch(10) nnClassifierModel = classifier.fit(df) diff --git a/pyzoo/zoo/examples/pytorch/train/resnet_finetune/resnet_finetune.py b/pyzoo/zoo/examples/pytorch/train/resnet_finetune/resnet_finetune.py index 7d21d5874a9..09c5f101456 100644 --- a/pyzoo/zoo/examples/pytorch/train/resnet_finetune/resnet_finetune.py +++ b/pyzoo/zoo/examples/pytorch/train/resnet_finetune/resnet_finetune.py @@ -60,8 +60,7 @@ def forward(self, x): def lossFunc(input, target): return nn.CrossEntropyLoss().forward(input, target.flatten().long()) - torchcriterion = TorchCriterion.from_pytorch(loss=lossFunc, input_shape=[1, 2], - sample_label=torch.LongTensor([1])) + torchcriterion = TorchCriterion.from_pytorch(lossFunc, [1, 2], torch.LongTensor([1])) # prepare training data as Spark DataFrame image_path = sys.argv[1] @@ -75,8 +74,7 @@ def lossFunc(input, target): # run training and evaluation featureTransformer = ChainedPreprocessing( [RowToImageFeature(), ImageCenterCrop(224, 224), - ImageChannelNormalize(0, 0, 0, 255.0, 255.0, 255.0), - ImageChannelNormalize(0.485, 0.456, 0.406, 0.229, 0.224, 0.225), + ImageChannelNormalize(123.0, 117.0, 104.0, 255.0, 255.0, 255.0), ImageMatToTensor(), ImageFeatureToTensor()]) classifier = NNClassifier(torchnet, torchcriterion, featureTransformer) \ diff --git a/pyzoo/zoo/examples/rayexample/__init__.py b/pyzoo/zoo/examples/rayexample/__init__.py new file mode 100644 index 00000000000..5976dc4df02 --- /dev/null +++ b/pyzoo/zoo/examples/rayexample/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2018 Analytics Zoo Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/pyzoo/zoo/examples/rayexample/parameter_server/__init__.py b/pyzoo/zoo/examples/rayexample/parameter_server/__init__.py new file mode 100644 index 00000000000..5976dc4df02 --- /dev/null +++ b/pyzoo/zoo/examples/rayexample/parameter_server/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2018 Analytics Zoo Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/pyzoo/zoo/examples/rayexample/parameter_server/async_parameter_server.py b/pyzoo/zoo/examples/rayexample/parameter_server/async_parameter_server.py new file mode 100644 index 00000000000..3d88558e379 --- /dev/null +++ b/pyzoo/zoo/examples/rayexample/parameter_server/async_parameter_server.py @@ -0,0 +1,121 @@ +# This file is adapted from https://github.com/ray-project/ray/blob +# /master/examples/parameter_server/async_parameter_server.py +# +# Copyright 2018 Analytics Zoo Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import os +import argparse +import time + +import ray +import model + +from zoo import init_spark_on_yarn, init_spark_on_local +from zoo.ray.util.raycontext import RayContext + +os.environ["LANG"] = "C.UTF-8" +parser = argparse.ArgumentParser(description="Run the asynchronous parameter " + "server example.") +parser.add_argument("--num-workers", default=4, type=int, + help="The number of workers to use.") +parser.add_argument("--iterations", default=10, type=int, + help="Iteration time.") +parser.add_argument("--hadoop_conf", type=str, + help="turn on yarn mode by passing the path to the hadoop" + "Configuration folder. Otherwise, turn on local mode.") + + +@ray.remote +class ParameterServer(object): + def __init__(self, keys, values): + # These values will be mutated, so we must create a copy that is not + # backed by the object store. + values = [value.copy() for value in values] + self.weights = dict(zip(keys, values)) + + def push(self, keys, values): + for key, value in zip(keys, values): + self.weights[key] += value + + def pull(self, keys): + return [self.weights[key] for key in keys] + + +@ray.remote +def worker_task(ps, worker_index, batch_size=50): + # Download MNIST. + print("Worker " + str(worker_index)) + mnist = model.download_mnist_retry(seed=worker_index) + + # Initialize the model. + net = model.SimpleCNN() + keys = net.get_weights()[0] + + while True: + # Get the current weights from the parameter server. + weights = ray.get(ps.pull.remote(keys)) + net.set_weights(keys, weights) + # Compute an update and push it to the parameter server. + xs, ys = mnist.train.next_batch(batch_size) + gradients = net.compute_update(xs, ys) + ps.push.remote(keys, gradients) + +if __name__ == "__main__": + args = parser.parse_args() + if args.hadoop_conf: + slave_num = 2 + sc = init_spark_on_yarn( + hadoop_conf=args.hadoop_conf, + conda_name="ray36", + num_executor=slave_num, + executor_cores=28, + executor_memory="10g", + driver_memory="2g", + driver_cores=4, + extra_executor_memory_for_ray="30g") + ray_ctx = RayContext(sc=sc, object_store_memory="25g") + else: + sc = init_spark_on_local(cores=4) + ray_ctx = RayContext(sc=sc, object_store_memory="4g") + ray_ctx.init() + + # Create a parameter server with some random weights. + net = model.SimpleCNN() + all_keys, all_values = net.get_weights() + ps = ParameterServer.remote(all_keys, all_values) + + # Start some training tasks. + worker_tasks = [worker_task.remote(ps, i) for i in range(args.num_workers)] + + # Download MNIST. + mnist = model.download_mnist_retry() + print("Begin iteration") + i = 0 + while i < args.iterations: + # Get and evaluate the current model. + print("-----Iteration" + str(i) + "------") + current_weights = ray.get(ps.pull.remote(all_keys)) + net.set_weights(all_keys, current_weights) + test_xs, test_ys = mnist.test.next_batch(1000) + accuracy = net.compute_accuracy(test_xs, test_ys) + print("Iteration {}: accuracy is {}".format(i, accuracy)) + i += 1 + time.sleep(1) + ray_ctx.stop() + sc.stop() diff --git a/pyzoo/zoo/examples/rayexample/parameter_server/model.py b/pyzoo/zoo/examples/rayexample/parameter_server/model.py new file mode 100644 index 00000000000..6c7f78b1cd8 --- /dev/null +++ b/pyzoo/zoo/examples/rayexample/parameter_server/model.py @@ -0,0 +1,218 @@ +# This file is adapted from https://github.com/ray-project/ray/blob +# /master/examples/parameter_server/model.py +# +# Copyright 2018 Analytics Zoo Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Most of the tensorflow code is adapted from Tensorflow's tutorial on using +# CNNs to train MNIST +# https://www.tensorflow.org/get_started/mnist/pros#build-a-multilayer-convolutional-network. # noqa: E501 + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import tensorflow as tf +from tensorflow.examples.tutorials.mnist import input_data + +import ray +import ray.experimental.tf_utils + + +def download_mnist_retry(seed=0, max_num_retries=20): + for _ in range(max_num_retries): + try: + return input_data.read_data_sets( + "MNIST_data", one_hot=True, seed=seed) + except tf.errors.AlreadyExistsError: + time.sleep(1) + raise Exception("Failed to download MNIST.") + + +class SimpleCNN(object): + def __init__(self, learning_rate=1e-4): + with tf.Graph().as_default(): + + # Create the model + self.x = tf.placeholder(tf.float32, [None, 784]) + + # Define loss and optimizer + self.y_ = tf.placeholder(tf.float32, [None, 10]) + + # Build the graph for the deep net + self.y_conv, self.keep_prob = deepnn(self.x) + + with tf.name_scope("loss"): + cross_entropy = tf.nn.softmax_cross_entropy_with_logits( + labels=self.y_, logits=self.y_conv) + self.cross_entropy = tf.reduce_mean(cross_entropy) + + with tf.name_scope("adam_optimizer"): + self.optimizer = tf.train.AdamOptimizer(learning_rate) + self.train_step = self.optimizer.minimize(self.cross_entropy) + + with tf.name_scope("accuracy"): + correct_prediction = tf.equal( + tf.argmax(self.y_conv, 1), tf.argmax(self.y_, 1)) + correct_prediction = tf.cast(correct_prediction, tf.float32) + self.accuracy = tf.reduce_mean(correct_prediction) + + self.sess = tf.Session( + config=tf.ConfigProto( + intra_op_parallelism_threads=1, + inter_op_parallelism_threads=1)) + self.sess.run(tf.global_variables_initializer()) + + # Helper values. + + self.variables = ray.experimental.tf_utils.TensorFlowVariables( + self.cross_entropy, self.sess) + + self.grads = self.optimizer.compute_gradients(self.cross_entropy) + self.grads_placeholder = [(tf.placeholder( + "float", shape=grad[1].get_shape()), grad[1]) + for grad in self.grads] + self.apply_grads_placeholder = self.optimizer.apply_gradients( + self.grads_placeholder) + + def compute_update(self, x, y): + # TODO(rkn): Computing the weights before and after the training step + # and taking the diff is awful. + weights = self.get_weights()[1] + self.sess.run( + self.train_step, + feed_dict={ + self.x: x, + self.y_: y, + self.keep_prob: 0.5 + }) + new_weights = self.get_weights()[1] + return [x - y for x, y in zip(new_weights, weights)] + + def compute_gradients(self, x, y): + return self.sess.run( + [grad[0] for grad in self.grads], + feed_dict={ + self.x: x, + self.y_: y, + self.keep_prob: 0.5 + }) + + def apply_gradients(self, gradients): + feed_dict = {} + for i in range(len(self.grads_placeholder)): + feed_dict[self.grads_placeholder[i][0]] = gradients[i] + self.sess.run(self.apply_grads_placeholder, feed_dict=feed_dict) + + def compute_accuracy(self, x, y): + return self.sess.run( + self.accuracy, + feed_dict={ + self.x: x, + self.y_: y, + self.keep_prob: 1.0 + }) + + def set_weights(self, variable_names, weights): + self.variables.set_weights(dict(zip(variable_names, weights))) + + def get_weights(self): + weights = self.variables.get_weights() + return list(weights.keys()), list(weights.values()) + + +def deepnn(x): + """deepnn builds the graph for a deep net for classifying digits. + Args: + x: an input tensor with the dimensions (N_examples, 784), where 784 is + the number of pixels in a standard MNIST image. + Returns: + A tuple (y, keep_prob). y is a tensor of shape (N_examples, 10), with + values equal to the logits of classifying the digit into one of 10 + classes (the digits 0-9). keep_prob is a scalar placeholder for the + probability of dropout. + """ + # Reshape to use within a convolutional neural net. + # Last dimension is for "features" - there is only one here, since images + # are grayscale -- it would be 3 for an RGB image, 4 for RGBA, etc. + with tf.name_scope("reshape"): + x_image = tf.reshape(x, [-1, 28, 28, 1]) + + # First convolutional layer - maps one grayscale image to 32 feature maps. + with tf.name_scope("conv1"): + W_conv1 = weight_variable([5, 5, 1, 32]) + b_conv1 = bias_variable([32]) + h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) + + # Pooling layer - downsamples by 2X. + with tf.name_scope("pool1"): + h_pool1 = max_pool_2x2(h_conv1) + + # Second convolutional layer -- maps 32 feature maps to 64. + with tf.name_scope("conv2"): + W_conv2 = weight_variable([5, 5, 32, 64]) + b_conv2 = bias_variable([64]) + h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) + + # Second pooling layer. + with tf.name_scope("pool2"): + h_pool2 = max_pool_2x2(h_conv2) + + # Fully connected layer 1 -- after 2 round of downsampling, our 28x28 image + # is down to 7x7x64 feature maps -- maps this to 1024 features. + with tf.name_scope("fc1"): + W_fc1 = weight_variable([7 * 7 * 64, 1024]) + b_fc1 = bias_variable([1024]) + + h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64]) + h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) + + # Dropout - controls the complexity of the model, prevents co-adaptation of + # features. + with tf.name_scope("dropout"): + keep_prob = tf.placeholder(tf.float32) + h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) + + # Map the 1024 features to 10 classes, one for each digit + with tf.name_scope("fc2"): + W_fc2 = weight_variable([1024, 10]) + b_fc2 = bias_variable([10]) + + y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2 + return y_conv, keep_prob + + +def conv2d(x, W): + """conv2d returns a 2d convolution layer with full stride.""" + return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding="SAME") + + +def max_pool_2x2(x): + """max_pool_2x2 downsamples a feature map by 2X.""" + return tf.nn.max_pool( + x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME") + + +def weight_variable(shape): + """weight_variable generates a weight variable of a given shape.""" + initial = tf.truncated_normal(shape, stddev=0.1) + return tf.Variable(initial) + + +def bias_variable(shape): + """bias_variable generates a bias variable of a given shape.""" + initial = tf.constant(0.1, shape=shape) + return tf.Variable(initial) diff --git a/pyzoo/zoo/examples/rayexample/parameter_server/sync_parameter_server.py b/pyzoo/zoo/examples/rayexample/parameter_server/sync_parameter_server.py new file mode 100644 index 00000000000..4f1514cb2fe --- /dev/null +++ b/pyzoo/zoo/examples/rayexample/parameter_server/sync_parameter_server.py @@ -0,0 +1,118 @@ +# This file is adapted from https://github.com/ray-project/ray/blob +# /master/examples/parameter_server/sync_parameter_server.py +# +# Copyright 2018 Analytics Zoo Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import os +import argparse +import numpy as np + +import ray +import model + +from zoo import init_spark_on_yarn, init_spark_on_local +from zoo.ray.util.raycontext import RayContext + +os.environ["LANG"] = "C.UTF-8" +parser = argparse.ArgumentParser(description="Run the synchronous parameter " + "server example.") +parser.add_argument("--num-workers", default=4, type=int, + help="The number of workers to use.") +parser.add_argument("--iterations", default=50, type=int, + help="Iteration time.") +parser.add_argument("--hadoop_conf", type=str, + help="turn on yarn mode by passing the path to the hadoop" + "configuration folder. Otherwise, turn on local mode.") + + +@ray.remote +class ParameterServer(object): + def __init__(self, learning_rate): + self.net = model.SimpleCNN(learning_rate=learning_rate) + + def apply_gradients(self, *gradients): + self.net.apply_gradients(np.mean(gradients, axis=0)) + return self.net.variables.get_flat() + + def get_weights(self): + return self.net.variables.get_flat() + + +@ray.remote +class Worker(object): + def __init__(self, worker_index, batch_size=50): + self.worker_index = worker_index + self.batch_size = batch_size + self.mnist = model.download_mnist_retry(seed=worker_index) + self.net = model.SimpleCNN() + + def compute_gradients(self, weights): + self.net.variables.set_flat(weights) + xs, ys = self.mnist.train.next_batch(self.batch_size) + return self.net.compute_gradients(xs, ys) + + +if __name__ == "__main__": + args = parser.parse_args() + if args.hadoop_conf: + slave_num = 2 + sc = init_spark_on_yarn( + hadoop_conf=args.hadoop_conf, + conda_name="ray36", + num_executor=slave_num, + executor_cores=28, + executor_memory="10g", + driver_memory="2g", + driver_cores=4, + extra_executor_memory_for_ray="30g") + ray_ctx = RayContext(sc=sc, object_store_memory="25g") + else: + sc = init_spark_on_local(cores=4) + ray_ctx = RayContext(sc=sc, object_store_memory="4g") + + ray_ctx.init() + + # Create a parameter server. + net = model.SimpleCNN() + ps = ParameterServer.remote(1e-4 * args.num_workers) + + # Create workers. + workers = [Worker.remote(worker_index) + for worker_index in range(args.num_workers)] + + # Download MNIST. + mnist = model.download_mnist_retry() + + i = 0 + current_weights = ps.get_weights.remote() + print("Begin iteration") + while i < args.iterations: + # Compute and apply gradients. + gradients = [worker.compute_gradients.remote(current_weights) + for worker in workers] + current_weights = ps.apply_gradients.remote(*gradients) + + if i % 10 == 0: + # Evaluate the current model. + net.variables.set_flat(ray.get(current_weights)) + test_xs, test_ys = mnist.test.next_batch(1000) + accuracy = net.compute_accuracy(test_xs, test_ys) + print("Iteration {}: accuracy is {}".format(i, accuracy)) + i += 1 + ray_ctx.stop() + sc.stop() diff --git a/pyzoo/zoo/examples/rayexample/rl_pong/__init__.py b/pyzoo/zoo/examples/rayexample/rl_pong/__init__.py new file mode 100644 index 00000000000..5976dc4df02 --- /dev/null +++ b/pyzoo/zoo/examples/rayexample/rl_pong/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2018 Analytics Zoo Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/pyzoo/zoo/examples/rayexample/rl_pong/rl_pong.py b/pyzoo/zoo/examples/rayexample/rl_pong/rl_pong.py new file mode 100644 index 00000000000..a69bf146b1b --- /dev/null +++ b/pyzoo/zoo/examples/rayexample/rl_pong/rl_pong.py @@ -0,0 +1,265 @@ +# This file is adapted from https://github.com/ray-project/ray/blob/master +# /examples/rl_pong/driver.py +# +# Copyright 2018 Analytics Zoo Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# play Pong https://gist.github.com/karpathy/a4166c7fe253700972fcbc77e4ea32c5. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import os +import argparse +import numpy as np +import os +import ray +import time +import gym +from zoo import init_spark_on_yarn, init_spark_on_local +from zoo.ray.util.raycontext import RayContext + +os.environ["LANG"] = "C.UTF-8" +# Define some hyperparameters. + +# The number of hidden layer neurons. +H = 200 +learning_rate = 1e-4 +# Discount factor for reward. +gamma = 0.99 +# The decay factor for RMSProp leaky sum of grad^2. +decay_rate = 0.99 + +# The input dimensionality: 80x80 grid. +D = 80 * 80 + + +def sigmoid(x): + # Sigmoid "squashing" function to interval [0, 1]. + return 1.0 / (1.0 + np.exp(-x)) + + +def preprocess(img): + """Preprocess 210x160x3 uint8 frame into 6400 (80x80) 1D float vector.""" + # Crop the image. + img = img[35:195] + # Downsample by factor of 2. + img = img[::2, ::2, 0] + # Erase background (background type 1). + img[img == 144] = 0 + # Erase background (background type 2). + img[img == 109] = 0 + # Set everything else (paddles, ball) to 1. + img[img != 0] = 1 + return img.astype(np.float).ravel() + + +def discount_rewards(r): + """take 1D float array of rewards and compute discounted reward""" + discounted_r = np.zeros_like(r) + running_add = 0 + for t in reversed(range(0, r.size)): + # Reset the sum, since this was a game boundary (pong specific!). + if r[t] != 0: + running_add = 0 + running_add = running_add * gamma + r[t] + discounted_r[t] = running_add + return discounted_r + + +# defines the policy network +# x is a vector that holds the preprocessed pixel information +def policy_forward(x, model): + # neurons in the hidden layer (W1) can detect various game senarios + h = np.dot(model["W1"], x) # compute hidden layer neuron activations + h[h < 0] = 0 # ReLU nonlinearity. threhold at zero + # weights in W2 can then decide if each case we should go UP or DOWN + logp = np.dot(model["W2"], h) # compuate the log probability of going up + p = sigmoid(logp) + # Return probability of taking action 2, and hidden state. + return p, h + + +def policy_backward(eph, epx, epdlogp, model): + """backward pass. (eph is array of intermediate hidden states)""" +# the way to change the policy parameters is to +# do some rollouts, take the gradient of the sampled actions +# multiply it by the score and add everything + dW2 = np.dot(eph.T, epdlogp).ravel() + dh = np.outer(epdlogp, model["W2"]) + # Backprop relu. + dh[eph <= 0] = 0 + dW1 = np.dot(dh.T, epx) + return {"W1": dW1, "W2": dW2} + + +@ray.remote +class PongEnv(object): + def __init__(self): + # Tell numpy to only use one core. If we don't do this, each actor may + # try to use all of the cores and the resulting contention may result + # in no speedup over the serial version. Note that if numpy is using + # OpenBLAS, then you need to set OPENBLAS_NUM_THREADS=1, and you + # probably need to do it from the command line (so it happens before + # numpy is imported). + os.environ["MKL_NUM_THREADS"] = "1" + self.env = gym.make("Pong-v0") + + def compute_gradient(self, model): + # model = {'W1':W1, 'W2':W2} + # given a model, run for one episode and return the parameter + # to be updated and sum(reward) + # Reset the game. + observation = self.env.reset() + # Note that prev_x is used in computing the difference frame. + prev_x = None + xs, hs, dlogps, drs = [], [], [], [] + reward_sum = 0 + done = False + while not done: + cur_x = preprocess(observation) + x = cur_x - prev_x if prev_x is not None else np.zeros(D) + prev_x = cur_x + + # feed difference frames into the network + # so that it can detect motion + aprob, h = policy_forward(x, model) + # Sample an action. + action = 2 if np.random.uniform() < aprob else 3 + + # The observation. + xs.append(x) + # The hidden state. + hs.append(h) + y = 1 if action == 2 else 0 # A "fake label". + # The gradient that encourages the action that was taken to be + # taken (see http://cs231n.github.io/neural-networks-2/#losses if + # confused). + dlogps.append(y - aprob) + + observation, reward, done, info = self.env.step(action) + reward_sum += reward + + # Record reward (has to be done after we call step() to get reward + # for previous action). + drs.append(reward) + + epx = np.vstack(xs) + eph = np.vstack(hs) + epdlogp = np.vstack(dlogps) + epr = np.vstack(drs) + # Reset the array memory. + xs, hs, dlogps, drs = [], [], [], [] + + # Compute the discounted reward backward through time. + discounted_epr = discount_rewards(epr) + # Standardize the rewards to be unit normal (helps control the gradient + # estimator variance). + discounted_epr -= np.mean(discounted_epr) + discounted_epr /= np.std(discounted_epr) + # Modulate the gradient with advantage (the policy gradient magic + # happens right here). + epdlogp *= discounted_epr + return policy_backward(eph, epx, epdlogp, model), reward_sum + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Train an RL agent") + + parser.add_argument("--hadoop_conf", + type=str, + help="turn on yarn mode by passing the hadoop path" + "configuration folder. Otherwise, turn on local mode.") + parser.add_argument( + "--batch-size", + default=10, + type=int, + help="The number of rollouts to do per batch.") + parser.add_argument( + "--iterations", + default=-1, + type=int, + help="The number of model updates to perform. By " + "default, training will not terminate.") + + args = parser.parse_args() + if args.hadoop_conf: + slave_num = 2 + sc = init_spark_on_yarn( + hadoop_conf=args.hadoop_conf, + conda_name="ray36", + num_executor=slave_num, + executor_cores=28, + executor_memory="10g", + driver_memory="2g", + driver_cores=4, + extra_executor_memory_for_ray="30g") + ray_ctx = RayContext(sc=sc, object_store_memory="25g") + else: + sc = init_spark_on_local(cores=4) + ray_ctx = RayContext(sc=sc, object_store_memory="4g") + ray_ctx.init() + + batch_size = args.batch_size + # Run the reinforcement learning. + running_reward = None + batch_num = 1 + model = {} + # "Xavier" initialization. + model["W1"] = np.random.randn(H, D) / np.sqrt(D) + model["W2"] = np.random.randn(H) / np.sqrt(H) + # Update buffers that add up gradients over a batch. + grad_buffer = {k: np.zeros_like(v) for k, v in model.items()} + # Update the rmsprop memory. + rmsprop_cache = {k: np.zeros_like(v) for k, v in model.items()} + actors = [PongEnv.remote() for _ in range(batch_size)] + iteration = 0 + while iteration != args.iterations: + iteration += 1 + model_id = ray.put(model) + actions = [] + # Launch tasks to compute gradients from multiple rollouts in parallel. + start_time = time.time() + # run rall_out for batch_size times + for i in range(batch_size): + # compute_gradient returns two variables, so action_id is a list + action_id = actors[i].compute_gradient.remote(model_id) + actions.append(action_id) + for i in range(batch_size): + # wait for one actor to finish its operation + # action_id is the ready object id + action_id, actions = ray.wait(actions) + grad, reward_sum = ray.get(action_id[0]) + # Accumulate the gradient of each weight parameter over batch. + for k in model: + grad_buffer[k] += grad[k] + running_reward = (reward_sum if running_reward is None else + running_reward * 0.99 + reward_sum * 0.01) + end_time = time.time() + print("Batch {} computed {} rollouts in {} seconds, " + "running mean is {}".format(batch_num, batch_size, + end_time - start_time, + running_reward)) + # update gradient after one iteration + for k, v in model.items(): + g = grad_buffer[k] + rmsprop_cache[k] = ( + decay_rate * rmsprop_cache[k] + (1 - decay_rate) * g**2) + model[k] += learning_rate * g / (np.sqrt(rmsprop_cache[k]) + 1e-5) + # Reset the batch gradient buffer. + grad_buffer[k] = np.zeros_like(v) + batch_num += 1 + + ray_ctx.stop() + sc.stop() diff --git a/pyzoo/zoo/examples/rayexample/rllibexample/__init__.py b/pyzoo/zoo/examples/rayexample/rllibexample/__init__.py new file mode 100644 index 00000000000..5976dc4df02 --- /dev/null +++ b/pyzoo/zoo/examples/rayexample/rllibexample/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2018 Analytics Zoo Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/pyzoo/zoo/examples/rayexample/rllibexample/multiagent_two_trainers.py b/pyzoo/zoo/examples/rayexample/rllibexample/multiagent_two_trainers.py new file mode 100644 index 00000000000..8c289df352a --- /dev/null +++ b/pyzoo/zoo/examples/rayexample/rllibexample/multiagent_two_trainers.py @@ -0,0 +1,141 @@ +# This file is adapted from https://github.com/ray-project/ray/blob +# /master/python/ray/rllib/examples/multiagent_two_trainers.py +# +# Copyright 2018 Analytics Zoo Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +"""Example of using two different training methods at once in multi-agent. +Here we create a number of CartPole agents, some of which are trained with +DQN, and some of which are trained with PPO. We periodically sync weights +between the two trainers (note that no such syncing is needed when using just +a single training method). +For a simpler example, see also: multiagent_cartpole.py +""" + +import argparse +import gym +import os + +import ray +from ray.rllib.agents.dqn.dqn import DQNTrainer +from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy +from ray.rllib.agents.ppo.ppo import PPOTrainer +from ray.rllib.agents.ppo.ppo_policy import PPOTFPolicy +from ray.rllib.tests.test_multi_agent_env import MultiCartpole +from ray.tune.logger import pretty_print +from ray.tune.registry import register_env +from zoo import init_spark_on_yarn, init_spark_on_local +from zoo.ray.util.raycontext import RayContext +os.environ["LANG"] = "C.UTF-8" + +parser = argparse.ArgumentParser() +parser.add_argument("--num-iters", type=int, default=20) +parser.add_argument("--hadoop_conf", type=str, + help="turn on yarn mode by passing the path to the hadoop" + " configuration folder. Otherwise, turn on local mode.") + +if __name__ == "__main__": + args = parser.parse_args() + if args.hadoop_conf: + slave_num = 2 + sc = init_spark_on_yarn( + hadoop_conf=args.hadoop_conf, + conda_name="ray36", + num_executor=slave_num, + executor_cores=28, + executor_memory="10g", + driver_memory="2g", + driver_cores=4, + extra_executor_memory_for_ray="30g") + ray_ctx = RayContext( + sc=sc, + object_store_memory="25g") + else: + sc = init_spark_on_local(cores=4) + ray_ctx = RayContext(sc=sc, object_store_memory="4g") + ray_ctx.init() + + # Simple environment with 4 independent cartpole entities + register_env("multi_cartpole", lambda _: MultiCartpole(4)) + single_env = gym.make("CartPole-v0") + obs_space = single_env.observation_space + act_space = single_env.action_space + + # You can also have multiple policies per trainer, but here we just + # show one each for PPO and DQN. + policies = { + "ppo_policy": (PPOTFPolicy, obs_space, act_space, {}), + "dqn_policy": (DQNTFPolicy, obs_space, act_space, {}), + } + + def policy_mapping_fn(agent_id): + if agent_id % 2 == 0: + return "ppo_policy" + else: + return "dqn_policy" + + ppo_trainer = PPOTrainer( + env="multi_cartpole", + config={ + "multiagent": { + "policies": policies, + "policy_mapping_fn": policy_mapping_fn, + "policies_to_train": ["ppo_policy"], + }, + # disable filters, otherwise we would need to synchronize those + # as well to the DQN agent + "observation_filter": "NoFilter", + }) + + dqn_trainer = DQNTrainer( + env="multi_cartpole", + config={ + "multiagent": { + "policies": policies, + "policy_mapping_fn": policy_mapping_fn, + "policies_to_train": ["dqn_policy"], + }, + "gamma": 0.95, + "n_step": 3, + }) + + # disable DQN exploration when used by the PPO trainer + ppo_trainer.optimizer.foreach_worker( + lambda ev: ev.for_policy( + lambda pi: pi.set_epsilon(0.0), policy_id="dqn_policy")) + + # You should see both the printed X and Y approach 200 as this trains: + # info: + # policy_reward_mean: + # dqn_policy: X + # ppo_policy: Y + for i in range(args.num_iters): + print("== Iteration", i, "==") + + # improve the DQN policy + print("-- DQN --") + print(pretty_print(dqn_trainer.train())) + + # improve the PPO policy + print("-- PPO --") + print(pretty_print(ppo_trainer.train())) + + # swap weights to synchronize + dqn_trainer.set_weights(ppo_trainer.get_weights(["ppo_policy"])) + ppo_trainer.set_weights(dqn_trainer.get_weights(["dqn_policy"])) + ray_ctx.stop() + sc.stop() diff --git a/pyzoo/zoo/examples/run-example-test-ray.sh b/pyzoo/zoo/examples/run-example-test-ray.sh new file mode 100644 index 00000000000..c89b53f9ccc --- /dev/null +++ b/pyzoo/zoo/examples/run-example-test-ray.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +export SPARK_HOME=$SPARK_HOME +export MASTER=local[4] +export FTP_URI=$FTP_URI +export ANALYTICS_ZOO_ROOT=$ANALYTICS_ZOO_ROOT +export ANALYTICS_ZOO_HOME=$ANALYTICS_ZOO_ROOT/dist +export ANALYTICS_ZOO_JAR=`find ${ANALYTICS_ZOO_HOME}/lib -type f -name "analytics-zoo*jar-with-dependencies.jar"` +export ANALYTICS_ZOO_PYZIP=`find ${ANALYTICS_ZOO_HOME}/lib -type f -name "analytics-zoo*python-api.zip"` +export ANALYTICS_ZOO_CONF=${ANALYTICS_ZOO_HOME}/conf/spark-analytics-zoo.conf +export PYTHONPATH=${ANALYTICS_ZOO_PYZIP}:$PYTHONPATH +export BIGDL_JARS=`find ${ANALYTICS_ZOO_HOME}/lib -type f -name "analytics-zoo*jar-with-dependencies.jar"` + +set -e + +echo "Start ray exmples tests" +#start execute +echo "Start pong example" +start=$(date "+%s") +python ${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/rayexample/rl_pong/rl_pong.py --iterations 10 +now=$(date "+%s") +time1=$((now-start)) + +echo "Start async_parameter example" +start=$(date "+%s") +python ${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/rayexample/parameter_server/async_parameter_server.py --iterations 10 +now=$(date "+%s") +time2=$((now-start)) + +echo "Start sync_parameter example" +start=$(date "+%s") +python ${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/rayexample/parameter_server/sync_parameter_server.py --iterations 10 +now=$(date "+%s") +time3=$((now-start)) + +echo "Start multiagent example" +start=$(date "+%s") +python ${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/rayexample/rllibexample/multiagent_two_trainers.py +now=$(date "+%s") +time4=$((now-start)) + +echo "End ray example tests" +echo "#9 rl_pong time used:$time1 seconds" +echo "#10 sync_parameter_server time used:$time2 seconds" +echo "#11 async_parameter_server time used:$time3 seconds" +echo "#12 multiagent_two_trainers time used:$time3 seconds" \ No newline at end of file diff --git a/pyzoo/zoo/examples/run-example-tests-pip-ray.sh b/pyzoo/zoo/examples/run-example-tests-pip-ray.sh new file mode 100644 index 00000000000..738e20714ba --- /dev/null +++ b/pyzoo/zoo/examples/run-example-tests-pip-ray.sh @@ -0,0 +1,41 @@ +#!/usr/bin/env bash +clear_up () { + echo "Clearing up environment. Uninstalling analytics-zoo" + pip uninstall -y analytics-zoo + pip uninstall -y bigdl + pip uninstall -y pyspark +} +#if image exist this two dependency, remove below +execute_ray_test(){ + echo "start example $1" + start=$(date "+%s") + python $2 + exit_status=$? + if [ $exit_status -ne 0 ]; + then + clear_up + echo "$1 failed" + exit $exit_status + fi + now=$(date "+%s") + return $((now-start)) +} + +execute_ray_test rl_pong ${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/rayexample/rl_pong/rl_pong.py --iterations 10 +time9=$? + +execute_ray_test sync_parameter_server ${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/rayexample/parameter_server/sync_parameter_server.py --iterations 10 +time10=$? + +execute_ray_test async_parameter_server ${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/rayexample/parameter_server/async_parameter_server.py --iterations 10 +time11=$? + +execute_ray_test multiagent_two_trainers ${ANALYTICS_ZOO_ROOT}/pyzoo/zoo/examples/rayexample/rllibexample/multiagent_two_trainers.py +time12=$? + +echo "#9 rl_pong time used:$time9 seconds" +echo "#10 sync_parameter_server time used:$time10 seconds" +echo "#11 async_parameter_server time used:$time11 seconds" +echo "#12 multiagent_two_trainers time used:$time12 seconds" + +clear_up \ No newline at end of file diff --git a/pyzoo/zoo/examples/run-example-tests-pip.sh b/pyzoo/zoo/examples/run-example-tests-pip.sh index df46041c914..d429682e2b6 100644 --- a/pyzoo/zoo/examples/run-example-tests-pip.sh +++ b/pyzoo/zoo/examples/run-example-tests-pip.sh @@ -1,5 +1,4 @@ #!/usr/bin/env bash - clear_up () { echo "Clearing up environment. Uninstalling analytics-zoo" pip uninstall -y analytics-zoo @@ -375,6 +374,5 @@ now=$(date "+%s") time9=$((now-start)) echo "qaranker time used:$time9 seconds" - # This should be done at the very end after all tests finish. clear_up diff --git a/pyzoo/zoo/examples/run-example-tests.sh b/pyzoo/zoo/examples/run-example-tests.sh index f15831dab27..fb813e8e68e 100644 --- a/pyzoo/zoo/examples/run-example-tests.sh +++ b/pyzoo/zoo/examples/run-example-tests.sh @@ -363,8 +363,6 @@ ${SPARK_HOME}/bin/spark-submit \ now=$(date "+%s") time8=$((now-start)) - - echo "#1 textclassification time used:$time1 seconds" echo "#2 customized loss and layer time used:$time2 seconds" echo "#3 image-classification time used:$time3 seconds" diff --git a/pyzoo/zoo/pipeline/api/keras/optimizers.py b/pyzoo/zoo/pipeline/api/keras/optimizers.py index af910c505cd..1801dd53057 100644 --- a/pyzoo/zoo/pipeline/api/keras/optimizers.py +++ b/pyzoo/zoo/pipeline/api/keras/optimizers.py @@ -105,3 +105,60 @@ def __init__(self, epsilon, weight_decay) self.bigdl_type = bigdl_type + + +from bigdl.optim.optimizer import DistriOptimizer as BDistriOptimizer, SGD + + +class DistriOptimizer(BDistriOptimizer): + def __init__(self, + model, + training_rdd, + criterion, + end_trigger=None, + batch_size=32, + optim_method=None, + bigdl_type="float"): + """ + Create an optimizer. + + + :param model: the neural net model + :param training_data: the training dataset + :param criterion: the loss function + :param optim_method: the algorithm to use for optimization, + e.g. SGD, Adagrad, etc. If optim_method is None, the default algorithm is SGD. + :param end_trigger: when to end the optimization + :param batch_size: training batch size + """ + if not optim_method: + optim_methods = {model.name(): SGD()} + elif isinstance(optim_method, OptimMethod): + optim_methods = {model.name(): optim_method} + elif isinstance(optim_method, JavaObject): + optim_methods = {model.name(): OptimMethod(optim_method, bigdl_type)} + else: + optim_methods = optim_method + if isinstance(training_rdd, RDD): + self.bigdl_type = bigdl_type + self.value = callBigDlFunc(self.bigdl_type, "createDistriOptimizerFromRDD", + model.value, training_rdd, criterion, + optim_methods, end_trigger, batch_size) + + def set_validation(self, batch_size, val_rdd, trigger, val_method=None): + """ + Configure validation settings. + + + :param batch_size: validation batch size + :param val_rdd: validation dataset + :param trigger: validation interval + :param val_method: the ValidationMethod to use,e.g. "Top1Accuracy", "Top5Accuracy", "Loss" + """ + callBigDlFunc(self.bigdl_type, "setValidationWithPaddingStrategy", self.value, batch_size, + trigger, val_rdd, to_list(val_method)) + + +class EpochStep(JavaValue): + def __init__(self, step_size, gamma, bigdl_type="float"): + JavaValue.__init__(self, None, bigdl_type, step_size, float(gamma)) diff --git a/pyzoo/zoo/pipeline/api/net/tf_dataset.py b/pyzoo/zoo/pipeline/api/net/tf_dataset.py index cd09d995711..6a2cf26bc6d 100644 --- a/pyzoo/zoo/pipeline/api/net/tf_dataset.py +++ b/pyzoo/zoo/pipeline/api/net/tf_dataset.py @@ -40,7 +40,7 @@ def _to_tensor_structure(tensors): elif isinstance(tensors, dict): tensor_structure = {} for key, value in tensors.items(): - tensor_structure[key] = TensorMeta(dtype=value[0], shape=value[1], name=key) + tensor_structure[key] = TensorMeta(dtype=value[0], shape=value[1], name="input_"+key) else: raise ValueError("In TFDataset.from_rdd, features and labels should be a tuple, " "a list of tuples or a dict of tuples") diff --git a/pyzoo/zoo/pipeline/api/net/tf_optimizer.py b/pyzoo/zoo/pipeline/api/net/tf_optimizer.py index 42cf7aba309..937e53736e4 100644 --- a/pyzoo/zoo/pipeline/api/net/tf_optimizer.py +++ b/pyzoo/zoo/pipeline/api/net/tf_optimizer.py @@ -23,9 +23,10 @@ from bigdl.nn.criterion import Criterion from bigdl.nn.layer import Layer -from bigdl.util.common import to_list, JavaValue -from bigdl.optim.optimizer import EveryEpoch, MaxEpoch, Optimizer +from bigdl.util.common import to_list, JavaValue, callBigDlFunc +from bigdl.optim.optimizer import EveryEpoch, MaxEpoch, SeveralIteration from zoo.pipeline.api.keras.engine.topology import to_bigdl_metric +from zoo.pipeline.api.keras.optimizers import DistriOptimizer from zoo.pipeline.api.net.utils import _find_placeholders, _check_the_same from zoo.util import nest @@ -53,6 +54,12 @@ def __init__(self, path, configProto): byte_arr = None super(TFTrainingHelper, self).__init__(None, "float", path, byte_arr) + def evaluate(self, dataset, batch_size, val_methods): + return callBigDlFunc(self.bigdl_type, + "tfEvaluate", + self.value, + dataset, batch_size, val_methods) + class TFOptimizer: def __init__(self, loss, optim_method, sess=None, dataset=None, inputs=None, @@ -181,10 +188,12 @@ def to_floats(vs): if val_outputs is not None and val_labels is not None: val_rdd = self.dataset.get_validation_data() + self.val_rdd = val_rdd if val_rdd is not None: val_method = [TFValidationMethod(m, len(val_outputs), len(val_labels)) for m in to_list(val_method)] training_rdd = sample_rdd + self.val_method = val_method elif val_split != 0.0: training_rdd, val_rdd = sample_rdd.randomSplit([1 - val_split, val_split]) @@ -194,18 +203,18 @@ def to_floats(vs): raise ValueError("Validation data is not specified. Please set " + "val rdd in TFDataset, or set val_split larger than zero") - self.optimizer = Optimizer.create(self.training_helper_layer, + self.optimizer = DistriOptimizer(self.training_helper_layer, training_rdd, IdentityCriterion(), batch_size=batch_size, optim_method=self.optim_method) self.optimizer.set_validation(self.dataset.batch_size, val_rdd, - EveryEpoch(), + SeveralIteration(50), val_method) else: training_rdd = sample_rdd - self.optimizer = Optimizer.create(self.training_helper_layer, + self.optimizer = DistriOptimizer(self.training_helper_layer, training_rdd, IdentityCriterion(), batch_size=batch_size, diff --git a/pyzoo/zoo/pipeline/api/net/torch_criterion.py b/pyzoo/zoo/pipeline/api/net/torch_criterion.py index 75dae2e34d3..2d2f532e77f 100644 --- a/pyzoo/zoo/pipeline/api/net/torch_criterion.py +++ b/pyzoo/zoo/pipeline/api/net/torch_criterion.py @@ -50,35 +50,37 @@ def __init__(self, path, bigdl_type="float"): super(TorchCriterion, self).__init__(None, bigdl_type, path) @staticmethod - def from_pytorch(loss, input_shape=None, label_shape=None, - sample_input=None, sample_label=None): + def from_pytorch(loss, input, label=None): """ - Create a TorchCriterion directly from PyTorch function. We need user to provide a sample - input and label to trace the loss function. User may just specify the input and label shape. - For specific data type or multiple input models, users can send sample_input and - sample_label. + Create a TorchCriterion directly from PyTorch function. We need users to provide example + input and label (or just their sizes) to trace the loss function. + :param loss: this can be a torch loss (e.g. nn.MSELoss()) or - a function that take two Tensor parameters: input and label. E.g. + a function that takes two Tensor parameters: input and label. E.g. def lossFunc(input, target): return nn.CrossEntropyLoss().forward(input, target.flatten().long()) - - :param input_shape: list of integers. - :param label_shape: list of integers. If not specified, it will be set equal to input_shape - :param sample_input: a sample of input. - :param sample_label: a sample of label. + :param input: example input. It can be: + 1. a torch tensor, or tuple of torch tensors for multi-input models + 2. list of integers, or tuple of int list for multi-input models. E.g. For + ResNet, this can be [1, 3, 224, 224]. A random tensor with the + specified size will be used as the example input. + :param label: example label. It can be: + 1. a torch tensor, or tuple of torch tensors for multi-input models + 2. list of integers, or tuple of int list for multi-input models. E.g. For + ResNet, this can be [1, 3, 224, 224]. A random tensor with the + specified size will be used as the example input. + When label is None, input will also be used as label. """ - if not input_shape and not label_shape and not sample_input and not sample_label: - raise Exception("please specify input_shape and label_shape, or sample_input" - " and sample_label") + if input is None: + raise Exception("please specify input and label") temp = tempfile.mkdtemp() - # use input_shape as label shape when label_shape is not specified - if not label_shape: - label_shape = input_shape + if label is None: + label = input - sample_input = TorchNet.get_sample_input(input_shape, sample_input) - sample_label = TorchNet.get_sample_input(label_shape, sample_label) + sample_input = TorchNet.get_sample_input(input) + sample_label = TorchNet.get_sample_input(label) traced_script_loss = torch.jit.trace(LossWrapperModule(loss), (sample_input, sample_label)) diff --git a/pyzoo/zoo/pipeline/api/net/torch_net.py b/pyzoo/zoo/pipeline/api/net/torch_net.py index 6059df374d0..7462a21125e 100644 --- a/pyzoo/zoo/pipeline/api/net/torch_net.py +++ b/pyzoo/zoo/pipeline/api/net/torch_net.py @@ -43,23 +43,31 @@ def __init__(self, path, bigdl_type="float"): super(TorchNet, self).__init__(None, bigdl_type, path) @staticmethod - def from_pytorch(module, input_shape=None, sample_input=None): + def from_pytorch(module, input, check_trace=True): """ Create a TorchNet directly from PyTorch model, e.g. model in torchvision.models. - Users need to specify sample_input or input_shape. + Users need to provide an example input or the input tensor shape. :param module: a PyTorch model - :param input_shape: list of integers, or tuple of list for multiple inputs models. E.g. - for ResNet, this may be [1, 3, 224, 224] - :param sample_input. A sample of Torch Tensor or tuple to trace the model. + :param input: To trace the tensor operations, torch jit trace requires users to + provide an example input. Here the input parameter can be: + 1. a torch tensor, or tuple of torch tensors for multi-input models + 2. list of integers, or tuple of int list for multi-input models. E.g. For + ResNet, this can be [1, 3, 224, 224]. A random tensor with the + specified size will be used as the example input. + :param check_trace: boolean value, optional. check if the same inputs run through + traced module produce the same outputs. Default: ``True``. You + might want to disable this if, for example, your network contains + non-deterministic ops or if you are sure that the network is + correct despite a checker failure. """ - if not input_shape and not sample_input: - raise Exception("please specify input_shape or sample_input") + if input is None: + raise Exception("please provide an example input or input Tensor size") - sample = TorchNet.get_sample_input(input_shape, sample_input) + sample = TorchNet.get_sample_input(input) temp = tempfile.mkdtemp() # save model - traced_script_module = torch.jit.trace(module, sample) + traced_script_module = torch.jit.trace(module, sample, check_trace=check_trace) path = os.path.join(temp, "model.pt") traced_script_module.save(path) @@ -69,15 +77,20 @@ def from_pytorch(module, input_shape=None, sample_input=None): return net @staticmethod - def get_sample_input(shape, sample): - if sample: - return sample - elif isinstance(shape, list): - return torch.rand(shape) - elif isinstance(shape, tuple): - return tuple(map(lambda s: torch.rand(s), shape)) - else: - raise Exception("please specify shape as list of ints or tuples of int lists") + def get_sample_input(input): + if isinstance(input, torch.Tensor): + return input + + elif isinstance(input, (list, tuple)) and len(input) > 0: + if all(isinstance(x, torch.Tensor) for x in input): # tensors + return tuple(input) + elif all(isinstance(x, int) for x in input): # ints + return torch.rand(input) + elif all(isinstance(x, (list, tuple)) for x in input) and \ + all(isinstance(y, int) for x in input for y in x): # nested int list (tuple) + return tuple(map(lambda size: torch.rand(size), input)) + + raise Exception("Unsupported input type: " + str(input)) def predict(self, x, batch_per_thread=1, distributed=True): """ diff --git a/zoo/src/main/scala/com/intel/analytics/zoo/pipeline/api/keras/python/PythonZooKeras.scala b/zoo/src/main/scala/com/intel/analytics/zoo/pipeline/api/keras/python/PythonZooKeras.scala index 956836e0293..006b29ddddf 100644 --- a/zoo/src/main/scala/com/intel/analytics/zoo/pipeline/api/keras/python/PythonZooKeras.scala +++ b/zoo/src/main/scala/com/intel/analytics/zoo/pipeline/api/keras/python/PythonZooKeras.scala @@ -18,11 +18,11 @@ package com.intel.analytics.zoo.pipeline.api.keras.python import java.util.{List => JList, Map => JMap} -import com.intel.analytics.bigdl.{Criterion, Module} -import com.intel.analytics.bigdl.dataset.{DataSet, LocalDataSet, MiniBatch} +import com.intel.analytics.bigdl.{Criterion, DataSet, Module} +import com.intel.analytics.bigdl.dataset.{Identity => DIdentity, Sample => JSample, _} import scala.collection.JavaConverters._ -import com.intel.analytics.bigdl.optim.{_} +import com.intel.analytics.bigdl.optim._ import com.intel.analytics.bigdl.python.api.{EvaluatedResult, JTensor, Sample} import com.intel.analytics.bigdl.tensor.Tensor import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric @@ -36,7 +36,7 @@ import com.intel.analytics.zoo.feature.image.ImageSet import com.intel.analytics.zoo.pipeline.api.autograd.{Constant, _} import com.intel.analytics.zoo.pipeline.api.keras.layers.{KerasLayerWrapper, _} import com.intel.analytics.zoo.pipeline.api.keras.layers.utils.KerasUtils -import com.intel.analytics.zoo.pipeline.api.keras.models.{KerasNet, Model, Sequential} +import com.intel.analytics.zoo.pipeline.api.keras.models.{InternalDistriOptimizer, KerasNet, Model, Sequential} import com.intel.analytics.zoo.pipeline.api.keras.objectives._ import com.intel.analytics.zoo.pipeline.api.keras.optimizers.{Adam, AdamWeightDecay} import org.apache.spark.api.java.JavaRDD @@ -1355,4 +1355,82 @@ class PythonZooKeras[T: ClassTag](implicit ev: TensorNumeric[T]) extends PythonZ new AdamWeightDecay[T](learningRate, warmupPortion, total, schedule, beta1, beta2, epsilon, weightDecay) } + + def batchingWithPaddingStrategy(dataset: DataSet[JSample[T]], batchSize: Int, featureSize: Int) + : DataSet[MiniBatch[T]] = { + println("Using Feature Padding Strategy") + val paddingTensor = Tensor[T](1).fill(ev.fromType(-1.0)) + val widePaddingTensor = Tensor[T](1).fill(ev.fromType(0.0)) + val paddingArray = Array.fill[Tensor[T]](featureSize-1)(paddingTensor) ++ Array(widePaddingTensor) + val featurePaddingParam = PaddingParam[T](Some(paddingArray)) + dataset.transform(SampleToMiniBatch( + batchSize = batchSize, featurePaddingParam = Some(featurePaddingParam))) + } + + def createDistriOptimizerFromRDD(model: AbstractModule[Activity, Activity, T], + trainingRdd: JavaRDD[Sample], + criterion: Criterion[T], + optimMethod: JMap[String, OptimMethod[T]], + endTrigger: Trigger, + batchSize: Int): Optimizer[T, MiniBatch[T]] = { + val sampleRDD = toJSample(trainingRdd) + val featureSize = sampleRDD.first().numFeature() + + val optimizer = new InternalDistriOptimizer( + _model = model, + _dataset = batchingWithPaddingStrategy(DataSet.rdd(sampleRDD), batchSize, featureSize) + .asInstanceOf[DistributedDataSet[MiniBatch[T]]], + _criterion = criterion + ).asInstanceOf[Optimizer[T, MiniBatch[T]]] + enrichOptimizer(optimizer, endTrigger, optimMethod.asScala.toMap) + } + + private def enrichOptimizer[T]( + optimizer: Optimizer[T, MiniBatch[T]], + endTrigger: Trigger, + optimMethod: Map[String, OptimMethod[T]]): Optimizer[T, MiniBatch[T]] = { + optimizer.setEndWhen(endTrigger) + + optimizer.setOptimMethods(optimMethod) + + // TODO: remove this + optimizer.disableCheckSingleton() + + optimizer + } + + def setValidationWithPaddingStrategy(optimizer: Optimizer[T, MiniBatch[T]], + batchSize: Int, + trigger: Trigger, + valRdd: JavaRDD[Sample], + vMethods: JList[ValidationMethod[T]]): Unit = { + val sampleRDD = toJSample(valRdd) + val featureSize = sampleRDD.first().numFeature() + optimizer.setValidation(trigger, + batchingWithPaddingStrategy(DataSet.rdd(sampleRDD), batchSize, featureSize), + vMethods.asScala.toArray) + } + + def createEpochStep(stepSize: Int, gamma: Double): SGD.EpochStep = { + SGD.EpochStep(stepSize, gamma) + } + + + def tfEvaluate(model: AbstractModule[Activity, Activity, T], + valRDD: JavaRDD[Sample], + batchSize: Int, + valMethods: JList[ValidationMethod[T]]) + : JList[EvaluatedResult] = { + val sampleRDD = toJSample(valRDD) + val featureSize = sampleRDD.first().numFeature() + val dataSet = batchingWithPaddingStrategy(DataSet.rdd(sampleRDD), batchSize, featureSize) + val rdd = dataSet.toDistributed().data(train = false) + val resultArray = model.evaluate(rdd, + valMethods.asScala.toArray) + val testResultArray = resultArray.map { result => + EvaluatedResult(result._1.result()._1, result._1.result()._2, + result._2.toString()) + } + testResultArray.toList.asJava + } } diff --git a/zoo/src/main/scala/com/intel/analytics/zoo/pipeline/api/net/TFTrainingHelper.scala b/zoo/src/main/scala/com/intel/analytics/zoo/pipeline/api/net/TFTrainingHelper.scala index 91fbb3d5df1..600755feac0 100644 --- a/zoo/src/main/scala/com/intel/analytics/zoo/pipeline/api/net/TFTrainingHelper.scala +++ b/zoo/src/main/scala/com/intel/analytics/zoo/pipeline/api/net/TFTrainingHelper.scala @@ -57,6 +57,27 @@ private[zoo] class TFTrainingHelper(tfnet: TFNet, setWeights(ws) } + private val weightsMap = { + val map = collection.mutable.Map[String, Tensor[Float]]() + var i = 0 + while (i < variables.length) { + map(variables(i)) = weights(i) + i +=1 + } + map + } + + override def apply(name: String): Option[AbstractModule[Activity, Activity, Float]] = { + val targetVariables = if (name == getName()) variables else variables.filter(_.contains(name)) + if (targetVariables == null) { + None + } + else { + val targetWeights = targetVariables.map(weightsMap) + Some(new TFSubGraph(targetWeights)) + } + } + private val gradWeights = variables.map(_ => Tensor[Float]()) @@ -144,6 +165,19 @@ private[zoo] class TFTrainingHelper(tfnet: TFNet, } } +private class TFSubGraph(weights: Array[Tensor[Float]]) extends AbstractModule[Activity, Activity, Float] { + override def updateOutput(input: Activity): Activity = { + input + } + override def updateGradInput(input: Activity, gradOutput: Activity): Activity = { + gradInput + } + + override def parameters(): (Array[Tensor[Float]], Array[Tensor[Float]]) = { + (weights, weights.map(_ => Tensor[Float]())) + } +} + object TFTrainingHelper { def apply(modelPath: String, sessionConfig: Array[Byte] = null): TFTrainingHelper = { diff --git a/zoo/src/test/scala/com/intel/analytics/zoo/pipeline/api/net/TFNetSpec.scala b/zoo/src/test/scala/com/intel/analytics/zoo/pipeline/api/net/TFNetSpec.scala index 3d5178279cf..b51e8ee37d1 100644 --- a/zoo/src/test/scala/com/intel/analytics/zoo/pipeline/api/net/TFNetSpec.scala +++ b/zoo/src/test/scala/com/intel/analytics/zoo/pipeline/api/net/TFNetSpec.scala @@ -16,19 +16,24 @@ package com.intel.analytics.zoo.pipeline.api.net -import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule +import com.intel.analytics.bigdl.dataset._ +import com.intel.analytics.bigdl.optim.{DistriOptimizer, SGD} import com.intel.analytics.bigdl.tensor.Tensor -import com.intel.analytics.bigdl.utils.{LayerException, T} -import com.intel.analytics.zoo.pipeline.api.Net -import com.intel.analytics.zoo.pipeline.api.keras.ZooSpecHelper -import com.intel.analytics.zoo.pipeline.api.keras.serializer.ModuleSerializationTest +import com.intel.analytics.bigdl.utils.{Engine, LayerException, LoggerFilter, T} +import com.intel.analytics.zoo.common.MaxEpoch +import com.intel.analytics.zoo.pipeline.api.keras.optimizers.Adam +import org.apache.log4j.{Level, Logger} import org.apache.spark.serializer.KryoSerializer -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkContext} import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers} -import scala.util.Random class TFNetSpec extends FlatSpec with Matchers with BeforeAndAfter { + LoggerFilter.redirectSparkInfoLogs() + Logger.getLogger("org").setLevel(Level.WARN) + Logger.getLogger("akka").setLevel(Level.WARN) + Logger.getLogger("com.intel.analytics.bigdl.optim").setLevel(Level.INFO) + Logger.getLogger("com.intel.analytics.bigdl").setLevel(Level.INFO) "TFNet " should "work with different data types" in { @@ -137,4 +142,33 @@ class TFNetSpec extends FlatSpec with Matchers with BeforeAndAfter { gradInput.size() should be (input.size()) } + "TFTraningHelper " should "work properly" in { + val conf = Engine.createSparkConf() + .setAppName("Optimizer test") + .setMaster("local[4]") + val sc = new SparkContext(conf) + Engine.init + val layer = TFTrainingHelper("/home/kai/mnist_keras") + val data = new Array[Sample[Float]](500) + var i = 0 + while (i < data.length) { + val input = Tensor[Float](28, 28, 1).rand() + val label = Tensor[Float](1).fill(0.0f) + data(i) = Sample(Array(input, label), label) + i += 1 + } + + val rdd = sc.parallelize(data) + val dataSet = DataSet.rdd(rdd) -> SampleToMiniBatch[Float](128) + val optimizer = new DistriOptimizer[Float]( + layer, + dataSet.asInstanceOf[DistributedDataSet[MiniBatch[Float]]], + new IdentityCriterion()) + .setOptimMethods( + Map("dense/" -> new SGD[Float](0.001), "dense_" -> new Adam[Float](0.02))) + .setEndWhen(MaxEpoch(2)) + optimizer.optimize() + println("1111") + } + }