Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 80 additions & 52 deletions tensorflow/core/kernels/save_restore_v2_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.

#include <string>
#include <vector>
#include <iostream>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
Expand All @@ -31,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/tensor_bundle/db_writer.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
#include "tensorflow/core/util/tensor_slice_reader.h"
Expand Down Expand Up @@ -106,25 +108,22 @@ class SaveV2 : public OpKernel {
}

template <typename TKey, typename TValue>
void DumpEvWithGlobalStep(OpKernelContext* context, int variable_index,
const string& tensor_name, BundleWriter& writer,
DataType global_step_type) {
void DumpEvWithGlobalStep(OpKernelContext* context, int variable_index,
const string& tensor_name, BundleWriter& writer,
DataType global_step_type) {
if (global_step_type == DT_INT32) {
DumpEv<TKey, TValue, int32>(context, variable_index,
tensor_name, writer);
DumpEv<TKey, TValue, int32>(context, variable_index, tensor_name, writer);
} else {
DumpEv<TKey, TValue, int64>(context, variable_index,
tensor_name, writer);
DumpEv<TKey, TValue, int64>(context, variable_index, tensor_name, writer);
}
}

template <typename TKey, typename TValue, typename TGlobalStep>
void DumpEv(OpKernelContext* context, int variable_index,
const string& tensor_name, BundleWriter& writer) {
void DumpEv(OpKernelContext* context, int variable_index,
const string& tensor_name, BundleWriter& writer) {
EmbeddingVar<TKey, TValue>* variable = nullptr;
OP_REQUIRES_OK(context,
LookupResource(context,
HandleFromInput(context, variable_index), &variable));
LookupResource(context, HandleFromInput(context, variable_index), &variable));
const Tensor& global_step = context->input(3);
Tensor part_offset_tensor;
context->allocate_temp(DT_INT32,
Expand All @@ -136,8 +135,7 @@ class SaveV2 : public OpKernel {
OP_REQUIRES_OK(context, variable->Shrink());
else
OP_REQUIRES_OK(context, variable->Shrink(global_step_scalar));
OP_REQUIRES_OK(context, DumpEmbeddingValues(variable, tensor_name,
&writer, &part_offset_tensor));
OP_REQUIRES_OK(context, DumpEmbeddingValues(variable, tensor_name, &writer, &part_offset_tensor));
}

void Compute(OpKernelContext* context) override {
Expand All @@ -146,36 +144,53 @@ class SaveV2 : public OpKernel {
const Tensor& shape_and_slices = context->input(2);
ValidateInputs(true /* is save op */, context, prefix, tensor_names,
shape_and_slices);
if (!context->status().ok()) return;

const int kFixedInputs = 3; // Prefix, tensor names, shape_and_slices.
const int num_tensors = static_cast<int>(tensor_names.NumElements());
const int num_tensors = static_cast<int>(tensor_names.NumElements());
const string& prefix_string = prefix.scalar<tstring>()();
const auto& tensor_names_flat = tensor_names.flat<tstring>();
const auto& shape_and_slices_flat = shape_and_slices.flat<tstring>();

BundleWriter writer(Env::Default(), prefix_string);
const int nosql_marker = 0;
auto tempstate = random::New64();
string db_prefix_tmp = strings::StrCat(prefix_string,"--temp",tempstate);
DBWriter dbwriter(Env::Default(), prefix_string,db_prefix_tmp);
OP_REQUIRES_OK(context, dbwriter.status());

BundleWriter writer(Env::Default(), prefix_string,db_prefix_tmp);
OP_REQUIRES_OK(context, writer.status());
VLOG(1) << "BundleWriter, prefix_string: " << prefix_string;


int start_index = 0;
if (has_ev_) {
start_index = 1;
}


int start_ev_key_index = 0;


for (int i = start_index; i < num_tensors; ++i) {
const string& tensor_name = tensor_names_flat(i);
const string& tensor_name = tensor_names_flat(i);


if (tensor_types_[i] == DT_RESOURCE) {
auto& handle = HandleFromInput(context, i + kFixedInputs);
if (IsHandle<EmbeddingVar<int64, float>>(handle)) {
EmbeddingVar<int64, float>* variable = nullptr;
OP_REQUIRES_OK(context,
LookupResource(context, HandleFromInput(context, i + kFixedInputs), &variable));
core::ScopedUnref unref_variable(variable);
const Tensor& global_step = context->input(3);
Tensor part_offset_tensor;
context->allocate_temp(DT_INT32,
TensorShape({kSavedPartitionNum + 1}),
&part_offset_tensor);

if (ev_key_types_[start_ev_key_index] == DT_INT32) {
DumpEvWithGlobalStep<int32, float>(context,
i + kFixedInputs, tensor_name, writer, tensor_types_[0]);
DumpEvWithGlobalStep<int32, float>(context, i + kFixedInputs, tensor_name, writer, tensor_types_[0]);
} else if (ev_key_types_[start_ev_key_index] == DT_INT64) {
DumpEvWithGlobalStep<int64, float>(context,
i + kFixedInputs, tensor_name, writer, tensor_types_[0]);
DumpEvWithGlobalStep<int64, float>(context, i + kFixedInputs, tensor_name, writer, tensor_types_[0]);
}
} else if (IsHandle<HashTableResource>(handle)) {
auto handles = context->input(i + kFixedInputs).flat<ResourceHandle>();
Expand Down Expand Up @@ -205,7 +220,6 @@ class SaveV2 : public OpKernel {

OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
shape_spec, &shape, &slice, &slice_shape));

std::vector<string> names_lst = str_util::Split(tensor_name, '|');
for (auto&& name : names_lst) {
std::vector<string> tensor_name_x =
Expand All @@ -218,15 +232,14 @@ class SaveV2 : public OpKernel {
OP_REQUIRES_OK(context, SaveHashTable(
&writer, hashtable, tensibles, table_name, tensible_name,
slice.start(0), slice.length(0), slice_shape.dim_size(0)));

}
} else if (IsHandle<HashTableAdmitStrategyResource>(handle)) {
HashTableAdmitStrategyResource* resource;
OP_REQUIRES_OK(context,
LookupResource(context,
HandleFromInput(context, i + kFixedInputs), &resource));
LookupResource(context, HandleFromInput(context, i + kFixedInputs), &resource));
HashTableAdmitStrategy* strategy = resource->Internal();
BloomFilterAdmitStrategy* bf =
dynamic_cast<BloomFilterAdmitStrategy*>(strategy);
BloomFilterAdmitStrategy* bf = dynamic_cast<BloomFilterAdmitStrategy*>(strategy);
CHECK(bf != nullptr) << "Cannot save Non-BloomFilterAdmitStrategy!";

string shape_spec = shape_and_slices_flat(i);
Expand All @@ -240,33 +253,54 @@ class SaveV2 : public OpKernel {
&writer, bf, tensor_name, slice.start(0),
slice.length(0), slice_shape.dim_size(0)));
}

start_ev_key_index++;
} else {
} else {
const Tensor& tensor = context->input(i + kFixedInputs);

if (!shape_and_slices_flat(i).empty()) {
const string& shape_spec = shape_and_slices_flat(i);
TensorShape shape;
TensorSlice slice(tensor.dims());
TensorShape slice_shape;


OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
shape_spec, &shape, &slice, &slice_shape));
shape_spec, &shape, &slice, &slice_shape));
OP_REQUIRES(context, slice_shape.IsSameSize(tensor.shape()),
errors::InvalidArgument("Slice in shape_and_slice "
"specification does not match the "
"shape of the tensor to save: ",
shape_spec, ", tensor: ",
tensor.shape().DebugString()));

OP_REQUIRES_OK(context,
writer.AddSlice(tensor_name, shape, slice, tensor));
errors::InvalidArgument("Slice in shape_and_slice "
"specification does not match the "
"shape of the tensor to save: ",
shape_spec, ", tensor: ",
tensor.shape().DebugString()));

if(nosql_marker==1){

OP_REQUIRES_OK(context,
dbwriter.AddSlice(tensor_name, shape, slice, tensor,"slice_tensor"));
} else{

OP_REQUIRES_OK(context,
writer.AddSlice(tensor_name, shape, slice, tensor));
}
} else {
OP_REQUIRES_OK(context, writer.Add(tensor_name, tensor));
if(nosql_marker==1){
OP_REQUIRES_OK(context,
dbwriter.Add(tensor_name, tensor,"normal_tensor"));
} else{
string tmp_dbfile_prefix_string =
strings::StrCat(prefix_string,"--temp",tempstate,"--data--0--1","--tensor--",tensor_name);
OP_REQUIRES_OK(context, writer.Add(tensor_name, tensor,tmp_dbfile_prefix_string));
}
}
}
}
OP_REQUIRES_OK(context, writer.Finish());
if(nosql_marker==1){

OP_REQUIRES_OK(context, dbwriter.Finish());
} else{

OP_REQUIRES_OK(context, writer.Finish());
}
}
private:
DataTypeVector tensor_types_;
Expand All @@ -278,8 +312,7 @@ REGISTER_KERNEL_BUILDER(Name("SaveV2").Device(DEVICE_CPU), SaveV2);
// Restores a list of named tensors from a tensor bundle (V2 checkpoint format).
class RestoreHashTableOp : public AsyncOpKernel {
public:
explicit RestoreHashTableOp(OpKernelConstruction* context)
: AsyncOpKernel(context) {
explicit RestoreHashTableOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("clear", &clear_));
}

