package org.apache.cassandra.tools;
/*
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*/
import java.io.Closeable;
import java.io.IOException;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import javax.management.JMX;
import javax.management.MBeanServerConnection;
import javax.management.MalformedObjectNameException;
import javax.management.ObjectName;
import org.apache.cassandra.serializers.TimestampSerializer;
import org.apache.cassandra.dht.IPartitioner;
import org.apache.cassandra.dht.Token;
import org.apache.cassandra.locator.EndpointSnitchInfoMBean;
import org.apache.cassandra.service.StorageServiceMBean;
import org.apache.cassandra.thrift.Cassandra;
import org.apache.cassandra.thrift.Compression;
import org.apache.cassandra.thrift.ConsistencyLevel;
import org.apache.cassandra.thrift.CqlResult;
import org.apache.cassandra.thrift.CqlRow;
import org.apache.cassandra.thrift.InvalidRequestException;
import org.apache.cassandra.thrift.TimedOutException;
import org.apache.cassandra.thrift.TokenRange;
import org.apache.cassandra.thrift.UnavailableException;
import org.apache.cassandra.utils.ByteBufferUtil;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.MissingArgumentException;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.transport.TFastFramedTransport;
import org.apache.thrift.transport.TSocket;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
/**
 * Command-line tool that builds a randomized token relocation plan for a cluster.
 *
 * The plan is written into each target endpoint's {@code system.range_xfers} table over
 * Thrift/CQL. Scheduled range transfers are enabled or disabled per node over JMX.
 */
public class Shuffle extends AbstractJmxClient
{
    private static final String ssObjName = "org.apache.cassandra.db:type=StorageService";
    private static final String epSnitchObjName = "org.apache.cassandra.db:type=EndpointSnitchInfo";
    /** Thrift port used when none is supplied. */
    private static final int DEFAULT_THRIFT_PORT = 9160;

    /** StorageService proxy bound to the JMX connection opened by the constructor. */
    private final StorageServiceMBean ssProxy;
    /** Source of randomness for choosing relocation targets. */
    private final Random rand = new Random(System.currentTimeMillis());
    private final String thriftHost;
    private final int thriftPort;
    private final boolean thriftFramed;

    static
    {
        addCmdOption("th", "thrift-host", true, "Thrift hostname or IP address (Default: JMX host)");
        addCmdOption("tp", "thrift-port", true, "Thrift port number (Default: 9160)");
        addCmdOption("tf", "thrift-framed", false, "Enable framed transport for Thrift (Default: false)");
        // A boolean flag (read with hasOption). With hasArg=true it would consume the next token.
        addCmdOption("en", "and-enable", false, "Immediately enable shuffling (create only)");
        addCmdOption("dc", "only-dc", true, "Apply only to named DC (create only)");
    }

    /**
     * Connect over JMX, using the same host for Thrift on the default Thrift port (unframed,
     * no credentials).
     *
     * @param host JMX (and Thrift) host
     * @param port JMX port
     * @throws IOException if the JMX connection cannot be established
     */
    public Shuffle(String host, int port) throws IOException
    {
        this(host, port, host, DEFAULT_THRIFT_PORT, false, null, null);
    }

    /**
     * Connect over JMX and remember the Thrift endpoint used for token-map and CQL requests.
     *
     * @param host JMX host
     * @param port JMX port
     * @param thriftHost Thrift host used as the seed for token-map and partitioner lookups
     * @param thriftPort Thrift port used for every Thrift connection
     * @param thriftFramed whether Thrift connections use framed transport
     * @param username JMX username, may be null
     * @param password JMX password, may be null
     * @throws IOException if the JMX connection cannot be established
     */
    public Shuffle(String host, int port, String thriftHost, int thriftPort, boolean thriftFramed, String username, String password)
    throws IOException
    {
        super(host, port, username, password);
        this.thriftHost = thriftHost;
        this.thriftPort = thriftPort;
        this.thriftFramed = thriftFramed;
        // Setup the StorageService proxy.
        ssProxy = getSSProxy(jmxConn.getMbeanServerConn());
    }

    /**
     * Create a StorageService MBean proxy over the given connection.
     *
     * @param mbeanConn JMX server connection
     * @return a StorageServiceMBean proxy
     */
    public StorageServiceMBean getSSProxy(MBeanServerConnection mbeanConn)
    {
        try
        {
            return JMX.newMBeanProxy(mbeanConn, new ObjectName(ssObjName), StorageServiceMBean.class);
        }
        catch (MalformedObjectNameException e)
        {
            throw new RuntimeException(e);
        }
    }

