/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.cassandra.transport; import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.UnknownHostException; import java.util.EnumMap; import java.util.concurrent.Callable; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.cassandra.auth.IAuthenticator; import org.apache.cassandra.auth.ISaslAwareAuthenticator; import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.config.EncryptionOptions; import org.apache.cassandra.metrics.ClientMetrics; import org.apache.cassandra.security.SSLFactory; import org.apache.cassandra.service.*; import org.apache.cassandra.transport.messages.EventMessage; import org.jboss.netty.bootstrap.ServerBootstrap; import org.jboss.netty.channel.*; import org.jboss.netty.channel.group.ChannelGroup; import org.jboss.netty.channel.group.DefaultChannelGroup; import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory; import org.jboss.netty.handler.execution.ExecutionHandler; import org.jboss.netty.handler.ssl.SslHandler; import org.jboss.netty.logging.InternalLoggerFactory; import org.jboss.netty.logging.Slf4JLoggerFactory; public class Server implements CassandraDaemon.Server { static { InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()); } private static final Logger logger = LoggerFactory.getLogger(Server.class); /** current version of the native protocol we support */ public static final int CURRENT_VERSION = 2; private final ConnectionTracker connectionTracker = new ConnectionTracker(); private final Connection.Factory connectionFactory = new Connection.Factory() { public Connection newConnection(Channel channel, int version) { return new ServerConnection(channel, version, connectionTracker); } }; public final InetSocketAddress socket; private final AtomicBoolean isRunning = new AtomicBoolean(false); private ChannelFactory factory; private ExecutionHandler executionHandler; public Server(InetSocketAddress socket) { this.socket = socket; EventNotifier notifier = new EventNotifier(this); StorageService.instance.register(notifier); MigrationManager.instance.register(notifier); registerMetrics(); } public Server(String hostname, int port) { this(new InetSocketAddress(hostname, port)); } public Server(InetAddress host, int port) { this(new InetSocketAddress(host, port)); } public Server(int port) { this(new InetSocketAddress(port)); } public void start() { if(!isRunning()) { run(); isRunning.set(true); } } public void stop() { if (isRunning.compareAndSet(true, false)) close(); } public boolean isRunning() { return isRunning.get(); } private void run() { // Check that a SaslAuthenticator can be provided by the configured // IAuthenticator. If not, don't start the server. IAuthenticator authenticator = DatabaseDescriptor.getAuthenticator(); if (authenticator.requireAuthentication() && !(authenticator instanceof ISaslAwareAuthenticator)) { logger.error("Not starting native transport as the configured IAuthenticator is not capable of SASL authentication"); isRunning.compareAndSet(true, false); return; } // Configure the server. executionHandler = new ExecutionHandler(new RequestThreadPoolExecutor()); factory = new NioServerSocketChannelFactory(Executors.newCachedThreadPool(), Executors.newCachedThreadPool()); ServerBootstrap bootstrap = new ServerBootstrap(factory); bootstrap.setOption("child.tcpNoDelay", true); // Set up the event pipeline factory. final EncryptionOptions.ClientEncryptionOptions clientEnc = DatabaseDescriptor.getClientEncryptionOptions(); if (clientEnc.enabled) { logger.info("Enabling encrypted CQL connections between client and server"); bootstrap.setPipelineFactory(new SecurePipelineFactory(this, clientEnc)); } else { bootstrap.setPipelineFactory(new PipelineFactory(this)); } // Bind and start to accept incoming connections. logger.info("Starting listening for CQL clients on {}...", socket); Channel channel = bootstrap.bind(socket); connectionTracker.allChannels.add(channel); } private void registerMetrics() { ClientMetrics.instance.addCounter("connectedNativeClients", new Callable<Integer>() { @Override public Integer call() throws Exception { return connectionTracker.getConnectedClients(); } }); } private void close() { // Close opened connections connectionTracker.closeAll(); factory.releaseExternalResources(); factory = null; executionHandler.releaseExternalResources(); executionHandler = null; logger.info("Stop listening for CQL clients"); } public static class ConnectionTracker implements Connection.Tracker { public final ChannelGroup allChannels = new DefaultChannelGroup(); private final EnumMap<Event.Type, ChannelGroup> groups = new EnumMap<Event.Type, ChannelGroup>(Event.Type.class); public ConnectionTracker() { for (Event.Type type : Event.Type.values()) groups.put(type, new DefaultChannelGroup(type.toString())); } public void addConnection(Channel ch, Connection connection) { allChannels.add(ch); } public void register(Event.Type type, Channel ch) { groups.get(type).add(ch); } public void unregister(Channel ch) { for (ChannelGroup group : groups.values()) group.remove(ch); } public void send(Event event) { groups.get(event.type).write(new EventMessage(event)); } public void closeAll() { allChannels.close().awaitUninterruptibly(); } public int getConnectedClients() { /* - When server is running: allChannels contains all clients' connections (channels) plus one additional channel used for the server's own bootstrap. - When server is stopped: the size is 0 */ return allChannels.size() != 0 ? allChannels.size() - 1 : 0; } } private static class PipelineFactory implements ChannelPipelineFactory { // Stateless handlers private static final Message.ProtocolDecoder messageDecoder = new Message.ProtocolDecoder(); private static final Message.ProtocolEncoder messageEncoder = new Message.ProtocolEncoder(); private static final Frame.Decompressor frameDecompressor = new Frame.Decompressor(); private static final Frame.Compressor frameCompressor = new Frame.Compressor(); private static final Frame.Encoder frameEncoder = new Frame.Encoder(); private static final Message.Dispatcher dispatcher = new Message.Dispatcher(); private final Server server; public PipelineFactory(Server server) { this.server = server; } public ChannelPipeline getPipeline() throws Exception { ChannelPipeline pipeline = Channels.pipeline(); //pipeline.addLast("debug", new LoggingHandler()); pipeline.addLast("frameDecoder", new Frame.Decoder(server.connectionFactory)); pipeline.addLast("frameEncoder", frameEncoder); pipeline.addLast("frameDecompressor", frameDecompressor); pipeline.addLast("frameCompressor", frameCompressor); pipeline.addLast("messageDecoder", messageDecoder); pipeline.addLast("messageEncoder", messageEncoder); pipeline.addLast("executor", server.executionHandler); pipeline.addLast("dispatcher", dispatcher); return pipeline; } } private static class SecurePipelineFactory extends PipelineFactory { private final SSLContext sslContext; private final EncryptionOptions encryptionOptions; public SecurePipelineFactory(Server server, EncryptionOptions encryptionOptions) { super(server); this.encryptionOptions = encryptionOptions; try { this.sslContext = SSLFactory.createSSLContext(encryptionOptions, false); } catch (IOException e) { throw new RuntimeException("Failed to setup secure pipeline", e); } } public ChannelPipeline getPipeline() throws Exception { SSLEngine sslEngine = sslContext.createSSLEngine(); sslEngine.setUseClientMode(false); sslEngine.setEnabledCipherSuites(encryptionOptions.cipher_suites); sslEngine.setNeedClientAuth(encryptionOptions.require_client_auth); SslHandler sslHandler = new SslHandler(sslEngine); sslHandler.setIssueHandshake(true); ChannelPipeline pipeline = super.getPipeline(); pipeline.addFirst("ssl", sslHandler); return pipeline; } } private static class EventNotifier implements IEndpointLifecycleSubscriber, IMigrationListener { private final Server server; private static final InetAddress bindAll; static { try { bindAll = InetAddress.getByAddress(new byte[4]); } catch (UnknownHostException e) { throw new AssertionError(e); } } private EventNotifier(Server server) { this.server = server; } private InetAddress getRpcAddress(InetAddress endpoint) { try { InetAddress rpcAddress = InetAddress.getByName(StorageService.instance.getRpcaddress(endpoint)); // If rpcAddress == 0.0.0.0 (i.e. bound on all addresses), returning that is not very helpful, // so return the internal address (which is ok since "we're bound on all addresses"). return rpcAddress.equals(bindAll) ? endpoint : rpcAddress; } catch (UnknownHostException e) { // That should not happen, so log an error, but return the // endpoint address since there's a good change this is right logger.error("Problem retrieving RPC address for {}", endpoint, e); return endpoint; } } public void onJoinCluster(InetAddress endpoint) { server.connectionTracker.send(Event.TopologyChange.newNode(getRpcAddress(endpoint), server.socket.getPort())); } public void onLeaveCluster(InetAddress endpoint) { server.connectionTracker.send(Event.TopologyChange.removedNode(getRpcAddress(endpoint), server.socket.getPort())); } public void onMove(InetAddress endpoint) { server.connectionTracker.send(Event.TopologyChange.movedNode(getRpcAddress(endpoint), server.socket.getPort())); } public void onUp(InetAddress endpoint) { server.connectionTracker.send(Event.StatusChange.nodeUp(getRpcAddress(endpoint), server.socket.getPort())); } public void onDown(InetAddress endpoint) { server.connectionTracker.send(Event.StatusChange.nodeDown(getRpcAddress(endpoint), server.socket.getPort())); } public void onCreateKeyspace(String ksName) { server.connectionTracker.send(new Event.SchemaChange(Event.SchemaChange.Change.CREATED, ksName)); } public void onCreateColumnFamily(String ksName, String cfName) { server.connectionTracker.send(new Event.SchemaChange(Event.SchemaChange.Change.CREATED, ksName, cfName)); } public void onUpdateKeyspace(String ksName) { server.connectionTracker.send(new Event.SchemaChange(Event.SchemaChange.Change.UPDATED, ksName)); } public void onUpdateColumnFamily(String ksName, String cfName) { server.connectionTracker.send(new Event.SchemaChange(Event.SchemaChange.Change.UPDATED, ksName, cfName)); } public void onDropKeyspace(String ksName) { server.connectionTracker.send(new Event.SchemaChange(Event.SchemaChange.Change.DROPPED, ksName)); } public void onDropColumnFamily(String ksName, String cfName) { server.connectionTracker.send(new Event.SchemaChange(Event.SchemaChange.Change.DROPPED, ksName, cfName)); } } }