Expand All @@ -289,8 +322,7 @@ class RestoreHashTableOp : public AsyncOpKernel {
const Tensor& shape_and_slices = context->input(2);
const Tensor& handles = context->input(3);
const string& prefix_string = prefix.scalar<string>()();
const string& shape_and_slices_string =
shape_and_slices.scalar<string>()();
const string& shape_and_slices_string = shape_and_slices.scalar<string>()();
auto tensor_names_flat = tensor_names.flat<string>();
auto handles_flat = handles.flat<ResourceHandle>();

Expand Down Expand Up @@ -376,8 +408,7 @@ class RestoreHashTableOp : public AsyncOpKernel {
private:
bool clear_;
};
REGISTER_KERNEL_BUILDER(Name("RestoreHashTable").Device(DEVICE_CPU),
RestoreHashTableOp);
REGISTER_KERNEL_BUILDER(Name("RestoreHashTable").Device(DEVICE_CPU), RestoreHashTableOp);

class RestoreBloomFilterOp : public AsyncOpKernel {
public:
Expand Down Expand Up @@ -408,8 +439,7 @@ class RestoreBloomFilterOp : public AsyncOpKernel {
OP_REQUIRES_OK_ASYNC(
context, LookupResource(context, handle_flat, &resource), done);
strategy = dynamic_cast<BloomFilterAdmitStrategy*>(resource->Internal());
CHECK(strategy != nullptr)
<< "Cannot restore BloomFilter from another strategy";
CHECK(strategy != nullptr) << "Cannot restore BloomFilter from another strategy";
}
Status st = RestoreBloomFilter(
reader.get(), strategy, tensor_name_flat, slice.start(0),
Expand All @@ -418,8 +448,7 @@ class RestoreBloomFilterOp : public AsyncOpKernel {
done();
}
};
REGISTER_KERNEL_BUILDER(Name("RestoreBloomFilter").Device(DEVICE_CPU),
RestoreBloomFilterOp);
REGISTER_KERNEL_BUILDER(Name("RestoreBloomFilter").Device(DEVICE_CPU), RestoreBloomFilterOp);

// Restores a list of named tensors from a tensor bundle (V2 checkpoint format).
class RestoreV2 : public OpKernel {
Expand All @@ -438,7 +467,6 @@ class RestoreV2 : public OpKernel {
" expected dtypes."));
ValidateInputs(false /* not save op */, context, prefix, tensor_names,
shape_and_slices);
if (!context->status().ok()) return;

const string& prefix_string = prefix.scalar<tstring>()();

Expand Down Expand Up @@ -501,7 +529,7 @@ class MergeV2Checkpoints : public OpKernel {
const string& merged_prefix = destination_prefix.scalar<tstring>()();
OP_REQUIRES_OK(
context, tensorflow::MergeBundles(env, input_prefixes, merged_prefix));

if (delete_old_dirs_) {
const string merged_dir(io::Dirname(merged_prefix));
for (const string& input_prefix : input_prefixes) {
Expand Down
Loading