    /**
     * Create an EndpointSnitchInfo MBean proxy over the given connection.
     *
     * @param mbeanConn JMX server connection
     * @return an EndpointSnitchInfoMBean proxy
     */
    public EndpointSnitchInfoMBean getEpSnitchProxy(MBeanServerConnection mbeanConn)
    {
        try
        {
            return JMX.newMBeanProxy(mbeanConn, new ObjectName(epSnitchObjName), EndpointSnitchInfoMBean.class);
        }
        catch (MalformedObjectNameException e)
        {
            throw new RuntimeException(e);
        }
    }

    /**
     * Given a Multimap of endpoint to tokens, return a new randomized mapping.
     *
     * Endpoints are visited in rounds. In each round every endpoint hands one of its tokens to a
     * randomly chosen other endpoint whose assigned count is still below the number of tokens it
     * currently owns (as reported by StorageService). If no such endpoint exists, the token stays
     * with its current owner. Processing ends when every endpoint's tokens are exhausted.
     *
     * @param endpointMap current mapping of endpoint to tokens
     * @return a new mapping of endpoint to tokens
     */
    public Multimap<String, String> calculateRelocations(Multimap<String, String> endpointMap)
    {
        Multimap<String, String> relocations = HashMultimap.create();
        Set<String> endpoints = new HashSet<String>(endpointMap.keySet());
        Map<String, Integer> endpointToNumTokens = new HashMap<String, Integer>(endpoints.size());
        Map<String, Iterator<String>> iterMap = new HashMap<String, Iterator<String>>(endpoints.size());

        // Create maps of endpoint to token iterators, and endpoint to number of tokens.
        for (String endpoint : endpoints)
        {
            try
            {
                endpointToNumTokens.put(endpoint, ssProxy.getTokens(endpoint).size());
            }
            catch (UnknownHostException e)
            {
                throw new RuntimeException("What that...?", e);
            }
            iterMap.put(endpoint, endpointMap.get(endpoint).iterator());
        }

        int epsToComplete = endpoints.size();
        Set<String> endpointsCompleted = new HashSet<String>();

        // We're done when we've exhausted all of the token iterators.
        while (endpointsCompleted.size() < epsToComplete)
        {
            endpoints.removeAll(endpointsCompleted);
            for (String endpoint : endpoints)
            {
                Iterator<String> tokens = iterMap.get(endpoint);
                if (!tokens.hasNext())
                {
                    endpointsCompleted.add(endpoint);
                    continue;
                }
                String token = tokens.next();

                List<String> candidates = new ArrayList<String>(endpoints);
                candidates.remove(endpoint);
                Collections.shuffle(candidates, rand);

                String target = endpoint; // fall back to the current owner
                for (String choice : candidates)
                {
                    if (relocations.get(choice).size() < endpointToNumTokens.get(choice))
                    {
                        target = choice;
                        break;
                    }
                }
                relocations.put(target, token);
            }
        }
        return relocations;
    }

    /**
     * Enable relocations.
     *
     * @param endpoints sequence of hostname or IP strings
     */
    public void enableRelocations(String...endpoints)
    {
        enableRelocations(Arrays.asList(endpoints));
    }

    /**
     * Enable relocations (scheduled range transfers) on each endpoint via its own JMX connection.
     * Failures are reported per endpoint and do not stop the loop.
     *
     * @param endpoints Collection of hostname or IP strings
     */
    public void enableRelocations(Collection<String> endpoints)
    {
        setScheduledRangeXfers(endpoints, true);
    }

    /**
     * Disable relocations.
     *
     * @param endpoints sequence of hostname or IP strings
     */
    public void disableRelocations(String...endpoints)
    {
        disableRelocations(Arrays.asList(endpoints));
    }

    /**
     * Disable relocations (scheduled range transfers) on each endpoint via its own JMX connection.
     * Failures are reported per endpoint and do not stop the loop.
     *
     * @param endpoints Collection of hostname or IP strings
     */
    public void disableRelocations(Collection<String> endpoints)
    {
        setScheduledRangeXfers(endpoints, false);
    }

