aboutsummaryrefslogtreecommitdiffstats
path: root/src/lib/x509/x509_ext.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib/x509/x509_ext.cpp')
-rw-r--r--src/lib/x509/x509_ext.cpp67
1 files changed, 67 insertions, 0 deletions
diff --git a/src/lib/x509/x509_ext.cpp b/src/lib/x509/x509_ext.cpp
index c22e9ebcb..f475c50c2 100644
--- a/src/lib/x509/x509_ext.cpp
+++ b/src/lib/x509/x509_ext.cpp
@@ -90,10 +90,52 @@ void Certificate_Extension::validate(const X509_Certificate&, const X509_Certifi
void Extensions::add(Certificate_Extension* extn, bool critical)
{
+ // sanity check: we don't want to have the same extension more than once
+ for(const auto& ext : m_extensions)
+ {
+ if(ext.first->oid_of() == extn->oid_of())
+ {
+ throw Invalid_Argument(extn->oid_name() + " extension already present");
+ }
+ }
+
+ if(m_extensions_raw.count(extn->oid_of()) > 0)
+ {
+ throw Invalid_Argument(extn->oid_name() + " extension already present");
+ }
+
m_extensions.push_back(std::make_pair(std::unique_ptr<Certificate_Extension>(extn), critical));
m_extensions_raw.emplace(extn->oid_of(), std::make_pair(extn->encode_inner(), critical));
}
+void Extensions::replace(Certificate_Extension* extn, bool critical)
+ {
+ for(auto it = m_extensions.begin(); it != m_extensions.end(); ++it)
+ {
+ if(it->first->oid_of() == extn->oid_of())
+ {
+ m_extensions.erase(it);
+ break;
+ }
+ }
+
+ m_extensions.push_back(std::make_pair(std::unique_ptr<Certificate_Extension>(extn), critical));
+ m_extensions_raw[extn->oid_of()] = std::make_pair(extn->encode_inner(), critical);
+ }
+
+Certificate_Extension* Extensions::get(const OID& oid) const
+ {
+ for(auto& ext : m_extensions)
+ {
+ if(ext.first->oid_of() == oid)
+ {
+ return ext.first.get();
+ }
+ }
+
+ return nullptr;
+ }
+
std::vector<std::pair<std::unique_ptr<Certificate_Extension>, bool>> Extensions::extensions() const
{
std::vector<std::pair<std::unique_ptr<Certificate_Extension>, bool>> exts;
@@ -114,6 +156,7 @@ std::map<OID, std::pair<std::vector<byte>, bool>> Extensions::extensions_raw() c
*/
void Extensions::encode_into(DER_Encoder& to_object) const
{
+ // encode any known extensions
for(size_t i = 0; i != m_extensions.size(); ++i)
{
const Certificate_Extension* ext = m_extensions[i].first.get();
@@ -130,6 +173,30 @@ void Extensions::encode_into(DER_Encoder& to_object) const
.end_cons();
}
}
+
+ // encode any unknown extensions
+ for(const auto& ext_raw : m_extensions_raw)
+ {
+ const bool is_critical = ext_raw.second.second;
+ const OID oid = ext_raw.first;
+ const std::vector<uint8_t> value = ext_raw.second.first;
+
+ auto pos = std::find_if(std::begin(m_extensions), std::end(m_extensions),
+ [&oid](const std::pair<std::unique_ptr<Certificate_Extension>, bool>& ext) -> bool
+ {
+ return ext.first->oid_of() == oid;
+ });
+
+ if(pos == std::end(m_extensions))
+ {
+ // not found in m_extensions, must be unknown
+ to_object.start_cons(SEQUENCE)
+ .encode(oid)
+ .encode_optional(is_critical, false)
+ .encode(value, OCTET_STRING)
+ .end_cons();
+ }
+ }
}
/*