allow capabilities to be sent when addresses resolved async

queue addresses that need resolution, use async dns to resolve them
and when final result known dispatch the capabilities to be sent

Signed-off-by: Caolán McNamara <caolan.mcnamara@collabora.com>
Change-Id: I13b6d0c4d47e6e8ecd06f7a449c8f808a41e5e7a
pull/9095/head
Caolán McNamara 2024-05-16 09:47:17 +01:00 committed by Caolán McNamara
parent fc41cf7694
commit 694f0488ad
2 changed files with 80 additions and 9 deletions

View File

@ -33,6 +33,7 @@
#include <Socket.hpp>
#include <UserMessages.hpp>
#include <Util.hpp>
#include <net/AsyncDNS.hpp>
#include <net/HttpHelper.hpp>
#if !MOBILEAPP
#include <HostUtil.hpp>
@ -371,15 +372,18 @@ getConvertToBrokerImplementation(const std::string& requestType, const std::stri
return nullptr;
}
class ConvertToAddressResolver
class ConvertToAddressResolver : public std::enable_shared_from_this<ConvertToAddressResolver>
{
std::shared_ptr<ConvertToAddressResolver> _selfLifecycle;
std::queue<std::string> _addressesToResolve;
ClientRequestDispatcher::AsyncFn _asyncCb;
bool _allow;
public:
ConvertToAddressResolver(std::queue<std::string> addressesToResolve)
ConvertToAddressResolver(std::queue<std::string> addressesToResolve, ClientRequestDispatcher::AsyncFn asyncCb)
: _addressesToResolve(addressesToResolve)
, _asyncCb(asyncCb)
, _allow(true)
{
}
@ -392,6 +396,7 @@ public:
// synchronous case
bool syncProcess()
{
assert(!_asyncCb);
while (!_addressesToResolve.empty())
{
const std::string& addressToCheck = _addressesToResolve.front();
@ -423,6 +428,56 @@ public:
}
return _allow;
}
// asynchronous case
void startAsyncProcessing()
{
assert(_asyncCb);
_selfLifecycle = shared_from_this();
dispatchNextLookup();
}
void dispatchNextLookup()
{
net::AsyncDNS::DNSThreadFn pushHostnameResolvedToPoll = [this](const std::string& hostname,
const std::string& exception) {
COOLWSD::getWebServerPoll()->addCallback([this, hostname, exception]() {
hostnameResolved(hostname, exception);
});
};
const std::string& addressToCheck = _addressesToResolve.front();
net::AsyncDNS::canonicalHostName(addressToCheck, pushHostnameResolvedToPoll);
}
void hostnameResolved(const std::string& hostToCheck, const std::string& exception)
{
if (!exception.empty())
{
LOG_ERR_S(exception);
// We can't find out the hostname, and it already failed the IP check
_allow = false;
}
else
testHostName(hostToCheck);
const std::string& addressToCheck = _addressesToResolve.front();
if (_allow)
LOG_INF_S("convert-to: Requesting address is allowed: " << addressToCheck);
else
LOG_WRN_S("convert-to: Requesting address is denied: " << addressToCheck);
_addressesToResolve.pop();
// If hostToCheck is not allowed, or there are no addresses
// left to check, then do callback and end
if (!_allow || _addressesToResolve.empty())
{
_asyncCb(_allow);
_selfLifecycle.reset();
return;
}
dispatchNextLookup();
}
};
bool ClientRequestDispatcher::allowPostFrom(const std::string& address)
@ -455,13 +510,15 @@ bool ClientRequestDispatcher::allowPostFrom(const std::string& address)
}
bool ClientRequestDispatcher::allowConvertTo(const std::string& address,
const Poco::Net::HTTPRequest& request)
const Poco::Net::HTTPRequest& request,
AsyncFn asyncCb)
{
const bool allow = allowPostFrom(address) || HostUtil::allowedWopiHost(request.getHost());
if (!allow)
{
LOG_WRN_S("convert-to: Requesting address is denied: " << address);
if (asyncCb)
asyncCb(false);
return false;
}
@ -490,9 +547,18 @@ bool ClientRequestDispatcher::allowConvertTo(const std::string& address,
}
if (addressesToResolve.empty())
{
if (asyncCb)
asyncCb(true);
return true;
}
std::shared_ptr<ConvertToAddressResolver> resolver = std::make_shared<ConvertToAddressResolver>(addressesToResolve);
std::shared_ptr<ConvertToAddressResolver> resolver = std::make_shared<ConvertToAddressResolver>(addressesToResolve, asyncCb);
if (asyncCb)
{
resolver->startAsyncProcessing();
return false;
}
return resolver->syncProcess();
}
@ -1311,7 +1377,7 @@ void ClientRequestDispatcher::handlePostRequest(const RequestDetails& requestDet
requestDetails.equals(1, "get-thumbnail"))
{
// Validate sender - FIXME: should do this even earlier.
if (!allowConvertTo(socket->clientAddress(), request))
if (!allowConvertTo(socket->clientAddress(), request, nullptr))
{
LOG_WRN(
"Conversion requests not allowed from this address: " << socket->clientAddress());
@ -1937,8 +2003,11 @@ void ClientRequestDispatcher::handleCapabilitiesRequest(const Poco::Net::HTTPReq
LOG_DBG("Wopi capabilities request: " << request.getURI());
const bool convertToAvailable = allowConvertTo(socket->clientAddress(), request);
sendCapabilities(convertToAvailable, socket);
AsyncFn convertToAllowedCb = [socket](bool allowedConvert){
COOLWSD::getWebServerPoll()->addCallback([socket, allowedConvert]() { sendCapabilities(allowedConvert, socket); });
};
allowConvertTo(socket->clientAddress(), request, convertToAllowedCb);
}
#endif

View File

@ -32,6 +32,8 @@ public:
StaticFileContentCache["discovery.xml"] = getDiscoveryXML();
}
typedef std::function<void(bool)> AsyncFn;
private:
/// Set the socket associated with this ResponseClient.
void onConnect(const std::shared_ptr<StreamSocket>& socket) override;
@ -51,7 +53,7 @@ private:
/// Does this address feature in the allowed hosts list.
static bool allowPostFrom(const std::string& address);
static bool allowConvertTo(const std::string& address, const Poco::Net::HTTPRequest& request);
static bool allowConvertTo(const std::string& address, const Poco::Net::HTTPRequest& request, AsyncFn asyncCb);
void handleRootRequest(const RequestDetails& requestDetails,
const std::shared_ptr<StreamSocket>& socket);