    /** Toggle scheduled range transfers on each endpoint, always closing each JMX connection. */
    private void setScheduledRangeXfers(Collection<String> endpoints, boolean enable)
    {
        for (String endpoint : endpoints)
        {
            JMXConnection conn = null;
            try
            {
                conn = new JMXConnection(endpoint, port, username, password);
                StorageServiceMBean proxy = getSSProxy(conn.getMbeanServerConn());
                if (enable)
                    proxy.enableScheduledRangeXfers();
                else
                    proxy.disableScheduledRangeXfers();
            }
            catch (IOException e)
            {
                writeln("Failed to %s shuffling on %s!", enable ? "enable" : "disable", endpoint);
            }
            finally
            {
                closeQuietly(conn);
            }
        }
    }

    /** Close a JMX connection if non-null, reporting (not propagating) any failure. */
    private void closeQuietly(JMXConnection conn)
    {
        if (conn == null)
            return;
        try
        {
            conn.close();
        }
        catch (Exception e)
        {
            writeln("Warning: failed to close JMX connection: %s", e.getMessage());
        }
    }

    /**
     * Return a list of the live nodes (using JMX).
     *
     * @return String endpoint names
     * @throws ShuffleError if the JMX connection fails
     */
    public Collection<String> getLiveNodes() throws ShuffleError
    {
        JMXConnection conn = null;
        try
        {
            conn = new JMXConnection(host, port, username, password);
            return getSSProxy(conn.getMbeanServerConn()).getLiveNodes();
        }
        catch (IOException e)
        {
            throw newShuffleError("Error retrieving list of nodes from JMX interface", e);
        }
        finally
        {
            closeQuietly(conn);
        }
    }

    /**
     * Create and distribute a new, randomized token to endpoint mapping.
     *
     * The token map is read over Thrift from the configured Thrift host and optionally
     * restricted to endpoints in {@code onlyDc}. Relocations are then computed and printed, and
     * stored as a CQL batch in each target endpoint's {@code system.range_xfers}.
     *
     * @param enable if true, enable relocations on every target endpoint afterwards
     * @param onlyDc if non-null, only endpoints the snitch places in this DC are shuffled
     * @throws ShuffleError on handled exceptions
     */
    public void shuffle(boolean enable, String onlyDc) throws ShuffleError
    {
        EndpointSnitchInfoMBean epSnitchProxy = getEpSnitchProxy(jmxConn.getMbeanServerConn());
        Map<String, String> tokenMap;

        CassandraClient seedClient = getThriftClient(thriftHost, thriftPort, thriftFramed);
        try
        {
            tokenMap = seedClient.describe_token_map();
        }
        catch (InvalidRequestException ire)
        {
            throw new RuntimeException("What that...?", ire);
        }
        catch (TException e)
        {
            throw newShuffleError(
                    String.format("Thrift request to %s:%d failed: %s", thriftHost, thriftPort, e.getMessage()), e);
        }
        finally
        {
            seedClient.close();
        }

        // The token map is keyed by token; invert it to endpoint -> tokens, honouring the DC filter.
        Multimap<String, String> endpointMap = HashMultimap.create();
        for (Map.Entry<String, String> entry : tokenMap.entrySet())
        {
            String endpoint = entry.getValue();
            String token = entry.getKey();
            try
            {
                if (onlyDc == null || onlyDc.equals(epSnitchProxy.getDatacenter(endpoint)))
                    endpointMap.put(endpoint, token);
            }
            catch (UnknownHostException e)
            {
                writeln("Warning: %s unknown to EndpointSnitch!", endpoint);
            }
        }

        IPartitioner<?> partitioner = getPartitioner(thriftHost, thriftPort, thriftFramed);
        Multimap<String, String> relocations = calculateRelocations(endpointMap);

        writeln("%-42s %-15s %-15s", "Token", "From", "To");
        writeln("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+~~~~~~~~~~~~~~~+~~~~~~~~~~~~~~~");

        // Store relocations on remote nodes.
        for (String endpoint : relocations.keySet())
        {
            Collection<String> tokens = relocations.get(endpoint);
            for (String tok : tokens)
                writeln("%-42s %-15s %-15s", tok, tokenMap.get(tok), endpoint);
            executeCqlQuery(endpoint, thriftPort, thriftFramed, createShuffleBatchInsert(tokens, partitioner));
        }

        if (enable)
            enableRelocations(relocations.keySet());
    }

