Source code

Revision control

Copy as Markdown

Other Tools

commit 4ad63eb3aa65ce7baa08190aac2770540dc25f43
Author: Greg Stoll <gstoll@mozilla.com>
Date: Wed, 27 Mar 2024 12:13:56 -0500
Mozilla improvements to content_analysis_sdk
- add ability for demo agent to block/warn/report specific regexes
- add ability for demo agent to chose a sequence of delays to apply
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 39477223f031c..5dacc81031117 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -203,6 +203,7 @@ add_executable(agent
./demo/agent.cc
./demo/handler.h
)
+target_compile_features(agent PRIVATE cxx_std_17)
target_include_directories(agent PRIVATE ${AGENT_INCLUDES})
target_link_libraries(agent PRIVATE cac_agent)
diff --git a/agent/src/event_win.h b/agent/src/event_win.h
index 9f8b6903566f2..f631f693dcd9c 100644
--- a/agent/src/event_win.h
+++ b/agent/src/event_win.h
@@ -28,6 +28,12 @@ class ContentAnalysisEventWin : public ContentAnalysisEventBase {
ResultCode Close() override;
ResultCode Send() override;
std::string DebugString() const override;
+ std::string SerializeStringToSendToBrowser() {
+ return agent_to_chrome()->SerializeAsString();
+ }
+ void SetResponseSent() { response_sent_ = true; }
+
+ HANDLE Pipe() const { return hPipe_; }
private:
void Shutdown();
diff --git a/browser/src/client_win.cc b/browser/src/client_win.cc
index 9d3d7e8c52662..039946d131398 100644
--- a/browser/src/client_win.cc
+++ b/browser/src/client_win.cc
@@ -418,7 +418,11 @@ DWORD ClientWin::ConnectToPipe(const std::string& pipename, HANDLE* handle) {
void ClientWin::Shutdown() {
if (hPipe_ != INVALID_HANDLE_VALUE) {
- FlushFileBuffers(hPipe_);
+ // TODO: This trips the LateWriteObserver. We could move this earlier
+ // (before the LateWriteObserver is created) or just remove it, although
+ // the later could mean an ACK message is not processed by the agent
+ // in time.
+ // FlushFileBuffers(hPipe_);
CloseHandle(hPipe_);
hPipe_ = INVALID_HANDLE_VALUE;
}
diff --git a/demo/agent.cc b/demo/agent.cc
index ff8b93f647ebd..3e168b0915a0c 100644
--- a/demo/agent.cc
+++ b/demo/agent.cc
@@ -2,12 +2,18 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
+#include <algorithm>
#include <fstream>
#include <iostream>
#include <string>
+#include <regex>
+#include <vector>
#include "content_analysis/sdk/analysis_agent.h"
#include "demo/handler.h"
+#include "demo/handler_misbehaving.h"
+
+using namespace content_analysis::sdk;
// Different paths are used depending on whether this agent should run as a
// use specific agent or not. These values are chosen to match the test
@@ -19,19 +25,50 @@ constexpr char kPathSystem[] = "brcm_chrm_cas";
std::string path = kPathSystem;
bool use_queue = false;
bool user_specific = false;
-unsigned long delay = 0; // In seconds.
+std::vector<unsigned long> delays = {0}; // In seconds.
unsigned long num_threads = 8u;
std::string save_print_data_path = "";
+RegexArray toBlock, toWarn, toReport;
+static bool useMisbehavingHandler = false;
+static std::string modeStr;
// Command line parameters.
-constexpr const char* kArgDelaySpecific = "--delay=";
+constexpr const char* kArgDelaySpecific = "--delays=";
constexpr const char* kArgPath = "--path=";
constexpr const char* kArgQueued = "--queued";
constexpr const char* kArgThreads = "--threads=";
constexpr const char* kArgUserSpecific = "--user";
+constexpr const char* kArgToBlock = "--toblock=";
+constexpr const char* kArgToWarn = "--towarn=";
+constexpr const char* kArgToReport = "--toreport=";
+constexpr const char* kArgMisbehave = "--misbehave=";
constexpr const char* kArgHelp = "--help";
constexpr const char* kArgSavePrintRequestDataTo = "--save-print-request-data-to=";
+std::map<std::string, Mode> sStringToMode = {
+#define AGENT_MODE(name) {#name, Mode::Mode_##name},
+#include "modes.h"
+#undef AGENT_MODE
+};
+
+std::map<Mode, std::string> sModeToString = {
+#define AGENT_MODE(name) {Mode::Mode_##name, #name},
+#include "modes.h"
+#undef AGENT_MODE
+};
+
+std::vector<std::pair<std::string, std::regex>>
+ParseRegex(const std::string str) {
+ std::vector<std::pair<std::string, std::regex>> ret;
+ for (auto it = str.begin(); it != str.end(); /* nop */) {
+ auto it2 = std::find(it, str.end(), ',');
+ ret.push_back(std::make_pair(std::string(it, it2), std::regex(it, it2)));
+ it = it2 == str.end() ? it2 : it2 + 1;
+ }
+
+ return ret;
+}
+
bool ParseCommandLine(int argc, char* argv[]) {
for (int i = 1; i < argc; ++i) {
const std::string arg = argv[i];
@@ -44,16 +81,38 @@ bool ParseCommandLine(int argc, char* argv[]) {
path = kPathUser;
user_specific = true;
} else if (arg.find(kArgDelaySpecific) == 0) {
- delay = std::stoul(arg.substr(strlen(kArgDelaySpecific)));
+ std::string delaysStr = arg.substr(strlen(kArgDelaySpecific));
+ delays.clear();
+ size_t posStart = 0, posEnd;
+ unsigned long delay;
+ while ((posEnd = delaysStr.find(',', posStart)) != std::string::npos) {
+ delay = std::stoul(delaysStr.substr(posStart, posEnd - posStart));
+ if (delay > 30) {
+ delay = 30;
+ }
+ delays.push_back(delay);
+ posStart = posEnd + 1;
+ }
+ delay = std::stoul(delaysStr.substr(posStart));
if (delay > 30) {
delay = 30;
}
+ delays.push_back(delay);
} else if (arg.find(kArgPath) == 0) {
path = arg.substr(strlen(kArgPath));
} else if (arg.find(kArgQueued) == 0) {
use_queue = true;
} else if (arg.find(kArgThreads) == 0) {
num_threads = std::stoul(arg.substr(strlen(kArgThreads)));
+ } else if (arg.find(kArgToBlock) == 0) {
+ toBlock = ParseRegex(arg.substr(strlen(kArgToBlock)));
+ } else if (arg.find(kArgToWarn) == 0) {
+ toWarn = ParseRegex(arg.substr(strlen(kArgToWarn)));
+ } else if (arg.find(kArgToReport) == 0) {
+ toReport = ParseRegex(arg.substr(strlen(kArgToReport)));
+ } else if (arg.find(kArgMisbehave) == 0) {
+ modeStr = arg.substr(strlen(kArgMisbehave));
+ useMisbehavingHandler = true;
} else if (arg.find(kArgHelp) == 0) {
return false;
} else if (arg.find(kArgSavePrintRequestDataTo) == 0) {
@@ -72,13 +131,17 @@ void PrintHelp() {
<< "A simple agent to process content analysis requests." << std::endl
<< "Data containing the string 'block' blocks the request data from being used." << std::endl
<< std::endl << "Options:" << std::endl
- << kArgDelaySpecific << "<delay> : Add a delay to request processing in seconds (max 30)." << std::endl
+ << kArgDelaySpecific << "<delay1,delay2,...> : Add delays to request processing in seconds. Delays are limited to 30 seconds and are applied round-robin to requests. Default is 0." << std::endl
<< kArgPath << " <path> : Used the specified path instead of default. Must come after --user." << std::endl
<< kArgQueued << " : Queue requests for processing in a background thread" << std::endl
<< kArgThreads << " : When queued, number of threads in the request processing thread pool" << std::endl
<< kArgUserSpecific << " : Make agent OS user specific." << std::endl
<< kArgHelp << " : prints this help message" << std::endl
- << kArgSavePrintRequestDataTo << " : saves the PDF data to the given file path for print requests";
+ << kArgSavePrintRequestDataTo << " : saves the PDF data to the given file path for print requests" << std::endl
+ << kArgToBlock << "<regex> : Regular expression matching file and text content to block." << std::endl
+ << kArgToWarn << "<regex> : Regular expression matching file and text content to warn about." << std::endl
+ << kArgToReport << "<regex> : Regular expression matching file and text content to report." << std::endl
+ << kArgMisbehave << "<mode> : Use 'misbehaving' agent in given mode for testing purposes." << std::endl;
}
int main(int argc, char* argv[]) {
@@ -87,9 +150,17 @@ int main(int argc, char* argv[]) {
return 1;
}
- auto handler = use_queue
- ? std::make_unique<QueuingHandler>(num_threads, delay, save_print_data_path)
- : std::make_unique<Handler>(delay, save_print_data_path);
+ auto handler =
+ useMisbehavingHandler
+ ? MisbehavingHandler::Create(modeStr, std::move(delays), save_print_data_path, std::move(toBlock), std::move(toWarn), std::move(toReport))
+ : use_queue
+ ? std::make_unique<QueuingHandler>(num_threads, std::move(delays), save_print_data_path, std::move(toBlock), std::move(toWarn), std::move(toReport))
+ : std::make_unique<Handler>(std::move(delays), save_print_data_path, std::move(toBlock), std::move(toWarn), std::move(toReport));
+
+ if (!handler) {
+ std::cout << "[Demo] Failed to construct handler." << std::endl;
+ return 1;
+ }
// Each agent uses a unique name to identify itself with Google Chrome.
content_analysis::sdk::ResultCode rc;
diff --git a/demo/handler.h b/demo/handler.h
index 9d1ccfdf9857a..88599963c51b0 100644
--- a/demo/handler.h
+++ b/demo/handler.h
@@ -7,31 +7,51 @@
#include <time.h>
+#include <algorithm>
+#include <atomic>
#include <chrono>
#include <cstdio>
#include <fstream>
#include <iostream>
+#include <optional>
#include <thread>
#include <utility>
+#include <regex>
#include <vector>
#include "content_analysis/sdk/analysis_agent.h"
#include "demo/atomic_output.h"
#include "demo/request_queue.h"
+using RegexArray = std::vector<std::pair<std::string, std::regex>>;
+
// An AgentEventHandler that dumps requests information to stdout and blocks
// any requests that have the keyword "block" in their data
class Handler : public content_analysis::sdk::AgentEventHandler {
public:
using Event = content_analysis::sdk::ContentAnalysisEvent;
- Handler(unsigned long delay, const std::string& print_data_file_path) :
- delay_(delay), print_data_file_path_(print_data_file_path) {
- }
+ Handler(std::vector<unsigned long>&& delays, const std::string& print_data_file_path,
+ RegexArray&& toBlock = RegexArray(),
+ RegexArray&& toWarn = RegexArray(),
+ RegexArray&& toReport = RegexArray()) :
+ toBlock_(std::move(toBlock)), toWarn_(std::move(toWarn)), toReport_(std::move(toReport)),
+ delays_(std::move(delays)), print_data_file_path_(print_data_file_path) {}
- unsigned long delay() { return delay_; }
+ const std::vector<unsigned long> delays() { return delays_; }
+ size_t nextDelayIndex() const { return nextDelayIndex_; }
protected:
+ // subclasses can override this
+ // returns whether the response has been set
+ virtual bool SetCustomResponse(AtomicCout& aout, std::unique_ptr<Event>& event) {
+ return false;
+ }
+ // subclasses can override this
+ // returns whether the response has been sent
+ virtual bool SendCustomResponse(std::unique_ptr<Event>& event) {
+ return false;
+ }
// Analyzes one request from Google Chrome and responds back to the browser
// with either an allow or block verdict.
void AnalyzeContent(AtomicCout& aout, std::unique_ptr<Event> event) {
@@ -43,29 +63,25 @@ class Handler : public content_analysis::sdk::AgentEventHandler {
DumpEvent(aout.stream(), event.get());
- bool block = false;
bool success = true;
- unsigned long delay = delay_;
-
- if (event->GetRequest().has_text_content()) {
- block = ShouldBlockRequest(
- event->GetRequest().text_content());
- GetFileSpecificDelay(event->GetRequest().text_content(), &delay);
- } else if (event->GetRequest().has_file_path()) {
- std::string content;
- success =
- ReadContentFromFile(event->GetRequest().file_path(),
- &content);
- if (success) {
- block = ShouldBlockRequest(content);
- GetFileSpecificDelay(content, &delay);
+ std::optional<content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action> caResponse;
+ bool setResponse = SetCustomResponse(aout, event);
+ if (!setResponse) {
+ caResponse = content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action_BLOCK;
+ if (event->GetRequest().has_text_content()) {
+ caResponse = DecideCAResponse(
+ event->GetRequest().text_content(), aout.stream());
+ } else if (event->GetRequest().has_file_path()) {
+ // TODO: Fix downloads to store file *first* so we can check contents.
+ // Until then, just check the file name:
+ caResponse = DecideCAResponse(
+ event->GetRequest().file_path(), aout.stream());
+ } else if (event->GetRequest().has_print_data()) {
+ // In the case of print request, normally the PDF bytes would be parsed
+ // for sensitive data violations. To keep this class simple, only the
+ // URL is checked for the word "block".
+ caResponse = DecideCAResponse(event->GetRequest().request_data().url(), aout.stream());
}
- } else if (event->GetRequest().has_print_data()) {
- // In the case of print request, normally the PDF bytes would be parsed
- // for sensitive data violations. To keep this class simple, only the
- // URL is checked for the word "block".
- block = ShouldBlockRequest(event->GetRequest().request_data().url());
- GetFileSpecificDelay(event->GetRequest().request_data().url(), &delay);
}
if (!success) {
@@ -75,22 +91,44 @@ class Handler : public content_analysis::sdk::AgentEventHandler {
content_analysis::sdk::ContentAnalysisResponse::Result::FAILURE);
aout.stream() << " Verdict: failed to reach verdict: ";
aout.stream() << event->DebugString() << std::endl;
- } else if (block) {
- auto rc = content_analysis::sdk::SetEventVerdictToBlock(event.get());
- aout.stream() << " Verdict: block";
- if (rc != content_analysis::sdk::ResultCode::OK) {
- aout.stream() << " error: "
- << content_analysis::sdk::ResultCodeToString(rc) << std::endl;
- aout.stream() << " " << event->DebugString() << std::endl;
+ } else {
+ aout.stream() << " Verdict: ";
+ if (caResponse) {
+ switch (caResponse.value()) {
+ case content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action_BLOCK:
+ aout.stream() << "BLOCK";
+ break;
+ case content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action_WARN:
+ aout.stream() << "WARN";
+ break;
+ case content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action_REPORT_ONLY:
+ aout.stream() << "REPORT_ONLY";
+ break;
+ case content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action_ACTION_UNSPECIFIED:
+ aout.stream() << "ACTION_UNSPECIFIED";
+ break;
+ default:
+ aout.stream() << "<error>";
+ break;
+ }
+ auto rc =
+ content_analysis::sdk::SetEventVerdictTo(event.get(), caResponse.value());
+ if (rc != content_analysis::sdk::ResultCode::OK) {
+ aout.stream() << " error: "
+ << content_analysis::sdk::ResultCodeToString(rc) << std::endl;
+ aout.stream() << " " << event->DebugString() << std::endl;
+ }
+ aout.stream() << std::endl;
+ } else {
+ aout.stream() << " Verdict: allow" << std::endl;
}
aout.stream() << std::endl;
- } else {
- aout.stream() << " Verdict: allow" << std::endl;
}
-
aout.stream() << std::endl;
// If a delay is specified, wait that much.
+ size_t nextDelayIndex = nextDelayIndex_.fetch_add(1);
+ unsigned long delay = delays_[nextDelayIndex % delays_.size()];
if (delay > 0) {
aout.stream() << "Delaying response to " << event->GetRequest().request_token()
<< " for " << delay << "s" << std::endl<< std::endl;
@@ -99,16 +137,19 @@ class Handler : public content_analysis::sdk::AgentEventHandler {
}
// Send the response back to Google Chrome.
- auto rc = event->Send();
- if (rc != content_analysis::sdk::ResultCode::OK) {
- aout.stream() << "[Demo] Error sending response: "
- << content_analysis::sdk::ResultCodeToString(rc)
- << std::endl;
- aout.stream() << event->DebugString() << std::endl;
+ bool sentCustomResponse = SendCustomResponse(event);
+ if (!sentCustomResponse) {
+ auto rc = event->Send();
+ if (rc != content_analysis::sdk::ResultCode::OK) {
+ aout.stream() << "[Demo] Error sending response: "
+ << content_analysis::sdk::ResultCodeToString(rc)
+ << std::endl;
+ aout.stream() << event->DebugString() << std::endl;
+ }
}
}
- private:
+ protected:
void OnBrowserConnected(
const content_analysis::sdk::BrowserInfo& info) override {
AtomicCout aout;
@@ -362,21 +403,40 @@ class Handler : public content_analysis::sdk::AgentEventHandler {
return true;
}
- bool ShouldBlockRequest(const std::string& content) {
- // Determines if the request should be blocked. For this simple example
- // the content is blocked if the string "block" is found. Otherwise the
- // content is allowed.
- return content.find("block") != std::string::npos;
- }
-
- void GetFileSpecificDelay(const std::string& content, unsigned long* delay) {
- auto pos = content.find("delay=");
- if (pos != std::string::npos) {
- std::sscanf(content.substr(pos).c_str(), "delay=%lu", delay);
+ std::optional<content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action>
+ DecideCAResponse(const std::string& content, std::stringstream& stream) {
+ for (auto& r : toBlock_) {
+ if (std::regex_search(content, r.second)) {
+ stream << "'" << content << "' matches BLOCK regex '"
+ << r.first << "'" << std::endl;
+ return content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action_BLOCK;
+ }
}
+ for (auto& r : toWarn_) {
+ if (std::regex_search(content, r.second)) {
+ stream << "'" << content << "' matches WARN regex '"
+ << r.first << "'" << std::endl;
+ return content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action_WARN;
+ }
+ }
+ for (auto& r : toReport_) {
+ if (std::regex_search(content, r.second)) {
+ stream << "'" << content << "' matches REPORT_ONLY regex '"
+ << r.first << "'" << std::endl;
+ return content_analysis::sdk::ContentAnalysisResponse_Result_TriggeredRule_Action_REPORT_ONLY;
+ }
+ }
+ stream << "'" << content << "' was ALLOWed\n";
+ return {};
}
- unsigned long delay_;
+ // For the demo, block any content that matches these wildcards.
+ RegexArray toBlock_;
+ RegexArray toWarn_;
+ RegexArray toReport_;
+
+ std::vector<unsigned long> delays_;
+ std::atomic<size_t> nextDelayIndex_;
std::string print_data_file_path_;
};
@@ -384,8 +444,11 @@ class Handler : public content_analysis::sdk::AgentEventHandler {
// any requests that have the keyword "block" in their data
class QueuingHandler : public Handler {
public:
- QueuingHandler(unsigned long threads, unsigned long delay, const std::string& print_data_file_path)
- : Handler(delay, print_data_file_path) {
+ QueuingHandler(unsigned long threads, std::vector<unsigned long>&& delays, const std::string& print_data_file_path,
+ RegexArray&& toBlock = RegexArray(),
+ RegexArray&& toWarn = RegexArray(),
+ RegexArray&& toReport = RegexArray())
+ : Handler(std::move(delays), print_data_file_path, std::move(toBlock), std::move(toWarn), std::move(toReport)) {
StartBackgroundThreads(threads);
}
@@ -421,6 +484,8 @@ class QueuingHandler : public Handler {
aout.stream() << std::endl << "----------" << std::endl;
aout.stream() << "Thread: " << std::this_thread::get_id()
<< std::endl;
+ aout.stream() << "Delaying request processing for "
+ << handler->delays()[handler->nextDelayIndex() % handler->delays().size()] << "s" << std::endl << std::endl;
aout.flush();
handler->AnalyzeContent(aout, std::move(event));
--
2.42.0.windows.2