/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tez.dag.library.vertexmanager;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.dag.library.vertexmanager.FairShuffleUserPayloads.FairShuffleEdgeManagerConfigPayloadProto;
import org.apache.tez.dag.library.vertexmanager.FairShuffleUserPayloads.FairShuffleEdgeManagerDestinationTaskPropProto;
import org.apache.tez.dag.library.vertexmanager.FairShuffleUserPayloads.RangeProto;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map.Entry;
/**
* Handles edge configuration serialization and de-serialization between
* {@link FairShuffleVertexManager} and {@link FairShuffleEdgeManager}.
*/
class FairEdgeConfiguration {
private final int numBuckets;
private final HashMap<Integer, DestinationTaskInputsProperty>
destinationInputsProperties;
public FairEdgeConfiguration(int numBuckets,
HashMap<Integer, DestinationTaskInputsProperty> routingTable) {
this.destinationInputsProperties = routingTable;
this.numBuckets = numBuckets;
}
private FairShuffleEdgeManagerConfigPayloadProto getConfigPayload() {
FairShuffleEdgeManagerConfigPayloadProto.Builder builder =
FairShuffleEdgeManagerConfigPayloadProto.newBuilder();
builder.setNumBuckets(numBuckets);
if (destinationInputsProperties != null) {
for (Entry<Integer, DestinationTaskInputsProperty> entry :
destinationInputsProperties.entrySet()) {
FairShuffleEdgeManagerDestinationTaskPropProto.Builder taskBuilder =
FairShuffleEdgeManagerDestinationTaskPropProto.newBuilder();
taskBuilder.
setDestinationTaskIndex(entry.getKey()).
setPartitions(newRange(entry.getValue().getFirstPartitionId(),
entry.getValue().getNumOfPartitions())).
setSourceTasks(newRange(entry.getValue().
getFirstSourceTaskIndex(), entry.getValue().getNumOfSourceTasks()));
builder.addDestinationTaskProps(taskBuilder.build());
}
}
return builder.build();
}
private RangeProto newRange(int firstIndex, int numOfIndexes) {
return RangeProto.newBuilder().
setFirstIndex(firstIndex).setNumOfIndexes(numOfIndexes).build();
}
static FairEdgeConfiguration fromUserPayload(UserPayload payload)
throws InvalidProtocolBufferException {
HashMap<Integer, DestinationTaskInputsProperty> routingTable = new HashMap<>();
FairShuffleEdgeManagerConfigPayloadProto proto =
FairShuffleEdgeManagerConfigPayloadProto.parseFrom(
ByteString.copyFrom(payload.getPayload()));
int numBuckets = proto.getNumBuckets();
if (proto.getDestinationTaskPropsList() != null) {
for (int i = 0; i < proto.getDestinationTaskPropsList().size(); i++) {
FairShuffleEdgeManagerDestinationTaskPropProto propProto =
proto.getDestinationTaskPropsList().get(i);
routingTable.put(
propProto.getDestinationTaskIndex(),
new DestinationTaskInputsProperty(
propProto.getPartitions().getFirstIndex(),
propProto.getPartitions().getNumOfIndexes(),
propProto.getSourceTasks().getFirstIndex(),
propProto.getSourceTasks().getNumOfIndexes()));
}
}
return new FairEdgeConfiguration(numBuckets, routingTable);
}
public HashMap<Integer, DestinationTaskInputsProperty> getRoutingTable() {
return destinationInputsProperties;
}
// The number of partitions used by source vertex.
int getNumBuckets() {
return numBuckets;
}
UserPayload getBytePayload() {
return UserPayload.create(ByteBuffer.wrap(
getConfigPayload().toByteArray()));
}
}