    /**
     * Print a list of pending token relocations for all nodes.
     *
     * @throws ShuffleError
     */
    public void ls() throws ShuffleError
    {
        Map<String, List<CqlRow>> queuedRelocations = listRelocations();
        IPartitioner<?> partitioner = getPartitioner(thriftHost, thriftPort, thriftFramed);
        boolean headerPrinted = false;

        for (Map.Entry<String, List<CqlRow>> entry : queuedRelocations.entrySet())
        {
            String endpoint = entry.getKey();
            for (CqlRow row : entry.getValue())
            {
                assert row.getColumns().size() == 2;

                // Print the header only if there is at least one row to show.
                if (!headerPrinted)
                {
                    writeln("%-42s %-15s %s", "Token", "Endpoint", "Requested at");
                    writeln("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+~~~~~~~~~~~~~~~+~~~~~~~~~~~~~~~~~~~~~~~~~~~~");
                    headerPrinted = true;
                }

                ByteBuffer tokenBytes = ByteBuffer.wrap(row.getColumns().get(0).getValue());
                ByteBuffer requestedAt = ByteBuffer.wrap(row.getColumns().get(1).getValue());
                Date time = TimestampSerializer.instance.deserialize(requestedAt);
                Token<?> token = partitioner.getTokenFactory().fromByteArray(tokenBytes);
                writeln("%-42s %-15s %s", token.toString(), endpoint, time.toString());
            }
        }
    }

    /**
     * List pending token relocations for all live nodes.
     *
     * @return map of live node to the (token_bytes, requested_at) rows of its system.range_xfers
     * @throws ShuffleError
     */
    private Map<String, List<CqlRow>> listRelocations() throws ShuffleError
    {
        String cqlQuery = "SELECT token_bytes,requested_at FROM system.range_xfers";
        Map<String, List<CqlRow>> results = new HashMap<String, List<CqlRow>>();
        for (String node : getLiveNodes())
            results.put(node, executeCqlQuery(node, thriftPort, thriftFramed, cqlQuery).getRows());
        return results;
    }

    /**
     * Clear pending token relocations on all nodes.
     *
     * @throws ShuffleError
     */
    public void clear() throws ShuffleError
    {
        Map<String, List<CqlRow>> queuedRelocations = listRelocations();
        for (Map.Entry<String, List<CqlRow>> entry : queuedRelocations.entrySet())
        {
            String endpoint = entry.getKey();
            for (CqlRow row : entry.getValue())
            {
                assert row.getColumns().size() == 2;
                ByteBuffer tokenBytes = ByteBuffer.wrap(row.getColumns().get(0).getValue());
                String query = String.format("DELETE FROM system.range_xfers WHERE token_bytes = 0x%s",
                                             ByteBufferUtil.bytesToHex(tokenBytes));
                executeCqlQuery(endpoint, thriftPort, thriftFramed, query);
            }
        }
    }

    /**
     * Enable shuffling on all nodes in the cluster.
     *
     * @throws ShuffleError
     */
    public void enable() throws ShuffleError
    {
        enableRelocations(getLiveNodes());
    }

    /**
     * Disable shuffling on all nodes in the cluster.
     *
     * @throws ShuffleError
     */
    public void disable() throws ShuffleError
    {
        disableRelocations(getLiveNodes());
    }

    /** Build a ShuffleError that keeps {@code cause} for diagnostics. */
    private static ShuffleError newShuffleError(String message, Throwable cause)
    {
        ShuffleError error = new ShuffleError(message);
        error.initCause(cause);
        return error;
    }

    /**
     * Setup and return a new Thrift RPC connection. The caller must close it.
     *
     * @param hostName hostname or address to connect to
     * @param port port number to connect to
     * @param framed wrap with framed transport if true
     * @return a CassandraClient instance
     * @throws ShuffleError if the connection cannot be opened
     */
    public static CassandraClient getThriftClient(String hostName, int port, boolean framed) throws ShuffleError
    {
        try
        {
            return new CassandraClient(hostName, port, framed);
        }
        catch (TTransportException e)
        {
            throw newShuffleError(String.format("Unable to connect to %s/%d: %s", hostName, port, e.getMessage()), e);
        }
    }

    /**
     * Execute a CQL v3 query on a fresh connection, which is always closed afterwards.
     *
     * @param hostName hostname or address to connect to
     * @param port port number to connect to
     * @param isFramed wrap with framed transport if true
     * @param cqlQuery CQL query string (sent UTF-8 encoded)
     * @return a Thrift CqlResult instance
     * @throws ShuffleError on connection failure, UnavailableException or TimedOutException
     */
    public static CqlResult executeCqlQuery(String hostName, int port, boolean isFramed, String cqlQuery) throws ShuffleError
    {
        // Connect outside the try so a connection failure propagates as ShuffleError
        // rather than being wrapped by the catch-all below.
        CassandraClient client = getThriftClient(hostName, port, isFramed);
        try
        {
            return client.execute_cql_query(ByteBufferUtil.bytes(cqlQuery), Compression.NONE);
        }
        catch (UnavailableException e)
        {
            throw newShuffleError(
                    String.format("Unable to write shuffle entries to %s. Reason: UnavailableException", hostName), e);
        }
        catch (TimedOutException e)
        {
            throw newShuffleError(
                    String.format("Unable to write shuffle entries to %s. Reason: TimedOutException", hostName), e);
        }
        catch (Exception e)
        {
            throw new RuntimeException(e);
        }
        finally
        {
            client.close();
        }
    }

    /**
     * Return a partitioner instance for remote host. The class named by describe_partitioner() is
     * instantiated reflectively.
     *
     * @param hostName hostname or address to connect to
     * @param port port number to connect to
     * @param framed wrap with framed transport if true
     * @return an IPartitioner instance
     * @throws ShuffleError on connection/Thrift failure or if the class is not on the classpath
     */
    public static IPartitioner<?> getPartitioner(String hostName, int port, boolean framed) throws ShuffleError
    {
        String partitionerName;
        CassandraClient client = getThriftClient(hostName, port, framed);
        try
        {
            partitionerName = client.describe_partitioner();
        }
        catch (InvalidRequestException e)
        {
            throw new RuntimeException("Error calling describe_partitioner() defies explanation", e);
        }
        catch (TException e)
        {
            throw newShuffleError(
                    String.format("Thrift request to %s:%d failed: %s", hostName, port, e.getMessage()), e);
        }
        finally
        {
            client.close();
        }

        try
        {
            Class<?> partitionerClass = Class.forName(partitionerName);
            return (IPartitioner<?>) partitionerClass.newInstance();
        }
        catch (ClassNotFoundException e)
        {
            throw newShuffleError("Unable to locate class for partitioner: " + partitionerName, e);
        }
        catch (Exception e)
        {
            throw new RuntimeException(e);
        }
    }

    /**
     * Create and return a CQL batch insert statement for a set of token relocations.
     *
     * @param tokens tokens to be relocated, in the partitioner's string form
     * @param partitioner an instance of the IPartitioner in use
     * @return a query string
     */
    public static String createShuffleBatchInsert(Collection<String> tokens, IPartitioner<?> partitioner)
    {
        StringBuilder query = new StringBuilder();
        query.append("BEGIN BATCH").append("\n");
        for (String tokenStr : tokens)
        {
            Token<?> token = partitioner.getTokenFactory().fromString(tokenStr);
            String hexToken = ByteBufferUtil.bytesToHex(partitioner.getTokenFactory().toByteArray(token));
            query.append("INSERT INTO system.range_xfers (token_bytes, requested_at) ")
                 .append("VALUES (").append("0x").append(hexToken).append(", 'now');").append("\n");
        }
        query.append("APPLY BATCH").append("\n");
        return query.toString();
    }

    /** Print usage information. */
    private static void printShuffleHelp()
    {
        String nl = String.format("%n");
        StringBuilder sb = new StringBuilder();
        sb.append("Sub-commands:").append(nl);
        sb.append(" create Initialize a new shuffle operation").append(nl);
        sb.append(" ls List pending relocations").append(nl);
        sb.append(" clear Clear pending relocations").append(nl);
        sb.append(" en[able] Enable shuffling").append(nl);
        sb.append(" dis[able] Disable shuffling").append(nl).append(nl);
        printHelp("shuffle [options] <sub-command>", sb.toString());
    }

    /**
     * Execute.
     *
     * @param args arguments passed on the command line
     * @throws Exception when face meets palm
     */
    public static void main(String[] args) throws Exception
    {
        CommandLine cmd = null;
        try
        {
            cmd = processArguments(args);
        }
        catch (MissingArgumentException e)
        {
            System.err.println(e.getMessage());
            System.exit(1);
        }

        // Sub command argument.
        if (cmd.getArgList().size() < 1)
        {
            System.err.println("Missing sub-command argument.");
            printShuffleHelp();
            System.exit(1);
        }
        String subCommand = (String) cmd.getArgList().get(0);

        String hostName = cmd.getOptionValue("host", DEFAULT_HOST);
        String port = cmd.getOptionValue("port", Integer.toString(DEFAULT_JMX_PORT));
        String username = cmd.getOptionValue("username");
        String password = cmd.getOptionValue("password");
        String thriftHost = cmd.getOptionValue("thrift-host", hostName);
        String thriftPort = cmd.getOptionValue("thrift-port", Integer.toString(DEFAULT_THRIFT_PORT));
        String onlyDc = cmd.getOptionValue("only-dc");
        boolean thriftFramed = cmd.hasOption("thrift-framed");
        boolean andEnable = cmd.hasOption("and-enable");

        // Both port strings are non-null here because defaults were supplied above.
        int portNum = -1;
        try
        {
            portNum = Integer.parseInt(port);
        }
        catch (NumberFormatException ferr)
        {
            System.err.printf("%s is not a valid JMX port number.%n", port);
            System.exit(1);
        }

        int thriftPortNum = -1;
        try
        {
            thriftPortNum = Integer.parseInt(thriftPort);
        }
        catch (NumberFormatException ferr)
        {
            System.err.printf("%s is not a valid port number.%n", thriftPort);
            System.exit(1);
        }

        Shuffle shuffler = new Shuffle(hostName, portNum, thriftHost, thriftPortNum, thriftFramed,
                                       username, password);
        try
        {
            if (subCommand.equals("create"))
                shuffler.shuffle(andEnable, onlyDc);
            else if (subCommand.equals("ls"))
                shuffler.ls();
            else if (subCommand.startsWith("en"))
                shuffler.enable();
            else if (subCommand.startsWith("dis"))
                shuffler.disable();
            else if (subCommand.equals("clear"))
                shuffler.clear();
            else
            {
                System.err.println("Unknown subcommand: " + subCommand);
                printShuffleHelp();
                System.exit(1);
            }
        }
        catch (ShuffleError err)
        {
            shuffler.writeln(err);
            System.exit(1);
        }
        finally
        {
            shuffler.close();
        }

        System.exit(0);
    }
}
/**
 * A self-contained Cassandra.Client that owns its Thrift transport; Closeable.
 *
 * Each instance opens its own socket in the constructor. Callers must {@link #close()} it.
 */
class CassandraClient implements Closeable
{
    final TTransport transport;
    final Cassandra.Client client;

    /**
     * Open a Thrift connection and select CQL version 3.0.0.
     *
     * @param hostName host to connect to
     * @param port Thrift port
     * @param framed if true, wrap the socket in a framed transport; otherwise use the raw socket
     * @throws TTransportException if the transport cannot be opened
     */
    CassandraClient(String hostName, int port, boolean framed) throws TTransportException
    {
        TSocket socket = new TSocket(hostName, port);
        // framed == true must yield framed transport (previously this ternary was inverted).
        transport = framed ? new TFastFramedTransport(socket) : socket;
        transport.open();
        client = new Cassandra.Client(new TBinaryProtocol(transport));
        try
        {
            client.set_cql_version("3.0.0");
        }
        catch (Exception e)
        {
            // Don't leak the already-open socket when the handshake fails.
            transport.close();
            throw new RuntimeException(e);
        }
    }

    /** Execute a CQL3 query at ConsistencyLevel.ONE. */
    CqlResult execute_cql_query(ByteBuffer cqlQuery, Compression compression) throws Exception
    {
        return client.execute_cql3_query(cqlQuery, compression, ConsistencyLevel.ONE);
    }

    /** Return the server's partitioner class name. */
    String describe_partitioner() throws TException, InvalidRequestException
    {
        return client.describe_partitioner();
    }

    /** Return the ring description for {@code keyspace}. */
    List<TokenRange> describe_ring(String keyspace) throws InvalidRequestException, TException
    {
        return client.describe_ring(keyspace);
    }

    /** Return the cluster token map as reported by the server (token string to endpoint). */
    Map<String, String> describe_token_map() throws InvalidRequestException, TException
    {
        return client.describe_token_map();
    }

    /** Close the underlying transport. */
    public void close()
    {
        transport.close();
    }
}
/**
 * Checked exception for shuffle failures; {@code Shuffle.main} catches it,
 * reports it and exits with status 1.
 */
@SuppressWarnings("serial")
class ShuffleError extends Exception
{
    /**
     * @param message human-readable description of what went wrong
     */
    ShuffleError(String message)
    {
        super(message);
    }
}